import asyncio
import json
from collections import defaultdict
from typing import AsyncIterable, List, Optional
from urllib.parse import urlencode

from fastapi import Body, Request
from fastapi.concurrency import run_in_threadpool
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from sse_starlette.sse import EventSourceResponse

from configs import (TEMPERATURE, USE_RERANKER, RERANKER_MODEL, RERANKER_MAX_LENGTH, MODEL_PATH, MAX_TOKENS,
                     MAX_CUT_TOKENS, LLM_MODELS)
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.reranker.reranker import LangchainReranker
from server.utils import BaseResponse, get_prompt_template
from server.utils import embedding_device
from server.utils import wrap_done, get_ChatOpenAI
from server.custom.custom_fun import chapter_overview_summary


async def task(param):
    """Run one summarisation chain and yield the tokens it produced.

    ``param`` is a mapping that must contain the keys ``contextk``, ``i``,
    ``model_name``, ``temperature``, ``max_tokens`` and ``chat_prompt``.

    The wrapped chain call is awaited before the first token is yielded, so
    the tokens are handed out only after that await has returned rather than
    while the model is still generating.
    """
    contextk = param["contextk"]
    i = param["i"]
    model_name = param["model_name"]
    temperature = param["temperature"]
    max_tokens = param["max_tokens"]
    chat_prompt = param["chat_prompt"]
    print(f"i:{i},len_context:{len(contextk)}\n")

    callback_temp = AsyncIteratorCallbackHandler()
    model_temp = get_ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        callbacks=[callback_temp],
    )
    chain_temp = LLMChain(prompt=chat_prompt, llm=model_temp)
    task_temp = wrap_done(chain_temp.acall({"context": contextk, "question": "对该部分内容进行总结"}),
                          callback_temp.done)
    await task_temp
    async for token in callback_temp.aiter():
        yield token


async def _collect_task_output(param):
    """Drain ``task(param)`` and return all of its tokens joined into one string."""
    return ''.join([token async for token in task(param)])


async def run_tasks_concurrently(params):
    """Run ``task`` for every entry of ``params`` concurrently on the event loop.

    ``params`` may be a regular iterable or an async iterable of parameter
    mappings (see ``task`` for the required keys).

    Returns a list with one joined string per entry. The strings are appended
    as the individual jobs finish, so the list is in completion order, which
    is not necessarily the order of ``params``. An exception raised by any job
    propagates to the caller.
    """
    if hasattr(params, "__aiter__"):
        param_list = [param async for param in params]
    else:
        param_list = list(params)

    # ``task`` is an async generator, which asyncio cannot schedule directly;
    # each one is wrapped in a coroutine that collects its tokens first.
    jobs = [_collect_task_output(param) for param in param_list]

    result = []
    # ``as_completed`` yields awaitables in completion order; each must be
    # awaited to obtain the finished job's string.
    for finished in asyncio.as_completed(jobs):
        result.append(await finished)
    return result


# Request handler that summarises the documents of a knowledge base chapter by
# chapter and returns the summaries as server-sent events.
#
# Documented with comments rather than a docstring so that any route
# description a framework derives from the docstring stays as it was.
#
# Parameters read by the body: ``knowledge_base_name``, ``stream``,
# ``model_name``, ``temperature``, ``max_tokens`` and ``source_name_list``.
# ``query``, ``prompt_name`` and ``request`` are accepted but not used here.
#
# Returns a ``BaseResponse`` with code 404 when no knowledge base service is
# found for ``knowledge_base_name``; otherwise an ``EventSourceResponse``
# wrapping the generator defined below.
async def chapter_overview(query: str = Body("为我总结这些内容", description="用户输入", examples=["你好"]),
                           knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                           stream: bool = Body(False, description="流式输出"),
                           model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                           temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                           max_tokens: Optional[int] = Body(
                               MAX_TOKENS,
                               description="限制LLM生成Token数量,默认None代表模型最大值"
                           ),
                           prompt_name: str = Body(
                               "Chapter Overview",
                               description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                           ),
                           source_name_list: List[str] = Body([], description="资源列表"),
                           request: Request = None,
                           ):
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    async def chapter_overview_iterator(
            model_name: str = model_name,
    ) -> AsyncIterable[str]:
        # Yields JSON strings that become the event payloads.
        nonlocal max_tokens
        # A non-positive integer limit is treated as "no limit".
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        # The document lookup is a synchronous call, so it is run in the
        # thread pool instead of on the event loop.
        docs = await run_in_threadpool(kb.get_doc_by_sources_name, source_name_list=source_name_list)

        chapter_summaries, global_summary = await chapter_overview_summary(docs, model_name, temperature, max_tokens)

        if stream:
            # One event per chapter summary, followed by one event carrying
            # the global summary.
            for h1, summaries in chapter_summaries.items():
                for summary in summaries:
                    yield json.dumps({"chapter_title": h1, "summary": summary}, ensure_ascii=False)
            yield json.dumps({"global_summary": global_summary}, ensure_ascii=False)
        else:
            # A single event containing everything.
            result = {
                "chapter_summaries": chapter_summaries,
                "global_summary": global_summary,
            }
            yield json.dumps(result, ensure_ascii=False)

    return EventSourceResponse(chapter_overview_iterator(model_name=model_name))