163 lines
7.1 KiB
Python
163 lines
7.1 KiB
Python
from __future__ import annotations
|
|
import json
|
|
import re
|
|
from langchain.agents import Tool, AgentOutputParser
|
|
from langchain.prompts import StringPromptTemplate
|
|
from typing import List
|
|
from langchain.schema import AgentAction, AgentFinish
|
|
from configs.basic_config import *
|
|
from configs import SUPPORT_AGENT_MODEL
|
|
from server.agent import model_container
|
|
from server.chat import utils
|
|
from server.chat.knowledge_base_name import KnowledgeBase
|
|
from collections import defaultdict
|
|
class CustomPromptTemplate(StringPromptTemplate):
    """Prompt template that injects the agent scratchpad and tool listing.

    Attributes are declared as model fields:
      * ``template`` - the format string to render.
      * ``tools`` - the tools whose names and descriptions are rendered into
        the ``tools`` and ``tool_names`` template variables.
    """

    template: str
    tools: List[Tool]

    def format(self, **kwargs) -> str:
        """Render ``template`` with the supplied variables.

        The following variables are computed here and overwrite any value of
        the same name passed by the caller:

        * ``agent_scratchpad`` - for every ``(action, observation)`` pair in
          ``intermediate_steps`` (popped from ``kwargs``; defaults to an empty
          list), ``action.log`` followed by an ``Observation:`` line.
        * ``tools`` - one line per tool with its name and description.
        * ``tool_names`` - the tool names joined by ``", "``.

        Numeric positional fields such as ``{0}`` are rewritten to the named
        field ``{arg0}``, and every simple ``{name}`` placeholder that has no
        value is filled with ``<missing name>`` instead of failing.

        Returns:
            The rendered prompt string.

        Raises:
            ValueError: if ``str.format_map`` still reports a missing key or
                rejects the format string. The original exception is chained
                as ``__cause__``.
        """
        intermediate_steps = kwargs.pop("intermediate_steps", [])
        kwargs["agent_scratchpad"] = "".join(
            action.log + f"\nObservation: {observation}\n"
            for action, observation in intermediate_steps
        )

        tools = getattr(self, "tools", [])
        kwargs["tools"] = "\n".join(
            f"【工具名称】{tool.name}: 【工具描述】{tool.description}" for tool in tools
        )
        kwargs["tool_names"] = ", ".join(tool.name for tool in tools)

        # Work on a local copy: rendering a prompt must not rewrite the
        # template stored on the (possibly shared) instance.
        template = re.sub(
            r'\{(\d+)\}',
            lambda found: f"{{arg{found.group(1)}}}",
            self.template,
        )

        # Give every simple placeholder a visible default so that a variable
        # the caller forgot does not abort the whole render.
        for placeholder in re.findall(r'\{(\w+)\}', template):
            kwargs.setdefault(placeholder, f"<missing {placeholder}>")

        try:
            return template.format_map(kwargs)
        except KeyError as e:
            # Still reachable for fields the simple regex above does not
            # recognise (e.g. attribute or index access such as ``{a.b}``).
            raise ValueError(f"Missing key in template formatting: {e}") from e
        except ValueError as e:
            raise ValueError(f"Format string contains positional fields: {e}") from e
|
|
|
|
|
|
|
|
class CustomOutputParser(AgentOutputParser):
    """Turn raw LLM text into an ``AgentAction`` or an ``AgentFinish``.

    The parser looks for the markers ``Final Answer:``, ``Action:`` and
    ``Action Input:`` in the model output and maps them to a tool call, a
    final answer, or a failure result.
    """

    # When True, the next parse() call may cut the output at "Observation:"
    # (only done if no SUPPORT_AGENT_MODEL entry is found in
    # model_container.MODEL). Set to True on construction and after a final
    # answer, and to False once a truncation has been applied.
    begin: bool = False
    # Object whose ``name`` attribute is overwritten when the action input
    # carries a "knowledge_name" key.
    knowledge_base_name: KnowledgeBase
    # Identifier appended to / passed as the tool input of every AgentAction.
    time_based_uuid: str

    def __init__(self, knowledge_base_name: KnowledgeBase, time_based_uuid: str):
        """Store the knowledge-base holder and the request uuid."""
        super().__init__(knowledge_base_name=knowledge_base_name, time_based_uuid=time_based_uuid)
        self.begin = True
        # Kept from the original code although the base initializer already
        # received both values.
        # NOTE(review): presumably this guarantees the parser holds the
        # caller's exact KnowledgeBase object (so the mutation in parse() is
        # visible to the caller) even if the base class copies field values
        # during validation - confirm before removing.
        self.knowledge_base_name = knowledge_base_name
        self.time_based_uuid = time_based_uuid

    def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
        """Parse one LLM completion.

        Args:
            llm_output: the raw text produced by the model.

        Returns:
            * ``AgentFinish`` with the text after ``Final Answer:`` when that
              marker is present.
            * ``AgentFinish`` echoing the output when there is no ``Action:``.
            * ``AgentAction`` for one of the tools in ``use_tools`` (tool
              input is the parsed action input with a ``{"uuid": ...}``
              suffix), for ``环节跳转``, or for the fallback ``无需调用工具``.
            * ``AgentFinish`` with a failure message when no tool name can be
              extracted or building the action raises.
        """
        if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
            self.begin = False
            # Drop everything from the first "Observation:" onwards.
            cut = llm_output.find("Observation:")
            if cut != -1:
                llm_output = llm_output[:cut]

        if "Final Answer:" in llm_output:
            self.begin = True
            return AgentFinish(
                return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
                log=llm_output,
            )

        parts = llm_output.split("Action:")
        if len(parts) < 2:
            return AgentFinish(
                return_values={"output": f" {llm_output} "},
                log=llm_output,
            )

        def failure() -> AgentFinish:
            return AgentFinish(
                return_values={"output": f"调用agent失败: `{llm_output}`"},
                log=llm_output,
            )

        # Tool name: the rest of the line after the first "Action:" that is
        # followed by non-whitespace text. None if there is no such text.
        action = None
        # "无" is the placeholder input used whenever none can be extracted.
        action_input = "无"
        match = re.search(r'Action:\s*(.+)', llm_output)
        if match is not None:
            action = match.group(1).strip().strip('\'').replace("【工具名称】", "")
            pieces = parts[1].split("Action Input:")
            if len(pieces) > 1:
                action_input = (
                    pieces[1]
                    .strip()
                    .replace("【调用结束】", "")
                    .replace("调用结束", "")
                )

        false_tokens = ("无", "None")
        use_tools = ("联网思索", "知识库联想", "图表绘制", "数学运算", "代码专家", "天气工具", "美术作品获取", "统计数据查询")

        print("action_input: ", action_input)

        # Best effort: when the action input is a JSON object naming a
        # knowledge base, switch the shared holder to it. Any failure here
        # (non-JSON input, unexpected payload type, rejected assignment) is
        # deliberately ignored so that parsing continues.
        try:
            payload = json.loads(action_input)
            if isinstance(payload, dict) and "knowledge_name" in payload:
                self.knowledge_base_name.name = payload["knowledge_name"]
        except Exception:
            pass

        if action is None:
            # "Action:" was present but no tool name followed it.
            return failure()

        try:
            tool_name = action.replace("【", "").replace("】", "")
            if tool_name not in false_tokens and any(known in action for known in use_tools):
                return AgentAction(
                    tool=tool_name,
                    tool_input=action_input.strip(" ").strip('"') + "{\"uuid\":\"" + self.time_based_uuid + "\"}",
                    log=llm_output,
                )
            if tool_name == "环节跳转":
                return AgentAction(
                    tool="环节跳转",
                    tool_input=self.time_based_uuid,
                    log=llm_output,
                )
            return AgentAction(
                tool="无需调用工具",
                tool_input=self.time_based_uuid,
                log=llm_output,
            )
        except Exception:
            return failure()
|