[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
253
langchain-chat/server/chat/agent_write_test.py
Normal file
253
langchain-chat/server/chat/agent_write_test.py
Normal file
@@ -0,0 +1,253 @@
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from server.agent.agent import Agent
|
||||
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
|
||||
from server.agent.tools_select import tools, tool_names, search_tool_names
|
||||
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
|
||||
from langchain.agents import LLMSingleActionAgent, AgentExecutor
|
||||
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
|
||||
from fastapi import Body
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
|
||||
from server.chat import utils
|
||||
from server.chat.knowledge_base_name import KnowledgeBase
|
||||
from server.utils import replace_variables, wrap_done, get_ChatOpenAI, get_prompt_template
|
||||
from langchain.chains import LLMChain
|
||||
from typing import AsyncIterable, Optional
|
||||
from server.agent.tools import rag_search
|
||||
import asyncio
|
||||
from typing import List
|
||||
from server.chat.utils import History, remove_after_and_including, remove_after_and_includings
|
||||
import json
|
||||
from server.agent import model_container
|
||||
from server.knowledge_base.kb_service.base import get_kb_details
|
||||
import ast
|
||||
import re
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from server.chat.policy_fun_iast import get_llm_model_response, get_llm_model_response_async, get_llm_model_response_stream_openai
|
||||
from configs import kb_config
|
||||
from configs.basic_config import *
|
||||
|
||||
|
||||
|
||||
_executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=8)
|
||||
|
||||
async def run_sync(func, /, *args, **kwargs):
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(_executor, lambda: func(*args, **kwargs))
|
||||
|
||||
|
||||
async def agent_write_test(
|
||||
user_prompt_name: Optional[str] = Body(None, description="用户输入", examples=[""]),
|
||||
style: Optional[str] = Body("默认风格", description="语言风格", examples=[""]),
|
||||
query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
think_content: str = Body(..., description="思考过程", examples=[""]),
|
||||
uuid: str=Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
history: List[History] = Body([],
|
||||
description="历史对话",
|
||||
examples=[[
|
||||
{"role": "user",
|
||||
"content": "请使用天气查询工具查询今天北京天气"},
|
||||
{"role": "assistant",
|
||||
"content": "今天是2024年3月22日,受冷空气影响,白天有3、4级偏北风,阵风6、7"
|
||||
"级,西部山区阵风相对明显,局地伴有扬沙。白天晴,局地有扬沙,偏北风,1级转3、4级,阵风6、7级,"
|
||||
"最高气温22℃。夜间晴间多云,偏北风,1、2级,最低气温6℃。"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.1, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
prompt_name: str = Body("default",
|
||||
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
) -> AsyncIterable[str]:
|
||||
agent = Agent()
|
||||
agent.step = "初始步骤"
|
||||
finish_tools = "已经调用过的工具名称:"
|
||||
if user_prompt_name == "complete_outline":
|
||||
user_prompt = get_prompt_template("llm_chat",user_prompt_name+"_with_history")
|
||||
user_prompt = replace_variables(user_prompt, replace_content=style, replace_param="{style}")
|
||||
else:
|
||||
user_prompt = get_prompt_template("llm_chat",user_prompt_name)
|
||||
user_prompt = user_prompt.replace("{{time}}",datetime.now().strftime("%Y%m%d")).replace("{{input}}",query)
|
||||
# 假设 Tool 有 .name 和 .func 属性
|
||||
tool_map = {tool.name: tool.func for tool in tools}
|
||||
if history:
|
||||
history_detail = str(history)
|
||||
else:
|
||||
history_detail = ""
|
||||
all_tool = "\n".join([f'【工具名称】{tool.name}: 【工具描述】{tool.description}' for tool in tools])
|
||||
rewrite = 0
|
||||
wrong_num = 0
|
||||
while not agent.step == "answer":
|
||||
if rewrite == 0:
|
||||
step_res = await run_sync(
|
||||
get_llm_model_response,
|
||||
strategy_name="query rewrite",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="get_next_write_tip",
|
||||
prompt_param_dict={
|
||||
"time": datetime.now().strftime("%Y%m%d"),
|
||||
"step": agent.step,
|
||||
"question": query,
|
||||
"user_prompt": user_prompt,
|
||||
"history": history_detail,
|
||||
"tools": all_tool,
|
||||
"res": agent.res,
|
||||
"finish_tools": finish_tools,
|
||||
"think_content": think_content,
|
||||
},
|
||||
temperature=0.01,
|
||||
max_tokens=512,
|
||||
)
|
||||
next_step_pattern = r'<step>(.*?)</step>'
|
||||
next_step_content = re.search(next_step_pattern, step_res, re.DOTALL)
|
||||
agent.step = next_step_content.group(1).strip() if next_step_content else ""
|
||||
agent.res += (
|
||||
step_res.replace("<step>", "")
|
||||
.replace("</step>", "")
|
||||
.replace(agent.step, "")
|
||||
)
|
||||
match agent.step:
|
||||
case "thinking":
|
||||
thinking = await run_sync(
|
||||
get_llm_model_response,
|
||||
strategy_name="query rewrite",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="agent_write_think",
|
||||
prompt_param_dict={
|
||||
"time": datetime.now().strftime("%Y%m%d"),
|
||||
"user_prompt": user_prompt,
|
||||
"think_content": think_content,
|
||||
"input": query,
|
||||
"history": history_detail,
|
||||
"res": agent.res,
|
||||
"tools": all_tool,
|
||||
"finish_tools": finish_tools,
|
||||
},
|
||||
temperature=0.01,
|
||||
max_tokens=512,
|
||||
)
|
||||
agent.res = agent.res+"<step>thinking</step>"+thinking
|
||||
print("thinking")
|
||||
case "select_tool":
|
||||
tool_desc = await run_sync(
|
||||
get_llm_model_response,
|
||||
strategy_name="query rewrite",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="tool_write_select",
|
||||
prompt_param_dict={
|
||||
"time": datetime.now().strftime("%Y%m%d"),
|
||||
"user_prompt": user_prompt,
|
||||
"think_content": think_content,
|
||||
"input": query,
|
||||
"history": history_detail,
|
||||
"res": agent.res,
|
||||
"tools": tools,
|
||||
"finish_tools": finish_tools
|
||||
},
|
||||
temperature=0.01,
|
||||
max_tokens=512,
|
||||
)
|
||||
agent.res += "<step>select_tool</step>"
|
||||
try:
|
||||
try:
|
||||
tool_name_pattern = r'<tool>(.*?)</tool>'
|
||||
toolname = re.search(tool_name_pattern, tool_desc.replace("\n","")).group(1).strip()
|
||||
|
||||
tool_input_pattern = r'<tool_input>(.*?)</tool_input>'
|
||||
toolinput=re.search(tool_input_pattern, tool_desc.replace("\n","")).group(1).strip()
|
||||
except Exception as e:
|
||||
if wrong_num > 1:
|
||||
tool_desc += f"请重新修正。以修正失败{wrong_num}次"
|
||||
tool_desc = await run_sync(
|
||||
get_llm_model_response,
|
||||
strategy_name="query rewrite",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="agent_rewrite",
|
||||
prompt_param_dict={"input": tool_desc, "format": all_tool},
|
||||
temperature=0.01,
|
||||
max_tokens=512,
|
||||
)
|
||||
try:
|
||||
tool_name_pattern = r'<tool>(.*?)</tool>'
|
||||
toolname = re.search(tool_name_pattern, tool_desc.replace("\n","")).group(1).strip()
|
||||
|
||||
tool_input_pattern = r'<tool_input>(.*?)</tool_input>'
|
||||
toolinput=re.search(tool_input_pattern, tool_desc.replace("\n","")).group(1).strip()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
wrong_num += 1
|
||||
rewrite = 1
|
||||
if wrong_num > 3:
|
||||
rewrite = 0
|
||||
finish_tools += tool_desc
|
||||
agent.res+=f"使用{tool_desc}工具出现了异常,请使用其他工具或用自身能力回答,禁止虚拟链接"
|
||||
finish_tools = finish_tools+ toolname +","
|
||||
rewrite = 0
|
||||
agent.res+="使用"+toolname+"工具的输入:【"+toolinput+"】"
|
||||
toolinput+="{\"uuid\":\""+uuid+"\"}"
|
||||
tool_func = tool_map[toolname]
|
||||
# ---------- 调用工具(阻塞 → 线程池) ----------
|
||||
if asyncio.iscoroutinefunction(tool_func):
|
||||
result = await tool_func(toolinput)
|
||||
else:
|
||||
result = await run_sync(tool_func, toolinput)
|
||||
if result == None and toolname == "统计数据查询":
|
||||
agent.res+="使用"+toolname+"工具没有搜索到结果,推荐用联网思索替换统计数据查询"
|
||||
else:
|
||||
agent.res+="使用"+toolname+"工具的结果:【"+result+"】"
|
||||
rewrite = 0
|
||||
if toolname == "知识库联想":
|
||||
source = utils.get_shared_variable(uuid)
|
||||
source_docs = source["source_docs"]
|
||||
try:
|
||||
docs_string = "\n"+"\n".join(f"{str(doc)}\n" for doc in source_docs)
|
||||
except:
|
||||
print("知识库联想存在异常")
|
||||
pass
|
||||
yield json.dumps({"docs": docs_string}, ensure_ascii=False)
|
||||
if toolname == "联网思索":
|
||||
source = utils.get_shared_variable(uuid)
|
||||
source_docs = source["source_docs"]
|
||||
try:
|
||||
docs_string = "\n"+"\n".join(f"{str(doc)}\n" for doc in source_docs)
|
||||
except:
|
||||
print("联网思索存在异常")
|
||||
pass
|
||||
yield json.dumps({"docs": docs_string}, ensure_ascii=False)
|
||||
# if toolname == "天气工具":
|
||||
# source = utils.get_shared_variable(uuid)
|
||||
# source_docs = source["source_docs"]
|
||||
# try:
|
||||
# docs_string = "\n"+"\n".join(f"{str(doc)}\n" for doc in source_docs)
|
||||
# except:
|
||||
# print("知识库联想存在异常")
|
||||
# pass
|
||||
# yield json.dumps({"docs": docs_string}, ensure_ascii=False)
|
||||
print("结果:", result)
|
||||
print("select_tool")
|
||||
rewrite = 0
|
||||
|
||||
except Exception as e:
|
||||
rewrite = 1
|
||||
wrong_num +=0.5
|
||||
if wrong_num > 3:
|
||||
rewrite = 0
|
||||
finish_tools += tool_desc
|
||||
agent.res+=f"使用{tool_desc}工具出现了异常,请使用其他工具或用自身能力回答,禁止虚拟链接"
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async for chunk in get_llm_model_response_stream_openai(
|
||||
type=1,
|
||||
strategy_name="query rewrite",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="agent_write_answer",
|
||||
prompt_param_dict={"time": datetime.now().strftime("%Y%m%d"),"user_prompt":user_prompt,"think_content":think_content,"input":query,"history":history_detail,"res":agent.res,"tools":all_tool,"finish_tools":finish_tools},
|
||||
temperature=0.01,
|
||||
max_tokens=20000
|
||||
):
|
||||
yield json.dumps(
|
||||
{"answer": chunk},
|
||||
ensure_ascii=False)
|
||||
Reference in New Issue
Block a user