from __future__ import annotations import json import re from langchain.agents import Tool, AgentOutputParser from langchain.prompts import StringPromptTemplate from typing import List from langchain.schema import AgentAction, AgentFinish from configs.basic_config import * from configs import SUPPORT_AGENT_MODEL from server.agent import model_container from server.chat import utils from server.chat.knowledge_base_name import KnowledgeBase from collections import defaultdict class CustomPromptTemplate(StringPromptTemplate): template: str tools: List[Tool] # def format(self, **kwargs) -> str: # intermediate_steps = kwargs.pop("intermediate_steps") # thoughts = "" # for action, observation in intermediate_steps: # thoughts += action.log # thoughts += f"\nObservation: {observation}\n" # kwargs["agent_scratchpad"] = thoughts # kwargs["tools"] = "\n".join([f"【工具名称】{tool.name}: 【工具描述】{tool.description}" for tool in self.tools]) # kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools]) # return self.template.format(**kwargs) def format(self, **kwargs) -> str: # 确保 intermediate_steps 存在 intermediate_steps = kwargs.pop("intermediate_steps", []) thoughts = "" for action, observation in intermediate_steps: thoughts += action.log thoughts += f"\nObservation: {observation}\n" # 设置默认值以防止 KeyError kwargs["agent_scratchpad"] = thoughts kwargs["tools"] = "\n".join([f"【工具名称】{tool.name}: 【工具描述】{tool.description}" for tool in getattr(self, 'tools', [])]) kwargs["tool_names"] = ", ".join([tool.name for tool in getattr(self, 'tools', [])]) # 使用正则表达式替换位置字段为命名字段 def replace_positional(match): index = match.group(1) return f"{{arg{index}}}" self.template = re.sub(r'\{(\d+)\}', replace_positional, self.template) # 为所有占位符提供默认值 placeholders = re.findall(r'\{(\w+)\}', self.template) for placeholder in placeholders: if placeholder not in kwargs: kwargs[placeholder] = f"" # 确保所有占位符都有对应的值 try: return self.template.format_map(kwargs) except KeyError as e: raise ValueError(f"Missing key in template formatting: {e}") except ValueError as e: raise ValueError(f"Format string contains positional fields: {e}") class CustomOutputParser(AgentOutputParser): begin: bool = False knowledge_base_name: KnowledgeBase time_based_uuid:str def __init__(self, knowledge_base_name: KnowledgeBase,time_based_uuid:str): super().__init__(knowledge_base_name = knowledge_base_name,time_based_uuid=time_based_uuid) self.begin = True self.knowledge_base_name = knowledge_base_name self.time_based_uuid = time_based_uuid def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction: if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin: self.begin = False stop_words = ["Observation:"] min_index = len(llm_output) for stop_word in stop_words: index = llm_output.find(stop_word) if index != -1 and index < min_index: min_index = index llm_output = llm_output[:min_index] if "Final Answer:" in llm_output: self.begin = True return AgentFinish( return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()}, log=llm_output, ) # print("llm_output>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",llm_output) # logger.info(f"llm_output: {llm_output}") parts = llm_output.split("Action:") # print("parts:", parts) if len(parts) < 2: return AgentFinish( # return_values={"output": f"以下内容如果不符合您的预期,可以试着换个方式问问我:\n\n {llm_output} "}, return_values={"output": f" {llm_output} "}, log=llm_output, ) # action = parts[1].split("Action Input:")[0].strip() match = re.search(r'Action:\s*(.+)', llm_output) try: action = match.group(1).strip().strip('\'').replace("【工具名称】","") # print("parts[1]>>>>",parts[1]) # if "\nAction Input" not in parts: action_input = parts[1].split("Action Input:")[1].strip() action_input = action_input.replace("【调用结束】","") action_input = action_input.replace("调用结束","") except: action_input = "无" # END = utils.get_shared_variable(self.time_based_uuid) # END["END"] = "ok" false_tokens = ["无","None"] # use_tools = ["联网思索","知识库联想","图表绘制","实景绘制","水墨画绘制"] use_tools = ["联网思索","知识库联想","图表绘制","数学运算","代码专家","天气工具","美术作品获取","统计数据查询"] # utils.set_shared_variable(self.time_based_uuid,END) print("action_input: ",action_input) try: if "knowledge_name" in json.loads(action_input): match = re.search(r'\{.*?\}', action_input,re.DOTALL) query = match.group(0) self.knowledge_base_name.name = json.loads(query)["knowledge_name"] except Exception as e: pass try: if not (any(false_token == action.strip("").replace("【","").replace("】","") for false_token in false_tokens)) and any(tool_temp in action for tool_temp in use_tools): ans = AgentAction( tool=action.replace("【","").replace("】",""), tool_input=action_input.strip(" ").strip('"')+"{\"uuid\":\""+self.time_based_uuid+"\"}", log=llm_output ) return ans elif action.strip("").replace("【","").replace("】","") == "环节跳转": ans = AgentAction( tool="环节跳转", tool_input=self.time_based_uuid, log=llm_output ) return ans else: ans = AgentAction( tool="无需调用工具", tool_input=self.time_based_uuid, log=llm_output ) # END = utils.get_shared_variable(self.time_based_uuid) # END["END"] = "ok" # utils.set_shared_variable(self.time_based_uuid,END) return ans except: # END = utils.get_shared_variable(self.time_based_uuid) # END["END"] = "ok" return AgentFinish( return_values={"output": f"调用agent失败: `{llm_output}`"}, log=llm_output, )