[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
208
langchain-chat/server/chat/knowledge_chat_test.py
Normal file
208
langchain-chat/server/chat/knowledge_chat_test.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import uuid
|
||||
from fastapi import Body
|
||||
from langchain.memory import (
|
||||
CombinedMemory,
|
||||
ConversationBufferMemory,
|
||||
ConversationSummaryMemory,
|
||||
ConversationBufferWindowMemory
|
||||
)
|
||||
from configs.model_config import MAX_TOKENS
|
||||
from configs.outline_config import outlines
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from typing import Any
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
|
||||
from configs.kb_config import CH_BASE_NAME
|
||||
from server.agent.tools.search_tool import search_tool
|
||||
from server.chat import utils
|
||||
from server.chat.agent_chat_test import agent_chat_test
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from server.chat.solve_problem import solve_problem
|
||||
from server.knowledge_base.kb_service.base import TextRank
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from langchain.chains import LLMChain, ConversationChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from typing import AsyncIterable
|
||||
import asyncio
|
||||
import json
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from typing import List, Optional, Union
|
||||
from server.chat.utils import History, compute_lps, remove_after_and_including,remove_before_and_including
|
||||
from langchain.prompts import PromptTemplate
|
||||
from server.utils import get_prompt_template, get_format_template
|
||||
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
|
||||
from server.db.repository import add_message_to_db
|
||||
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
||||
from datetime import datetime
|
||||
from langchain_core.messages import SystemMessage
|
||||
import time as t
|
||||
from configs.basic_config import *
|
||||
async def process_task(task):
|
||||
results = []
|
||||
async for result in task:
|
||||
results.append(result)
|
||||
return results
|
||||
async def knowledge_chat_test(
|
||||
uid: Optional[Any] = Body(None, description="用户ID"),
|
||||
rag: bool = Body(False, description="增强标识, 当knowledge_base_list有值时,需要传True"),
|
||||
knowledge_base_list: Optional[List[str]] = Body(None, description="知识库名称列表"),
|
||||
query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
conversation_id: str = Body("", description="对话框ID"),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
|
||||
max_tokens: Optional[int] = Body(
|
||||
MAX_TOKENS,
|
||||
description="限制LLM生成Token数量,默认None代表模型最大值"
|
||||
),
|
||||
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature 同时设置", gt=0.0, lt=1.0),
|
||||
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
):
|
||||
async def chat_iterator() -> AsyncIterable[str]:
|
||||
nonlocal rag,max_tokens, uid, model_name,knowledge_base_list
|
||||
# rag = True
|
||||
max_tokens = MAX_TOKENS
|
||||
history_summary = ""
|
||||
memory = None
|
||||
message_id = str(uuid.uuid1())+"q"
|
||||
think_type = {"text": "", "message_id": message_id}
|
||||
user_prompt_name = prompt_name
|
||||
outline_num = ""
|
||||
res = ""
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
callbacks = [callback]
|
||||
conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
|
||||
chat_type="llm_chat",
|
||||
query=query)
|
||||
callbacks.append(conversation_callback)
|
||||
time_based_uuid = str(uuid.uuid1())+"q"
|
||||
tip ={}
|
||||
tip["END"]=""
|
||||
tip["source_docs"]=[]
|
||||
tip["num"] = 0
|
||||
tip["title"]=[]
|
||||
utils.set_shared_variable(time_based_uuid,tip)
|
||||
# 先判断需不需要使用知识库或者联网检索同时需要获取大纲模板
|
||||
task1 = solve_problem(user_prompt_name= user_prompt_name,query=query, conversation_id="", history=[], model_name=model_name,temperature=temperature,max_tokens=max_tokens,prompt_name="think_route",stream=True)
|
||||
task2 = solve_problem(user_prompt_name= user_prompt_name,query=query, conversation_id="", history=[], model_name=model_name,temperature=temperature,max_tokens=max_tokens,prompt_name="outlines_route",stream=True)
|
||||
results = await asyncio.gather(process_task(task1), process_task(task2))
|
||||
for result0 in results[0]:
|
||||
think_type["text"] += json.loads(result0)["text"].strip()
|
||||
break
|
||||
for result1 in results[1]:
|
||||
outline_num += json.loads(result1)["text"]
|
||||
break
|
||||
|
||||
if(str(think_type["text"]) == "2") or rag == True:
|
||||
yield json.dumps({"text": "思考中..."}, ensure_ascii=False)
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
search_query = get_llm_model_response(
|
||||
strategy_name="query rewrite",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="query_rewrite",
|
||||
prompt_param_dict={"query": query, "history": [], "time": datetime.now().strftime("%Y%m%d")},
|
||||
temperature=0.01,
|
||||
max_tokens=512
|
||||
)
|
||||
# 第一个 JSON 对象
|
||||
keys = json.loads(search_query).keys()
|
||||
keys_list = list(keys)
|
||||
if rag:
|
||||
knowledge_base_list = list(set(CH_BASE_NAME).union(knowledge_base_list))
|
||||
logger.info(f"大纲撰写参考的个人知识库名称:{knowledge_base_list}")
|
||||
first_json = {
|
||||
"query": json.loads(search_query)[keys_list[0]],
|
||||
"knowledge_name": knowledge_base_list,
|
||||
"keywords": []
|
||||
}
|
||||
else:
|
||||
first_json = {
|
||||
"query": json.loads(search_query)[keys_list[0]],
|
||||
"knowledge_name": CH_BASE_NAME,
|
||||
"keywords": []
|
||||
}
|
||||
|
||||
# 第二个 JSON 对象
|
||||
second_json = {
|
||||
"uuid": time_based_uuid
|
||||
}
|
||||
|
||||
origin_query=json.dumps(first_json) + json.dumps(second_json)
|
||||
# 截断检索资料,避免超token
|
||||
res += search_tool(origin_query)
|
||||
if len(res) >= 7000:
|
||||
res = res[:7000]
|
||||
else:
|
||||
model = get_ChatOpenAI(
|
||||
model_name=LLM_MODELS[0],
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
prompt_template = get_prompt_template("knowledge_base_chat", "default_outlines")
|
||||
input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
|
||||
# input_msg = History(role="user", content=query).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages([input_prompt])
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
|
||||
|
||||
input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages([input_prompt])
|
||||
|
||||
format_template = get_format_template("knowledge_base_chat", "default_outlines")
|
||||
|
||||
chain = load_qa_chain(
|
||||
model, chain_type="stuff", memory=memory, prompt=chat_prompt, verbose=True
|
||||
)
|
||||
# docs = list(itertools.chain(policydocs, reportdocs, journaldocs, personaldocs))
|
||||
if not outline_num or outline_num == "" or "无" in outline_num:
|
||||
outline_num = len(outlines)-1
|
||||
if res =="":
|
||||
res+="根据你的自身能力回答"
|
||||
results1=[]
|
||||
results1.append(DocumentWithVSId(page_content=res))
|
||||
logger.info(f"outline_num:{outline_num}")
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({
|
||||
# "context": res,
|
||||
"input_documents": results1,
|
||||
"self_knowledge":outlines[int(outline_num)],
|
||||
"history": [],
|
||||
"question": query,
|
||||
"format_template": format_template,
|
||||
"time": datetime.now().strftime("%Y%m%d")
|
||||
}),
|
||||
callback.done),
|
||||
)
|
||||
|
||||
# if stream:
|
||||
async for token in callback.aiter():
|
||||
# Use server-sent-events to stream the response
|
||||
history_summary += token
|
||||
yield json.dumps(
|
||||
{"text": token, "message_id": message_id},
|
||||
ensure_ascii=False)
|
||||
# try:
|
||||
# source = utils.get_shared_variable(time_based_uuid)
|
||||
# source_docs = source["source_docs"]
|
||||
# docs_string = "\n"+"\n".join(f"{str(doc)}\n" for doc in source_docs)
|
||||
# yield json.dumps({"docs": docs_string}, ensure_ascii=False)
|
||||
# except Exception as e:
|
||||
# logging.error(f"获取知识库联想信息失败:{e}")
|
||||
# print("知识库联想存在异常或没有进行知识库联想")
|
||||
# pass
|
||||
#summary = TextRank(history_summary, 80)
|
||||
#yield json.dumps({"summary":summary}, ensure_ascii=False)
|
||||
# del question_history
|
||||
await task
|
||||
|
||||
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
|
||||
return EventSourceResponse(chat_iterator())
|
||||
Reference in New Issue
Block a user