Files
gangyan/langchain-chat/server/chat/abb_rewrite.py

97 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Body
from configs import LLM_MODELS, TEMPERATURE, MAX_TOKENS
from server.chat.policy_fun_iast import get_llm_model_response
from typing import Optional
from langchain.chains import LLMChain
from langchain.prompts import ChatPromptTemplate
from server.chat.utils import History
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.callbacks import AsyncIteratorCallbackHandler
import asyncio
from server.knowledge_base.kb_service.base import TextRank
from configs.basic_config import *
async def abb_rewrite(
# context: Optional[str] = Body(..., description="当前已撰写的全文", examples=[""]),
query: str = Body(..., description="用户框选的段落", examples=[""]),
# previous_text: Optional[str] = Body(..., description="用户框选段落的前文", examples=[""]),
# following_text: Optional[str] = Body(..., description="用户框选段落的后文", examples=[""]),
# con_direction: Optional[str] = Body("", description="用户输入的缩写指令", examples=[""]),
# stream: bool = Body(False, description="是否流式输出", examples=[False]),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
max_tokens: Optional[int] = Body(512, description="限制LLM生成Token数量默认None代表模型最大值"),
# prompt_name: Optional[str] = Body("abb_rewrite"),
):
logger.info(f"开始缩写...")
# 定义生成摘要的函数
# def generate_summary(text: str) -> str:
# """使用 TextRank 生成文本摘要"""
# if len(text) <= 20000:
# summary = TextRank(text, num_sentences=60) # 生成60句话的摘要
# else:
# summary = TextRank(text, num_sentences=80) # 生成80句话的摘要
# return summary
# # 根据上下文长度决定是否生成摘要
# if len(context) >= 15000:
# context_summary = generate_summary(context)
# logger.info(f"生成撰写文稿的摘要: %s", context_summary)
# else:
# context_summary = context # 直接使用原文
# logger.info(f"撰写文稿小于15000字符使用原文")
# 调用模型生成缩写内容
try:
abb_rewrite_content = get_llm_model_response(
strategy_name="abb_rewrite",
llm_model_name=LLM_MODELS[0],
template_prompt_name="abb_rewrite",
prompt_param_dict={
# "context": context_summary, # 使用摘要或原文
"paragraph_content": query,
# "con_direction": con_direction,
# "previous_text": previous_text,
# "following_text": following_text
},
temperature=TEMPERATURE,
max_tokens=MAX_TOKENS
)
# logger.info("生成的缩写内容: %s", abb_rewrite_content)
except Exception as e:
logger.error("生成缩写内容时出错: %s", e)
return (f"出错了。。请重试。。")
# # 如果 previous_text 和 following_text 存在空值,直接返回 abb_rewrite_content
# if not previous_text or not following_text:
# logger.info("上文或下文为空,直接返回生成的内容。。")
# final_content = abb_rewrite_content
# else:
# # 定义内容检查函数
# def abb_rewrite_check(abb_rewrite_content: str) -> str:
# logger.info("对文章缩写内容进行行文检查。。")
# try:
# abb_rewrite_check_content = get_llm_model_response(
# strategy_name="abb_rewrite_check",
# llm_model_name=LLM_MODELS[0],
# template_prompt_name="abb_rewrite_check",
# prompt_param_dict={
# "previous_text": previous_text,
# "following_text": following_text,
# "abb_rewrite_content": abb_rewrite_content, # 使用生成的 rewrite_content
# },
# temperature=temperature,
# max_tokens=max_tokens
# )
# logger.info("检查后的缩写内容: %s", abb_rewrite_check_content)
# return abb_rewrite_check_content
# except Exception as e:
# logger.error("检查缩写内容时出错: %s", e)
# return (f"出错了。。请重试。。")
# # 调用内容检查函数
# final_content = abb_rewrite_check(abb_rewrite_content)
# 返回最终生成的字符串
return abb_rewrite_content