# Source path: gangyan/langchain-chat/server/chat/agent_chat.py
# (Header captured from a repository web view: 258 lines, 14 KiB, Python.)

from langchain.memory import ConversationBufferWindowMemory
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names, search_tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain
from typing import AsyncIterable, Optional
import asyncio
from typing import List
from server.chat.utils import History
import json
from server.agent import model_container
from server.knowledge_base.kb_service.base import get_kb_details
import ast
import re
from configs.basic_config import *
def _literal_eval_or_none(text):
    """Safely parse ``text`` as a Python literal; return None if it is not one.

    ``ast.literal_eval`` never executes code, but malformed input may raise
    ValueError, TypeError, SyntaxError, MemoryError or RecursionError, so all
    of them are treated as "unparseable".
    """
    try:
        return ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return None


def _parse_web_search_output(output_str):
    """Split the raw output of the "联网思索" tool into display text and references.

    The output is presumably the repr of a Python list (TODO confirm against the
    tool implementation); the layout is chosen by marker strings it contains.

    Args:
        output_str: the tool's raw ``output_str`` from the callback payload.

    Returns:
        ``(display_text, search_answer, policy_answer)``. ``policy_answer`` is
        None when the output carries no policy section. Returns None instead of
        a tuple when the output cannot be parsed; callers then keep the raw
        output unchanged.
    """
    if "政策类资料" in output_str:
        output_arr = _literal_eval_or_none(output_str)
        if not output_arr:
            logger.warning("无法解析字符串为列表")
            return None
        # The first five items are joined as policy content and items 5-9 as the
        # policy references; the last item holds the knowledge-centre references.
        policy_content = ''.join(output_arr[:5])
        policy_answer = ''.join(output_arr[5:10])
        search_answer_str = str(output_arr[-1])
        if "暂未找到相关资料" in search_answer_str:
            search_answer = "\n知识中心资料: 暂无"
        else:
            # Drop the two leading and two trailing repr characters, then put
            # one reference per line.
            inner = search_answer_str[2:len(search_answer_str) - 2].replace("\\n", "")
            search_answer = '\n'.join(inner.split("', '"))
        return policy_content + policy_answer, search_answer, policy_answer

    if "暂未找到相关资料" in output_str:
        logger.debug("output_str %s", output_str)
        output_arr = _literal_eval_or_none(output_str)
        if not output_arr:
            logger.warning("无法解析字符串为列表")
            return None
        # Slice bounds must come from the same string that is being sliced.
        last = str(output_arr[-1])
        search_answer = last[2:len(last) - 3]
        # NOTE(review): output_arr[:0] is always empty, so the displayed tool
        # output is just its repr ("[]"); output_arr[0] may have been intended.
        return str(output_arr[:0]), search_answer, None

    # Presumably "['<content>', ['<ref1>', '<ref2>', ...]]": strip the outer
    # repr characters and split content from the reference list.
    stripped = output_str[2:len(output_str) - 3].replace("\\n", "")
    parts = stripped.split("', ['")
    if len(parts) < 2:
        logger.warning("无法解析联网思索工具输出: %s", output_str)
        return None
    search_answer = '\n'.join(parts[1].split("', '"))
    return parts[0], search_answer, None


def _parse_policy_kb_output(output_str):
    """Format the "policy_knowledgebase" tool output as a reference block.

    Assumes (from the indexing below — TODO confirm) the output is the repr of a
    sequence whose item 0 is a heading string and item 1 an iterable of text
    lines. Returns ``heading + "\\n\\n" + one stripped line per row``, or None
    when the output cannot be parsed that way.
    """
    # Escape raw newlines so they do not break string literals during parsing.
    parsed = _literal_eval_or_none(output_str.replace("\n", "\\n"))
    try:
        id_str, lines = parsed[0], parsed[1]
    except (TypeError, IndexError, KeyError):
        logger.warning("无法解析policy_knowledgebase工具输出: %s", output_str)
        return None
    return id_str + '\n\n' + ''.join(line.strip() + '\n' for line in lines)


async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                     history: List[History] = Body([],
                                                   description="历史对话",
                                                   examples=[[
                                                       {"role": "user",
                                                        "content": "请使用天气查询工具查询今天北京天气"},
                                                       {"role": "assistant",
                                                        "content": "今天是2024年3月22日受冷空气影响白天有3、4级偏北风阵风6、7"
                                                                   "级,西部山区阵风相对明显,局地伴有扬沙。白天晴,局地有扬沙,偏北风,1级转3、4级阵风6、7级,"
                                                                   "最高气温22℃。夜间晴间多云,偏北风,1、2级,最低气温6℃。"}]]
                                                   ),
                     stream: bool = Body(False, description="流式输出"),
                     model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                     temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
                     prompt_name: str = Body("default",
                                             description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                     ):
    """Run a tool-using agent on the user's query and return its output as server-sent events.

    The query is prefixed with "帮我搜索一下:" before it reaches the agent.
    Knowledge-base metadata and the agent model are published to the tools
    through ``model_container.DATABASE`` / ``model_container.MODEL``.

    Events emitted (JSON strings):
      * stream=True: ``{"tools": [...]}`` for each tool success/failure,
        ``{"answer": token}`` for other in-progress events, and
        ``{"final_answer": ...}`` at the end. References collected from the
        "联网思索" / "policy_knowledgebase" tools are appended to (or, when only
        policy references exist, replace) the final answer.
      * stream=False: a single ``{"answer": ..., "final_answer": ...}`` event.
    """
    history = [History.from_data(h) for h in history]
    query = "帮我搜索一下:" + query
    logger.info(f"agent query: {query}")

    async def agent_chat_iterator(
            query: str,
            history: Optional[List[History]],
            model_name: str = LLM_MODELS[0],
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        callback = CustomAsyncIteratorCallbackHandler()
        # Non-positive limits mean "no limit".
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )

        # Tools read the knowledge-base list and the agent model from globals.
        kb_list = {x["kb_name"]: x for x in get_kb_details()}
        model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
        if Agent_MODEL:
            # A dedicated model is configured for agent/tool work.
            model_container.MODEL = get_ChatOpenAI(
                model_name=Agent_MODEL,
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=[callback],
            )
        else:
            model_container.MODEL = model

        prompt_template = get_prompt_template("agent_chat", prompt_name)
        prompt_template_agent = CustomPromptTemplate(
            template=prompt_template,
            tools=tools,
            input_variables=["input", "intermediate_steps", "history"]
        )
        output_parser = CustomOutputParser()
        llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)

        # Convert the chat history into the agent's windowed memory.
        memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
        for message in history:
            if message.role == 'user':
                memory.chat_memory.add_user_message(message.content)
            else:
                memory.chat_memory.add_ai_message(message.content)

        if "chatglm3" in model_container.MODEL.model_name:
            agent_executor = initialize_glm3_agent(
                llm=model,
                tools=tools,
                callback_manager=None,
                # Langchain Prompt is not constructed directly here, it is constructed inside the GLM3 agent.
                prompt=prompt_template,
                input_variables=["input", "intermediate_steps", "history"],
                memory=memory,
                verbose=True,
            )
        else:
            agent = LLMSingleActionAgent(
                llm_chain=llm_chain,
                output_parser=output_parser,
                stop=["\nObservation:", "Observation"],
                allowed_tools=tool_names,
            )
            agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
                                                                tools=tools,
                                                                verbose=True,
                                                                memory=memory,
                                                                )

        # Schedule the agent run in the background; wrap_done presumably sets
        # callback.done when it finishes, which ends callback.aiter() below.
        task = asyncio.create_task(wrap_done(
            agent_executor.acall(query, callbacks=[callback], include_run_info=True),
            callback.done))

        if stream:
            search_answer = ""
            policy_answer = ""
            async for chunk in callback.aiter():
                data = json.loads(chunk)
                status = data["status"]
                if status == Status.start or status == Status.complete:
                    continue
                if status == Status.error:
                    tools_use = [
                        "\n```\n",
                        "工具名称: " + data["tool_name"],
                        "工具状态: " + "调用失败",
                        "错误信息: " + data["error"],
                        "重新开始尝试",
                        "\n```\n",
                    ]
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif status == Status.tool_finish:
                    tools_use = [
                        "\n```\n",
                        "工具名称: " + data["tool_name"],
                        "工具状态: " + "调用成功",
                        "工具输入: " + data["input_str"],
                    ]
                    if data["tool_name"] == "联网思索":
                        parsed = _parse_web_search_output(data["output_str"])
                        if parsed is not None:
                            data["output_str"], search_answer, parsed_policy = parsed
                            if parsed_policy is not None:
                                policy_answer = parsed_policy
                            logger.info("<<<工具输出>>>\n%s", data["output_str"])
                    elif data["tool_name"] == "policy_knowledgebase":
                        parsed_policy = _parse_policy_kb_output(data["output_str"])
                        if parsed_policy is not None:
                            policy_answer = parsed_policy
                            logger.info("policy_answer %s", policy_answer)
                    tools_use.append("工具输出: " + data["output_str"])
                    tools_use.append("\n```\n")
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif status == Status.agent_finish:
                    if search_answer:
                        references = (str(policy_answer) + str(search_answer)
                                      if policy_answer else str(search_answer))
                        final_answer = data["final_answer"] + "\n\n参考资料:\n\n" + references
                        logger.info("search_answer_output %s", search_answer)
                    elif policy_answer:
                        # Only policy references were found: they replace the model's answer.
                        final_answer = policy_answer
                        logger.info("policy_answer_output %s", policy_answer)
                    else:
                        final_answer = data["final_answer"]
                    yield json.dumps({"final_answer": final_answer}, ensure_ascii=False)
                else:
                    yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
        else:
            answer = ""
            final_answer = ""
            async for chunk in callback.aiter():
                data = json.loads(chunk)
                status = data["status"]
                if status == Status.start or status == Status.complete:
                    continue
                # elif chain: tool/error events must not also append llm_token.
                if status == Status.error:
                    answer += "\n```\n"
                    answer += "工具名称: " + data["tool_name"] + "\n"
                    answer += "工具状态: " + "调用失败" + "\n"
                    answer += "错误信息: " + data["error"] + "\n"
                    answer += "\n```\n"
                elif status == Status.tool_finish:
                    answer += "\n```\n"
                    answer += "工具名称: " + data["tool_name"] + "\n"
                    answer += "工具状态: " + "调用成功" + "\n"
                    answer += "工具输入: " + data["input_str"] + "\n"
                    answer += "工具输出: " + data["output_str"] + "\n"
                    answer += "\n```\n"
                elif status == Status.agent_finish:
                    final_answer = data["final_answer"]
                else:
                    answer += data["llm_token"]
            yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
        await task

    return EventSourceResponse(agent_chat_iterator(query=query,
                                                   history=history,
                                                   model_name=model_name,
                                                   prompt_name=prompt_name),
                               )