76 lines
3.1 KiB
Python
76 lines
3.1 KiB
Python
|
|
from fastapi import Body
|
|||
|
|
from configs import LLM_MODELS, TEMPERATURE, MAX_TOKENS
|
|||
|
|
from server.chat.policy_fun_iast import get_llm_model_response
|
|||
|
|
from typing import Optional
|
|||
|
|
from langchain.chains import LLMChain
|
|||
|
|
from langchain.prompts import ChatPromptTemplate
|
|||
|
|
from server.chat.utils import History
|
|||
|
|
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
|
|||
|
|
from langchain.callbacks import AsyncIteratorCallbackHandler
|
|||
|
|
import asyncio
|
|||
|
|
from server.knowledge_base.kb_service.base import TextRank
|
|||
|
|
from configs.basic_config import *
|
|||
|
|
|
|||
|
|
async def sentence_reference(
    paragraph_content: str = Body(..., description="用户框选的内容,<=2句", examples=[""]),
    temperature: float = Body(0.9, description="LLM 采样温度", ge=0.0, le=2.0),
    # NOTE(review): the description below mentions a None default, but the
    # declared default is 1024 -- confirm which one is intended.
    max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
):
    """Generate three candidate reference sentences for a user-selected passage.

    The same ``sentence_reference`` prompt is sent to the first configured
    model (``LLM_MODELS[0]``) three times concurrently, and the three answers
    are joined into one numbered text block.

    Args:
        paragraph_content: The text the user selected (request body field).
        temperature: Sampling temperature forwarded to the model call.
        max_tokens: Token limit forwarded to the model call.

    Returns:
        A string of the form ``"句子1:...\\n\\n句子2:...\\n\\n句子3:..."``.
        A candidate whose model call failed is replaced by the fixed retry
        message, so the other candidates are still returned. If gathering the
        candidates itself fails, only the retry message is returned.
    """
    # Number of candidates requested in parallel for one selection.
    candidate_count = 3
    # Fixed user-facing text returned in place of a failed generation.
    retry_message = "出错了。。请重试。。"

    logger.info("开始提示句子...")

    # NOTE: an earlier version also accepted the full preceding text as a
    # `context` body field, condensed it with TextRank once it reached 15000
    # characters (60 sentences, or 80 for texts over 20000 characters), and
    # passed the result to the prompt as "context". That path is disabled;
    # only the selected passage is sent to the prompt now.

    async def get_sentence_reference():
        """Run one model call in a worker thread and return its result."""
        try:
            # get_llm_model_response is called through asyncio.to_thread so
            # that it does not block the event loop while it runs.
            return await asyncio.to_thread(
                get_llm_model_response,
                strategy_name="sentence_reference",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="sentence_reference",
                prompt_param_dict={
                    "paragraph_content": paragraph_content,
                },
                temperature=temperature,
                max_tokens=max_tokens,
            )
        except Exception as e:
            # Best effort: one failed candidate must not discard the others.
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("生成提示句子内容时出错: %s", e)
            return retry_message

    # Issue all candidate requests concurrently; gather preserves call order.
    try:
        responses = await asyncio.gather(
            *(get_sentence_reference() for _ in range(candidate_count))
        )
    except Exception as e:
        # Last-resort guard: the helper already turns its own failures into
        # the retry message, so this only covers unexpected errors.
        logger.exception("并行调用 LLM 模型时出错: %s", e)
        return retry_message

    # Number the candidates from 1 and separate them with blank lines.
    final_output = "\n\n".join(
        f"句子{i + 1}:{response}" for i, response in enumerate(responses)
    )
    logger.info("生成的最终拼接内容: %s", final_output)

    return final_output
|