[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
161
langchain-chat/server/agent/callbacks.py
Normal file
161
langchain-chat/server/agent/callbacks.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from __future__ import annotations
|
||||
from uuid import UUID
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.schema import AgentFinish, AgentAction
|
||||
from langchain.schema.output import LLMResult
|
||||
|
||||
|
||||
def dumps(obj: Dict) -> str:
|
||||
return json.dumps(obj, ensure_ascii=False)
|
||||
|
||||
|
||||
class Status:
|
||||
start: int = 1
|
||||
running: int = 2
|
||||
complete: int = 3
|
||||
agent_action: int = 4
|
||||
agent_finish: int = 5
|
||||
error: int = 6
|
||||
tool_finish: int = 7
|
||||
|
||||
|
||||
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.queue = asyncio.Queue()
|
||||
self.done = asyncio.Event()
|
||||
self.cur_tool = {}
|
||||
self.out = True
|
||||
|
||||
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
|
||||
parent_run_id: UUID | None = None, tags: List[str] | None = None,
|
||||
metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
|
||||
|
||||
# 对于截断不能自理的大模型,我来帮他截断
|
||||
stop_words = ["Observation:", "Thought","\"","(", "\n","\t"]
|
||||
for stop_word in stop_words:
|
||||
index = input_str.find(stop_word)
|
||||
if index != -1:
|
||||
input_str = input_str[:index]
|
||||
break
|
||||
|
||||
self.cur_tool = {
|
||||
"tool_name": serialized["name"],
|
||||
"input_str": input_str,
|
||||
"output_str": "",
|
||||
"status": Status.agent_action,
|
||||
"run_id": run_id.hex,
|
||||
"llm_token": "",
|
||||
"final_answer": "",
|
||||
"error": "",
|
||||
}
|
||||
# print("\nInput Str:",self.cur_tool["input_str"])
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
|
||||
tags: List[str] | None = None, **kwargs: Any) -> None:
|
||||
self.out = True ## 重置输出
|
||||
self.cur_tool.update(
|
||||
status=Status.tool_finish,
|
||||
output_str=output.replace("Answer:", ""),
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
|
||||
parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
|
||||
self.cur_tool.update(
|
||||
status=Status.error,
|
||||
error=str(error),
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
# async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
# if "Action" in token: ## 减少重复输出
|
||||
# before_action = token.split("Action")[0]
|
||||
# self.cur_tool.update(
|
||||
# status=Status.running,
|
||||
# llm_token=before_action + "\n",
|
||||
# )
|
||||
# self.queue.put_nowait(dumps(self.cur_tool))
|
||||
#
|
||||
# self.out = False
|
||||
#
|
||||
# if token and self.out:
|
||||
# self.cur_tool.update(
|
||||
# status=Status.running,
|
||||
# llm_token=token,
|
||||
# )
|
||||
# self.queue.put_nowait(dumps(self.cur_tool))
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
special_tokens = ["Action", "<|observation|>"]
|
||||
for stoken in special_tokens:
|
||||
if stoken in token:
|
||||
before_action = token.split(stoken)[0]
|
||||
self.cur_tool.update(
|
||||
status=Status.running,
|
||||
llm_token=before_action + "\n",
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
self.out = False
|
||||
break
|
||||
|
||||
if token and self.out:
|
||||
self.cur_tool.update(
|
||||
status=Status.running,
|
||||
llm_token=token,
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
|
||||
self.cur_tool.update(
|
||||
status=Status.start,
|
||||
llm_token="",
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.cur_tool.update(
|
||||
status=Status.start,
|
||||
llm_token="",
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
self.cur_tool.update(
|
||||
status=Status.complete,
|
||||
llm_token="\n",
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
|
||||
self.cur_tool.update(
|
||||
status=Status.error,
|
||||
error=str(error),
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_agent_finish(
|
||||
self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# 返回最终答案
|
||||
self.cur_tool.update(
|
||||
status=Status.agent_finish,
|
||||
final_answer=finish.return_values["output"],
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
self.cur_tool = {}
|
||||
Reference in New Issue
Block a user