Files
gangyan/langchain-chat/server/chat/knowledge_chat_test.py

209 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import uuid
from fastapi import Body
from langchain.memory import (
CombinedMemory,
ConversationBufferMemory,
ConversationSummaryMemory,
ConversationBufferWindowMemory
)
from configs.model_config import MAX_TOKENS
from configs.outline_config import outlines
from langchain.chains.question_answering import load_qa_chain
from typing import Any
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
from configs.kb_config import CH_BASE_NAME
from server.agent.tools.search_tool import search_tool
from server.chat import utils
from server.chat.agent_chat_test import agent_chat_test
from server.chat.policy_fun_iast import get_llm_model_response
from server.chat.solve_problem import solve_problem
from server.knowledge_base.kb_service.base import TextRank
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain, ConversationChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional, Union
from server.chat.utils import History, compute_lps, remove_after_and_including,remove_before_and_including
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template, get_format_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
from datetime import datetime
from langchain_core.messages import SystemMessage
import time as t
from configs.basic_config import *
async def process_task(task):
    """Drain the async iterable *task* and return its items as a list, in order."""
    return [item async for item in task]
def _first_chunk_text(chunks):
    """Return the ``"text"`` field of the first JSON chunk, or ``""`` when there is none."""
    for chunk in chunks:
        return json.loads(chunk)["text"]
    return ""


def _first_rewritten_query(raw_response, fallback):
    """Return the first value of the JSON object in *raw_response*, else *fallback*.

    *fallback* is used when the response is not valid JSON, is not a JSON
    object, or is an empty object, so a malformed rewrite does not abort the
    stream.
    """
    try:
        parsed = json.loads(raw_response)
    except (TypeError, ValueError):
        parsed = None
    if isinstance(parsed, dict) and parsed:
        return next(iter(parsed.values()))
    logger.warning(f"query rewrite returned unusable output, falling back to raw query: {raw_response!r}")
    return fallback


def _resolve_outline(raw_index):
    """Turn the outline router's answer into ``(index, outline_template)``.

    Falls back to the last entry of ``outlines`` when the answer is empty,
    not an integer, or not a valid index.
    """
    fallback_index = len(outlines) - 1
    try:
        index = int(str(raw_index).strip())
        return index, outlines[index]
    except (ValueError, IndexError, KeyError):
        return fallback_index, outlines[fallback_index]


async def knowledge_chat_test(
    uid: Optional[Any] = Body(None, description="用户ID"),
    rag: bool = Body(False, description="增强标识, 当knowledge_base_list有值时需要传True"),
    knowledge_base_list: Optional[List[str]] = Body(None, description="知识库名称列表"),
    query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
    conversation_id: str = Body("", description="对话框ID"),
    model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
    temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
    max_tokens: Optional[int] = Body(
        MAX_TOKENS,
        description="限制LLM生成Token数量默认None代表模型最大值"
    ),
    prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
    """Stream an outline-style answer for *query* as server-sent events.

    Two routing prompts ("think_route" and "outlines_route") are run
    concurrently. Retrieval through ``search_tool`` is performed when the
    think route answers ``"2"`` or when *rag* is true; the outlines route
    selects which entry of ``outlines`` is passed to the chain as
    ``self_knowledge``. Generated tokens are emitted as JSON objects
    ``{"text": ..., "message_id": ...}``.

    Args:
        uid: User id; accepted but not used by this function.
        rag: When true, *knowledge_base_list* is merged with ``CH_BASE_NAME``
            and retrieval is always performed.
        knowledge_base_list: Extra knowledge base names used when *rag* is true.
        query: The user's input.
        conversation_id: Conversation id handed to the conversation callback.
        model_name: Model used for routing and for the retrieval branch.
        temperature: Sampling temperature.
        max_tokens: Accepted for interface compatibility; see the note in the body.
        prompt_name: Passed to the routing calls as ``user_prompt_name``.

    Returns:
        An ``EventSourceResponse`` wrapping the token stream.
    """

    async def chat_iterator() -> AsyncIterable[str]:
        # NOTE(review): the request's max_tokens has always been overridden by
        # the configured default here; kept as-is — confirm this is intended.
        token_limit = MAX_TOKENS
        message_id = str(uuid.uuid1()) + "q"

        callback = AsyncIteratorCallbackHandler()
        conversation_callback = ConversationCallbackHandler(
            conversation_id=conversation_id,
            message_id=message_id,
            chat_type="llm_chat",
            query=query,
        )
        callbacks = [callback, conversation_callback]

        # Shared slot keyed by a fresh id; the id is forwarded to search_tool below.
        time_based_uuid = str(uuid.uuid1()) + "q"
        utils.set_shared_variable(
            time_based_uuid,
            {"END": "", "source_docs": [], "num": 0, "title": []},
        )

        def run_route(route_prompt_name):
            # A fresh history list per call so the two routes never share state.
            return process_task(solve_problem(
                user_prompt_name=prompt_name,
                query=query,
                conversation_id="",
                history=[],
                model_name=model_name,
                temperature=temperature,
                max_tokens=token_limit,
                prompt_name=route_prompt_name,
                stream=True,
            ))

        # Decide whether retrieval is needed and which outline template to use.
        think_chunks, outline_chunks = await asyncio.gather(
            run_route("think_route"),
            run_route("outlines_route"),
        )
        think_route = _first_chunk_text(think_chunks).strip()
        outline_raw = _first_chunk_text(outline_chunks)

        context = ""
        if think_route == "2" or rag:
            yield json.dumps({"text": "思考中..."}, ensure_ascii=False)
            model = get_ChatOpenAI(
                model_name=model_name,
                temperature=temperature,
                max_tokens=token_limit,
                callbacks=callbacks,
            )
            search_query = get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="query_rewrite",
                prompt_param_dict={"query": query, "history": [], "time": datetime.now().strftime("%Y%m%d")},
                temperature=0.01,
                max_tokens=512,
            )
            rewritten_query = _first_rewritten_query(search_query, fallback=query)

            if rag:
                # knowledge_base_list defaults to None; treat that as "no extra bases".
                kb_names = list(set(CH_BASE_NAME).union(knowledge_base_list or []))
                logger.info(f"大纲撰写参考的个人知识库名称:{kb_names}")
            else:
                kb_names = CH_BASE_NAME

            # search_tool receives two JSON objects concatenated back to back.
            search_request = {
                "query": rewritten_query,
                "knowledge_name": kb_names,
                "keywords": [],
            }
            origin_query = json.dumps(search_request) + json.dumps({"uuid": time_based_uuid})
            # Truncate the retrieved material to avoid exceeding the token budget.
            context = search_tool(origin_query)[:7000]
        else:
            # NOTE(review): this branch uses LLM_MODELS[0] rather than model_name,
            # as the original did — confirm whether that is deliberate.
            model = get_ChatOpenAI(
                model_name=LLM_MODELS[0],
                temperature=temperature,
                max_tokens=token_limit,
                callbacks=callbacks,
            )

        prompt_template = get_prompt_template("knowledge_base_chat", "default_outlines")
        system_prompt = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([system_prompt])
        format_template = get_format_template("knowledge_base_chat", "default_outlines")
        chain = load_qa_chain(
            model, chain_type="stuff", memory=None, prompt=chat_prompt, verbose=True
        )

        outline_index, outline_template = _resolve_outline(outline_raw)
        logger.info(f"outline_num:{outline_index}")

        if not context:
            context = "根据你的自身能力回答"
        documents = [DocumentWithVSId(page_content=context)]

        task = asyncio.create_task(wrap_done(
            chain.acall({
                "input_documents": documents,
                "self_knowledge": outline_template,
                "history": [],
                "question": query,
                "format_template": format_template,
                "time": datetime.now().strftime("%Y%m%d"),
            }),
            callback.done,
        ))

        # Relay each generated token to the client as it arrives.
        async for token in callback.aiter():
            yield json.dumps(
                {"text": token, "message_id": message_id},
                ensure_ascii=False)

        await task

    return EventSourceResponse(chat_iterator())