Files
gangyan/langchain-chat/server/custom/chapter_overview.py

113 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
from typing import AsyncIterable, List, Optional
from urllib.parse import urlencode
from fastapi import Body, Request
from fastapi.concurrency import run_in_threadpool
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from sse_starlette.sse import EventSourceResponse
from configs import (TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
MODEL_PATH,
MAX_TOKENS,
MAX_CUT_TOKENS, LLM_MODELS)
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.reranker.reranker import LangchainReranker
from server.utils import BaseResponse, get_prompt_template
from server.utils import embedding_device
from server.utils import wrap_done, get_ChatOpenAI
from collections import defaultdict
from server.custom.custom_fun import chapter_overview_summary
async def task(param):
contextk = param["contextk"]
i = param["i"]
model_name = param["model_name"]
temperature = param["temperature"]
max_tokens = param["max_tokens"]
chat_prompt = param["chat_prompt"]
print(f"i:{i}len_context:{len(contextk)}\n")
callback_temp = AsyncIteratorCallbackHandler()
model_temp = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback_temp],
)
chain_temp = LLMChain(prompt=chat_prompt, llm=model_temp)
task_temp = wrap_done(chain_temp.acall({"context": contextk,
"question": "对该部分内容进行总结"}),
callback_temp.done)
await task_temp
async for token in callback_temp.aiter():
yield token
# 使用多线程执行任务
async def run_tasks_concurrently(params):
result = []
async for data in asyncio.as_completed([task(param) async for param in params]):
result.append(''.join([token async for token in data]))
return result
async def chapter_overview(query: str = Body("为我总结这些内容", description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(
MAX_TOKENS,
description="限制LLM生成Token数量默认None代表模型最大值"
),
prompt_name: str = Body(
"Chapter Overview",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
),
source_name_list: List[str] = Body([], description="资源列表"),
request: Request = None,
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
async def chapter_overview_iterator(
model_name: str = model_name,
) -> AsyncIterable[str]:
nonlocal max_tokens
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
docs = await run_in_threadpool(kb.get_doc_by_sources_name,source_name_list=source_name_list)
chapter_summaries, global_summary = await chapter_overview_summary(docs, model_name, temperature, max_tokens)
if stream:
for h1, summaries in chapter_summaries.items():
for summary in summaries:
yield json.dumps({"chapter_title": h1, "summary": summary}, ensure_ascii=False)
yield json.dumps({"global_summary": global_summary}, ensure_ascii=False)
# yield json.dumps({"docs": source_documents}, ensure_ascii=False)
else:
result = {
"chapter_summaries": chapter_summaries,
"global_summary": global_summary,
}
yield json.dumps(result, ensure_ascii=False)
return EventSourceResponse(chapter_overview_iterator(model_name=model_name))