from langchain.memory import ConversationBufferWindowMemory
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names, search_tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain
from typing import AsyncIterable, Optional
import asyncio
from typing import List
from server.chat.utils import History
import json
from server.agent import model_container
from server.knowledge_base.kb_service.base import get_kb_details
import ast
import re
from configs.basic_config import *

# Tool whose raw output is post-processed into displayed content + references.
_WEB_SEARCH_TOOL = "联网思索"
# Tool whose output is turned into the policy answer.
_POLICY_KB_TOOL = "policy_knowledgebase"
# Marker in the web-search output indicating that policy material is included.
_POLICY_MARKER = "政策类资料"
# Marker in tool output indicating that no reference material was found.
_NOT_FOUND_MARKER = "暂未找到相关资料"


def _literal_sequence(text):
    """Safely parse ``text`` as a Python list/tuple literal.

    Returns the parsed list/tuple, or ``None`` (after logging a warning) when the
    text is not a valid literal or does not evaluate to a list/tuple. Callers must
    treat ``None`` as "leave the raw tool output untouched".
    """
    try:
        value = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # literal_eval raises more than ValueError for malformed input.
        logger.warning("无法解析字符串为列表")
        return None
    if not isinstance(value, (list, tuple)):
        logger.warning("无法解析字符串为列表")
        return None
    return value


def _format_search_references(refs):
    """Turn the stringified reference list of the web-search tool into display text.

    Trims two characters from each end (presumably the ``['`` / ``']`` of a
    stringified list — confirm against the tool), removes literal backslash-n
    sequences and puts every ``', '``-separated item on its own line. When the
    not-found marker is present a fixed "none available" line is returned instead.
    """
    if _NOT_FOUND_MARKER in refs:
        return "\n知识中心资料: 暂无"
    body = refs[2:len(refs) - 2].replace("\\n", "")
    return "\n".join(body.split("', '"))


def _parse_policy_search_output(output_str):
    """Split web-search output that contains policy material.

    Items 0-4 are joined as policy content, items 5-9 as the policy answer, and
    the last item is formatted as the search references.

    Returns:
        ``(display_output, policy_answer, search_answer)``, or ``None`` when the
        output cannot be parsed or is empty.
    """
    items = _literal_sequence(output_str)
    if not items:
        return None
    policy_content = "".join(items[:5])
    policy_answer = "".join(items[5:10])
    search_answer = _format_search_references(str(items[-1]))
    return policy_content + policy_answer, policy_answer, search_answer


def _parse_not_found_search_output(output_str):
    """Handle web-search output that reports no reference material.

    Returns:
        ``(display_output, search_answer)``, or ``None`` when the output cannot be
        parsed or is empty. The displayed output is always the empty slice of the
        parsed sequence; the search answer is the string form of the last item with
        two leading and three trailing characters trimmed.
    """
    items = _literal_sequence(output_str)
    if not items:
        return None
    last = str(items[-1])
    # Slice by the length of the *string* form; using len(items[-1]) is wrong
    # whenever the last item is not itself a string.
    return str(items[:0]), last[2:len(last) - 3]


def _parse_plain_search_output(output_str):
    """Split ordinary web-search output into displayed content and references.

    The text is presumably shaped like ``['content', ['ref1', 'ref2']]`` (TODO
    confirm against the tool): two leading / three trailing characters are trimmed,
    literal backslash-n sequences removed, and the text split on ``', ['``.

    Returns:
        ``(display_output, search_answer)``, or ``None`` when no reference part is
        present.
    """
    body = output_str[2:len(output_str) - 3].replace("\\n", "")
    parts = body.split("', ['")
    if len(parts) < 2:
        logger.warning("Unexpected web-search tool output: %s", output_str)
        return None
    return parts[0], "\n".join(parts[1].split("', '"))


def _parse_policy_kb_output(output_str):
    """Format the policy knowledge-base tool output.

    Raw newlines are escaped first so the text parses as a literal. Item 0 is used
    as a header and item 1 as an iterable of lines; each line is stripped and
    re-terminated with a newline.

    Returns:
        ``header + "\\n\\n" + lines``, or ``None`` when the output cannot be parsed.
    """
    items = _literal_sequence(output_str.replace("\n", "\\n"))
    if items is None or len(items) < 2:
        return None
    header, lines = items[0], items[1]
    return header + "\n\n" + "".join(line.strip() + "\n" for line in lines)


async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                     history: List[History] = Body([],
                                                   description="历史对话",
                                                   examples=[[
                                                       {"role": "user",
                                                        "content": "请使用天气查询工具查询今天北京天气"},
                                                       {"role": "assistant",
                                                        "content": "今天是2024年3月22日,受冷空气影响,白天有3、4级偏北风,阵风6、7"
                                                                   "级,西部山区阵风相对明显,局地伴有扬沙。白天晴,局地有扬沙,偏北风,1级转3、4级,阵风6、7级,"
                                                                   "最高气温22℃。夜间晴间多云,偏北风,1、2级,最低气温6℃。"}]]
                                                   ),
                     stream: bool = Body(False, description="流式输出"),
                     model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                     temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                     prompt_name: str = Body("default",
                                             description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                     ):
    """Run the LangChain agent on ``query`` and return its output as server-sent events.

    The query is prefixed with "帮我搜索一下:" before being handed to the agent.

    Args:
        query: User input.
        history: Previous turns; loaded into a ``ConversationBufferWindowMemory``.
        stream: When true, every callback event is emitted as its own JSON message
            (``{"tools": [...]}``, ``{"answer": token}`` or ``{"final_answer": ...}``);
            otherwise a single ``{"answer", "final_answer"}`` message is emitted.
        model_name: LLM used by the agent chain.
        temperature: Sampling temperature.
        max_tokens: Token limit; non-positive values are treated as ``None``.
        prompt_name: Name of the ``agent_chat`` prompt template.

    Returns:
        An ``EventSourceResponse`` wrapping the JSON-string async iterator.
    """
    history = [History.from_data(h) for h in history]
    query = "帮我搜索一下:" + query
    logger.info(f"agent query: {query}")

    async def agent_chat_iterator(
            query: str,
            history: Optional[List[History]],
            model_name: str = LLM_MODELS[0],
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        """Build the agent, run it as a background task and yield its callback events as JSON."""
        nonlocal max_tokens
        callback = CustomAsyncIteratorCallbackHandler()
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )

        # The agent tools read these module-level globals.
        model_container.DATABASE = {kb["kb_name"]: kb["kb_info"] for kb in get_kb_details()}

        if Agent_MODEL:
            # A dedicated agent model is configured for tool use.
            model_container.MODEL = get_ChatOpenAI(
                model_name=Agent_MODEL,
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=[callback],
            )
        else:
            model_container.MODEL = model

        prompt_template = get_prompt_template("agent_chat", prompt_name)
        prompt_template_agent = CustomPromptTemplate(
            template=prompt_template,
            tools=tools,
            input_variables=["input", "intermediate_steps", "history"]
        )
        output_parser = CustomOutputParser()
        llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)

        # Convert the chat history into agent memory (window size k = HISTORY_LEN * 2).
        memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
        for message in history:
            if message.role == 'user':
                memory.chat_memory.add_user_message(message.content)
            else:
                # Every non-user role is recorded as an AI message.
                memory.chat_memory.add_ai_message(message.content)

        if "chatglm3" in model_container.MODEL.model_name:
            agent_executor = initialize_glm3_agent(
                llm=model,
                tools=tools,
                callback_manager=None,
                # Langchain Prompt is not constructed directly here, it is constructed inside the GLM3 agent.
                prompt=prompt_template,
                input_variables=["input", "intermediate_steps", "history"],
                memory=memory,
                verbose=True,
            )
        else:
            agent = LLMSingleActionAgent(
                llm_chain=llm_chain,
                output_parser=output_parser,
                stop=["\nObservation:", "Observation"],
                allowed_tools=tool_names,
            )
            agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
                                                                tools=tools,
                                                                verbose=True,
                                                                memory=memory,
                                                                )

        # Create the task once: create_task cannot fail for lack of a loop here, and the
        # former retry loop around a bare `except` would spin forever on any other error.
        task = asyncio.create_task(wrap_done(
            agent_executor.acall(query, callbacks=[callback], include_run_info=True),
            callback.done))

        if stream:
            # References collected from tool outputs, appended to the final answer.
            search_answer = ""
            policy_answer = ""
            async for chunk in callback.aiter():
                # Use server-sent-events to stream the response
                data = json.loads(chunk)
                status = data["status"]
                if status == Status.start or status == Status.complete:
                    continue
                elif status == Status.error:
                    tools_use = [
                        "\n```\n",
                        "工具名称: " + data["tool_name"],
                        "工具状态: 调用失败",
                        "错误信息: " + data["error"],
                        "重新开始尝试",
                        "\n```\n",
                    ]
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif status == Status.tool_finish:
                    tools_use = [
                        "\n```\n",
                        "工具名称: " + data["tool_name"],
                        "工具状态: 调用成功",
                        "工具输入: " + data["input_str"],
                    ]
                    if data["tool_name"] == _WEB_SEARCH_TOOL:
                        output_str = data["output_str"]
                        # On parse failure the raw output is shown and collected references are kept.
                        if _POLICY_MARKER in output_str:
                            parsed = _parse_policy_search_output(output_str)
                            if parsed is not None:
                                display_output, policy_answer, search_answer = parsed
                                data["output_str"] = display_output
                        elif _NOT_FOUND_MARKER in output_str:
                            logger.info("output_str %s", output_str)
                            parsed = _parse_not_found_search_output(output_str)
                            if parsed is not None:
                                display_output, search_answer = parsed
                                data["output_str"] = display_output
                        else:
                            parsed = _parse_plain_search_output(output_str)
                            if parsed is not None:
                                display_output, search_answer = parsed
                                data["output_str"] = display_output
                        logger.info("<<<工具输出>>>\n%s", data["output_str"])
                    elif data["tool_name"] == _POLICY_KB_TOOL:
                        parsed_policy = _parse_policy_kb_output(data["output_str"])
                        if parsed_policy is not None:
                            policy_answer = parsed_policy
                            logger.info("policy_answer: %s", policy_answer)
                    tools_use.append("工具输出: " + data["output_str"])
                    tools_use.append("\n```\n")
                    # Formatted tool output
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif status == Status.agent_finish:
                    if search_answer:
                        references = str(policy_answer) + str(search_answer) if policy_answer else str(search_answer)
                        yield json.dumps({"final_answer": data["final_answer"] + "\n\n参考资料:\n\n" + references},
                                         ensure_ascii=False)
                        logger.info("search_answer_output: %s", search_answer)
                    elif policy_answer:
                        # NOTE(review): the policy answer replaces the LLM's final answer here — confirm intended.
                        yield json.dumps({"final_answer": policy_answer}, ensure_ascii=False)
                        logger.info("policy_answer_output: %s", policy_answer)
                    else:
                        yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False)
                else:
                    yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)

        else:
            answer = ""
            final_answer = ""
            async for chunk in callback.aiter():
                data = json.loads(chunk)
                if data["status"] == Status.start or data["status"] == Status.complete:
                    continue
                if data["status"] == Status.error:
                    answer += "\n```\n"
                    answer += "工具名称: " + data["tool_name"] + "\n"
                    answer += "工具状态: " + "调用失败" + "\n"
                    answer += "错误信息: " + data["error"] + "\n"
                    answer += "\n```\n"
                if data["status"] == Status.tool_finish:
                    answer += "\n```\n"
                    answer += "工具名称: " + data["tool_name"] + "\n"
                    answer += "工具状态: " + "调用成功" + "\n"
                    answer += "工具输入: " + data["input_str"] + "\n"
                    answer += "工具输出: " + data["output_str"] + "\n"
                    answer += "\n```\n"
                if data["status"] == Status.agent_finish:
                    final_answer = data["final_answer"]
                else:
                    # Also runs for error/tool_finish events; assumes those carry an llm_token key.
                    answer += data["llm_token"]

            yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
        await task

    return EventSourceResponse(agent_chat_iterator(query=query,
                                                   history=history,
                                                   model_name=model_name,
                                                   prompt_name=prompt_name),
                               )