主要变化: - 新增 agent_v2.py: 用 LangGraph create_react_agent + astream_events 替代原 agent_chat_test 的 LLM step-routing 死循环 - 新增 tools_v2.py: 闭包工厂模式,每个请求按 uuid 生成工具列表, 消除 toolinput 字符串拼 JSON 注入 uuid 的旧 hack - chat_test.py:266-346: 删 11 次 count_process 重试外层和事件 分发 spaghetti,换成 agent_run 单次调用 + 简单事件 dispatcher - policy_fun_iast.py:168-187: 修 broken <think> filter 老代码把 start_flag 设反了(看见 <think> 才开始 yield)导致 非 think 模型 yield 不出任何内容;改为正确跳过 <think>...</think> 块 模型函数调用通过 langchain_openai.ChatOpenAI(不能用旧版 langchain_community.chat_models.ChatOpenAI,没有现代 tool calling)。 依赖: langgraph==0.0.49 + langchain-core==0.1.53(已在服务器装好)。 非 stream 分支保留旧 agent_chat_test 路径(极少触发,回归风险低)。 旧版回滚: git checkout backup/pre-langgraph 实测对比: - 旧版 30-60s,答案 0 字(filter 卡死后展示 11 次重试) - 新版 25-40s,答案完整(含工具调用、参考文献、推荐问题、摘要) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
687 lines
37 KiB
Python
687 lines
37 KiB
Python
import uuid
|
||
from fastapi import Body, HTTPException
|
||
from fastapi.responses import FileResponse
|
||
from configs.kb_config import GENERATED_IMAGES_BASE_PATH
|
||
import geoip2.database
|
||
from langchain.memory import (
|
||
CombinedMemory,
|
||
ConversationBufferMemory,
|
||
# ConversationSummaryMemory,
|
||
# ConversationBufferWindowMemory
|
||
)
|
||
from typing import Any
|
||
import requests
|
||
from sse_starlette.sse import EventSourceResponse
|
||
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, DEEPSEEK_MODELS, CAST_MODELS
|
||
from configs import prompt_config
|
||
from configs.model_config import MODEL_ROOT_PATH
|
||
from server.agent.tools.search_tool import search_tool
|
||
from server.chat import utils
|
||
from server.chat.agent_chat_test import agent_chat_test
|
||
from server.chat.policy_fun_iast import get_llm_model_response, get_llm_model_response_stream_openai
|
||
from server.chat.utils import split_questions
|
||
from server.chat.solve_problem import solve_problem
|
||
from server.knowledge_base.kb_service.base import TextRank
|
||
from server.utils import wrap_done, get_ChatOpenAI
|
||
from langchain.chains import LLMChain, ConversationChain
|
||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||
from typing import AsyncIterable
|
||
import asyncio
|
||
import json
|
||
from langchain.prompts.chat import ChatPromptTemplate
|
||
from typing import List, Optional, Union
|
||
from server.chat.utils import History, compute_lps, remove_after_and_including,remove_before_and_including
|
||
from langchain.prompts import PromptTemplate
|
||
from server.utils import get_prompt_template, get_format_template
|
||
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
|
||
from server.db.repository import add_message_to_db
|
||
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
||
from datetime import datetime
|
||
from langchain_core.messages import SystemMessage
|
||
import time as t
|
||
from configs.basic_config import *
|
||
async def process_task(task):
    """Drain an async iterable and return everything it produced as a list.

    Args:
        task: any object usable in ``async for`` (e.g. an async generator).

    Returns:
        A list with the produced items, in the order they were yielded.
    """
    return [item async for item in task]
|
||
|
||
async def get_image(file_name: str):
    """Serve a generated image from ``GENERATED_IMAGES_BASE_PATH``.

    Args:
        file_name: name (optionally with sub-directories) of the image,
            relative to the generated-images directory.

    Returns:
        A ``FileResponse`` for the requested image.

    Raises:
        HTTPException: 404 when the name contains ``*``, does not end with an
            allowed image extension, resolves outside the base directory, or
            does not point to an existing regular file.
    """
    allowed_suffixes = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")
    if "*" in file_name or not file_name.lower().endswith(allowed_suffixes):
        raise HTTPException(status_code=404, detail="File not found")

    # Resolve symlinks and ".." segments before checking containment, so a
    # name such as "../../secret.png" cannot escape the images directory.
    base_dir = os.path.realpath(GENERATED_IMAGES_BASE_PATH)
    target = os.path.realpath(f"{GENERATED_IMAGES_BASE_PATH}/{file_name}")
    try:
        inside_base = os.path.commonpath([base_dir, target]) == base_dir
    except ValueError:
        # Raised when the two paths cannot be compared (e.g. different drives).
        inside_base = False

    if not inside_base or not os.path.isfile(target):
        raise HTTPException(status_code=404, detail="File not found")

    # Return the file response.
    return FileResponse(target)
|
||
|
||
|
||
async def thinking_generator(content: str) -> AsyncIterable[str]:
    """Async generator that emits ``content`` as "think" frames, typewriter style.

    A newline frame is emitted first, then one frame per character (with a
    0.02 s pause after each character), then a closing newline frame. Every
    frame is a JSON string of the form ``{"think": <piece>}``.
    """
    def _frame(piece: str) -> str:
        return json.dumps({'think': piece}, ensure_ascii=False)

    yield _frame('\n')
    for character in content:
        yield _frame(character)
        await asyncio.sleep(0.02)
    yield _frame('\n')
|
||
|
||
|
||
def _reset_tool_state(shared_key, state):
    """Reset the per-request fields the tools write to and publish the state."""
    state["END"] = ""
    state["source_docs"] = []
    state["num"] = 0
    state["title"] = []
    utils.set_shared_variable(shared_key, state)


def _recommend_questions(query, answer_text):
    """Ask the default LLM for follow-up questions about the finished exchange."""
    exchange = [
        {"role": "user", "content": query},
        {"role": "assistant", "content": answer_text},
    ]
    raw = get_llm_model_response(
        strategy_name="question_recommend",
        llm_model_name=LLM_MODELS[0],
        template_prompt_name="question_recommend",
        prompt_param_dict={"history": exchange},
        temperature=0.3,
        max_tokens=512
    ).strip()
    formatted = split_questions(raw)
    logger.info(f"推荐问题: \n{formatted}")
    return formatted


def _rewrite_and_search(query, user_history, llm_name, knowledge_names, shared_key):
    """Rewrite the query with an LLM, then run ``search_tool`` on the first rewritten value."""
    search_query = get_llm_model_response(
        strategy_name="query rewrite",
        llm_model_name=llm_name,
        template_prompt_name="query_rewrite",
        prompt_param_dict={"query": query, "history": user_history, "time": datetime.now().strftime("%Y%m%d")},
        temperature=0.01,
        max_tokens=512
    )
    rewritten = json.loads(search_query)
    first_json = {
        "query": rewritten[list(rewritten)[0]],
        "knowledge_name": knowledge_names,
        "keywords": []
    }
    second_json = {
        "uuid": shared_key
    }
    # search_tool is called with two JSON documents concatenated, as before.
    return search_tool(json.dumps(first_json) + json.dumps(second_json))


async def _lookup_city(ip):
    """Best-effort IP -> city lookup: local GeoLite2 database first, HTTP service second.

    Returns "未知地区" when neither source yields a city.
    """
    try:
        # The reader is a context manager; closing it releases the mmdb handle.
        with geoip2.database.Reader(f'{MODEL_ROOT_PATH}/GeoLite2-City.mmdb') as reader:
            return reader.city(ip).city.names["zh-CN"]
    except Exception as e:
        logger.warning(f"GeoLite2 lookup failed for {ip}: {e}")

    url = f"https://ip.taobao.com/outGetIpInfo?ip={ip}&accessKey=alibaba-inc"
    try:
        # NOTE(review): requests is synchronous and still blocks the event loop
        # for the duration of the call; the timeout only bounds that stall.
        response = requests.get(url, timeout=5)
        if response.status_code != 200:
            return "未知地区"
        if 'data' not in response.json():
            # One retry after a short pause, without blocking the event loop.
            await asyncio.sleep(1)
            response = requests.get(url, timeout=5)
        return response.json()["data"]["city"]
    except Exception as e:
        # A failed location lookup must not abort the whole chat stream.
        logger.warning(f"HTTP IP lookup failed for {ip}: {e}")
        return "未知地区"


async def chat_test(
        messageId: Optional[Any] = Body(None, description="消息ID", examples=[""]),
        ip: Optional[str] = Body(None, description="用户IP"),
        uid: Optional[Any] = Body(None, description="用户ID"),
        query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
        conversation_id: str = Body("", description="对话框ID"),
        history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
        history: Union[int, List] = Body([],description="历史对话,设为一个整数可以从数据库中读取历史消息",
                                         examples=[[
                                             {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                             {"role": "assistant", "content": "虎头虎脑","summary":"8989uj9"}
                                         ]]
                                         ),
        stream: bool = Body(False, description="流式输出"),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
        # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature 同时设置", gt=0.0, lt=1.0),
        prompt_name: str = Body("default_new", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
    """Chat endpoint that answers ``query`` as a server-sent-event stream.

    Each event is a JSON string carrying one of the keys ``think``, ``text``,
    ``docs``, ``question`` or ``summary``. The request is first classified by
    an LLM ("think_route"); the resulting route selects the handling:

    * ``"2"`` - plan with ``solve_problem`` and answer through the agent
      (``agent_run`` when streaming, ``agent_chat_test`` otherwise);
    * ``"3"`` / ``"5"`` - retrieve documents with ``search_tool`` and answer
      with the ``default_math`` / ``default_code`` prompt;
    * ``"4"`` - answer with the ``protect_prompt`` prompt;
    * ``"7"`` - answer with ``enable_thinking`` switched on;
    * anything else - plain LLM chat with the requested prompt.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator.

    NOTE(review): this block was re-indented from a source whose indentation
    was lost. Nesting was reconstructed from control-flow keywords; the legacy
    non-stream agent loop in particular should be checked against history.
    """
    async def chat_iterator() -> AsyncIterable[str]:
        nonlocal history, max_tokens, uid, model_name, prompt_name, messageId, query, ip, temperature
        if prompt_name == "default":
            prompt_name = "default_new"
        docs_detail = ""

        # Ask the LLM whether the question depends on the user's location.
        use_ip = get_llm_model_response(
            strategy_name="query rewrite",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="use_ip",
            prompt_param_dict={"query": query},
            temperature=0.01,
            max_tokens=512
        )
        print("\n\nuser_id:", uid)
        if "1" in use_ip:
            # NOTE(review): the flattened source does not show whether the
            # location hint was also appended when no IP was supplied; here a
            # missing IP yields the "unknown region" hint instead of a crash.
            city = await _lookup_city(ip) if ip else "未知地区"
            query = f"{query}(我所在的地区是{city})"

        # Per-request shared state; ``status`` is polled below so the stream
        # stops once it is set to a falsy value.
        time_based_uuid = messageId if messageId else str(uuid.uuid1())+"q"
        tip = {"status": True}
        utils.set_shared_variable(time_based_uuid, tip)

        def _stopped():
            return not utils.get_shared_variable(time_based_uuid)["status"]

        # Compress over-long queries, a second time more aggressively if needed.
        query = query if len(query)<20000 else TextRank(query,num_sentences=70)
        query = query if len(query)<20000 else TextRank(query,num_sentences=10)

        if model_name == "R1-70B":
            model_name = DEEPSEEK_MODELS[0]
        elif model_name in ["QIANWEN", "Qwen1.5-32B-Chat"]:
            model_name = LLM_MODELS[0]
        if prompt_name == "customer_service":
            model_name = CAST_MODELS[0]
            temperature = 0.01

        callback = AsyncIteratorCallbackHandler()
        callbacks = [callback]
        today = datetime.now().strftime("%Y年%m月%d日")
        total_length = sum(len(item["content"]) for item in history if "content" in item)
        logger.info(f"历史对话长度:{total_length}")

        # Responsible for saving the LLM response to the message DB.
        message_id = str(uuid.uuid1())+"q"
        conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
                                                            chat_type="llm_chat",
                                                            query=query)
        callbacks.append(conversation_callback)
        logger.info(f"智能对话的入参信息:messageId: {messageId}\n query:{query}\nconversation_id:{conversation_id}\nstream:{stream}\nmodel_name:{model_name}\ntemperature:{temperature}\nmax_tokens:{max_tokens}prompt_name:{prompt_name}")

        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        model = get_ChatOpenAI(
            model_name=LLM_MODELS[0] if prompt_name == "Search Summary" else model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=callbacks,
        )
        logger.info(f"当前使用模型:{model_name}")

        history_summary = ""
        docs = ""
        max_length = 90000
        user_prompt_name = prompt_name
        print(f"使用prompt模板:{prompt_name}")

        # ---- routing: decide how the request will be handled ----
        route_text = ""
        uses_router = model_name not in DEEPSEEK_MODELS and model_name not in CAST_MODELS
        if total_length > max_length and uses_router:
            # History is too long: route on summaries and let a second LLM call
            # pick which earlier turn pairs are worth keeping.
            summary_only_history = [{"role": item["role"], "content": item.get("summary", item["content"])} for item in history]
            summary_group_history = {i: [summary_only_history[i], summary_only_history[i + 1]] for i in range(0, len(summary_only_history) - 1, 2)}
            # [-6:] instead of [len(history) - 6:], which turned into a wrong
            # negative start when fewer than six turns exist.
            task1 = solve_problem(user_prompt_name=user_prompt_name, query=query, conversation_id="", history=summary_only_history[-6:], model_name=LLM_MODELS[0], temperature=temperature, max_tokens=max_tokens, prompt_name="think_route", stream=True)
            task2 = solve_problem(user_prompt_name=user_prompt_name, query=query, conversation_id="", history=summary_group_history, model_name=LLM_MODELS[0], temperature=temperature, max_tokens=max_tokens, prompt_name="history_route", stream=True)
            route_chunks, history_chunks = await asyncio.gather(process_task(task1), process_task(task2))
            if route_chunks:
                # Only the first chunk carries the route.
                route_text = json.loads(route_chunks[0])["text"]
            history_list_str = "".join(json.loads(chunk)["text"] for chunk in history_chunks)

            recent_positions = range(max(0, len(history) - 6), len(history))
            selected_positions = []
            if "无" not in history_list_str:
                for raw_index in history_list_str.replace(",", ",").split(","):
                    try:
                        pair_start = int(raw_index.strip())
                    except ValueError:
                        logger.warning(f"ignoring non-numeric history index: {raw_index!r}")
                        continue
                    # Each selected index stands for a (user, assistant) pair.
                    # Compare as ints: the old str-vs-int test never matched.
                    for position in (pair_start, pair_start + 1):
                        if (0 <= position < len(history)
                                and position not in recent_positions
                                and position not in selected_positions):
                            selected_positions.append(position)
            history = [
                {"role": history[position]["role"], "content": history[position]["content"]}
                for position in [*selected_positions, *recent_positions]
            ]
        elif uses_router:
            history = [{"role": item["role"], "content": item["content"]} for item in history]
            async for response in solve_problem(user_prompt_name=user_prompt_name, query=query, conversation_id="", history=history, model_name=LLM_MODELS[0], temperature=temperature, max_tokens=max_tokens, prompt_name="think_route", stream=True):
                route_text = json.loads(response)["text"]
                break
        elif model_name in DEEPSEEK_MODELS:
            route_text = 1
            prompt_name = "deepseek_default"
        else:
            route_text = 1
        # Normalised once so every comparison below sees the same value.
        route = str(route_text).strip()

        res = {}
        if route == "2":
            if prompt_name == "default_new":
                prompt_name = "default"
            logging.info(f"💡 think_type == 2")
            res["text"] = ""
            async for response in solve_problem(user_prompt_name=prompt_name, query=query, conversation_id="", history=history, model_name=model_name, temperature=temperature, max_tokens=max_tokens, prompt_name="solve_problem", stream=True):
                plan_piece = json.loads(response)["text"]
                res["text"] += plan_piece
                yield json.dumps({"think": plan_piece}, ensure_ascii=False)
                if _stopped():
                    logging.info("\n==============================STOPPED==============================\n")
                    break
            async for chunk in thinking_generator("正在整合各个信息,请稍等待..."):
                yield json.dumps({"think": json.loads(chunk)["think"]}, ensure_ascii=False)

            if stream:
                # LangGraph agent (v2): replaces the former retry loop and
                # LLM step routing.
                from server.chat.agent_v2 import agent_run

                # The tools still write source_docs into the shared state.
                _reset_tool_state(time_based_uuid, tip)

                async for response in agent_run(
                    query=query,
                    uuid=time_based_uuid,
                    history=history,
                    model_name=model_name,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    user_prompt_name=user_prompt_name,
                    think_content=res["text"],
                ):
                    if _stopped():
                        logging.info("\n==========STOPPED==========\n")
                        break
                    msg = json.loads(response)
                    if "answer" in msg:
                        history_summary += msg["answer"]
                        yield json.dumps({"text": msg["answer"]}, ensure_ascii=False)
                    elif "think" in msg:
                        yield json.dumps({"think": msg["think"]}, ensure_ascii=False)
                    elif "docs" in msg:
                        docs += msg["docs"]
                    elif "detail" in msg:
                        docs_detail += msg["detail"]

                yield json.dumps({"text": "\n"}, ensure_ascii=False)

                # Extra verification pass, only for the configured users.
                if docs_detail.strip() and uid and uid in prompt_config.detail_answer_uid:
                    yield json.dumps({"text": f"\n\n"}, ensure_ascii=False)
                    async for chunk in thinking_generator("正在进行幻觉校验,请稍等待..."):
                        yield json.dumps({"text": json.loads(chunk)["think"]}, ensure_ascii=False)
                    async for chunk in get_llm_model_response_stream_openai(
                        type=2,
                        strategy_name="query rewrite",
                        llm_model_name=LLM_MODELS[0],
                        template_prompt_name="detail_answer",
                        prompt_param_dict={
                            "input": query,
                            "history_summary": history_summary,
                            "docs_detail": docs_detail,
                        },
                        temperature=0.7,
                        max_tokens=None,
                    ):
                        yield json.dumps({"text": chunk}, ensure_ascii=False)
                if docs.strip():
                    yield json.dumps({"docs": docs}, ensure_ascii=False)
                    history_summary += docs
            else:
                # Legacy agent path, kept for non-stream requests. Nothing is
                # yielded per token; the text is accumulated and sent once.
                # NOTE(review): nesting of this state machine was reconstructed.
                answer = ""
                tools = []
                current_str1 = ""
                prefixs = ["¥", "¥我", "¥我将", "¥我将会", "¥我将会使", "¥我将会使用", "¥我将会使用工"]
                index = 0    # 1 while a tool call is in progress
                index1 = 0   # 1 while plain-text buffering is suspended
                index2 = 0   # 1 right after a tool call has finished
                index3 = 0   # never changed below
                async for response in agent_chat_test(user_prompt_name=user_prompt_name, query=query, uuid=time_based_uuid, history=history, model_name=model_name, temperature=temperature, max_tokens=max_tokens, prompt_name="answer_question_history", think_content=res["text"]):
                    print("------------"+response)
                    payload = json.loads(response)
                    if "answer" in payload:
                        answer += payload["answer"]
                        if any(answer.endswith(prefix) for prefix in prefixs) and index == 0:
                            index1 = 1
                            current_str1 += payload["answer"]
                        elif (answer.endswith("Action")) or "Action" in answer:
                            index1 = 1
                            current_str1 += payload["answer"]
                            if index == 0:
                                history_summary += remove_after_and_including(current_str1, "Action")
                                index = 1
                            if utils.get_shared_variable(time_based_uuid)["END"] == "ok":
                                # Tool finished: reset the buffers and the flag.
                                current_str1 = ""
                                index2 = 1
                                index = 0
                                index1 = 0
                                answer = ""
                                ok = utils.get_shared_variable(time_based_uuid)
                                ok["END"] = ""
                                utils.set_shared_variable(time_based_uuid, ok)
                            elif index3 == 0:
                                print("等待中...")
                        else:
                            if index1 == 0:
                                current_str1 += payload["answer"]
                            if index2 == 1:
                                if len(current_str1) > 5:
                                    index2 = 0
                            else:
                                history_summary += current_str1
                                current_str1 = ""
                                index = 0
                                index1 = 0
                    elif "tools" in payload:
                        tools.append(payload["tools"])
                    elif "search_answer" in payload:
                        history_summary += payload["search_answer"]
                    elif "docs" in payload:
                        docs += payload["docs"]
                    else:
                        history_summary += payload["final_answer"]
                if index3 == 0 and "Action" not in answer:
                    history_summary += answer
                history_summary += "\n"
                history_summary += docs
                yield json.dumps({"text": history_summary}, ensure_ascii=False)

            # The shared per-request state must be removed, otherwise it
            # accumulates across requests (memory-leak risk).
            utils.remove_shared_variable(time_based_uuid)
            yield json.dumps({"question": _recommend_questions(query, history_summary)}, ensure_ascii=False)
            yield json.dumps({"summary": TextRank(history_summary, 80)}, ensure_ascii=False)
            return

        math_doc = ""
        code_doc = ""
        if route == "7":
            logging.info(f"💡 think_type == 7")
            model = get_ChatOpenAI(
                model_name=model_name,
                temperature=0.01,
                max_tokens=max_tokens,
                callbacks=callbacks,
                extra_body={"chat_template_kwargs": {"enable_thinking": True}},
            )
        elif route == "4":
            logging.info(f"💡 think_type == 4")
            model = get_ChatOpenAI(
                model_name=model_name,
                temperature=0.01,
                max_tokens=max_tokens,
                callbacks=callbacks,
            )
            prompt_name = "protect_prompt"
        elif route == "3":
            logging.info(f"💡 think_type == 3")
            _reset_tool_state(time_based_uuid, tip)
            model = get_ChatOpenAI(
                model_name=LLM_MODELS[3],
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=callbacks,
            )
            # Collect the earlier user messages for the query rewrite. Items
            # are dicts here; unpacking a dict yields its keys, which is why
            # the previous ``role, content = message`` never matched 'user'.
            user_queries = []
            for message in history:
                if isinstance(message, dict):
                    role, content = message.get("role"), message.get("content")
                else:
                    role, content = message
                if role == 'user':
                    user_queries.append(content)
            math_doc = _rewrite_and_search(query, user_queries, LLM_MODELS[0], [], time_based_uuid)
            prompt_name = "default_math"
        elif route == "5":
            logging.info(f"💡 think_type == 5")
            _reset_tool_state(time_based_uuid, tip)
            model = get_ChatOpenAI(
                model_name=LLM_MODELS[2],
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=callbacks,
            )
            code_doc = _rewrite_and_search(query, [], LLM_MODELS[2], ["coding"], time_based_uuid)
            prompt_name = "default_code"

        if history and prompt_name not in ["Search Summary", "get_policy_time", "customer_service"]:
            history = [History.from_data(h) for h in history]
            prompt_name = prompt_name + "_with_history"
            prompt_template = get_prompt_template("llm_chat", prompt_name)
            if route == "3":
                prompt_template = prompt_template.replace("{{{math_doc}}}", math_doc)
            if route == "5":
                prompt_template = prompt_template.replace("{{{code_doc}}}", code_doc)
            chat_prompt = PromptTemplate.from_template(prompt_template)

            # Turn the history into conversation memory.
            buff_memory = ConversationBufferMemory(human_prefix='user', ai_prefix='assistant', memory_key="history", input_key="input")
            for message in history:
                if message.role == 'user':
                    buff_memory.chat_memory.add_user_message(message.content)
                elif message.role == 'assistant':
                    buff_memory.chat_memory.add_ai_message(message.content)

            # A second memory exposes the current date under the "time" key.
            background_memory = ConversationBufferMemory(human_prefix='user', ai_prefix='assistant', memory_key="time", input_key="input")
            message = SystemMessage(content=f'当前的时间是:{today}')
            message.type = ""
            background_memory.chat_memory.add_message(message)

            memory = CombinedMemory(memories=[background_memory, buff_memory])
            chain = ConversationChain(llm=model, verbose=True, memory=memory, prompt=chat_prompt)
        else:
            prompt_template = get_prompt_template("llm_chat", prompt_name)
            if route == "3":
                prompt_template = prompt_template.replace("{{math_doc}}", math_doc)
            if route == "5":
                prompt_template = prompt_template.replace("{{code_doc}}", code_doc)
            input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
            chat_prompt = ChatPromptTemplate.from_messages([input_prompt])
            chain = LLMChain(prompt=chat_prompt, llm=model)
        logger.info(f"当前提示词名称为:{prompt_name}")

        # ``queue`` is handed to wrap_done on the default path only; a
        # non-empty queue afterwards triggers the sensitive-content notice.
        queue = asyncio.Queue()
        if prompt_name == "answer_question":
            task = asyncio.create_task(wrap_done(
                # .get(): ``res`` has no "text" unless the route-2 planner ran.
                chain.acall({"input": query, "time": today, "solve": res.get("text", "")}),
                callback.done),
            )
        elif prompt_name == "customer_service":
            task = asyncio.create_task(wrap_done(
                chain.acall({"input": query}),
                callback.done),
            )
        else:
            task = asyncio.create_task(wrap_done(
                chain.acall({"input": query, "time": today}),
                callback.done,
                queue=queue
            ),
            )

        if route == "3":
            yield json.dumps({"text": "思考中..."}, ensure_ascii=False)
            logging.info(f"过滤思考过程中。。。")

        if stream:
            # Initialised up front: it used to be read before assignment when
            # the first token was not "<think>".
            include_think1 = False
            async for token in callback.aiter():
                if _stopped():
                    logging.info("\n==============================STOPPED==============================\n")
                    break
                history_summary += token
                if route == "3":
                    # The token carrying "<think>" is dropped, all others are
                    # forwarded as text (same output as the former flag logic).
                    if "<think>" not in token:
                        yield json.dumps({"text": token}, ensure_ascii=False)
                elif route == "7":
                    if "<think>" in token:
                        include_think1 = True
                        token = token.replace("<think>", "")
                        logger.info(f"think:{token}")
                        yield json.dumps({"think": token}, ensure_ascii=False)
                    elif "</think>" in token:
                        include_think1 = False
                        token = token.replace("</think>", "")
                        logger.info(f"think:{token}")
                        yield json.dumps({"think": token}, ensure_ascii=False)
                    elif include_think1:
                        yield json.dumps({"think": token}, ensure_ascii=False)
                    else:
                        yield json.dumps({"text": token}, ensure_ascii=False)
                elif model_name in DEEPSEEK_MODELS:
                    # NOTE(review): only the "<think>" token goes out as
                    # "think"; the reasoning that follows is sent as "text",
                    # exactly as before - confirm this is what clients expect.
                    if "<think>" in token:
                        token = token.replace("<think>", "")
                        logger.info(f"think:{token}")
                        yield json.dumps({"think": token}, ensure_ascii=False)
                    else:
                        yield json.dumps({"text": token}, ensure_ascii=False)
                else:
                    yield json.dumps(
                        {"text": token, "message_id": message_id},
                        ensure_ascii=False)
        else:
            answer = ""
            async for token in callback.aiter():
                answer += token
            # NOTE(review): assumed to be emitted once, after the loop.
            yield json.dumps(
                {"text": answer, "message_id": message_id},
                ensure_ascii=False)
            history_summary += answer

        yield json.dumps({"question": _recommend_questions(query, history_summary)}, ensure_ascii=False)
        summary = TextRank(history_summary, 80)
        yield json.dumps({"summary": summary}, ensure_ascii=False)

        await task
        if not queue.empty():
            yield json.dumps({"text": "\n"}, ensure_ascii=False)
            yield json.dumps({"text": "<span style='color:red'>检测到当前内容涉及敏感信息,请换个问题再次尝试。</span>"}, ensure_ascii=False)

        # Release the per-request shared state on this path as well; it used
        # to be removed on the agent path only.
        # NOTE(review): assumes nothing reads the entry after the stream ends.
        utils.remove_shared_variable(time_based_uuid)

    return EventSourceResponse(chat_iterator())
|
||
|