2026-04-02 11:36:05 +08:00
|
|
|
|
import asyncio
|
|
|
|
|
|
import time
|
|
|
|
|
|
from langchain.chains import LLMChain
|
|
|
|
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
|
|
from configs import KB_PROMPT, LLM_PROMPT, logger
|
|
|
|
|
|
|
|
|
|
|
|
from configs.prompt_config import AGENT_PROMPT, AGENT_WRITE_PROMPT, COMPARISON
|
|
|
|
|
|
from server.chat.utils import History
|
|
|
|
|
|
from server.utils import get_prompt_template, get_strategy_prompt_template, get_ChatOpenAI
|
|
|
|
|
|
import openai
|
|
|
|
|
|
from typing import Any, AsyncGenerator
|
|
|
|
|
|
from langchain.schema import HumanMessage
|
|
|
|
|
|
|
|
|
|
|
|
# Extra attempts allowed after the first failed LLM call
# (total attempts per request = MAX_RETRIES + 1).
MAX_RETRIES: int = 2
# Seconds to wait between attempts.
RETRY_DELAY: int = 1
# Upper bound for max_tokens imposed by the Alibaba Cloud DashScope API.
MAX_MAX_TOKENS: int = 8192
|
|
|
|
|
|
|
|
|
|
|
|
def _with_agent_prefix(template_prompt_name: str, prefix_name: str, standalone_name: str) -> str:
    """Return an ``agent_chat`` template, prefixed by the ``prefix_name`` template.

    The prefix is skipped when ``template_prompt_name`` is the prefix template
    itself or ``standalone_name``; otherwise the result is the prefix template
    text immediately followed by the requested template text.
    """
    if template_prompt_name in (prefix_name, standalone_name):
        return get_prompt_template("agent_chat", template_prompt_name)
    prefix = get_prompt_template("agent_chat", prefix_name)
    body = get_prompt_template("agent_chat", template_prompt_name)
    return f"{prefix}{body}"


def _resolve_strategy_prompt(template_prompt_name: str) -> str:
    """Look up the prompt template text for ``template_prompt_name``.

    The name is checked against KB_PROMPT, LLM_PROMPT, COMPARISON,
    AGENT_PROMPT and AGENT_WRITE_PROMPT in that order; any other name falls
    back to ``get_strategy_prompt_template("knowledge_base_chat", ...)``.
    """
    if template_prompt_name in KB_PROMPT:
        return get_prompt_template("knowledge_base_chat", template_prompt_name)
    if template_prompt_name in LLM_PROMPT:
        return get_prompt_template("llm_chat", template_prompt_name)
    if template_prompt_name in COMPARISON:
        return get_prompt_template("comparison_chat", template_prompt_name)
    # The two agent branches only apply to the templates of the new agent
    # flow; check for conflicts before adding other names here.
    if template_prompt_name in AGENT_PROMPT:
        return _with_agent_prefix(template_prompt_name, "Think Test Bak", "get_next_tip")
    if template_prompt_name in AGENT_WRITE_PROMPT:
        return _with_agent_prefix(template_prompt_name, "Write Test Bak", "get_next_write_tip")
    return get_strategy_prompt_template("knowledge_base_chat", template_prompt_name)


def get_llm_model_response(
    strategy_name: str,
    llm_model_name: str,
    template_prompt_name: str,
    prompt_param_dict: dict,
    temperature: float,
    max_tokens: int,
    **kwargs: Any,
) -> str:
    """Call the LLM for one strategy step and return the generated text.

    Args:
        strategy_name: not used by the body.
        llm_model_name: model name passed to ``get_ChatOpenAI``.
        template_prompt_name: prompt name, resolved to template text by
            ``_resolve_strategy_prompt`` and used as a system message.
        prompt_param_dict: template variables passed to ``LLMChain.run``.
        temperature: sampling temperature passed to ``get_ChatOpenAI``.
        max_tokens: completion limit; values above ``MAX_MAX_TOKENS`` are
            clamped (with a warning).
        **kwargs: extra keyword arguments forwarded to ``get_ChatOpenAI``.

    Returns:
        The text returned by ``LLMChain.run``.

    Raises:
        Exception: the last chain error once ``MAX_RETRIES`` retries (with
            ``RETRY_DELAY`` seconds between them) are exhausted.
    """
    # Keep max_tokens within the API limit.
    if max_tokens is not None and max_tokens > MAX_MAX_TOKENS:
        logger.warning(f"max_tokens({max_tokens}) 超过 API 限制,已调整为 {MAX_MAX_TOKENS}")
        max_tokens = MAX_MAX_TOKENS

    def make_model(**overrides: Any):
        # No callbacks here: otherwise this model response would be added to
        # the final answer. ``overrides`` take precedence over caller kwargs,
        # so a retry can force ``streaming=False`` even when the caller already
        # passed ``streaming`` (previously a duplicate-keyword TypeError).
        return get_ChatOpenAI(
            model_name=llm_model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[],
            **{**kwargs, **overrides},
        )

    # The first attempt uses get_ChatOpenAI's default streaming setting;
    # retries below rebuild the model with streaming disabled.
    model = make_model()

    prompt_template = _resolve_strategy_prompt(template_prompt_name)
    input_msg = History(role="system", content=prompt_template).to_msg_template(False)
    prompt = ChatPromptTemplate.from_messages([input_msg])

    # One initial attempt plus up to MAX_RETRIES retries.
    for attempt in range(1, MAX_RETRIES + 2):
        try:
            llm_chain = LLMChain(prompt=prompt, llm=model, verbose=True)
            return llm_chain.run(prompt_param_dict)
        except Exception as e:
            if attempt > MAX_RETRIES:
                logger.error(f"LLM调用失败,已达到最大重试次数 {MAX_RETRIES}: {e}")
                raise
            logger.warning(f"LLM调用第 {attempt} 次失败,{RETRY_DELAY}秒后重试: {e}")
            time.sleep(RETRY_DELAY)
            # Retry with a freshly built, non-streaming model.
            model = make_model(streaming=False)
|
|
|
|
|
|
|
|
|
|
|
|
async def get_llm_model_response_async(
    strategy_name: str,
    llm_model_name: str,
    template_prompt_name: str,
    prompt_param_dict: dict,
    temperature: float,
    max_tokens: int,
    **kwargs: Any,
) -> str:
    """Run ``get_llm_model_response`` in the default executor and await its result.

    The synchronous call (including its blocking ``time.sleep`` retry
    back-off) runs in a worker thread so the event loop is not blocked.

    Args:
        strategy_name, llm_model_name, template_prompt_name, prompt_param_dict,
        temperature, max_tokens: passed positionally to
            ``get_llm_model_response``.
        **kwargs: extra keyword arguments forwarded to
            ``get_llm_model_response``; previously they could not be passed
            through this wrapper at all.

    Returns:
        Whatever ``get_llm_model_response`` returns.
    """
    import functools  # run_in_executor only accepts positional arguments

    # Inside a coroutine the running loop is the one to use.
    loop = asyncio.get_running_loop()
    call = functools.partial(
        get_llm_model_response,
        strategy_name,
        llm_model_name,
        template_prompt_name,
        prompt_param_dict,
        temperature,
        max_tokens,
        **kwargs,
    )
    return await loop.run_in_executor(None, call)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ThinkBlockFilter:
    """Incrementally strip ``<think>...</think>`` spans from streamed text.

    A tag may be split across chunks (e.g. ``"<thi"`` then ``"nk>"``), so any
    trailing fragment that could be the start of the tag currently being
    searched for is held back until the next chunk arrives (or ``flush``).
    """

    _OPEN = "<think>"
    _CLOSE = "</think>"

    def __init__(self) -> None:
        self._in_think = False  # True between <think> and </think>
        self._pending = ""  # held-back text that may be a partial tag

    @staticmethod
    def _partial_tag_len(buf: str, tag: str) -> int:
        """Length of the longest proper prefix of ``tag`` that ``buf`` ends with."""
        for size in range(min(len(buf), len(tag) - 1), 0, -1):
            if buf.endswith(tag[:size]):
                return size
        return 0

    def feed(self, text: str) -> str:
        """Consume one streamed chunk and return its visible (non-think) part."""
        buf = self._pending + text
        self._pending = ""
        visible = []
        while buf:
            tag = self._CLOSE if self._in_think else self._OPEN
            idx = buf.find(tag)
            if idx < 0:
                cut = len(buf) - self._partial_tag_len(buf, tag)
                if not self._in_think:
                    visible.append(buf[:cut])
                # Text inside a think block is discarded; only a possible
                # partial tag is kept for the next chunk.
                self._pending = buf[cut:]
                break
            if not self._in_think:
                visible.append(buf[:idx])
            buf = buf[idx + len(tag):]
            self._in_think = not self._in_think
        return "".join(visible)

    def flush(self) -> str:
        """Release held-back text at end of stream (dropped if still inside a think block)."""
        pending, self._pending = self._pending, ""
        return "" if self._in_think else pending


def _build_stream_prompt(mode: int, template_prompt_name: str, prompt_param_dict: dict) -> str:
    """Build the prompt text for ``get_llm_model_response_stream_openai``.

    ``mode`` 0 / 1 prepend the "Think Test Bak" / "Write Test Bak"
    ``agent_chat`` template to the named ``agent_chat`` template; ``mode`` 2
    uses the named ``llm_chat`` template. Every ``{{key}}`` placeholder is then
    replaced with ``prompt_param_dict[key]``.

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode == 2:
        template = get_prompt_template("llm_chat", template_prompt_name)
    elif mode in (0, 1):
        prefix_name = "Think Test Bak" if mode == 0 else "Write Test Bak"
        prefix = get_prompt_template("agent_chat", prefix_name)
        body = get_prompt_template("agent_chat", template_prompt_name)
        template = f"{prefix}{body}"
    else:
        raise ValueError(f"Unsupported stream prompt type: {mode!r} (expected 0, 1 or 2)")
    for key, value in prompt_param_dict.items():
        template = template.replace(f"{{{{{key}}}}}", value)
    return template


async def get_llm_model_response_stream_openai(
    type: int,
    strategy_name: str,
    llm_model_name: str,
    template_prompt_name: str,
    prompt_param_dict: dict,
    temperature: float,
    max_tokens: int,
) -> AsyncGenerator[str, None]:
    """Stream the model answer for an agent / LLM prompt, hiding ``<think>`` blocks.

    Args:
        type: prompt mode (the name shadows the builtin but is kept for
            keyword callers). 0: "Think Test Bak" + named ``agent_chat``
            template, thinking enabled; 1: "Write Test Bak" + named
            ``agent_chat`` template; 2: named ``llm_chat`` template, thinking
            enabled.
        strategy_name: not used by the body.
        llm_model_name: model name passed to ``get_ChatOpenAI``.
        template_prompt_name: name of the template to load.
        prompt_param_dict: values substituted for ``{{key}}`` placeholders.
        temperature: sampling temperature passed to ``get_ChatOpenAI``.
        max_tokens: completion limit; values above ``MAX_MAX_TOKENS`` are
            clamped (with a warning).

    Yields:
        Non-empty visible text fragments; content between ``<think>`` and
        ``</think>`` is dropped (models without think blocks are unaffected).

    Raises:
        ValueError: if ``type`` is not 0, 1 or 2.
        Exception: the last model error once ``MAX_RETRIES`` retries are
            exhausted, or immediately if the failure happens after text was
            already yielded (a retry would re-send text the consumer has
            already received).
    """
    if max_tokens is not None and max_tokens > MAX_MAX_TOKENS:
        logger.warning(f"max_tokens({max_tokens}) 超过 API 限制,已调整为 {MAX_MAX_TOKENS}")
        max_tokens = MAX_MAX_TOKENS

    # The prompt is deterministic, so build it once instead of on every retry;
    # template / type errors then surface immediately rather than being retried.
    prompt_text = _build_stream_prompt(type, template_prompt_name, prompt_param_dict)
    messages = [HumanMessage(content=prompt_text)]
    # Modes 0 and 2 ask the backend to enable its thinking mode.
    model_kwargs = (
        {"extra_body": {"chat_template_kwargs": {"enable_thinking": True}}}
        if type in (0, 2)
        else {}
    )

    retry_count = 0
    while True:
        yielded_any = False
        try:
            model = get_ChatOpenAI(
                model_name=llm_model_name,
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=[],
                **model_kwargs,
            )
            think_filter = _ThinkBlockFilter()
            async for chunk in model.astream(messages):
                visible = think_filter.feed(chunk.content or "")
                if visible:
                    yielded_any = True
                    yield visible
            tail = think_filter.flush()
            if tail:
                yielded_any = True
                yield tail
            return  # stream completed successfully
        except Exception as e:
            if yielded_any:
                # Partial output already reached the consumer; restarting the
                # stream would duplicate it, so surface the error instead.
                logger.error(f"流式LLM调用在已输出部分内容后失败,不再重试: {e}")
                raise
            retry_count += 1
            if retry_count > MAX_RETRIES:
                logger.error(f"流式LLM调用失败,已达到最大重试次数 {MAX_RETRIES}: {e}")
                raise
            logger.warning(f"流式LLM调用第 {retry_count} 次失败,{RETRY_DELAY}秒后重试: {e}")
            await asyncio.sleep(RETRY_DELAY)
|