import uuid
from fastapi import Body
from langchain.memory import (
    CombinedMemory,
    ConversationBufferMemory,
    ConversationSummaryMemory,
    ConversationBufferWindowMemory,
)
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain, ConversationChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional, Union
from server.chat.utils import History
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template, get_format_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
from datetime import datetime
from langchain_core.messages import SystemMessage
import time as t
from configs.basic_config import *


# Prompt names that always use the plain (history-free) LLMChain, even when history is supplied.
_NO_HISTORY_PROMPTS = ("Search Summary", "get_policy_time")


def _history_prompt_template(prompt_name: str):
    """Return the llm_chat template "<prompt_name>_with_history".

    Falls back to "default_with_history" when no such template is configured.
    NOTE(review): assumes get_prompt_template returns a falsy value (e.g. None)
    for unknown names rather than raising — confirm against server.utils.
    """
    template = get_prompt_template("llm_chat", f"{prompt_name}_with_history")
    if not template:
        template = get_prompt_template("llm_chat", "default_with_history")
    return template


def _build_memory(turns: List[History], current_date: str) -> CombinedMemory:
    """Build the memory for a ConversationChain.

    Two buffers are combined:
    * "time": a single system message carrying the current date;
    * "history": the supplied user/assistant turns replayed in order
      (turns with any other role are skipped).
    """
    history_memory = ConversationBufferMemory(human_prefix='user', ai_prefix='assistant',
                                              memory_key="history", input_key="input")
    for turn in turns:
        if turn.role == 'user':
            history_memory.chat_memory.add_user_message(turn.content)
        elif turn.role == 'assistant':
            history_memory.chat_memory.add_ai_message(turn.content)

    date_memory = ConversationBufferMemory(human_prefix='user', ai_prefix='assistant',
                                           memory_key="time", input_key="input")
    date_memory.chat_memory.add_message(SystemMessage(content=f'当前的时间是:{current_date}'))

    return CombinedMemory(memories=[date_memory, history_memory])


async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
               conversation_id: str = Body("", description="对话框ID"),
               history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
               history: Union[int, List[History]] = Body([],
                                                         description="历史对话,设为一个整数可以从数据库中读取历史消息",
                                                         examples=[[
                                                             {"role": "user",
                                                              "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                             {"role": "assistant", "content": "虎头虎脑"}]]
                                                         ),
               stream: bool = Body(False, description="流式输出"),
               model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
               max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
               # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
               prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
               ):
    """Chat with the LLM and return the reply as a server-sent-event stream.

    With ``stream=True`` each generated token is emitted as its own JSON event.
    Otherwise a single JSON event carries the full answer. Each event has the
    shape ``{"text": ..., "message_id": ...}``.

    When ``history`` is a non-empty list and ``prompt_name`` is not in
    ``_NO_HISTORY_PROMPTS``, the turns are loaded into a ConversationChain
    memory together with the current date. Otherwise a plain LLMChain is used.

    NOTE: reading history from the database (an integer ``history`` or
    ``history_len``) is not implemented. An integer ``history`` is treated as
    "no inline history".
    """

    async def chat_iterator() -> AsyncIterable[str]:
        callback = AsyncIteratorCallbackHandler()
        callbacks = [callback]
        current_date = datetime.now().strftime("%Y年%m月%d日")

        # Responsible for saving the LLM response to the message DB.
        message_id = str(uuid.uuid1()) + "q"
        conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id,
                                                            message_id=message_id,
                                                            chat_type="llm_chat",
                                                            query=query)
        callbacks.append(conversation_callback)
        logger.info(f"智能对话的入参信息:query:{query},conversation_id:{conversation_id},history:{history},stream:{stream},model_name:{model_name},temperature:{temperature},max_tokens:{max_tokens},prompt_name:{prompt_name}")

        # A non-positive limit means "no limit"; the model treats None as its maximum.
        token_limit = None if isinstance(max_tokens, int) and max_tokens <= 0 else max_tokens

        # "Search Summary" always runs on the primary configured model, ignoring model_name.
        effective_model_name = LLM_MODELS[0] if prompt_name == "Search Summary" else model_name
        model = get_ChatOpenAI(
            model_name=effective_model_name,
            temperature=temperature,
            max_tokens=token_limit,
            callbacks=callbacks,
        )
        logger.info(f"当前使用的模型为:{effective_model_name}")

        # Only an explicit, non-empty list of turns enables the memory-backed chain.
        # An integer history would require the (unimplemented) DB path.
        use_history = (isinstance(history, list) and len(history) > 0
                       and prompt_name not in _NO_HISTORY_PROMPTS)
        if use_history:
            turns = [History.from_data(h) for h in history]
            chat_prompt = PromptTemplate.from_template(_history_prompt_template(prompt_name))
            memory = _build_memory(turns, current_date)
            chain = ConversationChain(llm=model, verbose=True, memory=memory, prompt=chat_prompt)
        else:
            prompt_template = get_prompt_template("llm_chat", prompt_name)
            input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
            chat_prompt = ChatPromptTemplate.from_messages([input_prompt])
            chain = LLMChain(prompt=chat_prompt, llm=model)

        # Run the chain in the background; wrap_done signals callback.done when it
        # finishes, which ends the callback.aiter() loops below.
        task = asyncio.create_task(wrap_done(
            chain.acall({"input": query, "time": current_date}),
            callback.done),
        )

        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
                yield json.dumps(
                    {"text": token, "message_id": message_id},
                    ensure_ascii=False)
        else:
            parts = []
            async for token in callback.aiter():
                parts.append(token)
            yield json.dumps(
                {"text": "".join(parts), "message_id": message_id},
                ensure_ascii=False)

        await task

    return EventSourceResponse(chat_iterator())