import uuid
from fastapi import Body
from langchain.memory import (
    CombinedMemory,
    ConversationBufferMemory,
    ConversationSummaryMemory,
    ConversationBufferWindowMemory
)
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain, ConversationChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional, Union
from server.chat.utils import History
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template, get_format_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
from datetime import datetime
from langchain_core.messages import SystemMessage
import time as t
from server.utils import replace_variables
from configs.basic_config import *
from configs.outline_config import outlines

# Number of additional attempts made after the first failed one.
MAX_RETRIES = 2
# Seconds to sleep between two attempts.
RETRY_DELAY = 1


def _format_outlines() -> str:
    """Render every outline except the last as the text substituted for ``{outlines}``."""
    outline_detail = [
        f"\"index\": \"{outline['index']}\", \"title\": \"{outline['title']}\", \"summary\": \"{outline['summary']}\""
        for outline in outlines[:-1]
    ]
    return str(outline_detail)


def _build_prompt_template(
    prompt_name: str,
    user_prompt_name: Optional[str],
    history,
    with_history: bool,
) -> str:
    """Select the ``llm_chat`` prompt template and fill its static placeholders.

    Args:
        prompt_name: Name of the requested prompt template.
        user_prompt_name: Optional name of a user prompt spliced into
            ``{user_prompt}``.
        history: Conversation history; only read for ``history_route``, where
            ``str(history)`` replaces ``{history_summary}``.
        with_history: True when the history-aware template variants are used.

    Returns:
        The template text with ``{user_prompt}``, ``{outlines}``,
        ``{history_summary}`` and ``{year}`` substituted where applicable.
    """
    if with_history:
        if prompt_name == "solve_problem":
            # NOTE(review): raises TypeError when user_prompt_name is None;
            # this mirrors the original behaviour.
            user_prompt = get_prompt_template("llm_chat", user_prompt_name + "_with_history")
            template = get_prompt_template("llm_chat", "solve_problem_history")
            template = replace_variables(
                template, replace_content=user_prompt, replace_param="{user_prompt}"
            )
        elif prompt_name == "solve_problem_outline":
            template = get_prompt_template("llm_chat", "solve_problem_outline_history")
        elif prompt_name == "outlines_route":
            template = get_prompt_template("llm_chat", "outlines_route_with_history")
            template = replace_variables(
                template, replace_content=_format_outlines(), replace_param="{outlines}"
            )
        else:
            # Every other prompt name falls back to the think-route template.
            template = get_prompt_template("llm_chat", "think_route_history")
            if user_prompt_name:
                user_prompt = get_prompt_template("llm_chat", user_prompt_name + "_with_history")
                template = replace_variables(
                    template, replace_content=user_prompt, replace_param="{user_prompt}"
                )
        return template

    template = get_prompt_template("llm_chat", prompt_name)
    if user_prompt_name and prompt_name == "think_route":
        user_prompt = get_prompt_template("llm_chat", user_prompt_name)
        template = replace_variables(
            template, replace_content=user_prompt, replace_param="{user_prompt}"
        )
    if prompt_name == "history_route":
        # The history is embedded into the prompt text instead of a memory object.
        template = replace_variables(
            template, replace_content=str(history), replace_param="{history_summary}"
        )
    if prompt_name == "outlines_route":
        template = replace_variables(
            template, replace_content=_format_outlines(), replace_param="{outlines}"
        )
    if prompt_name == "solve_problem_outline":
        template = get_prompt_template("llm_chat", "solve_problem_outline")
        template = replace_variables(
            template, replace_content=datetime.now().strftime("%Y"), replace_param="{year}"
        )
    return template


def _build_model(model_name: str, temperature: float, max_tokens: Optional[int],
                 callbacks: list, prompt_name: str):
    """Create the chat model; ``solve_problem`` additionally enables thinking mode."""
    extra_kwargs = {}
    if prompt_name == "solve_problem":
        extra_kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": True}}
    return get_ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        callbacks=callbacks,
        **extra_kwargs,
    )


def _build_memory(history, current_date: str) -> CombinedMemory:
    """Turn the history into a combined memory exposing ``time`` and ``history`` keys."""
    buff_memory = ConversationBufferMemory(
        human_prefix='user', ai_prefix='assistant', memory_key="history", input_key="input"
    )
    for message in history:
        # Only user and assistant turns are replayed; other roles are skipped.
        if message.role == 'user':
            buff_memory.chat_memory.add_user_message(message.content)
        elif message.role == 'assistant':
            buff_memory.chat_memory.add_ai_message(message.content)

    background_memory = ConversationBufferMemory(
        human_prefix='user', ai_prefix='assistant', memory_key="time", input_key="input"
    )
    background_memory.chat_memory.add_message(
        SystemMessage(content=f'当前的时间是:{current_date}')
    )
    return CombinedMemory(memories=[background_memory, buff_memory])


async def solve_problem(
        user_prompt_name: Optional[str] = Body(None, description="用户输入"),
        query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
        conversation_id: str = Body("", description="对话框ID"),
        history: Union[int, List[History]] = Body([], description="历史对话"),
        model_name: str = Body("default_model", description="LLM 模型名称。"),
        temperature: float = Body(0.7, description="LLM 采样温度", ge=0.0, le=2.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量"),
        prompt_name: str = Body("default", description="使用的prompt模板名称"),
        stream: bool = Body(False, description="流式输出")
) -> AsyncIterable[str]:
    """Run ``query`` through an LLM chain and yield the answer as JSON strings.

    A ``ConversationChain`` backed by memory is used when ``history`` is
    non-empty and ``prompt_name`` is not ``"history_route"``; otherwise a plain
    ``LLMChain`` with a single system message is used. A failing attempt is
    retried up to ``MAX_RETRIES`` times with ``RETRY_DELAY`` seconds in between.

    Args:
        user_prompt_name: Optional user prompt template name merged into the
            main template's ``{user_prompt}`` placeholder.
        query: The user input passed to the chain as ``input``.
        conversation_id: Conversation identifier (not read by this function).
        history: Prior turns; converted with ``History.from_data``.
            NOTE(review): an ``int`` value is accepted by the annotation but is
            not handled here - confirm with callers.
        model_name: Name of the LLM to use.
        temperature: Sampling temperature.
        max_tokens: Token limit; non-positive integers are treated as ``None``.
        prompt_name: Name of the prompt template to use.
        stream: When True each token is yielded as ``{"text": token}``;
            otherwise a single ``{"text": answer, "message_id": id}`` is yielded.

    Yields:
        JSON-encoded strings as described for ``stream``.

    Raises:
        Exception: Re-raises the last error once all retries are exhausted.
    """
    current_date = datetime.now().strftime("%Y年%m月%d日")
    message_id = str(uuid.uuid1()) + "q"
    if isinstance(max_tokens, int) and max_tokens <= 0:
        max_tokens = None

    # Decide whether the history drives a memory-backed conversation chain.
    is_history_route = prompt_name == "history_route"
    use_conversation_chain = bool(history) and not is_history_route
    if use_conversation_chain or is_history_route:
        history = [History.from_data(h) for h in history]

    # The prompt does not change between attempts, so it is built only once.
    prompt_template = _build_prompt_template(
        prompt_name, user_prompt_name, history, use_conversation_chain
    )
    if use_conversation_chain:
        chat_prompt = PromptTemplate.from_template(prompt_template)
        chain_input = {"input": query}
    else:
        input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_prompt])
        chain_input = {"input": query, "time": current_date}

    answer = ""
    for attempt in range(1, MAX_RETRIES + 2):
        task = None
        # Start from scratch so tokens of a failed attempt are not kept.
        answer_parts: List[str] = []
        try:
            # Callback, model and chain are recreated for every attempt.
            callback = AsyncIteratorCallbackHandler()
            model = _build_model(model_name, temperature, max_tokens, [callback], prompt_name)
            if use_conversation_chain:
                chain = ConversationChain(
                    llm=model,
                    verbose=True,
                    memory=_build_memory(history, current_date),
                    prompt=chat_prompt,
                )
            else:
                chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True)

            task = asyncio.create_task(
                wrap_done(chain.acall(chain_input), callback.done)
            )

            # NOTE(review): in stream mode, tokens yielded before a failure
            # cannot be taken back; a retry streams the answer from the start.
            async for token in callback.aiter():
                if stream:
                    yield json.dumps({"text": token}, ensure_ascii=False)
                else:
                    answer_parts.append(token)

            answer = "".join(answer_parts)
            # `logger` is presumably provided by the configs.basic_config wildcard import.
            logger.info(f'solve_problem: {str(answer)}')
            await task
            break
        except Exception as e:
            # Stop the failed attempt's background call before retrying.
            if task is not None and not task.done():
                task.cancel()
            if attempt > MAX_RETRIES:
                logger.error(f"流式传输失败,已达到最大重试次数 {MAX_RETRIES}: {e}")
                raise
            logger.warning(f"流式传输第 {attempt} 次失败,{RETRY_DELAY}秒后重试: {e}")
            await asyncio.sleep(RETRY_DELAY)

    if stream:
        return
    else:
        yield json.dumps(
            {"text": answer, "message_id": message_id},
            ensure_ascii=False
        )