Files
gangyan/langchain-chat/server/chat/sentence_reference.py

76 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Body
from configs import LLM_MODELS, TEMPERATURE, MAX_TOKENS
from server.chat.policy_fun_iast import get_llm_model_response
from typing import Optional
from langchain.chains import LLMChain
from langchain.prompts import ChatPromptTemplate
from server.chat.utils import History
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.callbacks import AsyncIteratorCallbackHandler
import asyncio
from server.knowledge_base.kb_service.base import TextRank
from configs.basic_config import *
async def sentence_reference(
        paragraph_content: str = Body(..., description="用户框选的内容,<=2句", examples=[""]),
        temperature: float = Body(0.9, description="LLM 采样温度", ge=0.0, le=2.0),
        max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量默认None代表模型最大值"),
):
    """Generate several candidate reference sentences for a user-selected passage.

    The "sentence_reference" prompt strategy/template is run against
    ``LLM_MODELS[0]`` three times concurrently. Each blocking
    ``get_llm_model_response`` call runs in a worker thread. The answers are
    joined into one string, blank-line separated, each prefixed with
    "句子1", "句子2", ... .

    NOTE: an earlier revision also accepted the full preceding text
    (``context``) and condensed it with ``TextRank`` when it was 15000
    characters or longer. That path is disabled; the prompt currently receives
    only ``paragraph_content``.

    Args:
        paragraph_content: The text the user selected (at most two sentences,
            per the API description).
        temperature: LLM sampling temperature, validated to [0.0, 2.0].
        max_tokens: Generation token limit forwarded to the model call.
            NOTE(review): the API description says None means "model maximum",
            but the default is 1024 -- confirm which is intended.

    Returns:
        The joined candidates as a string. A candidate whose call raised is
        replaced by "出错了。。请重试。。". If every call raised, that message is
        returned on its own.
    """
    logger.info("开始提示句子...")

    error_message = "出错了。。请重试。。"
    num_candidates = 3
    # Sentinel marking a failed call, so failures are never confused with
    # whatever the model actually returned.
    failed = object()

    async def get_sentence_reference():
        """Run one LLM call off the event loop; return ``failed`` if it raises."""
        try:
            # get_llm_model_response is synchronous; to_thread keeps it from
            # blocking the event loop and lets the parallel calls overlap.
            return await asyncio.to_thread(
                get_llm_model_response,
                strategy_name="sentence_reference",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="sentence_reference",
                prompt_param_dict={"paragraph_content": paragraph_content},
                temperature=temperature,
                max_tokens=max_tokens,
            )
        except Exception as e:
            # logger.exception also records the traceback (logger.error did not).
            logger.exception("生成提示句子内容时出错: %s", e)
            return failed

    # Every coroutine handles its own Exceptions, so gather() cannot raise one
    # here; the per-call sentinel is what detects failures.
    responses = await asyncio.gather(
        *(get_sentence_reference() for _ in range(num_candidates))
    )

    # If nothing succeeded, return the single error message instead of
    # N concatenated copies of it.
    if all(response is failed for response in responses):
        return error_message

    # Partial failure: keep the successful candidates and mark the failed slots.
    responses = [error_message if response is failed else response for response in responses]

    final_output = "\n\n".join(
        f"句子{index}{response}" for index, response in enumerate(responses, start=1)
    )
    logger.info("生成的最终拼接内容: %s", final_output)
    return final_output