[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
186
langchain-chat/server/chat/policy_fun_iast.py
Normal file
186
langchain-chat/server/chat/policy_fun_iast.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import asyncio
|
||||
import time
|
||||
from langchain.chains import LLMChain
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from configs import KB_PROMPT, LLM_PROMPT, logger
|
||||
|
||||
from configs.prompt_config import AGENT_PROMPT, AGENT_WRITE_PROMPT, COMPARISON
|
||||
from server.chat.utils import History
|
||||
from server.utils import get_prompt_template, get_strategy_prompt_template, get_ChatOpenAI
|
||||
import openai
|
||||
from typing import Any, AsyncGenerator
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
MAX_RETRIES = 2
|
||||
RETRY_DELAY = 1
|
||||
MAX_MAX_TOKENS = 8192 # 阿里云 DashScope API 限制
|
||||
|
||||
def get_llm_model_response(
|
||||
strategy_name: str,
|
||||
llm_model_name: str,
|
||||
template_prompt_name: str,
|
||||
prompt_param_dict: dict,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
'''调用大模型,实现不同策略'''
|
||||
# 校验 max_tokens 不超过 API 限制
|
||||
if max_tokens is not None and max_tokens > MAX_MAX_TOKENS:
|
||||
logger.warning(f"max_tokens({max_tokens}) 超过 API 限制,已调整为 {MAX_MAX_TOKENS}")
|
||||
max_tokens = MAX_MAX_TOKENS
|
||||
|
||||
# 读取指定的大模型,这里不能加入callback,否则会把这部分模型响应加入最终的回答
|
||||
# 同步调用关闭 streaming,避免流式传输错误
|
||||
model = get_ChatOpenAI(
|
||||
model_name=llm_model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[],
|
||||
# streaming=False,
|
||||
**kwargs
|
||||
)
|
||||
# 获取prompt
|
||||
if template_prompt_name in KB_PROMPT:
|
||||
prompt_template = get_prompt_template("knowledge_base_chat", template_prompt_name)
|
||||
elif template_prompt_name in LLM_PROMPT:
|
||||
prompt_template = get_prompt_template("llm_chat", template_prompt_name)
|
||||
elif template_prompt_name in COMPARISON:
|
||||
prompt_template = get_prompt_template("comparison_chat", template_prompt_name)
|
||||
# 此处仅对全新agent流程的模板提示词奏效如果添加其他请注意是否冲突
|
||||
elif template_prompt_name in AGENT_PROMPT:
|
||||
if not template_prompt_name == "Think Test Bak" and not template_prompt_name == "get_next_tip":
|
||||
prompt_template1 = get_prompt_template("agent_chat", "Think Test Bak")
|
||||
prompt_template2 = get_prompt_template("agent_chat", template_prompt_name)
|
||||
prompt_template = f"{prompt_template1}{prompt_template2}"
|
||||
else:
|
||||
prompt_template = get_prompt_template("agent_chat", template_prompt_name)
|
||||
elif template_prompt_name in AGENT_WRITE_PROMPT:
|
||||
if not template_prompt_name == "Write Test Bak" and not template_prompt_name == "get_next_write_tip":
|
||||
prompt_template1 = get_prompt_template("agent_chat", "Write Test Bak")
|
||||
prompt_template2 = get_prompt_template("agent_chat", template_prompt_name)
|
||||
prompt_template = f"{prompt_template1}{prompt_template2}"
|
||||
else:
|
||||
prompt_template = get_prompt_template("agent_chat", template_prompt_name)
|
||||
else:
|
||||
prompt_template = get_strategy_prompt_template("knowledge_base_chat", template_prompt_name)
|
||||
input_msg = History(role="system", content=prompt_template).to_msg_template(False)
|
||||
prompt = ChatPromptTemplate.from_messages([input_msg])
|
||||
# print("strategy_prompt_name: ",template_prompt_name, "\n","strategy_prompt:",prompt_template)
|
||||
|
||||
# 获取模型响应,带重试机制
|
||||
retry_count = 0
|
||||
last_error = None
|
||||
while retry_count <= MAX_RETRIES:
|
||||
try:
|
||||
llm_chain = LLMChain(prompt=prompt, llm=model, verbose=True)
|
||||
model_response = llm_chain.run(prompt_param_dict)
|
||||
# print(f'---------after {strategy_name}------------------')
|
||||
# print(model_response)
|
||||
return model_response
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
retry_count += 1
|
||||
if retry_count > MAX_RETRIES:
|
||||
logger.error(f"LLM调用失败,已达到最大重试次数 {MAX_RETRIES}: {e}")
|
||||
raise
|
||||
logger.warning(f"LLM调用第 {retry_count} 次失败,{RETRY_DELAY}秒后重试: {e}")
|
||||
time.sleep(RETRY_DELAY)
|
||||
# 重新创建 model,关闭 streaming
|
||||
model = get_ChatOpenAI(
|
||||
model_name=llm_model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[],
|
||||
streaming=False,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def get_llm_model_response_async(
|
||||
strategy_name: str,
|
||||
llm_model_name: str,
|
||||
template_prompt_name: str,
|
||||
prompt_param_dict: dict,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> str:
|
||||
'''异步调用大模型,实现不同策略'''
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
get_llm_model_response,
|
||||
strategy_name,
|
||||
llm_model_name,
|
||||
template_prompt_name,
|
||||
prompt_param_dict,
|
||||
temperature,
|
||||
max_tokens
|
||||
)
|
||||
|
||||
|
||||
async def get_llm_model_response_stream_openai(
|
||||
type: int,
|
||||
strategy_name: str,
|
||||
llm_model_name: str,
|
||||
template_prompt_name: str,
|
||||
prompt_param_dict: dict,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
# 校验 max_tokens
|
||||
if max_tokens is not None and max_tokens > MAX_MAX_TOKENS:
|
||||
max_tokens = MAX_MAX_TOKENS
|
||||
|
||||
retry_count = 0
|
||||
|
||||
while retry_count <= MAX_RETRIES:
|
||||
try:
|
||||
if type == 0 or type == 2:
|
||||
kwargs = {}
|
||||
kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": True}}
|
||||
model = get_ChatOpenAI(
|
||||
model_name=llm_model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[],
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
model = get_ChatOpenAI(
|
||||
model_name=llm_model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[],
|
||||
)
|
||||
# 调用流式接口
|
||||
if type == 0:
|
||||
prompt_template1 = get_prompt_template("agent_chat", "Think Test Bak")
|
||||
if type == 1:
|
||||
prompt_template1 = get_prompt_template("agent_chat", "Write Test Bak")
|
||||
if type == 2:
|
||||
prompt_template = get_prompt_template("llm_chat", template_prompt_name)
|
||||
else:
|
||||
prompt_template2 = get_prompt_template("agent_chat", template_prompt_name)
|
||||
prompt_template = f"{prompt_template1}{prompt_template2}"
|
||||
for key in prompt_param_dict:
|
||||
prompt_template = prompt_template.replace(f"{{{{{key}}}}}", prompt_param_dict[key])
|
||||
messages = [HumanMessage(content=prompt_template)]
|
||||
if type == 0 or type == 2:
|
||||
start_flag = False
|
||||
async for chunk in model.astream(messages):
|
||||
if start_flag:
|
||||
yield chunk.content
|
||||
if "<think>" in chunk.content:
|
||||
start_flag = True
|
||||
else:
|
||||
async for chunk in model.astream(messages):
|
||||
yield chunk.content
|
||||
return # 成功完成,退出函数
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count > MAX_RETRIES:
|
||||
logger.error(f"流式LLM调用失败,已达到最大重试次数 {MAX_RETRIES}: {e}")
|
||||
raise
|
||||
logger.warning(f"流式LLM调用第 {retry_count} 次失败,{RETRY_DELAY}秒后重试: {e}")
|
||||
await asyncio.sleep(RETRY_DELAY)
|
||||
Reference in New Issue
Block a user