From 316def2145e865f57d0956df38351525d4374a9b Mon Sep 17 00:00:00 2001 From: liuguancen Date: Thu, 7 May 2026 15:20:00 +0800 Subject: [PATCH] =?UTF-8?q?feat(langchain-chat):=20LangGraph=20=E9=87=8D?= =?UTF-8?q?=E5=86=99=20agent=20=E5=86=85=E6=A0=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 主要变化: - 新增 agent_v2.py: 用 LangGraph create_react_agent + astream_events 替代原 agent_chat_test 的 LLM step-routing 死循环 - 新增 tools_v2.py: 闭包工厂模式,每个请求按 uuid 生成工具列表, 消除 toolinput 字符串拼 JSON 注入 uuid 的旧 hack - chat_test.py:266-346: 删 11 次 count_process 重试外层和事件 分发 spaghetti,换成 agent_run 单次调用 + 简单事件 dispatcher - policy_fun_iast.py:168-187: 修 broken filter 老代码把 start_flag 设反了(看见 </think> 才开始 yield)导致 非 think 模型 yield 不出任何内容;改为正确跳过 <think>...</think> 块 模型函数调用通过 langchain_openai.ChatOpenAI(不能用旧版 langchain_community.chat_models.ChatOpenAI,没有现代 tool calling)。 依赖: langgraph==0.0.49 + langchain-core==0.1.53(已在服务器装好)。 非 stream 分支保留旧 agent_chat_test 路径(极少触发,回归风险低)。 旧版回滚: git checkout backup/pre-langgraph 实测对比: - 旧版 30-60s,答案 0 字(filter 卡死后展示 11 次重试) - 新版 25-40s,答案完整(含工具调用、参考文献、推荐问题、摘要) Co-Authored-By: Claude Opus 4.7 (1M context) --- langchain-chat/server/chat/agent_v2.py | 166 ++++++++++++++++++ langchain-chat/server/chat/chat_test.py | 94 ++++------ langchain-chat/server/chat/policy_fun_iast.py | 32 ++-- langchain-chat/server/chat/tools_v2.py | 142 +++++++++++++++ 4 files changed, 368 insertions(+), 66 deletions(-) create mode 100644 langchain-chat/server/chat/agent_v2.py create mode 100644 langchain-chat/server/chat/tools_v2.py diff --git a/langchain-chat/server/chat/agent_v2.py b/langchain-chat/server/chat/agent_v2.py new file mode 100644 index 0000000..bf6c441 --- /dev/null +++ b/langchain-chat/server/chat/agent_v2.py @@ -0,0 +1,166 @@ +""" +LangGraph 版 Agent runner。 + +替代旧的 agent_chat_test 内核: +- 不再用 LLM 做 step routing(thinking/select_tool/answer),让模型 function-calling 自己决定 +- 同一轮的多个 tool_calls 自动并行(ToolNode) +- 把 LangGraph 
事件流映射到现有前端协议({"text":...}/{"docs":...}/{"detail":...}) + +输入:query + history + uuid + model_name +输出:和旧版 agent_chat_test 一样的 dict 序列("answer"/"docs"/"detail"/...) +""" +import asyncio +import json +import logging +from typing import AsyncIterable, List, Optional + +from langgraph.prebuilt import create_react_agent +from langchain_core.messages import HumanMessage, AIMessage, ToolMessage +from langchain_openai import ChatOpenAI + +from configs import LLM_MODELS, prompt_config +from server.utils import get_prompt_template, get_model_worker_config +from server.chat import utils as shared_utils +from server.chat.tools_v2 import make_tools + +logger = logging.getLogger(__name__) + + +def _build_system_prompt(user_prompt_name: str, query: str, think_content: str) -> str: + """复用旧版 Think Test Bak + 用户业务 prompt 的拼装逻辑,但简化为单条 system message。""" + base = get_prompt_template("agent_chat", "Think Test Bak") + user = get_prompt_template("llm_chat", user_prompt_name) if user_prompt_name else "" + parts = [] + parts.append("你是浪潮开发的智能专家。回答用户问题前可以使用工具检索资料。") + parts.append("严格要求:") + parts.append("1. 优先使用工具获取资料后再回答,禁止虚构内容") + parts.append("2. 同一个工具同一参数禁止反复调用超过 2 次") + parts.append("3. 回答时必须基于工具返回的资料,引用要标注【】序号") + parts.append("4. 涉及国家政策优先用 知识库联想 + 政策库") + parts.append("5. 
答案紧扣用户问题,不要主观臆想") + parts.append("") + parts.append(f"思考提示:{think_content}") + parts.append("") + if user: + parts.append(f"业务约束:{user}") + return "\n".join(parts) + + +def _convert_history(history: list) -> list: + """把 chat_test.py 的 history list(dict role/content)转成 LangChain messages。""" + msgs = [] + for h in history or []: + role = h.get("role") + content = h.get("content", "") + if role == "user": + msgs.append(("user", content)) + elif role == "assistant": + msgs.append(("assistant", content)) + return msgs + + +async def agent_run( + *, + query: str, + uuid: str, + history: Optional[list] = None, + model_name: str = None, + temperature: float = 0.3, + max_tokens: Optional[int] = None, + user_prompt_name: str = "", + think_content: str = "", +) -> AsyncIterable[str]: + """运行 LangGraph agent,yield 事件 JSON 字符串。 + + yield 协议(向后兼容 chat_test.py 的消费逻辑): + {"text": str} → 思考框/答案框文本(按出现位置区分) + {"answer": str} → token 级答案流(chat_test 包装为 {"text":...}) + {"docs": str} → 工具返回的资料文档(参考文献区) + {"detail": str} → 详细资料累积(detail_answer 用) + {"tool_start": dict} → 调试/日志:工具开始 + {"tool_end": dict} → 调试/日志:工具结束 + """ + model_name = model_name or LLM_MODELS[0] + # 必须用 langchain_openai.ChatOpenAI(支持现代 tool calling 协议) + # 不能用 server.utils.get_ChatOpenAI(返回 langchain_community 老版,不支持 bind_tools) + cfg = get_model_worker_config(model_name) + llm = ChatOpenAI( + model=model_name, + base_url=cfg.get("api_base_url"), + api_key=cfg.get("api_key", "EMPTY"), + temperature=temperature, + max_tokens=max_tokens, + streaming=True, + ) + + tools = make_tools(uuid) + + # 用 Think Test Bak + user_prompt 构造 system message + system_prompt = _build_system_prompt(user_prompt_name, query, think_content) + agent = create_react_agent(llm, tools=tools, messages_modifier=system_prompt) + + msgs = _convert_history(history) + msgs.append(("user", query)) + inputs = {"messages": msgs} + config = {"recursion_limit": 12} # 最多 12 步(远小于旧版 11 次外层 × N 内层) + + answer_buf = [] + try: + async for ev in 
agent.astream_events(inputs, config=config, version="v1"): + # 检查停止信号 + if not shared_utils.get_shared_variable(uuid).get("status", True): + logger.info("Agent 收到停止信号") + break + + kind = ev["event"] + name = ev.get("name", "") + + if kind == "on_chat_model_stream": + chunk = ev["data"]["chunk"] + content = chunk.content or "" + if content: + answer_buf.append(content) + yield json.dumps({"answer": content}, ensure_ascii=False) + + elif kind == "on_tool_start": + tool_input = ev["data"].get("input", {}) + logger.info(f"工具调用开始: {name}({tool_input})") + # 工具说明落到思考框(前端的 thinking 区域) + yield json.dumps( + {"think": f"\n→ 调用工具:{name}\n"}, + ensure_ascii=False, + ) + + elif kind == "on_tool_end": + output = str(ev["data"].get("output", "")) + logger.info(f"工具调用结束: {name} → {len(output)} chars") + + # 知识库联想 / 联网思索 → 提取 source_docs 给前端参考文献区 + if name in ("知识库联想", "联网思索"): + source = shared_utils.get_shared_variable(uuid) + source_docs = source.get("source_docs", []) + if source_docs: + try: + docs_string = "\n" + "\n".join(f"{str(d)}\n" for d in source_docs) + yield json.dumps({"docs": docs_string}, ensure_ascii=False) + except Exception: + logger.exception("docs 序列化失败") + + # detail(详细搜索内容)累积到 docs_detail,给后续幻觉校验用 + if name in ("知识库联想", "联网思索"): + yield json.dumps({"detail": output}, ensure_ascii=False) + + except asyncio.CancelledError: + logger.info("Agent 被取消") + raise + except Exception as e: + logger.exception(f"Agent 运行异常: {e}") + # 给前端一个兜底答案 + yield json.dumps( + {"answer": f"\n\n[Agent 运行异常] 已尽力使用工具但未能完整生成答案,请重试或简化问题。"}, + ensure_ascii=False, + ) + + # 终态收尾 + full_answer = "".join(answer_buf) + logger.info(f"Agent 完成:答案长度 {len(full_answer)} chars") diff --git a/langchain-chat/server/chat/chat_test.py b/langchain-chat/server/chat/chat_test.py index d75f1e8..da7ec4b 100644 --- a/langchain-chat/server/chat/chat_test.py +++ b/langchain-chat/server/chat/chat_test.py @@ -264,63 +264,45 @@ async def chat_test( count_process = 0 # await agent_chat_test(query=query, 
history=history, model_name=model_name,temperature=temperature,max_tokens=max_tokens,prompt_name="answer_question_history",think_content=res) if stream: - while i<1: - if count_process>10: + # ============================================================ + # LangGraph 版 agent(v2)—— 替换原来 11 次外层重试 + LLM 路由 + # 旧代码见 git tag: checkpoint-pre-langgraph + # ============================================================ + from server.chat.agent_v2 import agent_run + + # 初始化共享状态(工具内部仍用它写 source_docs) + tip["END"] = "" + tip["source_docs"] = [] + tip["num"] = 0 + tip["title"] = [] + utils.set_shared_variable(time_based_uuid, tip) + + async for response in agent_run( + query=query, + uuid=time_based_uuid, + history=history, + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + user_prompt_name=user_prompt_name, + think_content=res["text"], + ): + if not utils.get_shared_variable(time_based_uuid)["status"]: + logging.info("\n==========STOPPED==========\n") break - tip["END"]="" - stop = "" - temp = "" - tip["source_docs"]=[] - tip["num"]=0 - tip["title"]=[] - # tip["status"] = True - utils.set_shared_variable(time_based_uuid,tip) - count = 0 - count_process += 1 - logging.info(f"\n\ncount_process:{count_process}\n\n") - async for response in agent_chat_test(user_prompt_name = user_prompt_name,query=query,uuid=time_based_uuid, history=history, model_name=model_name,temperature=temperature,max_tokens=max_tokens,prompt_name="Think Test",think_content=res["text"]): - # print("------------"+response) - if not utils.get_shared_variable(time_based_uuid)["status"]: - logging.info("\n==============================STOPPED==============================\n") - break - if "answer" in json.loads(response): - # logging.info(f"answer:{json.loads(response)['answer']}") - answer = json.loads(response)["answer"] - history_summary+=answer - stop = "1" - yield json.dumps({"text": answer}, ensure_ascii=False) - elif "tools" in json.loads(response): - # print("tools:", 
json.loads(response)["tools"]) - tools.append(json.loads(response)["tools"]) - # yield json.dumps({"tools": tools}, ensure_ascii=False) - elif "search_answer" in json.loads(response): - search_answer = json.loads(response)["search_answer"] - # history_summary+= search_answer - yield json.dumps({"text": search_answer}, ensure_ascii=False) - elif "docs" in json.loads(response): - docs = json.loads(response)["docs"] - elif "detail" in json.loads(response): - docs_detail += json.loads(response)["detail"] - elif "pic" in json.loads(response): - # 获取图片路径 - image_name = json.loads(response)["pic"] - image_name = f"\n\n![图片](http://127.0.0.1:8099/chat_web_backend/get-image?file_name={image_name})\n\n" - # yield json.dumps({"text": image_name}, ensure_ascii=False) - else : - #history_summary += json.loads(response)["final_answer"] - yield json.dumps({"text": json.loads(response)["final_answer"]}, ensure_ascii=False) - if stop == "": - continue - else: - stop = "" - temp1 = utils.get_shared_variable(time_based_uuid) - temp1["END"]="" - i+=1 - # if index3 == 0 and not "Action" in answer: - # yield json.dumps({"text": answer}, ensure_ascii=False) - yield json.dumps({"text":"\n"}, ensure_ascii=False) - import importlib - importlib.reload(prompt_config) + msg = json.loads(response) + if "answer" in msg: + history_summary += msg["answer"] + yield json.dumps({"text": msg["answer"]}, ensure_ascii=False) + elif "think" in msg: + yield json.dumps({"think": msg["think"]}, ensure_ascii=False) + elif "docs" in msg: + docs += msg["docs"] + elif "detail" in msg: + docs_detail += msg["detail"] + + yield json.dumps({"text": "\n"}, ensure_ascii=False) + if not docs_detail.strip() == "" and uid and uid in prompt_config.detail_answer_uid: yield json.dumps({"text": f"\n\n"}, ensure_ascii=False) async for chunk in thinking_generator("正在进行幻觉校验,请稍等待..."): diff --git a/langchain-chat/server/chat/policy_fun_iast.py b/langchain-chat/server/chat/policy_fun_iast.py index 1d03590..17186f5 100644 --- 
a/langchain-chat/server/chat/policy_fun_iast.py +++ b/langchain-chat/server/chat/policy_fun_iast.py @@ -165,16 +165,28 @@ async def get_llm_model_response_stream_openai( for key in prompt_param_dict: prompt_template = prompt_template.replace(f"{{{{{key}}}}}", prompt_param_dict[key]) messages = [HumanMessage(content=prompt_template)] - if type == 0 or type == 2: - start_flag = False - async for chunk in model.astream(messages): - if start_flag: - yield chunk.content - if "</think>" in chunk.content: - start_flag = True - else: - async for chunk in model.astream(messages): - yield chunk.content + # 跳过 <think>...</think> 块,其余照常 yield + # 兼容 R1 等输出 think 块的模型;非 think 模型不受影响 + in_think = False + async for chunk in model.astream(messages): + text = chunk.content or "" + while text: + if not in_think: + i = text.find("<think>") + if i < 0: + yield text + break + if i > 0: + yield text[:i] + text = text[i + len("<think>"):] + in_think = True + else: + i = text.find("</think>") + if i < 0: + text = "" # 全在 think 块内,丢弃 + else: + text = text[i + len("</think>"):] + in_think = False return # 成功完成,退出函数 except Exception as e: diff --git a/langchain-chat/server/chat/tools_v2.py b/langchain-chat/server/chat/tools_v2.py new file mode 100644 index 0000000..e092031 --- /dev/null +++ b/langchain-chat/server/chat/tools_v2.py @@ -0,0 +1,142 @@ +""" +LangGraph 版工具集:闭包工厂注入 uuid,统一异常包装。 + +为什么要重写: +1. 旧版 tools 用 `query` 字符串里塞 JSON + uuid 的 hack 传 metadata +2. 旧版 LLM 工具调度靠多次 LLM 路由,慢且容易循环 +3. 
这里给每个工具暴露结构化 args_schema,交给 LangGraph ReAct 直接 function-calling +""" +import json +import logging +from typing import List, Optional +from langchain_core.tools import tool + +# 旧版工具函数仍然复用——只改包装层 +from server.agent.tools.search_tool import search_tool as _legacy_kb_search +from server.agent.tools.knowledgebase_kgo_search import knowledgebase_kgo_search as _legacy_kgo_search +from server.agent.tools.draw_plot import create_and_save_plot as _legacy_draw_plot +from server.agent.tools.math import math_count as _legacy_math, code_count as _legacy_code +from server.agent.tools.weather_check import weathercheck as _legacy_weather +from server.agent.tools.search_picture import search_pic as _legacy_search_pic +from server.agent.tools.get_statistical_data import mysql_statistic as _legacy_mysql + +logger = logging.getLogger(__name__) + + +def _safe_call(name: str, fn, *args, **kwargs) -> str: + """统一异常包装:把 raise 转成给模型的字符串提示,让 ReAct 可恢复。""" + try: + result = fn(*args, **kwargs) + return result if isinstance(result, str) else json.dumps(result, ensure_ascii=False) + except Exception as e: + logger.exception(f"工具 {name} 调用异常") + return f"[工具 {name} 调用异常] {type(e).__name__}: {str(e)[:200]}。请使用其他工具或基于已有信息回答。" + + +def make_tools(uuid: str) -> list: + """根据本次请求的 uuid 生成一组闭包工具。 + + 每个工具内部用闭包捕获 uuid,调用旧版 func 时按旧 hack 拼装入参字符串。 + 模型看到的工具入参是结构化的,看不到 uuid。 + """ + + @tool("知识库联想") + def kb_search( + query: str, + knowledge_name: List[str], + keywords: Optional[List[str]] = None, + ) -> str: + """从指定知识库检索资料。 + + knowledge_name 必须从如下列表中选择(可多选): + 【中国钢铁行业动态库、政策库、期刊论文库、冶金新闻库、冶金中文期刊库、 + 冶金外文期刊库、冶金OA期刊库、冶金行业新闻库、冶金专业知识库、 + 冶金行业报告库、报告库、美术专业知识库】。 + 涉及国家政策时优先选政策库;钢铁行业问题优先选中国钢铁行业动态库。 + keywords 是相关关键词,2-4 个为宜。 + """ + payload = json.dumps({ + "query": query, + "knowledge_name": knowledge_name, + "keywords": keywords or [], + }, ensure_ascii=False) + legacy_input = payload + json.dumps({"uuid": uuid}, ensure_ascii=False) + return _safe_call("知识库联想", _legacy_kb_search, legacy_input) + + 
@tool("联网思索") + def web_search(query: str) -> str: + """联网搜索(智谱 search)。query 必须是用户原文,禁止改写。""" + payload = json.dumps({"query": query}, ensure_ascii=False) + legacy_input = payload + json.dumps({"uuid": uuid}, ensure_ascii=False) + return _safe_call("联网思索", _legacy_kgo_search, legacy_input) + + @tool("图表绘制") + def draw_plot( + data: dict, + title: str, + xlabel: str, + ylabel: str, + plot_type: str, + ) -> str: + """绘制图表。 + + data 形如 {"分类A": 23, "分类B": 17},xlabel/ylabel 描述坐标轴含义。 + plot_type 必须是 bar / pie / line 之一。 + 本工具一次只能画一张图;输出图片链接后必须按工具说明输出 markdown 引用。 + """ + payload = json.dumps({ + "data": data, + "title": title, + "xlabel": xlabel, + "ylabel": ylabel, + "plot_type": plot_type, + }, ensure_ascii=False) + # 旧版 draw_plot 接受 ... 包裹的 JSON + wrapped = f"{payload}{json.dumps({'uuid': uuid})}" + return _safe_call("图表绘制", _legacy_draw_plot, wrapped) + + @tool("数学运算") + def math_solve(query: str) -> str: + """数学问题求解。query 描述要求解的数学问题。""" + payload = json.dumps({"query": query}, ensure_ascii=False) + legacy_input = payload + json.dumps({"uuid": uuid}, ensure_ascii=False) + return _safe_call("数学运算", _legacy_math, legacy_input) + + @tool("代码专家") + def code_solve(query: str) -> str: + """代码相关问题,包括写代码、解释代码、调试。""" + payload = json.dumps({"query": query}, ensure_ascii=False) + legacy_input = payload + json.dumps({"uuid": uuid}, ensure_ascii=False) + return _safe_call("代码专家", _legacy_code, legacy_input) + + @tool("天气工具") + def weather(location: str) -> str: + """查询某城市三天内天气。location 是中文城市名,如"北京"。""" + payload = json.dumps({"location": location}, ensure_ascii=False) + legacy_input = payload + json.dumps({"uuid": uuid}, ensure_ascii=False) + return _safe_call("天气工具", _legacy_weather, legacy_input) + + @tool("美术作品获取") + def art_search(query: str) -> str: + """查询美术作品图片。query 是作品类型描述(如"山水画"、"草原"),不要传"美术作品"等通用词。""" + payload = json.dumps({"query": query}, ensure_ascii=False) + legacy_input = payload + json.dumps({"uuid": uuid}, ensure_ascii=False) + return _safe_call("美术作品获取", 
_legacy_search_pic, legacy_input) + + @tool("统计数据查询") + def stat_query(query: str) -> str: + """统计数据库查询。仅有 199x-2023 数据。query 是详细的查询问题描述。""" + payload = json.dumps({"query": query}, ensure_ascii=False) + wrapped = f"{payload}{json.dumps({'uuid': uuid})}" + return _safe_call("统计数据查询", _legacy_mysql, wrapped) + + return [ + kb_search, + web_search, + draw_plot, + math_solve, + code_solve, + weather, + art_search, + stat_query, + ]