297 lines
13 KiB
Python
297 lines
13 KiB
Python
from concurrent.futures import ThreadPoolExecutor
|
|
from datetime import datetime
|
|
import uuid
|
|
from langchain.memory import ConversationBufferMemory
|
|
from server.agent.agent import Agent
|
|
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
|
|
from server.agent.tools_select import tools, tool_names, search_tool_names
|
|
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
|
|
from langchain.agents import LLMSingleActionAgent, AgentExecutor
|
|
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
|
|
from fastapi import Body
|
|
from sse_starlette.sse import EventSourceResponse
|
|
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
|
|
from server.chat import utils
|
|
from server.chat.knowledge_base_name import KnowledgeBase
|
|
from server.utils import replace_variables, wrap_done, get_ChatOpenAI, get_prompt_template
|
|
from langchain.chains import LLMChain
|
|
from typing import AsyncIterable, Optional
|
|
from server.agent.tools import rag_search
|
|
import asyncio
|
|
from typing import List
|
|
from server.chat.utils import History, remove_after_and_including, remove_after_and_includings
|
|
import json
|
|
from server.agent import model_container
|
|
from server.knowledge_base.kb_service.base import get_kb_details
|
|
import ast
|
|
import re
|
|
from server.chat.policy_fun_iast import get_llm_model_response, get_llm_model_response_async, get_llm_model_response_stream_openai
|
|
from configs import kb_config
|
|
from configs.basic_config import *
|
|
|
|
|
|
|
|
from datetime import datetime
|
|
from typing import AsyncIterable, List, Optional, Dict, Any
|
|
|
|
from fastapi import Body
|
|
|
|
|
|
# Module-wide thread pool (8 workers) that run_sync submits blocking callables to.
_executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=8)
|
|
|
|
|
|
async def run_sync(func, /, *args, **kwargs):
    """Run a blocking callable on the module thread pool and await its result.

    The call is handed to ``_executor`` so the event loop is not blocked while
    ``func`` runs.

    Usage: result = await run_sync(blocking_fn, *args, **kwargs)
    """

    def _invoke():
        # run_in_executor only forwards positional arguments, so the keyword
        # arguments are captured by this closure instead.
        return func(*args, **kwargs)

    event_loop = asyncio.get_running_loop()
    pending = event_loop.run_in_executor(_executor, _invoke)
    return await pending
|
|
|
|
|
|
|
|
async def agent_chat_test(
|
|
user_prompt_name: Optional[str] = Body(None, description="用户输入"),
|
|
style: Optional[str] = Body(None, description="语言风格"),
|
|
query: str = Body(..., description="用户输入"),
|
|
think_content: str = Body(..., description="思考过程"),
|
|
uuid: str = Body(..., description="uuid"),
|
|
history: List["History"] = Body([], description="历史对话"),
|
|
stream: bool = Body(False, description="流式输出"),
|
|
model_name: str = Body(..., description="LLM 模型名称"),
|
|
temperature: float = Body(0.1, description="LLM 采样温度", ge=0.0, le=1.0),
|
|
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量"),
|
|
prompt_name: str = Body("default", description="Prompt 模板名称"),
|
|
) -> AsyncIterable[str]:
|
|
agent = Agent()
|
|
agent.step = "暂未执行步骤"
|
|
finish_tools = "已经调用过的工具名称:"
|
|
|
|
# ------------- prompt 预处理 -------------
|
|
user_prompt = get_prompt_template("llm_chat", user_prompt_name)
|
|
if user_prompt_name == "complete_outline":
|
|
user_prompt = replace_variables(
|
|
user_prompt, replace_content=style, replace_param="{style}"
|
|
)
|
|
user_prompt = (
|
|
user_prompt.replace("{{time}}", datetime.now().strftime("%Y%m%d"))
|
|
.replace("{{input}}", query)
|
|
)
|
|
|
|
# 工具映射
|
|
tool_map: Dict[str, Any] = {tool.name: tool.func for tool in tools}
|
|
|
|
history_detail = str(history) if history else ""
|
|
all_tool = "\n".join(
|
|
f"【工具名称】{tool.name}: 【工具描述】{tool.description}" for tool in tools
|
|
)
|
|
|
|
rewrite = 0
|
|
wrong_num = 0
|
|
|
|
# ====================== 主循环 ======================
|
|
while agent.step != "answer":
|
|
# ---------- ① 获取下一步提示 ----------
|
|
if rewrite == 0:
|
|
step_res = await run_sync(
|
|
get_llm_model_response,
|
|
strategy_name="query rewrite",
|
|
llm_model_name=LLM_MODELS[0],
|
|
template_prompt_name="get_next_tip",
|
|
prompt_param_dict={
|
|
"time": datetime.now().strftime("%Y%m%d"),
|
|
"step": agent.step,
|
|
"question": query,
|
|
"user_prompt": user_prompt,
|
|
"history": history_detail,
|
|
"tools": all_tool,
|
|
"res": agent.res,
|
|
"finish_tools": finish_tools,
|
|
},
|
|
temperature=0.01,
|
|
max_tokens=512,
|
|
)
|
|
next_step_pattern = r"<step>(.*?)</step>"
|
|
next_step_content = re.search(next_step_pattern, step_res, re.DOTALL)
|
|
agent.step = next_step_content.group(1).strip() if next_step_content else ""
|
|
# agent.res += (
|
|
# step_res.replace("<step>", "")
|
|
# .replace("</step>", "")
|
|
# .replace(agent.step, "")
|
|
# )
|
|
agent.res += step_res
|
|
current_tip = (
|
|
step_res.replace("<step>", "")
|
|
.replace("</step>", "")
|
|
.replace(agent.step, "")
|
|
)
|
|
# ---------- ② 按 step 分支 ----------
|
|
match agent.step:
|
|
# ========== thinking ==========
|
|
case "thinking":
|
|
thinking = await run_sync(
|
|
get_llm_model_response,
|
|
strategy_name="query rewrite",
|
|
llm_model_name=LLM_MODELS[0],
|
|
template_prompt_name="agent_think",
|
|
prompt_param_dict={
|
|
"time": datetime.now().strftime("%Y%m%d"),
|
|
"user_prompt": user_prompt,
|
|
"think_content": think_content,
|
|
"input": query,
|
|
"history": history_detail,
|
|
"res": agent.res,
|
|
"tools": all_tool,
|
|
"finish_tools": finish_tools,
|
|
},
|
|
temperature=0.01,
|
|
max_tokens=512,
|
|
)
|
|
agent.res += thinking
|
|
print("thinking")
|
|
|
|
# ========== select_tool ==========
|
|
case "select_tool":
|
|
tool_desc = await run_sync(
|
|
get_llm_model_response,
|
|
strategy_name="query rewrite",
|
|
llm_model_name=LLM_MODELS[0],
|
|
template_prompt_name="tool_select",
|
|
prompt_param_dict={
|
|
"time": datetime.now().strftime("%Y%m%d"),
|
|
"user_prompt": user_prompt,
|
|
"think_content": think_content,
|
|
"input": query,
|
|
"history": history_detail,
|
|
"res": agent.res,
|
|
"current_tip": current_tip,
|
|
"tools": tools,
|
|
"finish_tools": finish_tools,
|
|
},
|
|
temperature=0.01,
|
|
max_tokens=512,
|
|
)
|
|
try:
|
|
try:
|
|
# -------- 你的原始解析逻辑 --------
|
|
tool_name_pattern = r"<tool>(.*?)</tool>"
|
|
toolname = re.search(
|
|
tool_name_pattern, tool_desc.replace("\n", "")
|
|
).group(1).strip()
|
|
|
|
tool_input_pattern = r"<tool_input>(.*?)</tool_input>"
|
|
toolinput = re.search(
|
|
tool_input_pattern, tool_desc.replace("\n", "")
|
|
).group(1).strip()
|
|
except Exception as e:
|
|
print("开始修正:")
|
|
if wrong_num > 1:
|
|
tool_desc += f"请重新修正。以修正失败{wrong_num}次"
|
|
tool_desc = await run_sync(
|
|
get_llm_model_response,
|
|
strategy_name="query rewrite",
|
|
llm_model_name=LLM_MODELS[0],
|
|
template_prompt_name="agent_rewrite",
|
|
prompt_param_dict={"input": tool_desc, "format": all_tool},
|
|
temperature=0.01,
|
|
max_tokens=512,
|
|
)
|
|
try:
|
|
toolname = re.search(
|
|
r"<tool>(.*?)</tool>", tool_desc.replace("\n", "")
|
|
).group(1).strip()
|
|
|
|
toolinput = re.search(
|
|
r"<tool_input>(.*?)</tool_input>",
|
|
tool_desc.replace("\n", ""),
|
|
).group(1).strip()
|
|
except Exception as e:
|
|
print(e)
|
|
wrong_num += 1
|
|
rewrite = 1
|
|
if wrong_num > 3:
|
|
rewrite = 0
|
|
finish_tools += tool_desc
|
|
agent.res+=f"使用{tool_desc}工具出现了异常,请使用其他工具或用自身能力回答,禁止虚拟链接"
|
|
|
|
finish_tools = finish_tools + toolname + ","
|
|
rewrite = 0
|
|
agent.res+="使用"+toolname+"工具的输入:【"+toolinput+"】"
|
|
toolinput += '{"uuid":"' + uuid + '"}'
|
|
tool_func = tool_map[toolname]
|
|
|
|
# ---------- 调用工具(阻塞 → 线程池) ----------
|
|
if asyncio.iscoroutinefunction(tool_func):
|
|
result = await tool_func(toolinput)
|
|
else:
|
|
result = await run_sync(tool_func, toolinput)
|
|
|
|
# ---------- 原来的结果处理 ----------
|
|
if result is None and toolname == "统计数据查询":
|
|
agent.res += (
|
|
"使用" + toolname + "工具没有搜索到结果,推荐用联网思索替换统计数据查询"
|
|
)
|
|
else:
|
|
agent.res += "使用" + toolname + "工具的结果:【" + str(result) + "】"
|
|
if toolname == "知识库联想" or toolname == "联网思索":
|
|
yield json.dumps({"detail": str(result)}, ensure_ascii=False)
|
|
|
|
if toolname == "知识库联想":
|
|
source = utils.get_shared_variable(uuid)
|
|
source_docs = source["source_docs"]
|
|
try:
|
|
docs_string = "\n" + "\n".join(
|
|
f"{str(doc)}\n" for doc in source_docs
|
|
)
|
|
except Exception:
|
|
print("知识库联想存在异常")
|
|
docs_string = ""
|
|
yield json.dumps({"docs": docs_string}, ensure_ascii=False)
|
|
|
|
if toolname == "联网思索":
|
|
source = utils.get_shared_variable(uuid)
|
|
source_docs = source["source_docs"]
|
|
try:
|
|
docs_string = "\n" + "\n".join(
|
|
f"{str(doc)}\n" for doc in source_docs
|
|
)
|
|
except Exception:
|
|
print("联网思索存在异常")
|
|
docs_string = ""
|
|
yield json.dumps({"docs": docs_string}, ensure_ascii=False)
|
|
|
|
print("结果:", result)
|
|
print("select_tool")
|
|
rewrite = 0
|
|
except Exception as e:
|
|
rewrite = 1
|
|
wrong_num += 0.5
|
|
if wrong_num >3:
|
|
finish_tools += tool_desc
|
|
agent.res+=f"使用{tool_desc}工具出现了异常,请使用其他工具或用自身能力回答,禁止虚拟链接"
|
|
rewrite = 0
|
|
# 其它 step 分支可继续在这里扩展……
|
|
|
|
# ---------- ③ 每轮让出事件循环 ----------
|
|
await asyncio.sleep(0)
|
|
|
|
# ====================== 最终回答 ======================
|
|
async for chunk in get_llm_model_response_stream_openai(
|
|
type=0,
|
|
strategy_name="query rewrite",
|
|
llm_model_name=LLM_MODELS[0],
|
|
template_prompt_name="agent_answer",
|
|
prompt_param_dict={
|
|
"time": datetime.now().strftime("%Y%m%d"),
|
|
"user_prompt": user_prompt,
|
|
"think_content": '无',
|
|
"input": query,
|
|
"history": history_detail,
|
|
"res": agent.res,
|
|
"tools": all_tool,
|
|
"finish_tools": finish_tools,
|
|
},
|
|
temperature=0.7,
|
|
max_tokens=20000,
|
|
):
|
|
yield json.dumps({"answer": chunk}, ensure_ascii=False) |