[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
162
langchain-chat/server/agent/custom_template.py
Normal file
162
langchain-chat/server/agent/custom_template.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
from langchain.agents import Tool, AgentOutputParser
|
||||
from langchain.prompts import StringPromptTemplate
|
||||
from typing import List
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
from configs.basic_config import *
|
||||
from configs import SUPPORT_AGENT_MODEL
|
||||
from server.agent import model_container
|
||||
from server.chat import utils
|
||||
from server.chat.knowledge_base_name import KnowledgeBase
|
||||
from collections import defaultdict
|
||||
class CustomPromptTemplate(StringPromptTemplate):
|
||||
template: str
|
||||
tools: List[Tool]
|
||||
|
||||
# def format(self, **kwargs) -> str:
|
||||
# intermediate_steps = kwargs.pop("intermediate_steps")
|
||||
# thoughts = ""
|
||||
# for action, observation in intermediate_steps:
|
||||
# thoughts += action.log
|
||||
# thoughts += f"\nObservation: {observation}\n"
|
||||
# kwargs["agent_scratchpad"] = thoughts
|
||||
# kwargs["tools"] = "\n".join([f"【工具名称】{tool.name}: 【工具描述】{tool.description}" for tool in self.tools])
|
||||
# kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
|
||||
# return self.template.format(**kwargs)
|
||||
|
||||
|
||||
|
||||
def format(self, **kwargs) -> str:
|
||||
# 确保 intermediate_steps 存在
|
||||
intermediate_steps = kwargs.pop("intermediate_steps", [])
|
||||
thoughts = ""
|
||||
for action, observation in intermediate_steps:
|
||||
thoughts += action.log
|
||||
thoughts += f"\nObservation: {observation}\n"
|
||||
|
||||
# 设置默认值以防止 KeyError
|
||||
kwargs["agent_scratchpad"] = thoughts
|
||||
kwargs["tools"] = "\n".join([f"【工具名称】{tool.name}: 【工具描述】{tool.description}" for tool in getattr(self, 'tools', [])])
|
||||
kwargs["tool_names"] = ", ".join([tool.name for tool in getattr(self, 'tools', [])])
|
||||
|
||||
# 使用正则表达式替换位置字段为命名字段
|
||||
def replace_positional(match):
|
||||
index = match.group(1)
|
||||
return f"{{arg{index}}}"
|
||||
|
||||
self.template = re.sub(r'\{(\d+)\}', replace_positional, self.template)
|
||||
|
||||
# 为所有占位符提供默认值
|
||||
placeholders = re.findall(r'\{(\w+)\}', self.template)
|
||||
for placeholder in placeholders:
|
||||
if placeholder not in kwargs:
|
||||
kwargs[placeholder] = f"<missing {placeholder}>"
|
||||
|
||||
# 确保所有占位符都有对应的值
|
||||
try:
|
||||
return self.template.format_map(kwargs)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing key in template formatting: {e}")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Format string contains positional fields: {e}")
|
||||
|
||||
|
||||
|
||||
class CustomOutputParser(AgentOutputParser):
|
||||
begin: bool = False
|
||||
knowledge_base_name: KnowledgeBase
|
||||
time_based_uuid:str
|
||||
def __init__(self, knowledge_base_name: KnowledgeBase,time_based_uuid:str):
|
||||
super().__init__(knowledge_base_name = knowledge_base_name,time_based_uuid=time_based_uuid)
|
||||
self.begin = True
|
||||
self.knowledge_base_name = knowledge_base_name
|
||||
self.time_based_uuid = time_based_uuid
|
||||
|
||||
def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
|
||||
if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
|
||||
self.begin = False
|
||||
stop_words = ["Observation:"]
|
||||
min_index = len(llm_output)
|
||||
for stop_word in stop_words:
|
||||
index = llm_output.find(stop_word)
|
||||
if index != -1 and index < min_index:
|
||||
min_index = index
|
||||
llm_output = llm_output[:min_index]
|
||||
|
||||
if "Final Answer:" in llm_output:
|
||||
self.begin = True
|
||||
return AgentFinish(
|
||||
return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
|
||||
log=llm_output,
|
||||
)
|
||||
# print("llm_output>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",llm_output)
|
||||
# logger.info(f"llm_output: {llm_output}")
|
||||
parts = llm_output.split("Action:")
|
||||
# print("parts:", parts)
|
||||
if len(parts) < 2:
|
||||
return AgentFinish(
|
||||
# return_values={"output": f"以下内容如果不符合您的预期,可以试着换个方式问问我:\n\n {llm_output} "},
|
||||
return_values={"output": f" {llm_output} "},
|
||||
log=llm_output,
|
||||
)
|
||||
|
||||
# action = parts[1].split("Action Input:")[0].strip()
|
||||
match = re.search(r'Action:\s*(.+)', llm_output)
|
||||
try:
|
||||
action = match.group(1).strip().strip('\'').replace("【工具名称】","")
|
||||
# print("parts[1]>>>>",parts[1])
|
||||
# if "\nAction Input" not in parts:
|
||||
action_input = parts[1].split("Action Input:")[1].strip()
|
||||
action_input = action_input.replace("【调用结束】","")
|
||||
action_input = action_input.replace("调用结束","")
|
||||
except:
|
||||
action_input = "无"
|
||||
# END = utils.get_shared_variable(self.time_based_uuid)
|
||||
# END["END"] = "ok"
|
||||
false_tokens = ["无","None"]
|
||||
# use_tools = ["联网思索","知识库联想","图表绘制","实景绘制","水墨画绘制"]
|
||||
use_tools = ["联网思索","知识库联想","图表绘制","数学运算","代码专家","天气工具","美术作品获取","统计数据查询"]
|
||||
|
||||
# utils.set_shared_variable(self.time_based_uuid,END)
|
||||
print("action_input: ",action_input)
|
||||
try:
|
||||
if "knowledge_name" in json.loads(action_input):
|
||||
match = re.search(r'\{.*?\}', action_input,re.DOTALL)
|
||||
query = match.group(0)
|
||||
self.knowledge_base_name.name = json.loads(query)["knowledge_name"]
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
if not (any(false_token == action.strip("").replace("【","").replace("】","") for false_token in false_tokens)) and any(tool_temp in action for tool_temp in use_tools):
|
||||
ans = AgentAction(
|
||||
tool=action.replace("【","").replace("】",""),
|
||||
tool_input=action_input.strip(" ").strip('"')+"{\"uuid\":\""+self.time_based_uuid+"\"}",
|
||||
log=llm_output
|
||||
)
|
||||
return ans
|
||||
elif action.strip("").replace("【","").replace("】","") == "环节跳转":
|
||||
ans = AgentAction(
|
||||
tool="环节跳转",
|
||||
tool_input=self.time_based_uuid,
|
||||
log=llm_output
|
||||
)
|
||||
return ans
|
||||
else:
|
||||
ans = AgentAction(
|
||||
tool="无需调用工具",
|
||||
tool_input=self.time_based_uuid,
|
||||
log=llm_output
|
||||
)
|
||||
# END = utils.get_shared_variable(self.time_based_uuid)
|
||||
# END["END"] = "ok"
|
||||
# utils.set_shared_variable(self.time_based_uuid,END)
|
||||
return ans
|
||||
except:
|
||||
# END = utils.get_shared_variable(self.time_based_uuid)
|
||||
# END["END"] = "ok"
|
||||
return AgentFinish(
|
||||
return_values={"output": f"调用agent失败: `{llm_output}`"},
|
||||
log=llm_output,
|
||||
)
|
||||
Reference in New Issue
Block a user