Files
gangyan/langchain-chat/server/chat/agent_v2.py

167 lines
6.9 KiB
Python
Raw Normal View History

"""
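LangGraph agent runner.

Replaces the core of the old agent_chat_test:
- No more LLM step routing (thinking / select_tool / answer); the model decides for itself through function calling.
- Multiple tool_calls issued in the same turn run in parallel automatically (ToolNode).
- The LangGraph event stream is mapped onto the existing front-end protocol ({"text":...} / {"docs":...} / {"detail":...}).

Input: query + history + uuid + model_name.
Output: the same sequence of dict-shaped events as the old agent_chat_test ("answer" / "docs" / "detail" / ...).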
"""
import asyncio
import json
import logging
from typing import AsyncIterable, List, Optional
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_openai import ChatOpenAI
from configs import LLM_MODELS, prompt_config
from server.utils import get_prompt_template, get_model_worker_config
from server.chat import utils as shared_utils
from server.chat.tools_v2 import make_tools
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def _build_system_prompt(user_prompt_name: str, query: str, think_content: str) -> str:
"""复用旧版 Think Test Bak + 用户业务 prompt 的拼装逻辑,但简化为单条 system message。"""
base = get_prompt_template("agent_chat", "Think Test Bak")
user = get_prompt_template("llm_chat", user_prompt_name) if user_prompt_name else ""
parts = []
parts.append("你是浪潮开发的智能专家。回答用户问题前可以使用工具检索资料。")
parts.append("严格要求:")
parts.append("1. 优先使用工具获取资料后再回答,禁止虚构内容")
parts.append("2. 同一个工具同一参数禁止反复调用超过 2 次")
parts.append("3. 回答时必须基于工具返回的资料,引用要标注【】序号")
parts.append("4. 涉及国家政策优先用 知识库联想 + 政策库")
parts.append("5. 答案紧扣用户问题,不要主观臆想")
parts.append("")
parts.append(f"思考提示:{think_content}")
parts.append("")
if user:
parts.append(f"业务约束:{user}")
return "\n".join(parts)
def _convert_history(history: list) -> list:
"""把 chat_test.py 的 history listdict role/content转成 LangChain messages。"""
msgs = []
for h in history or []:
role = h.get("role")
content = h.get("content", "")
if role == "user":
msgs.append(("user", content))
elif role == "assistant":
msgs.append(("assistant", content))
return msgs
def _dump_event(key: str, value) -> str:
    """Serialize one ``{key: value}`` protocol event, keeping non-ASCII text readable."""
    return json.dumps({key: value}, ensure_ascii=False)


def _tool_output_text(raw) -> str:
    """Return a tool result as plain text, unwrapping objects that carry ``.content``."""
    # NOTE(review): depending on the installed langchain-core version the
    # on_tool_end payload may be a plain string or a message object; str() on
    # a message would embed its repr, so prefer its content when present.
    return str(getattr(raw, "content", raw))


async def agent_run(
    *,
    query: str,
    uuid: str,
    history: Optional[list] = None,
    model_name: Optional[str] = None,
    temperature: float = 0.3,
    max_tokens: Optional[int] = None,
    user_prompt_name: str = "",
    think_content: str = "",
) -> AsyncIterable[str]:
    """Run the LangGraph ReAct agent and yield protocol events as JSON strings.

    Each yielded item is a JSON object with exactly one key:

    * ``{"answer": str}`` - one streamed chunk of model text.
    * ``{"think": str}``  - a note that a tool is being called.
    * ``{"docs": str}``   - source documents for the reference area, emitted
      after the ``知识库联想`` / ``联网思索`` tools finish when the shared state
      for ``uuid`` holds ``source_docs``.
    * ``{"detail": str}`` - the raw text output of those same two tools.

    On an unexpected error a final fallback ``{"answer": ...}`` event is
    yielded instead of propagating the exception.

    Args:
        query: The user's question; appended as the last user message.
        uuid: Session key used to build the tools and to read shared state
            (the ``status`` stop flag and ``source_docs``).
        history: Prior turns as dicts with ``role`` / ``content``.
        model_name: Model to use; defaults to ``LLM_MODELS[0]``.
        temperature: Sampling temperature passed to the chat model.
        max_tokens: Optional completion token cap passed to the chat model.
        user_prompt_name: Optional business prompt template name.
        think_content: Optional thinking hint for the system prompt.

    Raises:
        asyncio.CancelledError: Re-raised after logging when the task is cancelled.
    """
    model_name = model_name or LLM_MODELS[0]

    # langchain_openai.ChatOpenAI is required here because it supports the
    # modern tool-calling protocol. server.utils.get_ChatOpenAI returns the
    # older langchain_community class, which does not support bind_tools.
    cfg = get_model_worker_config(model_name)
    llm = ChatOpenAI(
        model=model_name,
        base_url=cfg.get("api_base_url"),
        api_key=cfg.get("api_key", "EMPTY"),
        temperature=temperature,
        max_tokens=max_tokens,
        streaming=True,
    )
    tools = make_tools(uuid)

    system_prompt = _build_system_prompt(user_prompt_name, query, think_content)
    agent = create_react_agent(llm, tools=tools, messages_modifier=system_prompt)

    msgs = _convert_history(history)
    msgs.append(("user", query))
    inputs = {"messages": msgs}
    # At most 12 graph steps (far fewer than the old 11 outer x N inner loops).
    config = {"recursion_limit": 12}

    # Tools whose results feed the reference area and the detail stream.
    doc_tools = ("知识库联想", "联网思索")
    answer_buf = []

    try:
        async for ev in agent.astream_events(inputs, config=config, version="v1"):
            # Honour the stop flag held in the shared state for this session.
            if not shared_utils.get_shared_variable(uuid).get("status", True):
                logger.info("Agent 收到停止信号")
                break

            kind = ev["event"]
            name = ev.get("name", "")

            if kind == "on_chat_model_stream":
                content = ev["data"]["chunk"].content or ""
                if content:
                    answer_buf.append(content)
                    yield _dump_event("answer", content)

            elif kind == "on_tool_start":
                tool_input = ev["data"].get("input", {})
                logger.info("工具调用开始: %s(%s)", name, tool_input)
                # Surface the tool call in the front end's thinking area.
                yield _dump_event("think", f"\n→ 调用工具:{name}\n")

            elif kind == "on_tool_end":
                output = _tool_output_text(ev["data"].get("output", ""))
                logger.info("工具调用结束: %s (%d chars)", name, len(output))

                if name in doc_tools:
                    # Source documents are read from the shared state for this
                    # session and forwarded to the reference area.
                    source_docs = shared_utils.get_shared_variable(uuid).get("source_docs", [])
                    if source_docs:
                        try:
                            docs_string = "\n" + "\n".join(f"{str(d)}\n" for d in source_docs)
                            docs_event = _dump_event("docs", docs_string)
                        except Exception:
                            logger.exception("docs 序列化失败")
                        else:
                            # Yield outside the try so only serialization
                            # errors are swallowed here.
                            yield docs_event

                    # The detailed search text is accumulated downstream
                    # (docs_detail) for the later hallucination check.
                    yield _dump_event("detail", output)

    except asyncio.CancelledError:
        logger.info("Agent 被取消")
        raise
    except Exception as e:
        logger.exception("Agent 运行异常: %s", e)
        # Give the front end a fallback answer instead of a broken stream.
        yield _dump_event(
            "answer",
            "\n\n[Agent 运行异常] 已尽力使用工具但未能完整生成答案,请重试或简化问题。",
        )

    # Final bookkeeping: log how much answer text was streamed.
    full_answer = "".join(answer_buf)
    logger.info(f"Agent 完成:答案长度 {len(full_answer)} chars")