Files
gangyan/langchain-chat/server/custom/paper_translation.py

78 lines
3.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
from typing import AsyncIterable, List, Optional
from urllib.parse import urlencode
from fastapi import Body, Request
from fastapi.concurrency import run_in_threadpool
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from sse_starlette.sse import EventSourceResponse
from configs import (TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
MODEL_PATH,
MAX_TOKENS,
MAX_CUT_TOKENS, LLM_MODELS)
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.reranker.reranker import LangchainReranker
from server.utils import BaseResponse, get_prompt_template
from server.utils import embedding_device
from server.utils import wrap_done, get_ChatOpenAI
from collections import defaultdict
from server.custom.custom_fun import chi_translation,eng_translation
async def paper_translation(query: str = Body("为我总结这些内容", description="用户输入", examples=["你好"]),
                            knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                            stream: bool = Body(False, description="流式输出"),
                            model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                            temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                            max_tokens: Optional[int] = Body(
                                MAX_TOKENS,
                                description="限制LLM生成Token数量默认None代表模型最大值"
                            ),
                            prompt_name: str = Body(
                                "eng_chi",
                                description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                            ),
                            source_name_list: List[str] = Body([], description="资源列表"),
                            request: Request = None,
                            ):
    """Translate the documents of selected sources in a knowledge base, streamed as SSE.

    The knowledge base is looked up by ``knowledge_base_name``. Its documents
    for ``source_name_list`` are fetched, and the result is passed to
    ``chi_translation`` (when ``prompt_name == "chi_eng"``) or to
    ``eng_translation`` (any other prompt name). Every chunk those helpers
    yield is forwarded to the client.

    Args:
        query: Accepted for API compatibility; not used by this implementation.
        knowledge_base_name: Name of the knowledge base to read documents from.
        stream: Accepted for API compatibility; the response is always SSE.
        model_name: LLM model name forwarded to the translation helper.
        temperature: Sampling temperature forwarded to the translation helper.
        max_tokens: Generation token limit. Non-positive integers are
            normalized to ``None`` before being forwarded.
        prompt_name: Template name under ``"knowledge_base_chat"``. It also
            selects the translation direction (``"chi_eng"`` vs. anything else).
        source_name_list: Source names whose documents are fetched.
        request: Unused.

    Returns:
        ``BaseResponse(code=404, ...)`` if the knowledge base does not exist.
        Otherwise, an ``EventSourceResponse`` streaming the translation chunks.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    # A non-positive limit is treated as "no explicit limit".
    if isinstance(max_tokens, int) and max_tokens <= 0:
        max_tokens = None

    # "chi_eng" selects chi_translation; every other prompt name falls back to eng_translation.
    translate = chi_translation if prompt_name == "chi_eng" else eng_translation

    async def translation_iterator() -> AsyncIterable[str]:
        # Document lookup is synchronous, so run it off the event loop.
        docs = await run_in_threadpool(kb.get_doc_by_sources_name,
                                       source_name_list=source_name_list)
        # If no matching documents were found, use the "empty" template.
        template_name = prompt_name if docs else "empty"
        prompt_template = get_prompt_template("knowledge_base_chat", template_name)
        async for chunk in translate(docs, prompt_template, model_name, temperature, max_tokens):
            yield chunk

    return EventSourceResponse(translation_iterator())