Files
gangyan/langchain-chat/server/chat/agent_write_test.py

254 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime
import uuid
from langchain.memory import ConversationBufferMemory
from server.agent.agent import Agent
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names, search_tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.chat import utils
from server.chat.knowledge_base_name import KnowledgeBase
from server.utils import replace_variables, wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain
from typing import AsyncIterable, Optional
from server.agent.tools import rag_search
import asyncio
from typing import List
from server.chat.utils import History, remove_after_and_including, remove_after_and_includings
import json
from server.agent import model_container
from server.knowledge_base.kb_service.base import get_kb_details
import ast
import re
import asyncio
from concurrent.futures import ThreadPoolExecutor
from server.chat.policy_fun_iast import get_llm_model_response, get_llm_model_response_async, get_llm_model_response_stream_openai
from configs import kb_config
from configs.basic_config import *
# Shared worker pool used to run blocking (synchronous) calls off the event loop.
_executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=8)


async def run_sync(func, /, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` on ``_executor`` and await its result.

    ``loop.run_in_executor`` only forwards positional arguments, so the call
    is wrapped in a zero-argument closure to carry keyword arguments too.
    ``func`` is positional-only, so callers may still pass a keyword named
    ``func`` through to the target. Any exception raised by ``func`` is
    re-raised to the awaiting coroutine.
    """
    def invoke():
        return func(*args, **kwargs)

    running_loop = asyncio.get_running_loop()
    pending = running_loop.run_in_executor(_executor, invoke)
    return await pending
async def agent_write_test(
user_prompt_name: Optional[str] = Body(None, description="用户输入", examples=[""]),
style: Optional[str] = Body("默认风格", description="语言风格", examples=[""]),
query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
think_content: str = Body(..., description="思考过程", examples=[""]),
uuid: str=Body(..., description="用户输入", examples=["恼羞成怒"]),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "请使用天气查询工具查询今天北京天气"},
{"role": "assistant",
"content": "今天是2024年3月22日受冷空气影响白天有3、4级偏北风阵风6、7"
"级,西部山区阵风相对明显,局地伴有扬沙。白天晴,局地有扬沙,偏北风,1级转3、4级阵风6、7级,"
"最高气温22℃。夜间晴间多云,偏北风,1、2级,最低气温6℃。"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(0.1, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
) -> AsyncIterable[str]:
agent = Agent()
agent.step = "初始步骤"
finish_tools = "已经调用过的工具名称:"
if user_prompt_name == "complete_outline":
user_prompt = get_prompt_template("llm_chat",user_prompt_name+"_with_history")
user_prompt = replace_variables(user_prompt, replace_content=style, replace_param="{style}")
else:
user_prompt = get_prompt_template("llm_chat",user_prompt_name)
user_prompt = user_prompt.replace("{{time}}",datetime.now().strftime("%Y%m%d")).replace("{{input}}",query)
# 假设 Tool 有 .name 和 .func 属性
tool_map = {tool.name: tool.func for tool in tools}
if history:
history_detail = str(history)
else:
history_detail = ""
all_tool = "\n".join([f'【工具名称】{tool.name}: 【工具描述】{tool.description}' for tool in tools])
rewrite = 0
wrong_num = 0
while not agent.step == "answer":
if rewrite == 0:
step_res = await run_sync(
get_llm_model_response,
strategy_name="query rewrite",
llm_model_name=LLM_MODELS[0],
template_prompt_name="get_next_write_tip",
prompt_param_dict={
"time": datetime.now().strftime("%Y%m%d"),
"step": agent.step,
"question": query,
"user_prompt": user_prompt,
"history": history_detail,
"tools": all_tool,
"res": agent.res,
"finish_tools": finish_tools,
"think_content": think_content,
},
temperature=0.01,
max_tokens=512,
)
next_step_pattern = r'<step>(.*?)</step>'
next_step_content = re.search(next_step_pattern, step_res, re.DOTALL)
agent.step = next_step_content.group(1).strip() if next_step_content else ""
agent.res += (
step_res.replace("<step>", "")
.replace("</step>", "")
.replace(agent.step, "")
)
match agent.step:
case "thinking":
thinking = await run_sync(
get_llm_model_response,
strategy_name="query rewrite",
llm_model_name=LLM_MODELS[0],
template_prompt_name="agent_write_think",
prompt_param_dict={
"time": datetime.now().strftime("%Y%m%d"),
"user_prompt": user_prompt,
"think_content": think_content,
"input": query,
"history": history_detail,
"res": agent.res,
"tools": all_tool,
"finish_tools": finish_tools,
},
temperature=0.01,
max_tokens=512,
)
agent.res = agent.res+"<step>thinking</step>"+thinking
print("thinking")
case "select_tool":
tool_desc = await run_sync(
get_llm_model_response,
strategy_name="query rewrite",
llm_model_name=LLM_MODELS[0],
template_prompt_name="tool_write_select",
prompt_param_dict={
"time": datetime.now().strftime("%Y%m%d"),
"user_prompt": user_prompt,
"think_content": think_content,
"input": query,
"history": history_detail,
"res": agent.res,
"tools": tools,
"finish_tools": finish_tools
},
temperature=0.01,
max_tokens=512,
)
agent.res += "<step>select_tool</step>"
try:
try:
tool_name_pattern = r'<tool>(.*?)</tool>'
toolname = re.search(tool_name_pattern, tool_desc.replace("\n","")).group(1).strip()
tool_input_pattern = r'<tool_input>(.*?)</tool_input>'
toolinput=re.search(tool_input_pattern, tool_desc.replace("\n","")).group(1).strip()
except Exception as e:
if wrong_num > 1:
tool_desc += f"请重新修正。以修正失败{wrong_num}"
tool_desc = await run_sync(
get_llm_model_response,
strategy_name="query rewrite",
llm_model_name=LLM_MODELS[0],
template_prompt_name="agent_rewrite",
prompt_param_dict={"input": tool_desc, "format": all_tool},
temperature=0.01,
max_tokens=512,
)
try:
tool_name_pattern = r'<tool>(.*?)</tool>'
toolname = re.search(tool_name_pattern, tool_desc.replace("\n","")).group(1).strip()
tool_input_pattern = r'<tool_input>(.*?)</tool_input>'
toolinput=re.search(tool_input_pattern, tool_desc.replace("\n","")).group(1).strip()
except Exception as e:
print(e)
wrong_num += 1
rewrite = 1
if wrong_num > 3:
rewrite = 0
finish_tools += tool_desc
agent.res+=f"使用{tool_desc}工具出现了异常,请使用其他工具或用自身能力回答,禁止虚拟链接"
finish_tools = finish_tools+ toolname +","
rewrite = 0
agent.res+="使用"+toolname+"工具的输入:【"+toolinput+""
toolinput+="{\"uuid\":\""+uuid+"\"}"
tool_func = tool_map[toolname]
# ---------- 调用工具(阻塞 → 线程池) ----------
if asyncio.iscoroutinefunction(tool_func):
result = await tool_func(toolinput)
else:
result = await run_sync(tool_func, toolinput)
if result == None and toolname == "统计数据查询":
agent.res+="使用"+toolname+"工具没有搜索到结果,推荐用联网思索替换统计数据查询"
else:
agent.res+="使用"+toolname+"工具的结果:【"+result+""
rewrite = 0
if toolname == "知识库联想":
source = utils.get_shared_variable(uuid)
source_docs = source["source_docs"]
try:
docs_string = "\n"+"\n".join(f"{str(doc)}\n" for doc in source_docs)
except:
print("知识库联想存在异常")
pass
yield json.dumps({"docs": docs_string}, ensure_ascii=False)
if toolname == "联网思索":
source = utils.get_shared_variable(uuid)
source_docs = source["source_docs"]
try:
docs_string = "\n"+"\n".join(f"{str(doc)}\n" for doc in source_docs)
except:
print("联网思索存在异常")
pass
yield json.dumps({"docs": docs_string}, ensure_ascii=False)
# if toolname == "天气工具":
# source = utils.get_shared_variable(uuid)
# source_docs = source["source_docs"]
# try:
# docs_string = "\n"+"\n".join(f"{str(doc)}\n" for doc in source_docs)
# except:
# print("知识库联想存在异常")
# pass
# yield json.dumps({"docs": docs_string}, ensure_ascii=False)
print("结果:", result)
print("select_tool")
rewrite = 0
except Exception as e:
rewrite = 1
wrong_num +=0.5
if wrong_num > 3:
rewrite = 0
finish_tools += tool_desc
agent.res+=f"使用{tool_desc}工具出现了异常,请使用其他工具或用自身能力回答,禁止虚拟链接"
await asyncio.sleep(0)
async for chunk in get_llm_model_response_stream_openai(
type=1,
strategy_name="query rewrite",
llm_model_name=LLM_MODELS[0],
template_prompt_name="agent_write_answer",
prompt_param_dict={"time": datetime.now().strftime("%Y%m%d"),"user_prompt":user_prompt,"think_content":think_content,"input":query,"history":history_detail,"res":agent.res,"tools":all_tool,"finish_tools":finish_tools},
temperature=0.01,
max_tokens=20000
):
yield json.dumps(
{"answer": chunk},
ensure_ascii=False)