from fastapi import Body from configs import LLM_MODELS, TEMPERATURE, MAX_TOKENS from server.chat.policy_fun_iast import get_llm_model_response from typing import Optional from langchain.chains import LLMChain from langchain.prompts import ChatPromptTemplate from server.chat.utils import History from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template from langchain.callbacks import AsyncIteratorCallbackHandler import asyncio from server.knowledge_base.kb_service.base import TextRank from configs.basic_config import * async def abb_rewrite( # context: Optional[str] = Body(..., description="当前已撰写的全文", examples=[""]), query: str = Body(..., description="用户框选的段落", examples=[""]), # previous_text: Optional[str] = Body(..., description="用户框选段落的前文", examples=[""]), # following_text: Optional[str] = Body(..., description="用户框选段落的后文", examples=[""]), # con_direction: Optional[str] = Body("", description="用户输入的缩写指令", examples=[""]), # stream: bool = Body(False, description="是否流式输出", examples=[False]), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0), max_tokens: Optional[int] = Body(512, description="限制LLM生成Token数量,默认None代表模型最大值"), # prompt_name: Optional[str] = Body("abb_rewrite"), ): logger.info(f"开始缩写...") # 定义生成摘要的函数 # def generate_summary(text: str) -> str: # """使用 TextRank 生成文本摘要""" # if len(text) <= 20000: # summary = TextRank(text, num_sentences=60) # 生成60句话的摘要 # else: # summary = TextRank(text, num_sentences=80) # 生成80句话的摘要 # return summary # # 根据上下文长度决定是否生成摘要 # if len(context) >= 15000: # context_summary = generate_summary(context) # logger.info(f"生成撰写文稿的摘要: %s", context_summary) # else: # context_summary = context # 直接使用原文 # logger.info(f"撰写文稿小于15000字符,使用原文") # 调用模型生成缩写内容 try: abb_rewrite_content = get_llm_model_response( strategy_name="abb_rewrite", llm_model_name=LLM_MODELS[0], template_prompt_name="abb_rewrite", prompt_param_dict={ # "context": context_summary, # 使用摘要或原文 "paragraph_content": query, # "con_direction": con_direction, # "previous_text": previous_text, # "following_text": following_text }, temperature=TEMPERATURE, max_tokens=MAX_TOKENS ) # logger.info("生成的缩写内容: %s", abb_rewrite_content) except Exception as e: logger.error("生成缩写内容时出错: %s", e) return (f"出错了。。请重试。。") # # 如果 previous_text 和 following_text 存在空值,直接返回 abb_rewrite_content # if not previous_text or not following_text: # logger.info("上文或下文为空,直接返回生成的内容。。") # final_content = abb_rewrite_content # else: # # 定义内容检查函数 # def abb_rewrite_check(abb_rewrite_content: str) -> str: # logger.info("对文章缩写内容进行行文检查。。") # try: # abb_rewrite_check_content = get_llm_model_response( # strategy_name="abb_rewrite_check", # llm_model_name=LLM_MODELS[0], # template_prompt_name="abb_rewrite_check", # prompt_param_dict={ # "previous_text": previous_text, # "following_text": following_text, # "abb_rewrite_content": abb_rewrite_content, # 使用生成的 rewrite_content # }, # temperature=temperature, # max_tokens=max_tokens # ) # logger.info("检查后的缩写内容: %s", abb_rewrite_check_content) # return abb_rewrite_check_content # except Exception as e: # logger.error("检查缩写内容时出错: %s", e) # return (f"出错了。。请重试。。") # # 调用内容检查函数 # final_content = abb_rewrite_check(abb_rewrite_content) # 返回最终生成的字符串 return abb_rewrite_content