from datetime import datetime import uuid from langchain.memory import ConversationBufferMemory from server.agent.agent import Agent from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent from server.agent.tools_select import tools, tool_names, search_tool_names from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status from langchain.agents import LLMSingleActionAgent, AgentExecutor from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate from fastapi import Body from sse_starlette.sse import EventSourceResponse from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL from server.chat import utils from server.chat.knowledge_base_name import KnowledgeBase from server.utils import replace_variables, wrap_done, get_ChatOpenAI, get_prompt_template from langchain.chains import LLMChain from typing import AsyncIterable, Optional from server.agent.tools import rag_search import asyncio from typing import List from server.chat.utils import History, remove_after_and_including, remove_after_and_includings import json from server.agent import model_container from server.knowledge_base.kb_service.base import get_kb_details import ast import re import asyncio from concurrent.futures import ThreadPoolExecutor from server.chat.policy_fun_iast import get_llm_model_response, get_llm_model_response_async, get_llm_model_response_stream_openai from configs import kb_config from configs.basic_config import * _executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=8) async def run_sync(func, /, *args, **kwargs): loop = asyncio.get_running_loop() return await loop.run_in_executor(_executor, lambda: func(*args, **kwargs)) async def agent_write_test( user_prompt_name: Optional[str] = Body(None, description="用户输入", examples=[""]), style: Optional[str] = Body("默认风格", description="语言风格", examples=[""]), query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), think_content: str = Body(..., description="思考过程", examples=[""]), uuid: str=Body(..., description="用户输入", examples=["恼羞成怒"]), history: List[History] = Body([], description="历史对话", examples=[[ {"role": "user", "content": "请使用天气查询工具查询今天北京天气"}, {"role": "assistant", "content": "今天是2024年3月22日,受冷空气影响,白天有3、4级偏北风,阵风6、7" "级,西部山区阵风相对明显,局地伴有扬沙。白天晴,局地有扬沙,偏北风,1级转3、4级,阵风6、7级," "最高气温22℃。夜间晴间多云,偏北风,1、2级,最低气温6℃。"}]] ), stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), temperature: float = Body(0.1, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ) -> AsyncIterable[str]: agent = Agent() agent.step = "初始步骤" finish_tools = "已经调用过的工具名称:" if user_prompt_name == "complete_outline": user_prompt = get_prompt_template("llm_chat",user_prompt_name+"_with_history") user_prompt = replace_variables(user_prompt, replace_content=style, replace_param="{style}") else: user_prompt = get_prompt_template("llm_chat",user_prompt_name) user_prompt = user_prompt.replace("{{time}}",datetime.now().strftime("%Y%m%d")).replace("{{input}}",query) # 假设 Tool 有 .name 和 .func 属性 tool_map = {tool.name: tool.func for tool in tools} if history: history_detail = str(history) else: history_detail = "" all_tool = "\n".join([f'【工具名称】{tool.name}: 【工具描述】{tool.description}' for tool in tools]) rewrite = 0 wrong_num = 0 while not agent.step == "answer": if rewrite == 0: step_res = await run_sync( get_llm_model_response, strategy_name="query rewrite", llm_model_name=LLM_MODELS[0], template_prompt_name="get_next_write_tip", prompt_param_dict={ "time": datetime.now().strftime("%Y%m%d"), "step": agent.step, "question": query, "user_prompt": user_prompt, "history": history_detail, "tools": all_tool, "res": agent.res, "finish_tools": finish_tools, "think_content": think_content, }, temperature=0.01, max_tokens=512, ) next_step_pattern = r'(.*?)' next_step_content = re.search(next_step_pattern, step_res, re.DOTALL) agent.step = next_step_content.group(1).strip() if next_step_content else "" agent.res += ( step_res.replace("", "") .replace("", "") .replace(agent.step, "") ) match agent.step: case "thinking": thinking = await run_sync( get_llm_model_response, strategy_name="query rewrite", llm_model_name=LLM_MODELS[0], template_prompt_name="agent_write_think", prompt_param_dict={ "time": datetime.now().strftime("%Y%m%d"), "user_prompt": user_prompt, "think_content": think_content, "input": query, "history": history_detail, "res": agent.res, "tools": all_tool, "finish_tools": finish_tools, }, temperature=0.01, max_tokens=512, ) agent.res = agent.res+"thinking"+thinking print("thinking") case "select_tool": tool_desc = await run_sync( get_llm_model_response, strategy_name="query rewrite", llm_model_name=LLM_MODELS[0], template_prompt_name="tool_write_select", prompt_param_dict={ "time": datetime.now().strftime("%Y%m%d"), "user_prompt": user_prompt, "think_content": think_content, "input": query, "history": history_detail, "res": agent.res, "tools": tools, "finish_tools": finish_tools }, temperature=0.01, max_tokens=512, ) agent.res += "select_tool" try: try: tool_name_pattern = r'(.*?)' toolname = re.search(tool_name_pattern, tool_desc.replace("\n","")).group(1).strip() tool_input_pattern = r'(.*?)' toolinput=re.search(tool_input_pattern, tool_desc.replace("\n","")).group(1).strip() except Exception as e: if wrong_num > 1: tool_desc += f"请重新修正。以修正失败{wrong_num}次" tool_desc = await run_sync( get_llm_model_response, strategy_name="query rewrite", llm_model_name=LLM_MODELS[0], template_prompt_name="agent_rewrite", prompt_param_dict={"input": tool_desc, "format": all_tool}, temperature=0.01, max_tokens=512, ) try: tool_name_pattern = r'(.*?)' toolname = re.search(tool_name_pattern, tool_desc.replace("\n","")).group(1).strip() tool_input_pattern = r'(.*?)' toolinput=re.search(tool_input_pattern, tool_desc.replace("\n","")).group(1).strip() except Exception as e: print(e) wrong_num += 1 rewrite = 1 if wrong_num > 3: rewrite = 0 finish_tools += tool_desc agent.res+=f"使用{tool_desc}工具出现了异常,请使用其他工具或用自身能力回答,禁止虚拟链接" finish_tools = finish_tools+ toolname +"," rewrite = 0 agent.res+="使用"+toolname+"工具的输入:【"+toolinput+"】" toolinput+="{\"uuid\":\""+uuid+"\"}" tool_func = tool_map[toolname] # ---------- 调用工具(阻塞 → 线程池) ---------- if asyncio.iscoroutinefunction(tool_func): result = await tool_func(toolinput) else: result = await run_sync(tool_func, toolinput) if result == None and toolname == "统计数据查询": agent.res+="使用"+toolname+"工具没有搜索到结果,推荐用联网思索替换统计数据查询" else: agent.res+="使用"+toolname+"工具的结果:【"+result+"】" rewrite = 0 if toolname == "知识库联想": source = utils.get_shared_variable(uuid) source_docs = source["source_docs"] try: docs_string = "\n"+"\n".join(f"{str(doc)}\n" for doc in source_docs) except: print("知识库联想存在异常") pass yield json.dumps({"docs": docs_string}, ensure_ascii=False) if toolname == "联网思索": source = utils.get_shared_variable(uuid) source_docs = source["source_docs"] try: docs_string = "\n"+"\n".join(f"{str(doc)}\n" for doc in source_docs) except: print("联网思索存在异常") pass yield json.dumps({"docs": docs_string}, ensure_ascii=False) # if toolname == "天气工具": # source = utils.get_shared_variable(uuid) # source_docs = source["source_docs"] # try: # docs_string = "\n"+"\n".join(f"{str(doc)}\n" for doc in source_docs) # except: # print("知识库联想存在异常") # pass # yield json.dumps({"docs": docs_string}, ensure_ascii=False) print("结果:", result) print("select_tool") rewrite = 0 except Exception as e: rewrite = 1 wrong_num +=0.5 if wrong_num > 3: rewrite = 0 finish_tools += tool_desc agent.res+=f"使用{tool_desc}工具出现了异常,请使用其他工具或用自身能力回答,禁止虚拟链接" await asyncio.sleep(0) async for chunk in get_llm_model_response_stream_openai( type=1, strategy_name="query rewrite", llm_model_name=LLM_MODELS[0], template_prompt_name="agent_write_answer", prompt_param_dict={"time": datetime.now().strftime("%Y%m%d"),"user_prompt":user_prompt,"think_content":think_content,"input":query,"history":history_detail,"res":agent.res,"tools":all_tool,"finish_tools":finish_tools}, temperature=0.01, max_tokens=20000 ): yield json.dumps( {"answer": chunk}, ensure_ascii=False)