# File: gangyan/langchain-chat/server/chat/knowledge_base_chat_old.py
# Size: 138 lines, 6.9 KiB. Language: Python.
#
# Note: this file contains Unicode characters that might be confused with
# other, similar-looking characters. If this is intentional, the warning can
# safely be ignored.

from fastapi import Body, Request
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import PromptTemplate
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS,
VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
MODEL_PATH,
MAX_TOKENS,
MAX_CUT_TOKENS, HISTORY_LEN)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template, get_format_template
from server.utils import get_strategy_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device
from server.chat.policy_fun import add_summary_retrieved_results
from server.chat.policy_fun_iast import get_llm_model_response
import json
from configs.basic_config import *
async def knowledge_base_chat_old(
        query: str = Body(..., description="用户输入", examples=["你好"]),
        fileName: List = Body([], description="文件名称", examples=[["123.txt"]]),
        knowledge_base_name: str = Body(..., description="知识库名称",
                                        examples=["t_policy_total_bge_v1"]),
        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
        score_threshold: float = Body(
            SCORE_THRESHOLD,
            description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右",
            ge=0,
            le=2
        ),
        history: List[History] = Body(
            [],
            description="历史对话",
            examples=[[
                {"role": "user",
                 "content": "我们来玩成语接龙,我先来,生龙活虎"},
                {"role": "assistant",
                 "content": "虎头虎脑"}]]
        ),
        stream: bool = Body(False, description="流式输出"),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(
            MAX_TOKENS,
            description="限制LLM生成Token数量默认None代表模型最大值"
        ),
        prompt_name: str = Body(
            "default",
            description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
        ),
        request: Request = None,
        use_summary=False,
        chunk_size: int = 20000,
        min_chunk_size: int = 2000,
        summary_model_name=LLM_MODELS[0],
        query_rewrite_model_name=LLM_MODELS[0]
):
    """Answer ``query`` from a knowledge base and return the reply as server-sent events.

    Flow, as implemented below:

    1. Look up ``knowledge_base_name``. If no such knowledge base exists, return
       a ``BaseResponse`` with ``code=404`` instead of an event stream.
    2. Retrieve documents with ``search_docs``, called through
       ``run_in_threadpool``, using ``fileName``, ``query``, ``top_k`` and
       ``score_threshold``.
    3. Join the retrieved ``page_content`` strings with newlines into one
       ``context``. Pass ``context`` and ``question`` to an ``LLMChain`` whose
       single system message comes from the
       ``("knowledge_base_chat", "Question Assistant")`` prompt template.
    4. Emit JSON events:

       * ``{"answer": <token>}`` per token when ``stream`` is true; or
       * ``{"answer": <full answer>}`` once when ``stream`` is false.

       Then always emit a final ``{"docs": []}`` event.

    Args:
        query: User question.
        fileName: File names forwarded to ``search_docs`` to restrict retrieval.
        knowledge_base_name: Name of the knowledge base to query.
        top_k: Number of vector matches requested from ``search_docs``.
        score_threshold: Relevance threshold forwarded to ``search_docs``.
        history: Prior dialogue turns. They are converted to ``History``
            objects but not used in the prompt.
        stream: Emit one event per token (True) or a single answer event (False).
        model_name: LLM model used for the chain.
        temperature: Sampling temperature for the LLM.
        max_tokens: Generation limit. A non-positive int is treated as ``None``.
        prompt_name: Accepted for API compatibility.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator, or a
        ``BaseResponse`` (code 404) when the knowledge base is unknown.

    NOTE(review): ``prompt_name``, ``request``, ``use_summary``, ``chunk_size``,
    ``min_chunk_size``, ``summary_model_name`` and ``query_rewrite_model_name``
    are not used by this implementation. The prompt template is always
    "Question Assistant", regardless of ``prompt_name``.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
    # ``logger`` is expected to come from the star import of configs.basic_config.
    logger.info(f'当前知识库:{knowledge_base_name}')
    history = [History.from_data(h) for h in history]

    async def knowledge_base_chat_old_iterator(
            query: str,
            top_k: int,
            history: Optional[List[History]],
            model_name: str = model_name,
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        """Yield the JSON-encoded answer event(s) followed by the ``docs`` event."""
        callback = AsyncIteratorCallbackHandler()

        # Treat a non-positive limit as "no limit". Use a local name instead of
        # rebinding the enclosing ``max_tokens`` through ``nonlocal``.
        effective_max_tokens = max_tokens
        if isinstance(effective_max_tokens, int) and effective_max_tokens <= 0:
            effective_max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=effective_max_tokens,
            callbacks=[callback],
        )

        # search_docs is a synchronous call; run it off the event loop.
        docs = await run_in_threadpool(search_docs,
                                       fileName=fileName,
                                       query=query,
                                       knowledge_base_name=knowledge_base_name,
                                       top_k=top_k,
                                       score_threshold=score_threshold)
        context = "\n".join(doc.page_content for doc in docs)

        prompt_template = get_prompt_template("knowledge_base_chat", "Question Assistant")
        input_msg = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_msg])
        chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True)

        # Run the chain in the background. Tokens arrive via ``callback``.
        # ``wrap_done`` receives ``callback.done``, presumably to end the
        # token iteration below once the chain finishes (confirm in server.utils).
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        # Retrieved documents are not reported back; the final "docs" event is empty.
        source_documents = []

        tokens = []
        async for token in callback.aiter():
            tokens.append(token)
            if stream:
                # Use server-sent events to stream each token as it arrives.
                yield json.dumps({"answer": token}, ensure_ascii=False)
        answer = "".join(tokens)
        if not stream:
            # Same serialisation as the streamed events, so non-ASCII text is
            # not \u-escaped on the wire.
            yield json.dumps({"answer": answer}, ensure_ascii=False)
        logger.info(f"推荐问题:\n{answer}")

        await task
        yield json.dumps({"docs": source_documents}, ensure_ascii=False)

    return EventSourceResponse(knowledge_base_chat_old_iterator(query, top_k, history, model_name, prompt_name))