209 lines
9.9 KiB
Python
209 lines
9.9 KiB
Python
|
|
import uuid
|
|||
|
|
from fastapi import Body
|
|||
|
|
from langchain.memory import (
|
|||
|
|
CombinedMemory,
|
|||
|
|
ConversationBufferMemory,
|
|||
|
|
ConversationSummaryMemory,
|
|||
|
|
ConversationBufferWindowMemory
|
|||
|
|
)
|
|||
|
|
from configs.model_config import MAX_TOKENS
|
|||
|
|
from configs.outline_config import outlines
|
|||
|
|
from langchain.chains.question_answering import load_qa_chain
|
|||
|
|
from typing import Any
|
|||
|
|
from sse_starlette.sse import EventSourceResponse
|
|||
|
|
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
|
|||
|
|
from configs.kb_config import CH_BASE_NAME
|
|||
|
|
from server.agent.tools.search_tool import search_tool
|
|||
|
|
from server.chat import utils
|
|||
|
|
from server.chat.agent_chat_test import agent_chat_test
|
|||
|
|
from server.chat.policy_fun_iast import get_llm_model_response
|
|||
|
|
from server.chat.solve_problem import solve_problem
|
|||
|
|
from server.knowledge_base.kb_service.base import TextRank
|
|||
|
|
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
|||
|
|
from server.utils import wrap_done, get_ChatOpenAI
|
|||
|
|
from langchain.chains import LLMChain, ConversationChain
|
|||
|
|
from langchain.callbacks import AsyncIteratorCallbackHandler
|
|||
|
|
from typing import AsyncIterable
|
|||
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
from langchain.prompts.chat import ChatPromptTemplate
|
|||
|
|
from typing import List, Optional, Union
|
|||
|
|
from server.chat.utils import History, compute_lps, remove_after_and_including,remove_before_and_including
|
|||
|
|
from langchain.prompts import PromptTemplate
|
|||
|
|
from server.utils import get_prompt_template, get_format_template
|
|||
|
|
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
|
|||
|
|
from server.db.repository import add_message_to_db
|
|||
|
|
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
|||
|
|
from datetime import datetime
|
|||
|
|
from langchain_core.messages import SystemMessage
|
|||
|
|
import time as t
|
|||
|
|
from configs.basic_config import *
|
|||
|
|
async def process_task(task):
    """Drain an asynchronous iterable and return its items as a list.

    Args:
        task: Any object usable in ``async for`` (e.g. an async generator).

    Returns:
        A list holding every item produced by ``task``, in the order yielded.
    """
    return [chunk async for chunk in task]
|
|||
|
|
def _first_chunk_text(chunks):
    """Return the ``"text"`` field of the first JSON-encoded chunk, or ``""`` if there is none."""
    if not chunks:
        return ""
    return json.loads(chunks[0])["text"]


def _extract_rewritten_query(raw_response, fallback):
    """Return the first value of the JSON object in ``raw_response``.

    Falls back to ``fallback`` when the response is not valid JSON, is not a
    JSON object, or is an empty object, so a malformed rewrite does not abort
    the whole response stream.
    """
    try:
        parsed = json.loads(raw_response)
    except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
        parsed = None
    if isinstance(parsed, dict) and parsed:
        return next(iter(parsed.values()))
    logger.warning("query rewrite returned unusable output, using raw query: %r", raw_response)
    return fallback


def _pick_outline(raw_index):
    """Resolve the router's outline answer to ``(index, outline)``.

    An empty answer, an answer containing "无", a non-numeric answer, or an
    index that does not exist in ``outlines`` all fall back to the last
    outline (``len(outlines) - 1``).
    """
    fallback = len(outlines) - 1
    if not raw_index or "无" in raw_index:
        index = fallback
    else:
        try:
            index = int(raw_index)
        except ValueError:
            logger.warning("outline router returned a non-numeric answer: %r", raw_index)
            index = fallback
    try:
        return index, outlines[index]
    except (IndexError, KeyError):
        logger.warning("outline index %r is not available, using the last outline", index)
        return fallback, outlines[fallback]


async def knowledge_chat_test(
        uid: Optional[Any] = Body(None, description="用户ID"),
        rag: bool = Body(False, description="增强标识, 当knowledge_base_list有值时,需要传True"),
        knowledge_base_list: Optional[List[str]] = Body(None, description="知识库名称列表"),
        query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
        conversation_id: str = Body("", description="对话框ID"),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
        max_tokens: Optional[int] = Body(
            MAX_TOKENS,
            description="限制LLM生成Token数量,默认None代表模型最大值"
        ),
        prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
    """Stream an outline-style answer, optionally grounded in retrieved material.

    Two routing prompts ("think_route" and "outlines_route") are run
    concurrently. When the first answers "2", or when ``rag`` is set, the query
    is rewritten, passed to ``search_tool`` and the result (capped in length)
    is given to the answering chain as its only input document. The answer is
    streamed token by token as JSON strings.

    Args:
        uid: User identifier; not used by this function.
        rag: Force retrieval and merge ``knowledge_base_list`` with ``CH_BASE_NAME``.
        knowledge_base_list: Extra knowledge-base names; ``None`` means none.
        query: The user's input.
        conversation_id: Conversation identifier handed to the conversation callback.
        model_name: Model used for routing and for the retrieval branch.
        temperature: Sampling temperature.
        max_tokens: Accepted for interface compatibility; see the note below.
        prompt_name: Forwarded to the routing calls as ``user_prompt_name``.

    Returns:
        An ``EventSourceResponse`` wrapping the token stream.
    """
    # Upper bound on characters of retrieved material, to stay within the token budget.
    max_context_chars = 7000

    async def chat_iterator() -> AsyncIterable[str]:
        # NOTE(review): the original code overwrote the request's ``max_tokens``
        # with MAX_TOKENS; that behavior is kept here - confirm it is intended.
        token_limit = MAX_TOKENS
        message_id = str(uuid.uuid1()) + "q"

        callback = AsyncIteratorCallbackHandler()
        conversation_callback = ConversationCallbackHandler(
            conversation_id=conversation_id,
            message_id=message_id,
            chat_type="llm_chat",
            query=query,
        )
        callbacks = [callback, conversation_callback]

        # Register a per-request record under a fresh id; the id is passed to
        # search_tool below as the "uuid" field.
        time_based_uuid = str(uuid.uuid1()) + "q"
        utils.set_shared_variable(
            time_based_uuid,
            {"END": "", "source_docs": [], "num": 0, "title": []},
        )

        # Decide whether retrieval is needed and which outline template to use.
        route_task = solve_problem(
            user_prompt_name=prompt_name, query=query, conversation_id="", history=[],
            model_name=model_name, temperature=temperature, max_tokens=token_limit,
            prompt_name="think_route", stream=True,
        )
        outline_task = solve_problem(
            user_prompt_name=prompt_name, query=query, conversation_id="", history=[],
            model_name=model_name, temperature=temperature, max_tokens=token_limit,
            prompt_name="outlines_route", stream=True,
        )
        route_chunks, outline_chunks = await asyncio.gather(
            process_task(route_task), process_task(outline_task)
        )
        # Only the first chunk of each routing stream is consulted.
        route_answer = _first_chunk_text(route_chunks).strip()
        outline_answer = _first_chunk_text(outline_chunks)

        res = ""
        if route_answer == "2" or rag:
            yield json.dumps({"text": "思考中..."}, ensure_ascii=False)
            model = get_ChatOpenAI(
                model_name=model_name,
                temperature=temperature,
                max_tokens=token_limit,
                callbacks=callbacks,
            )
            search_query = get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="query_rewrite",
                prompt_param_dict={"query": query, "history": [], "time": datetime.now().strftime("%Y%m%d")},
                temperature=0.01,
                max_tokens=512
            )
            rewritten_query = _extract_rewritten_query(search_query, query)

            if rag:
                # ``knowledge_base_list`` defaults to None; treat that as "no extra bases"
                # instead of letting set.union(None) raise TypeError.
                kb_names = list(set(CH_BASE_NAME).union(knowledge_base_list or []))
                logger.info(f"大纲撰写参考的个人知识库名称:{kb_names}")
            else:
                kb_names = CH_BASE_NAME

            # search_tool receives two JSON objects concatenated back to back.
            first_json = {
                "query": rewritten_query,
                "knowledge_name": kb_names,
                "keywords": []
            }
            second_json = {"uuid": time_based_uuid}
            origin_query = json.dumps(first_json) + json.dumps(second_json)

            # Truncate the retrieved material to avoid exceeding the token limit.
            res = search_tool(origin_query)[:max_context_chars]
        else:
            # NOTE(review): this branch uses LLM_MODELS[0] rather than the
            # requested ``model_name``, as in the original - confirm intent.
            model = get_ChatOpenAI(
                model_name=LLM_MODELS[0],
                temperature=temperature,
                max_tokens=token_limit,
                callbacks=callbacks,
            )

        prompt_template = get_prompt_template("knowledge_base_chat", "default_outlines")
        input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_prompt])
        format_template = get_format_template("knowledge_base_chat", "default_outlines")

        chain = load_qa_chain(
            model, chain_type="stuff", memory=None, prompt=chat_prompt, verbose=True
        )

        outline_num, outline = _pick_outline(outline_answer)
        if res == "":
            res = "根据你的自身能力回答"
        documents = [DocumentWithVSId(page_content=res)]
        logger.info(f"outline_num:{outline_num}")

        task = asyncio.create_task(wrap_done(
            chain.acall({
                "input_documents": documents,
                "self_knowledge": outline,
                "history": [],
                "question": query,
                "format_template": format_template,
                "time": datetime.now().strftime("%Y%m%d")
            }),
            callback.done),
        )

        # Use server-sent events to stream the response token by token.
        async for token in callback.aiter():
            yield json.dumps(
                {"text": token, "message_id": message_id},
                ensure_ascii=False)
        await task

    return EventSourceResponse(chat_iterator())
|