209 lines
9.9 KiB
Python
209 lines
9.9 KiB
Python
import uuid
|
||
from fastapi import Body
|
||
from langchain.memory import (
|
||
CombinedMemory,
|
||
ConversationBufferMemory,
|
||
ConversationSummaryMemory,
|
||
ConversationBufferWindowMemory
|
||
)
|
||
from configs.model_config import MAX_TOKENS
|
||
from configs.outline_config import outlines
|
||
from langchain.chains.question_answering import load_qa_chain
|
||
from typing import Any
|
||
from sse_starlette.sse import EventSourceResponse
|
||
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
|
||
from configs.kb_config import CH_BASE_NAME
|
||
from server.agent.tools.search_tool import search_tool
|
||
from server.chat import utils
|
||
from server.chat.agent_chat_test import agent_chat_test
|
||
from server.chat.policy_fun_iast import get_llm_model_response
|
||
from server.chat.solve_problem import solve_problem
|
||
from server.knowledge_base.kb_service.base import TextRank
|
||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||
from server.utils import wrap_done, get_ChatOpenAI
|
||
from langchain.chains import LLMChain, ConversationChain
|
||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||
from typing import AsyncIterable
|
||
import asyncio
|
||
import json
|
||
from langchain.prompts.chat import ChatPromptTemplate
|
||
from typing import List, Optional, Union
|
||
from server.chat.utils import History, compute_lps, remove_after_and_including,remove_before_and_including
|
||
from langchain.prompts import PromptTemplate
|
||
from server.utils import get_prompt_template, get_format_template
|
||
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
|
||
from server.db.repository import add_message_to_db
|
||
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
||
from datetime import datetime
|
||
from langchain_core.messages import SystemMessage
|
||
import time as t
|
||
from configs.basic_config import *
|
||
async def process_task(task):
    """Drain an async iterable and return everything it yielded, in order.

    Lets several streaming calls be awaited together with ``asyncio.gather``
    while still collecting each one's full output.
    """
    return [item async for item in task]
|
||
# Upper bound, in characters, on retrieved text handed to the QA chain, so the
# stuffed prompt stays within the model's token budget.
_MAX_CONTEXT_CHARS = 7000

# Placeholder document used when retrieval is skipped or returns nothing
# (roughly: "answer from your own knowledge").
_NO_CONTEXT_HINT = "根据你的自身能力回答"


def _first_chunk_text(chunks: List[str]) -> str:
    """Return the ``"text"`` field of the first JSON chunk, or ``""`` if none.

    Only the first streamed chunk of a router reply is consulted.
    """
    if not chunks:
        return ""
    return json.loads(chunks[0])["text"]


def _parse_outline_index(raw: str, count: int) -> int:
    """Map the outline router's reply to a valid index into a list of *count* outlines.

    Falls back to the last outline (``count - 1``) when the reply is empty,
    contains "无" ("none"), is not an integer, or is out of range, instead of
    letting ``int()``/indexing raise in the middle of the event stream.
    """
    fallback = count - 1
    if not raw or "无" in raw:
        return fallback
    try:
        index = int(raw.strip())
    except ValueError:
        return fallback
    return index if 0 <= index < count else fallback


def _rewritten_query(raw: Any, fallback: str) -> Any:
    """Return the value of the first key of the query-rewrite JSON reply.

    Returns *fallback* (the user's original query) when the reply is not a
    non-empty JSON object, so a malformed rewrite degrades gracefully.
    """
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):
        return fallback
    if not isinstance(parsed, dict) or not parsed:
        return fallback
    return next(iter(parsed.values()))


def _retrieve_context(
    query: str,
    rag: bool,
    knowledge_base_list: Optional[List[str]],
    search_uuid: str,
) -> str:
    """Rewrite *query*, run ``search_tool`` on it and return the truncated result.

    When *rag* is true the caller's knowledge bases are merged with
    ``CH_BASE_NAME``; otherwise only ``CH_BASE_NAME`` is searched.
    *search_uuid* is forwarded to ``search_tool``; presumably it is used to
    record sources into the shared-variable slot registered under that id.
    """
    search_query = get_llm_model_response(
        strategy_name="query rewrite",
        llm_model_name=LLM_MODELS[0],
        template_prompt_name="query_rewrite",
        prompt_param_dict={"query": query, "history": [], "time": datetime.now().strftime("%Y%m%d")},
        temperature=0.01,
        max_tokens=512,
    )

    if rag:
        # `or []` guards the default knowledge_base_list=None, which would
        # otherwise make set.union raise TypeError.
        kb_names = list(set(CH_BASE_NAME).union(knowledge_base_list or []))
        logger.info(f"大纲撰写参考的个人知识库名称:{kb_names}")
    else:
        kb_names = CH_BASE_NAME

    first_json = {
        "query": _rewritten_query(search_query, query),
        "knowledge_name": kb_names,
        "keywords": [],
    }
    second_json = {"uuid": search_uuid}

    # search_tool receives the two JSON objects concatenated into one string.
    found = search_tool(json.dumps(first_json) + json.dumps(second_json))
    return (found or "")[:_MAX_CONTEXT_CHARS]


async def knowledge_chat_test(
    uid: Optional[Any] = Body(None, description="用户ID"),
    rag: bool = Body(False, description="增强标识, 当knowledge_base_list有值时,需要传True"),
    knowledge_base_list: Optional[List[str]] = Body(None, description="知识库名称列表"),
    query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
    conversation_id: str = Body("", description="对话框ID"),
    model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
    temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
    max_tokens: Optional[int] = Body(
        MAX_TOKENS,
        description="限制LLM生成Token数量,默认None代表模型最大值"
    ),
    prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
    """Stream an outline-style answer to *query* as server-sent events.

    Two router prompts run concurrently first:
    - ``think_route`` decides whether retrieval is needed (reply ``"2"``);
    - ``outlines_route`` picks an outline template index.

    When retrieval is needed, or *rag* is true, the query is rewritten and
    searched. The results are stuffed into a QA chain together with the
    chosen outline. The model's tokens are streamed back as JSON events
    ``{"text": ..., "message_id": ...}``.

    ``uid`` is accepted but not used by this handler.
    """

    async def chat_iterator() -> AsyncIterable[str]:
        # NOTE(review): the request's max_tokens is always replaced by the
        # configured MAX_TOKENS (as before); confirm this override is intended.
        effective_max_tokens = MAX_TOKENS
        message_id = str(uuid.uuid1()) + "q"

        callback = AsyncIteratorCallbackHandler()
        conversation_callback = ConversationCallbackHandler(
            conversation_id=conversation_id,
            message_id=message_id,
            chat_type="llm_chat",
            query=query,
        )
        callbacks = [callback, conversation_callback]

        # Register a shared slot under a fresh id that is passed along with
        # the search request.
        search_uuid = str(uuid.uuid1()) + "q"
        utils.set_shared_variable(
            search_uuid, {"END": "", "source_docs": [], "num": 0, "title": []}
        )

        def _route(route_prompt: str):
            # Build one router call; each gets its own fresh history list.
            return solve_problem(
                user_prompt_name=prompt_name,
                query=query,
                conversation_id="",
                history=[],
                model_name=model_name,
                temperature=temperature,
                max_tokens=effective_max_tokens,
                prompt_name=route_prompt,
                stream=True,
            )

        # Decide whether to retrieve and which outline template to use.
        think_chunks, outline_chunks = await asyncio.gather(
            process_task(_route("think_route")),
            process_task(_route("outlines_route")),
        )
        route = _first_chunk_text(think_chunks).strip()
        outline_raw = _first_chunk_text(outline_chunks)

        use_retrieval = route == "2" or rag
        context = ""
        if use_retrieval:
            yield json.dumps({"text": "思考中..."}, ensure_ascii=False)
            # NOTE(review): the query rewrite and search_tool are synchronous
            # calls, so they block the event loop for their duration.
            context = _retrieve_context(query, rag, knowledge_base_list, search_uuid)

        # The retrieval path uses the requested model; the direct path uses the
        # default LLM_MODELS[0] (unchanged). NOTE(review): confirm this is intended.
        model = get_ChatOpenAI(
            model_name=model_name if use_retrieval else LLM_MODELS[0],
            temperature=temperature,
            max_tokens=effective_max_tokens,
            callbacks=callbacks,
        )

        prompt_template = get_prompt_template("knowledge_base_chat", "default_outlines")
        format_template = get_format_template("knowledge_base_chat", "default_outlines")
        system_msg = History(role="system", content=prompt_template).to_msg_template(False)
        chain = load_qa_chain(
            model,
            chain_type="stuff",
            prompt=ChatPromptTemplate.from_messages([system_msg]),
            verbose=True,
        )

        outline_index = _parse_outline_index(outline_raw, len(outlines))
        logger.info(f"outline_num:{outline_index}")

        docs = [DocumentWithVSId(page_content=context or _NO_CONTEXT_HINT)]

        task = asyncio.create_task(wrap_done(
            chain.acall({
                "input_documents": docs,
                "self_knowledge": outlines[outline_index],
                "history": [],
                "question": query,
                "format_template": format_template,
                "time": datetime.now().strftime("%Y%m%d"),
            }),
            callback.done,
        ))

        # Forward each generated token to the client as it arrives.
        async for token in callback.aiter():
            yield json.dumps(
                {"text": token, "message_id": message_id},
                ensure_ascii=False,
            )

        await task

    return EventSourceResponse(chat_iterator())
|