Files
gangyan/langchain-chat/server/agent/custom_template.py

163 lines
7.1 KiB
Python

from __future__ import annotations
import json
import re
from langchain.agents import Tool, AgentOutputParser
from langchain.prompts import StringPromptTemplate
from typing import List
from langchain.schema import AgentAction, AgentFinish
from configs.basic_config import *
from configs import SUPPORT_AGENT_MODEL
from server.agent import model_container
from server.chat import utils
from server.chat.knowledge_base_name import KnowledgeBase
from collections import defaultdict
class CustomPromptTemplate(StringPromptTemplate):
    """Prompt template that injects the tool list and the agent scratchpad.

    ``format`` fills three placeholders on top of the caller's keyword
    arguments: ``agent_scratchpad``, ``tools`` and ``tool_names``.
    """

    # Prompt text with ``str.format``-style placeholders.
    template: str
    # Tools whose names and descriptions are rendered into the prompt.
    tools: List[Tool]

    def format(self, **kwargs) -> str:
        """Render the template with the given keyword arguments.

        Args:
            **kwargs: Values for the template placeholders. An optional
                ``intermediate_steps`` entry (an iterable of
                ``(action, observation)`` pairs) is consumed and rendered
                into ``agent_scratchpad``.

        Returns:
            The formatted prompt string.

        Raises:
            ValueError: If formatting fails because of a missing key or an
                invalid format string.
        """
        # A missing "intermediate_steps" is treated as "no steps yet".
        intermediate_steps = kwargs.pop("intermediate_steps", [])
        thoughts = "".join(
            action.log + f"\nObservation: {observation}\n"
            for action, observation in intermediate_steps
        )

        tools = getattr(self, "tools", [])
        kwargs["agent_scratchpad"] = thoughts
        kwargs["tools"] = "\n".join(
            f"【工具名称】{tool.name}: 【工具描述】{tool.description}" for tool in tools
        )
        kwargs["tool_names"] = ", ".join(tool.name for tool in tools)

        # Turn positional fields such as "{0}" into named fields "{arg0}",
        # because format_map() rejects positional fields. The rewrite is done
        # on a local copy so that formatting never modifies self.template.
        template = re.sub(r"\{(\d+)\}", r"{arg\g<1>}", self.template)

        # Give every simple "{name}" placeholder a visible fallback value so
        # that a missing argument does not abort formatting.
        for placeholder in re.findall(r"\{(\w+)\}", template):
            kwargs.setdefault(placeholder, f"<missing {placeholder}>")

        try:
            return template.format_map(kwargs)
        except KeyError as e:
            raise ValueError(f"Missing key in template formatting: {e}") from e
        except ValueError as e:
            raise ValueError(f"Format string contains positional fields: {e}") from e
class CustomOutputParser(AgentOutputParser):
    """Parse raw LLM text into an ``AgentAction`` or an ``AgentFinish``."""

    # While True, the next parse() call may truncate the output at
    # "Observation:" (only for models not listed in SUPPORT_AGENT_MODEL).
    begin: bool = False
    # Object whose ``name`` attribute is updated when the action input
    # carries a "knowledge_name" entry.
    knowledge_base_name: KnowledgeBase
    # Identifier appended to / passed as the tool input of every action.
    time_based_uuid: str

    def __init__(self, knowledge_base_name: KnowledgeBase, time_based_uuid: str):
        """Store the knowledge-base holder and the request identifier."""
        super().__init__(knowledge_base_name=knowledge_base_name, time_based_uuid=time_based_uuid)
        self.begin = True
        # Re-assigned after the base initialiser; presumably to make sure the
        # parser holds the caller's own objects -- TODO confirm.
        self.knowledge_base_name = knowledge_base_name
        self.time_based_uuid = time_based_uuid

    def _update_knowledge_base(self, action_input: str) -> None:
        """Best-effort: copy "knowledge_name" from a JSON action input.

        Nothing happens when the input is not a JSON object or has no
        "knowledge_name" key.
        """
        try:
            payload = json.loads(action_input)
            if isinstance(payload, dict) and "knowledge_name" in payload:
                self.knowledge_base_name.name = payload["knowledge_name"]
        except (ValueError, TypeError, AttributeError):
            # Most action inputs are plain text rather than JSON; that is
            # expected and must not interrupt parsing.
            pass

    def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
        """Convert one LLM completion into the agent's next step.

        Args:
            llm_output: Raw text produced by the model.

        Returns:
            ``AgentFinish`` when the text contains "Final Answer:", has no
            "Action:" section, or cannot be turned into an action;
            otherwise an ``AgentAction`` for the selected tool.
        """
        if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
            self.begin = False
            # Drop everything from the earliest stop word onwards.
            stop_words = ["Observation:"]
            min_index = len(llm_output)
            for stop_word in stop_words:
                index = llm_output.find(stop_word)
                if index != -1 and index < min_index:
                    min_index = index
            llm_output = llm_output[:min_index]

        if "Final Answer:" in llm_output:
            self.begin = True
            return AgentFinish(
                return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
                log=llm_output,
            )

        parts = llm_output.split("Action:")
        if len(parts) < 2:
            # No action requested: hand the text back as the answer.
            return AgentFinish(
                return_values={"output": f" {llm_output} "},
                log=llm_output,
            )

        # ``action`` stays None when nothing usable follows "Action:".
        action = None
        action_input = ""
        match = re.search(r'Action:\s*(.+)', llm_output)
        if match is not None:
            # Cut at "Action Input:" so that an input written on the same
            # line as the action does not become part of the tool name.
            action = (
                match.group(1)
                .split("Action Input:", 1)[0]
                .strip()
                .strip('\'')
                .replace("【工具名称】", "")
            )
            segments = parts[1].split("Action Input:")
            if len(segments) > 1:
                action_input = segments[1].strip()
                action_input = action_input.replace("【调用结束】", "")
                action_input = action_input.replace("调用结束", "")

        false_tokens = ("", "None")
        use_tools = ("联网思索", "知识库联想", "图表绘制", "数学运算", "代码专家", "天气工具", "美术作品获取", "统计数据查询")
        print("action_input: ", action_input)

        self._update_knowledge_base(action_input)

        failure = AgentFinish(
            return_values={"output": f"调用agent失败: `{llm_output}`"},
            log=llm_output,
        )
        if action is None:
            return failure

        # NOTE(review): the previous revision also ran ``action.strip("")``
        # and ``.replace("", "")`` here; with empty-string arguments those
        # calls change nothing, so they are dropped. If they were meant to
        # remove special marker characters, add them back explicitly --
        # TODO confirm.
        try:
            if action not in false_tokens and any(tool_name in action for tool_name in use_tools):
                return AgentAction(
                    tool=action,
                    tool_input=action_input.strip(" ").strip('"') + "{\"uuid\":\"" + self.time_based_uuid + "\"}",
                    log=llm_output,
                )
            if action == "环节跳转":
                return AgentAction(
                    tool="环节跳转",
                    tool_input=self.time_based_uuid,
                    log=llm_output,
                )
            return AgentAction(
                tool="无需调用工具",
                tool_input=self.time_based_uuid,
                log=llm_output,
            )
        except Exception:
            # Building the action failed (e.g. validation or type error);
            # report it to the caller as a finished run instead of raising.
            return failure