Files
gangyan/langchain-chat/server/chat/solve_problem.py

249 lines
14 KiB
Python

import uuid
from fastapi import Body
from langchain.memory import (
CombinedMemory,
ConversationBufferMemory,
ConversationSummaryMemory,
ConversationBufferWindowMemory
)
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain, ConversationChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional, Union
from server.chat.utils import History
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template, get_format_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
from datetime import datetime
from langchain_core.messages import SystemMessage
import time as t
from server.utils import replace_variables
from configs.basic_config import *
from configs.outline_config import outlines
# Number of retries allowed after the first attempt fails
# (i.e. at most MAX_RETRIES + 1 attempts in total).
MAX_RETRIES = 2
# Seconds to wait before each retry.
RETRY_DELAY = 1


def _outline_summary() -> str:
    """Render every outline except the last one as the text injected for ``{outlines}``."""
    outline_detail = [
        f"\"index\": \"{outline['index']}\", \"title\": \"{outline['title']}\", \"summary\": \"{outline['summary']}\""
        for outline in outlines[:-1]
    ]
    return str(outline_detail)


def _build_model(model_name, temperature, max_tokens, callbacks, prompt_name):
    """Create the chat model bound to ``callbacks``.

    For ``prompt_name == "solve_problem"`` the request additionally carries
    ``extra_body={"chat_template_kwargs": {"enable_thinking": True}}``.
    """
    extra = {}
    if prompt_name == "solve_problem":
        extra["extra_body"] = {"chat_template_kwargs": {"enable_thinking": True}}
    return get_ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        callbacks=callbacks,
        **extra,
    )


def _build_history_prompt_template(prompt_name, user_prompt_name):
    """Return the prompt template text used when the turn has prior history."""
    if prompt_name == "solve_problem":
        # NOTE(review): user_prompt_name must be set here; None raises TypeError
        # on the string concatenation (same as before this refactor).
        user_prompt = get_prompt_template("llm_chat", user_prompt_name + "_with_history")
        prompt_template = get_prompt_template("llm_chat", "solve_problem_history")
        return replace_variables(prompt_template, replace_content=user_prompt, replace_param="{user_prompt}")
    if prompt_name == "solve_problem_outline":
        return get_prompt_template("llm_chat", "solve_problem_outline_history")
    if prompt_name == "outlines_route":
        prompt_template = get_prompt_template("llm_chat", "outlines_route_with_history")
        return replace_variables(prompt_template, replace_content=_outline_summary(), replace_param="{outlines}")
    # Every other prompt name falls back to the think-route history template.
    prompt_template = get_prompt_template("llm_chat", "think_route_history")
    if user_prompt_name:
        user_prompt = get_prompt_template("llm_chat", user_prompt_name + "_with_history")
        prompt_template = replace_variables(prompt_template, replace_content=user_prompt, replace_param="{user_prompt}")
    return prompt_template


def _build_plain_prompt_template(prompt_name, user_prompt_name, history):
    """Return the prompt template text used when no conversation memory is attached."""
    prompt_template = get_prompt_template("llm_chat", prompt_name)
    if user_prompt_name and prompt_name == "think_route":
        user_prompt = get_prompt_template("llm_chat", user_prompt_name)
        prompt_template = replace_variables(prompt_template, replace_content=user_prompt, replace_param="{user_prompt}")
    if prompt_name == "history_route":
        # The history is inlined into the prompt text instead of going through memory.
        prompt_template = replace_variables(prompt_template, replace_content=str(history), replace_param="{history_summary}")
    if prompt_name == "outlines_route":
        prompt_template = replace_variables(prompt_template, replace_content=_outline_summary(), replace_param="{outlines}")
    if prompt_name == "solve_problem_outline":
        # The template itself was already fetched above under this prompt_name.
        prompt_template = replace_variables(prompt_template, replace_content=datetime.now().strftime("%Y"), replace_param="{year}")
    return prompt_template


def _build_conversation_memory(history, today):
    """Build memory exposing ``history`` (user/assistant turns) and ``time`` (a date system message)."""
    buff_memory = ConversationBufferMemory(human_prefix='user', ai_prefix='assistant', memory_key="history", input_key="input")
    for message in history:
        # Only user and assistant turns are replayed; other roles are skipped.
        if message.role == 'user':
            buff_memory.chat_memory.add_user_message(message.content)
        elif message.role == 'assistant':
            buff_memory.chat_memory.add_ai_message(message.content)
    background_memory = ConversationBufferMemory(human_prefix='user', ai_prefix='assistant', memory_key="time", input_key="input")
    background_memory.chat_memory.add_message(SystemMessage(content=f'当前的时间是:{today}'))
    return CombinedMemory(memories=[background_memory, buff_memory])


async def solve_problem(
    user_prompt_name: Optional[str] = Body(None, description="用户输入"),
    query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
    conversation_id: str = Body("", description="对话框ID"),
    history: Union[int, List[History]] = Body([], description="历史对话"),
    model_name: str = Body("default_model", description="LLM 模型名称。"),
    temperature: float = Body(0.7, description="LLM 采样温度", ge=0.0, le=2.0),
    max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量"),
    prompt_name: str = Body("default", description="使用的prompt模板名称"),
    stream: bool = Body(False, description="流式输出")
) -> AsyncIterable[str]:
    """Run one LLM turn for ``query`` and yield the answer as JSON strings.

    Prompt templates come from the ``llm_chat`` group, selected by
    ``prompt_name`` (and ``user_prompt_name`` where applicable).

    - If ``history`` is non-empty and ``prompt_name`` is not
      ``"history_route"``, the turn runs through a ``ConversationChain``.
      Its memory holds the history plus a system message with today's date.
    - Otherwise the turn runs through an ``LLMChain`` with a single system
      message.

    Args:
        user_prompt_name: Optional user-specific prompt spliced into ``{user_prompt}``.
        query: The user's input.
        conversation_id: Accepted for API compatibility; not used here.
        history: Prior messages. NOTE(review): the annotation admits an int,
            but a non-zero int is not iterable and fails during conversion.
            Only list history is actually supported.
        model_name: LLM model name passed to ``get_ChatOpenAI``.
        temperature: Sampling temperature.
        max_tokens: Generation limit; values <= 0 are treated as "no limit".
        prompt_name: Template name that selects the prompt flow.
        stream: Whether to stream tokens or return a single final message.

    Yields:
        With ``stream``: one ``{"text": token}`` JSON string per token.
        Without: one ``{"text": answer, "message_id": id}`` JSON string.

    Raises:
        The last exception, once ``MAX_RETRIES`` retries are exhausted. Also
        raised immediately if a streamed attempt fails after tokens were
        already sent, because retrying would re-send duplicate text.
    """
    today = datetime.now().strftime("%Y年%m月%d")
    message_id = str(uuid.uuid1()) + "q"
    if isinstance(max_tokens, int) and max_tokens <= 0:
        max_tokens = None

    # Same conversion condition as before: convert whenever there is history,
    # or whenever history_route will inline it into the prompt.
    if history or prompt_name == "history_route":
        history = [History.from_data(h) for h in history]
    has_history = bool(history) and prompt_name != "history_route"

    # Templates are deterministic: build them once, outside the retry loop.
    # Template errors therefore surface immediately instead of being retried.
    if has_history:
        chat_prompt = PromptTemplate.from_template(
            _build_history_prompt_template(prompt_name, user_prompt_name)
        )
    else:
        prompt_template = _build_plain_prompt_template(prompt_name, user_prompt_name, history)
        input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_prompt])

    answer = ""
    retry_count = 0
    while retry_count <= MAX_RETRIES:
        # Each attempt needs a fresh callback (and therefore a fresh model bound
        # to it) plus fresh memory, since a finished attempt consumes/updates them.
        callback = AsyncIteratorCallbackHandler()
        task = None
        answer = ""  # discard partial output from a failed attempt
        emitted = False
        try:
            model = _build_model(model_name, temperature, max_tokens, [callback], prompt_name)
            if has_history:
                memory = _build_conversation_memory(history, today)
                chain = ConversationChain(llm=model, verbose=True, memory=memory, prompt=chat_prompt)
                inputs = {"input": query}
            else:
                chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True)
                inputs = {"input": query, "time": today}
            # NOTE(review): whether LLM errors reach the except-clause below
            # depends on how server.utils.wrap_done handles exceptions — confirm.
            task = asyncio.create_task(wrap_done(chain.acall(inputs), callback.done))
            async for token in callback.aiter():
                answer += token
                if stream:
                    emitted = True
                    yield json.dumps({"text": token}, ensure_ascii=False)
            await task
            break
        except Exception as e:
            # Stop the in-flight generation before sleeping/retrying.
            if task is not None and not task.done():
                task.cancel()
            retry_count += 1
            if retry_count > MAX_RETRIES:
                logger.error(f"流式传输失败,已达到最大重试次数 {MAX_RETRIES}: {e}")
                raise
            if emitted:
                # Tokens were already sent to the client; a retry would re-send
                # the answer from the beginning, so fail instead.
                logger.error(f"流式传输失败,已向客户端输出部分内容,无法重试: {e}")
                raise
            logger.warning(f"流式传输第 {retry_count} 次失败,{RETRY_DELAY}秒后重试: {e}")
            await asyncio.sleep(RETRY_DELAY)
        finally:
            # Also covers client disconnects (GeneratorExit/CancelledError).
            if task is not None and not task.done():
                task.cancel()

    # Log the final answer once rather than on every token.
    logger.info(f'solve_problem: {answer}')
    if not stream:
        yield json.dumps(
            {"text": answer, "message_id": message_id},
            ensure_ascii=False
        )