[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
228
langchain-chat/server/agent/custom_agent/ChatGLM3Agent.py
Normal file
228
langchain-chat/server/agent/custom_agent/ChatGLM3Agent.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
This file is a modified version for ChatGLM3-6B the original glm3_agent.py file from the langchain repo.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, List, Sequence, Tuple, Optional, Union
|
||||
from pydantic.schema import model_schema
|
||||
|
||||
|
||||
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
|
||||
from langchain.memory import ConversationBufferWindowMemory
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
|
||||
from langchain.agents.agent import AgentOutputParser
|
||||
from langchain.output_parsers import OutputFixingParser
|
||||
from langchain.pydantic_v1 import Field
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StructuredChatOutputParserWithRetries(AgentOutputParser):
|
||||
"""Output parser with retries for the structured chat agent."""
|
||||
|
||||
base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
|
||||
"""The base parser to use."""
|
||||
output_fixing_parser: Optional[OutputFixingParser] = None
|
||||
"""The output fixing parser to use."""
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
special_tokens = ["Action:", "<|observation|>"]
|
||||
first_index = min([text.find(token) if token in text else len(text) for token in special_tokens])
|
||||
text = text[:first_index]
|
||||
if "tool_call" in text:
|
||||
action_end = text.find("```")
|
||||
action = text[:action_end].strip()
|
||||
params_str_start = text.find("(") + 1
|
||||
params_str_end = text.rfind(")")
|
||||
params_str = text[params_str_start:params_str_end]
|
||||
|
||||
params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
|
||||
params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs}
|
||||
|
||||
action_json = {
|
||||
"action": action,
|
||||
"action_input": params
|
||||
}
|
||||
else:
|
||||
action_json = {
|
||||
"action": "Final Answer",
|
||||
"action_input": text
|
||||
}
|
||||
action_str = f"""
|
||||
Action:
|
||||
```
|
||||
{json.dumps(action_json, ensure_ascii=False)}
|
||||
```"""
|
||||
try:
|
||||
if self.output_fixing_parser is not None:
|
||||
parsed_obj: Union[
|
||||
AgentAction, AgentFinish
|
||||
] = self.output_fixing_parser.parse(action_str)
|
||||
else:
|
||||
parsed_obj = self.base_parser.parse(action_str)
|
||||
return parsed_obj
|
||||
except Exception as e:
|
||||
raise OutputParserException(f"Could not parse LLM output: {text}") from e
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "structured_chat_ChatGLM3_6b_with_retries"
|
||||
|
||||
|
||||
class StructuredGLM3ChatAgent(Agent):
|
||||
"""Structured Chat Agent."""
|
||||
|
||||
output_parser: AgentOutputParser = Field(
|
||||
default_factory=StructuredChatOutputParserWithRetries
|
||||
)
|
||||
"""Output parser for the agent."""
|
||||
|
||||
@property
|
||||
def observation_prefix(self) -> str:
|
||||
"""Prefix to append the ChatGLM3-6B observation with."""
|
||||
return "Observation:"
|
||||
|
||||
@property
|
||||
def llm_prefix(self) -> str:
|
||||
"""Prefix to append the llm call with."""
|
||||
return "Thought:"
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
|
||||
if not isinstance(agent_scratchpad, str):
|
||||
raise ValueError("agent_scratchpad should be of type string.")
|
||||
if agent_scratchpad:
|
||||
return (
|
||||
f"This was your previous work "
|
||||
f"(but I haven't seen any of it! I only see what "
|
||||
f"you return as final answer):\n{agent_scratchpad}"
|
||||
)
|
||||
else:
|
||||
return agent_scratchpad
|
||||
|
||||
@classmethod
|
||||
def _get_default_output_parser(
|
||||
cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
|
||||
) -> AgentOutputParser:
|
||||
return StructuredChatOutputParserWithRetries(llm=llm)
|
||||
|
||||
@property
|
||||
def _stop(self) -> List[str]:
|
||||
return ["<|observation|>"]
|
||||
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
prompt: str = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
tools_json = []
|
||||
tool_names = []
|
||||
for tool in tools:
|
||||
tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
|
||||
simplified_config_langchain = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool_schema.get("properties", {})
|
||||
}
|
||||
tools_json.append(simplified_config_langchain)
|
||||
tool_names.append(tool.name)
|
||||
formatted_tools = "\n".join([
|
||||
f"{tool['name']}: {tool['description']}, args: {tool['parameters']}"
|
||||
for tool in tools_json
|
||||
])
|
||||
formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}")
|
||||
template = prompt.format(tool_names=tool_names,
|
||||
tools=formatted_tools,
|
||||
history="None",
|
||||
input="{input}",
|
||||
agent_scratchpad="{agent_scratchpad}")
|
||||
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
_memory_prompts = memory_prompts or []
|
||||
messages = [
|
||||
SystemMessagePromptTemplate.from_template(template),
|
||||
*_memory_prompts,
|
||||
]
|
||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
prompt: str = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prompt=prompt,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser or cls._get_default_output_parser(llm=llm)
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
output_parser=_output_parser,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def _agent_type(self) -> str:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def initialize_glm3_agent(
|
||||
tools: Sequence[BaseTool],
|
||||
llm: BaseLanguageModel,
|
||||
prompt: str = None,
|
||||
memory: Optional[ConversationBufferWindowMemory] = None,
|
||||
agent_kwargs: Optional[dict] = None,
|
||||
*,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
tags_ = list(tags) if tags else []
|
||||
agent_kwargs = agent_kwargs or {}
|
||||
agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
prompt=prompt,
|
||||
**agent_kwargs
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent_obj,
|
||||
tools=tools,
|
||||
memory=memory,
|
||||
tags=tags_,
|
||||
**kwargs,
|
||||
)
|
||||
Reference in New Issue
Block a user