Files
gangyan/langchain-chat/server/chat/policy_fun_iast.py

187 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import time
from langchain.chains import LLMChain
from langchain_core.prompts import ChatPromptTemplate
from configs import KB_PROMPT, LLM_PROMPT, logger
from configs.prompt_config import AGENT_PROMPT, AGENT_WRITE_PROMPT, COMPARISON
from server.chat.utils import History
from server.utils import get_prompt_template, get_strategy_prompt_template, get_ChatOpenAI
import openai
from typing import Any, AsyncGenerator
from langchain.schema import HumanMessage
# Number of extra attempts made after the first LLM call fails
# (the retry loops below run while retry_count <= MAX_RETRIES).
MAX_RETRIES = 2
# Pause between attempts, in seconds (passed to time.sleep / asyncio.sleep).
RETRY_DELAY = 1
MAX_MAX_TOKENS = 8192 # Alibaba Cloud DashScope API limit
def _select_prompt_template(template_prompt_name: str):
    """Look up the prompt template text registered under *template_prompt_name*.

    The name is matched against the prompt registries in this order:
    ``KB_PROMPT``, ``LLM_PROMPT``, ``COMPARISON``, ``AGENT_PROMPT``,
    ``AGENT_WRITE_PROMPT``. Names found in none of them fall back to
    ``get_strategy_prompt_template("knowledge_base_chat", ...)``.
    """
    if template_prompt_name in KB_PROMPT:
        return get_prompt_template("knowledge_base_chat", template_prompt_name)
    if template_prompt_name in LLM_PROMPT:
        return get_prompt_template("llm_chat", template_prompt_name)
    if template_prompt_name in COMPARISON:
        return get_prompt_template("comparison_chat", template_prompt_name)

    # This only applies to the template prompts of the new agent flow; if you
    # add other prompt groups, check that they do not conflict with it.
    # Each entry: (registry, shared base template, "next tip" template).
    agent_groups = (
        (AGENT_PROMPT, "Think Test Bak", "get_next_tip"),
        (AGENT_WRITE_PROMPT, "Write Test Bak", "get_next_write_tip"),
    )
    for registry, base_name, tip_name in agent_groups:
        if template_prompt_name not in registry:
            continue
        if template_prompt_name in (base_name, tip_name):
            # The base template and the "next tip" template are used as-is.
            return get_prompt_template("agent_chat", template_prompt_name)
        # Every other agent template is prefixed with the group's base template.
        base_text = get_prompt_template("agent_chat", base_name)
        specific_text = get_prompt_template("agent_chat", template_prompt_name)
        return f"{base_text}{specific_text}"

    return get_strategy_prompt_template("knowledge_base_chat", template_prompt_name)


def get_llm_model_response(
    strategy_name: str,
    llm_model_name: str,
    template_prompt_name: str,
    prompt_param_dict: dict,
    temperature: float,
    max_tokens: int,
    **kwargs: Any,
) -> str:
    """Call the LLM synchronously with a named prompt template, with retries.

    Args:
        strategy_name: Label of the calling strategy. Not used by this
            function; kept so the signature stays stable for callers.
        llm_model_name: Model name forwarded to ``get_ChatOpenAI``.
        template_prompt_name: Name of the prompt template to resolve.
        prompt_param_dict: Values passed to ``LLMChain.run`` to fill the
            prompt template.
        temperature: Sampling temperature forwarded to ``get_ChatOpenAI``.
        max_tokens: Completion token limit; values above ``MAX_MAX_TOKENS``
            are clamped (with a warning). ``None`` is passed through.
        **kwargs: Extra keyword arguments forwarded to ``get_ChatOpenAI``.

    Returns:
        The value returned by ``LLMChain.run``.

    Raises:
        Exception: The last error from the chain is re-raised once
            ``MAX_RETRIES`` retries have been exhausted.
    """
    # Make sure max_tokens does not exceed the API limit.
    if max_tokens is not None and max_tokens > MAX_MAX_TOKENS:
        logger.warning(f"max_tokens({max_tokens}) 超过 API 限制,已调整为 {MAX_MAX_TOKENS}")
        max_tokens = MAX_MAX_TOKENS

    # Load the requested model. No callbacks may be attached here, otherwise
    # this intermediate model response would be added to the final answer.
    # The first attempt leaves `streaming` at whatever get_ChatOpenAI / kwargs
    # decide; retries below force it off.
    model = get_ChatOpenAI(
        model_name=llm_model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        callbacks=[],
        **kwargs,
    )

    prompt_template = _select_prompt_template(template_prompt_name)
    input_msg = History(role="system", content=prompt_template).to_msg_template(False)
    prompt = ChatPromptTemplate.from_messages([input_msg])

    # Get the model response, retrying on failure.
    retry_count = 0
    while True:
        try:
            llm_chain = LLMChain(prompt=prompt, llm=model, verbose=True)
            return llm_chain.run(prompt_param_dict)
        except Exception as e:
            retry_count += 1
            if retry_count > MAX_RETRIES:
                logger.error(f"LLM调用失败已达到最大重试次数 {MAX_RETRIES}: {e}")
                raise
            logger.warning(f"LLM调用第 {retry_count} 次失败,{RETRY_DELAY}秒后重试: {e}")
            time.sleep(RETRY_DELAY)
            # Rebuild the model with streaming disabled. The override is merged
            # into kwargs so a caller-supplied `streaming` cannot collide with
            # it and raise TypeError while the original error is being handled.
            retry_kwargs = {**kwargs, "streaming": False}
            model = get_ChatOpenAI(
                model_name=llm_model_name,
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=[],
                **retry_kwargs,
            )
async def get_llm_model_response_async(
    strategy_name: str,
    llm_model_name: str,
    template_prompt_name: str,
    prompt_param_dict: dict,
    temperature: float,
    max_tokens: int,
) -> str:
    """Await ``get_llm_model_response`` without blocking the event loop.

    The synchronous function is submitted to the running loop's default
    executor (``run_in_executor(None, ...)``). All arguments are forwarded
    positionally and unchanged; no extra keyword arguments are passed.

    Returns:
        Whatever ``get_llm_model_response`` returns.

    Raises:
        Exception: Any error raised by ``get_llm_model_response`` propagates
            to the awaiting caller.
    """
    # Inside a coroutine the running loop is the one we want;
    # get_running_loop() is the documented way to obtain it.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None,
        get_llm_model_response,
        strategy_name,
        llm_model_name,
        template_prompt_name,
        prompt_param_dict,
        temperature,
        max_tokens,
    )
async def get_llm_model_response_stream_openai(
    type: int,
    strategy_name: str,
    llm_model_name: str,
    template_prompt_name: str,
    prompt_param_dict: dict,
    temperature: float,
    max_tokens: int,
) -> AsyncGenerator[str, None]:
    """Stream the model's reply chunk by chunk, retrying failed calls.

    Args:
        type: Selects the prompt layout and streaming mode:
            ``0`` - agent template prefixed with "Think Test Bak", thinking
            enabled; ``1`` - agent template prefixed with "Write Test Bak",
            thinking not requested; ``2`` - plain ``llm_chat`` template,
            thinking enabled.
        strategy_name: Label of the calling strategy. Not used by this
            function; kept so the signature stays stable for callers.
        llm_model_name: Model name forwarded to ``get_ChatOpenAI``.
        template_prompt_name: Name of the prompt template to load.
        prompt_param_dict: Mapping of placeholder name to value; every
            ``{{name}}`` occurrence in the template is replaced by
            ``str(value)``.
        temperature: Sampling temperature forwarded to ``get_ChatOpenAI``.
        max_tokens: Completion token limit; values above ``MAX_MAX_TOKENS``
            are clamped. ``None`` is passed through.

    Yields:
        The ``content`` of each streamed chunk. With thinking enabled
        (``type`` 0 or 2) chunks are only yielded after the chunk that
        contains ``"<think>"``; that chunk itself is not yielded.

    Raises:
        ValueError: If ``type`` is not 0, 1 or 2.
        Exception: The last streaming error, once ``MAX_RETRIES`` retries are
            exhausted, or immediately if the failure happens after output has
            already been yielded.
    """
    if type not in (0, 1, 2):
        # Previously this surfaced as an UnboundLocalError after pointless retries.
        raise ValueError(f"unsupported type {type!r}; expected 0, 1 or 2")

    # Make sure max_tokens does not exceed the API limit.
    if max_tokens is not None and max_tokens > MAX_MAX_TOKENS:
        logger.warning(f"max_tokens({max_tokens}) 超过 API 限制,已调整为 {MAX_MAX_TOKENS}")
        max_tokens = MAX_MAX_TOKENS

    thinking_enabled = type in (0, 2)
    model_kwargs = {}
    if thinking_enabled:
        model_kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": True}}

    # Build the prompt once: it does not depend on the attempt number, so a
    # failure here is not something a retry could fix.
    if type == 2:
        prompt_template = get_prompt_template("llm_chat", template_prompt_name)
    else:
        base_name = "Think Test Bak" if type == 0 else "Write Test Bak"
        base_text = get_prompt_template("agent_chat", base_name)
        specific_text = get_prompt_template("agent_chat", template_prompt_name)
        prompt_template = f"{base_text}{specific_text}"
    for key, value in prompt_param_dict.items():
        # f"{{{{{key}}}}}" renders as the literal placeholder "{{<key>}}".
        # str() lets non-string values (ints, floats, ...) be substituted too.
        prompt_template = prompt_template.replace(f"{{{{{key}}}}}", str(value))
    messages = [HumanMessage(content=prompt_template)]

    retry_count = 0
    emitted = False  # True once at least one chunk has been handed to the consumer
    while True:
        try:
            model = get_ChatOpenAI(
                model_name=llm_model_name,
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=[],
                **model_kwargs,
            )
            # Call the streaming interface.
            # Without thinking, every chunk is forwarded from the start.
            started = not thinking_enabled
            async for chunk in model.astream(messages):
                if started:
                    emitted = True  # set before yielding: the consumer sees this chunk
                    yield chunk.content
                elif "<think>" in chunk.content:
                    started = True
            return  # completed successfully
        except Exception as e:
            if emitted:
                # Part of the reply has already been delivered; restarting the
                # stream would send the beginning again and duplicate output.
                logger.error(f"流式LLM调用在输出过程中失败,已输出部分内容,不再重试: {e}")
                raise
            retry_count += 1
            if retry_count > MAX_RETRIES:
                logger.error(f"流式LLM调用失败已达到最大重试次数 {MAX_RETRIES}: {e}")
                raise
            logger.warning(f"流式LLM调用第 {retry_count} 次失败,{RETRY_DELAY}秒后重试: {e}")
            await asyncio.sleep(RETRY_DELAY)