Files
gangyan/langchain-chat/server/chat/agent_chat.py

258 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from langchain.memory import ConversationBufferWindowMemory
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names, search_tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain
from typing import AsyncIterable, Optional
import asyncio
from typing import List
from server.chat.utils import History
import json
from server.agent import model_container
from server.knowledge_base.kb_service.base import get_kb_details
import ast
import re
from configs.basic_config import *
def _literal_eval_or_none(text):
    """Parse ``text`` as a Python literal; return ``None`` if it is malformed.

    ``ast.literal_eval`` can raise ValueError, TypeError, SyntaxError,
    MemoryError or RecursionError on bad input; all of them map to ``None``.
    A literal ``None`` in ``text`` is therefore indistinguishable from a
    parse failure, which is acceptable for the list-shaped tool outputs
    handled below.
    """
    try:
        return ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return None


async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                     history: List[History] = Body([],
                                                   description="历史对话",
                                                   examples=[[
                                                       {"role": "user",
                                                        "content": "请使用天气查询工具查询今天北京天气"},
                                                       {"role": "assistant",
                                                        "content": "今天是2024年3月22日受冷空气影响白天有3、4级偏北风阵风6、7"
                                                                   "级,西部山区阵风相对明显,局地伴有扬沙。白天晴,局地有扬沙,偏北风,1级转3、4级阵风6、7级,"
                                                                   "最高气温22℃。夜间晴间多云,偏北风,1、2级,最低气温6℃。"}]]
                                                   ),
                     stream: bool = Body(False, description="流式输出"),
                     model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                     temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
                     prompt_name: str = Body("default",
                                             description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                     ):
    """Answer ``query`` with a tool-using agent and return server-sent events.

    The query is prefixed with "帮我搜索一下:" ("help me search:") before it
    reaches the agent. Knowledge-base metadata and the agent model are
    published through ``model_container`` globals so the agent/tools can use
    them.

    With ``stream=True`` each event is a JSON object with one of the keys
    ``"tools"`` (a formatted tool-call report), ``"answer"`` (an LLM token) or
    ``"final_answer"``. The final answer has the reference material collected
    from the "联网思索" / "policy_knowledgebase" tools appended. When only
    policy material was collected, that material replaces it. With
    ``stream=False`` a single event ``{"answer": ..., "final_answer": ...}``
    is sent after the agent finishes.
    """
    history = [History.from_data(h) for h in history]
    # Every query is prefixed with a fixed "help me search:" instruction.
    query = "帮我搜索一下:" + query
    logger.info(f"agent query: {query}")

    async def agent_chat_iterator(
            query: str,
            history: Optional[List[History]],
            model_name: str = LLM_MODELS[0],
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        """Run the agent on ``query`` and yield JSON-encoded event payloads."""
        nonlocal max_tokens
        callback = CustomAsyncIteratorCallbackHandler()
        # Non-positive limits are treated as "no limit".
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )

        # Pass state to the agent through module-level globals.
        kb_list = {x["kb_name"]: x for x in get_kb_details()}
        model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
        if Agent_MODEL:
            # A dedicated model is configured for agent work.
            model_agent = get_ChatOpenAI(
                model_name=Agent_MODEL,
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=[callback],
            )
            model_container.MODEL = model_agent
        else:
            model_container.MODEL = model

        prompt_template = get_prompt_template("agent_chat", prompt_name)
        prompt_template_agent = CustomPromptTemplate(
            template=prompt_template,
            tools=tools,
            input_variables=["input", "intermediate_steps", "history"]
        )
        output_parser = CustomOutputParser()
        llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)

        # Replay the request history into the agent's conversation memory.
        memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
        for message in history:
            if message.role == 'user':
                memory.chat_memory.add_user_message(message.content)
            else:
                memory.chat_memory.add_ai_message(message.content)

        if "chatglm3" in model_container.MODEL.model_name:
            # NOTE(review): this branch is selected by the agent model's name
            # but is built with ``model`` (the chat model); confirm this is
            # intended when Agent_MODEL is set.
            agent_executor = initialize_glm3_agent(
                llm=model,
                tools=tools,
                callback_manager=None,
                # The LangChain prompt is not built here; the GLM3 agent builds it.
                prompt=prompt_template,
                input_variables=["input", "intermediate_steps", "history"],
                memory=memory,
                verbose=True,
            )
        else:
            agent = LLMSingleActionAgent(
                llm_chain=llm_chain,
                output_parser=output_parser,
                stop=["\nObservation:", "Observation"],
                allowed_tools=tool_names,
            )
            agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
                                                                tools=tools,
                                                                verbose=True,
                                                                memory=memory,
                                                                )

        # Run the agent in the background; its progress is read from
        # ``callback`` below. Errors here propagate instead of being retried
        # in a loop that never yields to the event loop.
        task = asyncio.create_task(wrap_done(
            agent_executor.acall(query, callbacks=[callback], include_run_info=True),
            callback.done))

        if stream:
            search_answer = ""
            policy_answer = ""
            async for chunk in callback.aiter():
                tools_use = []
                data = json.loads(chunk)
                status = data["status"]
                if status == Status.start or status == Status.complete:
                    continue
                elif status == Status.error:
                    tools_use.append("\n```\n")
                    tools_use.append("工具名称: " + data["tool_name"])
                    tools_use.append("工具状态: " + "调用失败")
                    tools_use.append("错误信息: " + data["error"])
                    tools_use.append("重新开始尝试")
                    tools_use.append("\n```\n")
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif status == Status.tool_finish:
                    tools_use.append("\n```\n")
                    tools_use.append("工具名称: " + data["tool_name"])
                    tools_use.append("工具状态: " + "调用成功")
                    tools_use.append("工具输入: " + data["input_str"])
                    if data["tool_name"] == "联网思索":
                        output_str = data["output_str"]
                        if "政策类资料" in output_str:
                            # Presumably a stringified list: items 0-4 policy
                            # content, 5-9 policy references, last item the
                            # search results. TODO confirm with the tool.
                            output_arr = _literal_eval_or_none(output_str)
                            if not output_arr:
                                # Keep the raw tool output instead of crashing.
                                logger.warning("无法解析字符串为列表")
                            else:
                                policy_content = ''.join(output_arr[:5])
                                policy_answer = ''.join(output_arr[5:10])
                                search_answer_str = str(output_arr[-1])
                                if "暂未找到相关资料" in search_answer_str:
                                    search_answer = "\n知识中心资料: 暂无"
                                else:
                                    # Drop the 2 leading / 2 trailing chars of
                                    # the list repr and put one item per line.
                                    search_answer_arr = search_answer_str[2: len(search_answer_str) - 2].replace("\\n", "")
                                    search_answer = '\n'.join(search_answer_arr.split("\', \'"))
                                data["output_str"] = policy_content + policy_answer
                                logger.info("<<<工具输出>>>\n%s", data["output_str"])
                        elif "暂未找到相关资料" in output_str:
                            logger.info("output_str %s", output_str)
                            output_arr = _literal_eval_or_none(output_str)
                            if not output_arr:
                                # Keep the raw tool output instead of crashing.
                                logger.warning("无法解析字符串为列表")
                            else:
                                # NOTE(review): the end index uses len() of the
                                # raw last item, not of its str(); it strips
                                # "']" only for a one-element list - confirm.
                                search_answer = str(output_arr[-1])[2: len(output_arr[-1]) - 3]
                                # NOTE(review): output_arr[:0] is always empty,
                                # so the displayed tool output is always empty.
                                data["output_str"] = str(output_arr[:0])
                                logger.info("<<<工具输出>>>\n%s", data["output_str"])
                        else:
                            # Drop the first 2 / last 3 chars of the list repr,
                            # then split on "', ['": the part before is shown as
                            # tool output, the part after becomes the references.
                            search_output_str = output_str[2: len(output_str) - 3].replace("\\n", "")
                            search_output_arr = search_output_str.split("\', [\'")
                            search_content = search_output_arr[0]
                            if len(search_output_arr) > 1:
                                search_answer = '\n'.join(search_output_arr[1].split("\', \'"))
                            else:
                                logger.warning("无法解析字符串为列表")
                            data["output_str"] = search_content
                            logger.info("<<<工具输出>>>\n%s", data["output_str"])
                    if data["tool_name"] == "policy_knowledgebase":
                        # Raw newlines are invalid inside a quoted literal, so
                        # escape them before parsing.
                        policy_output = _literal_eval_or_none(data["output_str"].replace("\n", "\\n"))
                        if policy_output is None:
                            logger.warning("无法解析字符串为列表")
                        else:
                            id_str = policy_output[0]
                            processed_lines = [line.strip() + '\n' for line in policy_output[1]]
                            policy_answer = id_str + '\n\n' + ''.join(processed_lines)
                            logger.info("policy_answer %s", policy_answer)
                    tools_use.append("工具输出: " + data["output_str"])
                    tools_use.append("\n```\n")
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif status == Status.agent_finish and search_answer:
                    # Append the collected references to the agent's answer.
                    if policy_answer:
                        yield json.dumps({"final_answer": data["final_answer"] + "\n\n参考资料:\n\n" + str(policy_answer) + str(search_answer)}, ensure_ascii=False)
                    else:
                        yield json.dumps({"final_answer": data["final_answer"] + "\n\n参考资料:\n\n" + str(search_answer)}, ensure_ascii=False)
                    logger.info("search_answer_output %s", search_answer)
                elif status == Status.agent_finish and policy_answer:
                    # Only policy material was collected: it replaces the answer.
                    yield json.dumps({"final_answer": policy_answer}, ensure_ascii=False)
                    logger.info("policy_answer_output %s", policy_answer)
                elif status == Status.agent_finish:
                    yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False)
                else:
                    yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
        else:
            answer = ""
            final_answer = ""
            async for chunk in callback.aiter():
                data = json.loads(chunk)
                status = data["status"]
                if status == Status.start or status == Status.complete:
                    continue
                # One branch per event, as in the streaming path, so tool
                # events never fall through to the token-append branch.
                if status == Status.error:
                    answer += "\n```\n"
                    answer += "工具名称: " + data["tool_name"] + "\n"
                    answer += "工具状态: " + "调用失败" + "\n"
                    answer += "错误信息: " + data["error"] + "\n"
                    answer += "\n```\n"
                elif status == Status.tool_finish:
                    answer += "\n```\n"
                    answer += "工具名称: " + data["tool_name"] + "\n"
                    answer += "工具状态: " + "调用成功" + "\n"
                    answer += "工具输入: " + data["input_str"] + "\n"
                    answer += "工具输出: " + data["output_str"] + "\n"
                    answer += "\n```\n"
                elif status == Status.agent_finish:
                    final_answer = data["final_answer"]
                else:
                    answer += data["llm_token"]
            yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
        await task

    return EventSourceResponse(agent_chat_iterator(query=query,
                                                   history=history,
                                                   model_name=model_name,
                                                   prompt_name=prompt_name),
                               )