import asyncio
import json
from typing import AsyncIterable, List, Optional
from urllib.parse import urlencode

from fastapi import Body, Request
from fastapi.concurrency import run_in_threadpool
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from sse_starlette.sse import EventSourceResponse

from configs import (TEMPERATURE, USE_RERANKER, RERANKER_MODEL, RERANKER_MAX_LENGTH, MODEL_PATH, MAX_TOKENS,
                     MAX_CUT_TOKENS, LLM_MODELS)
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.reranker.reranker import LangchainReranker
from server.utils import BaseResponse, get_prompt_template
from server.utils import embedding_device
from server.utils import wrap_done, get_ChatOpenAI
from collections import defaultdict
from server.custom.custom_fun import chi_translation, eng_translation


async def paper_translation(
    query: str = Body("为我总结这些内容", description="用户输入", examples=["你好"]),
    knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
    stream: bool = Body(False, description="流式输出"),
    model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
    temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
    max_tokens: Optional[int] = Body(
        MAX_TOKENS,
        description="限制LLM生成Token数量,默认None代表模型最大值",
    ),
    prompt_name: str = Body(
        "eng_chi",
        description="使用的prompt模板名称(在configs/prompt_config.py中配置)",
    ),
    source_name_list: List[str] = Body([], description="资源列表"),
    request: Request = None,
):
    """Stream the translation of selected knowledge-base documents as SSE.

    The documents are looked up in the knowledge base by source name and
    handed, together with a prompt template, to ``chi_translation`` (when
    ``prompt_name == "chi_eng"``) or ``eng_translation`` (any other prompt
    name). Every chunk those helpers yield is forwarded to the client
    unchanged.

    Args:
        query: Accepted for API compatibility; not read by this endpoint.
        knowledge_base_name: Name used to resolve the knowledge-base service.
        stream: Accepted for API compatibility; not read by this endpoint.
            The success response is always an ``EventSourceResponse``.
        model_name: Forwarded to the translation helper.
        temperature: Forwarded to the translation helper.
        max_tokens: Forwarded to the translation helper. A non-positive
            integer is replaced by ``None`` before forwarding.
        prompt_name: Name of the ``knowledge_base_chat`` prompt template to
            use when documents are found; also selects the helper.
        source_name_list: Source names passed to
            ``kb.get_doc_by_sources_name``.
        request: Not read by this endpoint.

    Returns:
        A ``BaseResponse`` with ``code=404`` when the knowledge base cannot
        be resolved, otherwise an ``EventSourceResponse`` wrapping the
        translation stream.
    """
    # NOTE: deliberately no return annotation; FastAPI may interpret an
    # endpoint's return annotation as its response model.
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    async def translation_iterator() -> AsyncIterable[str]:
        """Yield the chunks produced by the selected translation helper."""
        # A non-positive integer limit is treated as "no limit" (None).
        if isinstance(max_tokens, int) and max_tokens <= 0:
            token_limit = None
        else:
            token_limit = max_tokens

        # The lookup is a synchronous call, so it is run in the threadpool
        # rather than directly inside this coroutine.
        docs = await run_in_threadpool(
            kb.get_doc_by_sources_name,
            source_name_list=source_name_list,
        )

        # No matching documents were found: fall back to the "empty" template.
        template_name = prompt_name if docs else "empty"
        prompt_template = get_prompt_template("knowledge_base_chat", template_name)

        # Helper selection depends on the requested prompt name only, even
        # when the "empty" template was substituted above.
        translate = chi_translation if prompt_name == "chi_eng" else eng_translation
        async for chunk in translate(docs, prompt_template, model_name, temperature, token_limit):
            yield chunk

    return EventSourceResponse(translation_iterator())