Files
gangyan/langchain-chat/server/chat/exp_rewrite.py

96 lines
4.4 KiB
Python
Raw Normal View History

from fastapi import Body
from configs import LLM_MODELS, TEMPERATURE, MAX_TOKENS
from server.chat.policy_fun_iast import get_llm_model_response
from typing import Optional
from langchain.chains import LLMChain
from langchain.prompts import ChatPromptTemplate
from server.chat.utils import History
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.callbacks import AsyncIteratorCallbackHandler
import asyncio
from server.knowledge_base.kb_service.base import TextRank
from configs.basic_config import *
async def exp_rewrite(
context: str = Body(..., description="当前已撰写的全文", examples=[""]),
paragraph_content: str = Body(..., description="用户框选的段落", examples=[""]),
previous_text: str = Body(..., description="用户框选段落的前文", examples=[""]),
following_text: str = Body(..., description="用户框选段落的后文", examples=[""]),
con_direction: Optional[str] = Body("", description="用户输入的扩写指令", examples=[""]),
stream: bool = Body(False, description="是否流式输出", examples=[False]),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
max_tokens: Optional[int] = Body(MAX_TOKENS, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("exp_rewrite"),
):
logger.info(f"开始扩写...")
# 定义生成摘要的函数
# 定义生成摘要的函数
def generate_summary(text: str) -> str:
"""使用 TextRank 生成文本摘要"""
summary = TextRank(text, num_sentences=80) # 生成80句话的摘要
return summary
# 根据上下文长度决定是否生成摘要
if len(context) >= 30000:
context_summary = generate_summary(context)
logger.info(f"生成撰写文稿的摘要: %s", context_summary)
else:
context_summary = context # 直接使用原文
logger.info(f"撰写文稿小于30000字符使用原文")
# 调用模型生成扩写内容
try:
exp_rewrite_content = get_llm_model_response(
strategy_name="exp_rewrite",
llm_model_name=LLM_MODELS[0],
template_prompt_name="exp_rewrite",
prompt_param_dict={
"context": context_summary, # 使用摘要或原文
"paragraph_content": paragraph_content,
"con_direction": con_direction,
"previous_text": previous_text,
"following_text": following_text
},
temperature=TEMPERATURE,
max_tokens=MAX_TOKENS
)
# logger.info("生成的扩写内容: %s", exp_rewrite_content)
except Exception as e:
logger.error("生成扩写内容时出错: %s", e)
return (f"出错了。。请重试。。")
# 如果 previous_text 和 following_text 存在空值,直接返回 exp_rewrite_content
if not previous_text or not following_text:
logger.info("上文或下文为空,直接返回生成的内容。。")
final_content = exp_rewrite_content
else:
# 定义内容检查函数
def exp_rewrite_check(exp_rewrite_content: str) -> str:
logger.info("对文章扩写内容进行行文检查。。")
logger.info("检查前的扩写内容: %s", exp_rewrite_content)
try:
exp_rewrite_check_content = get_llm_model_response(
strategy_name="exp_rewrite_check",
llm_model_name=LLM_MODELS[0],
template_prompt_name="exp_rewrite_check",
prompt_param_dict={
"previous_text": previous_text,
"following_text": following_text,
"exp_rewrite_content": exp_rewrite_content, # 使用生成的 rewrite_content
},
temperature=temperature,
max_tokens=max_tokens
)
logger.info("检查后的扩写内容: %s", exp_rewrite_check_content)
return exp_rewrite_check_content
except Exception as e:
logger.error("检查扩写内容时出错: %s", e)
return (f"出错了。。请重试。。")
# 调用内容检查函数
final_content = exp_rewrite_check(exp_rewrite_content)
# 返回最终生成的字符串
return final_content