import asyncio
import json
import time as t
import uuid
from datetime import datetime
from typing import Any, AsyncIterable, List, Optional, Union

from fastapi import Body
from langchain.memory import (
    CombinedMemory,
    ConversationBufferMemory,
    ConversationSummaryMemory,
    ConversationBufferWindowMemory,
)
from langchain.chains.question_answering import load_qa_chain
from sse_starlette.sse import EventSourceResponse
from langchain.chains import LLMChain, ConversationChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate
from langchain.prompts import PromptTemplate
from langchain_core.messages import SystemMessage

from configs.model_config import MAX_TOKENS
from configs.outline_config import outlines
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
from configs.kb_config import CH_BASE_NAME
from server.agent.tools.search_tool import search_tool
from server.chat import utils
from server.chat.agent_chat_test import agent_chat_test
from server.chat.policy_fun_iast import get_llm_model_response
from server.chat.solve_problem import solve_problem
from server.knowledge_base.kb_service.base import TextRank
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from server.utils import wrap_done, get_ChatOpenAI
from server.chat.utils import (
    History,
    compute_lps,
    remove_after_and_including,
    remove_before_and_including,
)
from server.utils import get_prompt_template, get_format_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
# Kept last on purpose: the star import (which supplies `logger`) must keep
# the same shadowing order it had in the original file.
from configs.basic_config import *

# Upper bound (in characters) on retrieved reference text passed to the LLM,
# so the prompt does not exceed the model's token budget.
_MAX_REFERENCE_CHARS = 7000

# Context handed to the chain when no retrieval was performed / nothing was found.
_NO_REFERENCE_CONTEXT = "根据你的自身能力回答"


async def process_task(task):
    """Drain an async iterable and return everything it produced as a list."""
    return [item async for item in task]


def _first_chunk_text(chunks) -> str:
    """Return the "text" field of the first JSON-encoded chunk, or "" if there is none."""
    if not chunks:
        return ""
    return json.loads(chunks[0])["text"]


def _extract_search_query(raw_response, fallback: str):
    """Return the first value of the JSON object in *raw_response*.

    The rewrite comes from an LLM, so it may be malformed; in that case the
    caller's *fallback* (the user's original query) is used instead of
    aborting the whole response stream.
    """
    try:
        return next(iter(json.loads(raw_response).values()))
    except (TypeError, ValueError, AttributeError, StopIteration):
        # TypeError: not a str; ValueError: invalid JSON; AttributeError: JSON
        # is not an object; StopIteration: empty object.
        logger.warning("query rewrite returned unusable output %r; using the original query", raw_response)
        return fallback


def _resolve_outline(outline_num: str):
    """Map the router's outline answer to ``(index, outline)``.

    Falls back to the last entry of ``outlines`` when the answer is empty,
    contains "无", is not an integer, or does not address an existing outline.
    """
    if outline_num and "无" not in outline_num:
        try:
            index = int(outline_num)
            return index, outlines[index]
        except (ValueError, IndexError, KeyError):
            logger.warning("unusable outline index %r; using the default outline", outline_num)
    fallback_index = len(outlines) - 1
    return fallback_index, outlines[fallback_index]


async def knowledge_chat_test(
    uid: Optional[Any] = Body(None, description="用户ID"),
    rag: bool = Body(False, description="增强标识, 当knowledge_base_list有值时,需要传True"),
    knowledge_base_list: Optional[List[str]] = Body(None, description="知识库名称列表"),
    query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
    conversation_id: str = Body("", description="对话框ID"),
    model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
    temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
    max_tokens: Optional[int] = Body(
        MAX_TOKENS, description="限制LLM生成Token数量,默认None代表模型最大值"
    ),
    # NOTE: a `top_p` body parameter was sketched here previously but is disabled.
    prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
    """Stream an outline-style answer to *query* as server-sent events.

    Two routing prompts ("think_route" and "outlines_route") are run
    concurrently. When the first one answers "2", or when *rag* is set, the
    query is rewritten and passed to ``search_tool`` and the (truncated)
    result becomes the reference document. The answer is then generated with a
    "stuff" QA chain using the "default_outlines" prompt and the outline
    selected by the second routing prompt.

    Args:
        uid: User id. Not read by this function.
        rag: Force retrieval and merge *knowledge_base_list* with ``CH_BASE_NAME``.
        knowledge_base_list: Extra knowledge-base names; only used when *rag* is true.
        query: The user's input.
        conversation_id: Passed to the conversation callback handler.
        model_name: LLM used for routing and for the retrieval branch.
        temperature: Sampling temperature.
        max_tokens: Accepted for API compatibility; see the note in the body.
        prompt_name: Forwarded to ``solve_problem`` as ``user_prompt_name``.

    Returns:
        An ``EventSourceResponse`` whose events are JSON strings carrying a
        "text" field (and "message_id" for generated tokens).
    """

    async def chat_iterator() -> AsyncIterable[str]:
        # NOTE(review): the request's max_tokens is overridden by the configured
        # MAX_TOKENS, exactly as before — confirm this is intended.
        token_limit = MAX_TOKENS
        message_id = str(uuid.uuid1()) + "q"

        callback = AsyncIteratorCallbackHandler()
        conversation_callback = ConversationCallbackHandler(
            conversation_id=conversation_id,
            message_id=message_id,
            chat_type="llm_chat",
            query=query,
        )
        callbacks = [callback, conversation_callback]

        # Register a per-request record under a fresh id; the same id is sent
        # to search_tool below.
        time_based_uuid = str(uuid.uuid1()) + "q"
        utils.set_shared_variable(
            time_based_uuid,
            {"END": "", "source_docs": [], "num": 0, "title": []},
        )

        def route(route_prompt: str):
            # Built per call so the two routing requests never share a history list.
            return solve_problem(
                user_prompt_name=prompt_name,
                query=query,
                conversation_id="",
                history=[],
                model_name=model_name,
                temperature=temperature,
                max_tokens=token_limit,
                prompt_name=route_prompt,
                stream=True,
            )

        # Decide up front whether retrieval is needed and which outline
        # template to use; the two routing prompts run concurrently.
        think_chunks, outline_chunks = await asyncio.gather(
            process_task(route("think_route")),
            process_task(route("outlines_route")),
        )
        think_text = _first_chunk_text(think_chunks).strip()
        outline_num = _first_chunk_text(outline_chunks)

        reference_text = ""
        if think_text == "2" or rag:
            yield json.dumps({"text": "思考中..."}, ensure_ascii=False)
            model = get_ChatOpenAI(
                model_name=model_name,
                temperature=temperature,
                max_tokens=token_limit,
                callbacks=callbacks,
            )
            # NOTE(review): get_llm_model_response and search_tool are called
            # synchronously here and will block the event loop while they run.
            rewritten = get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="query_rewrite",
                prompt_param_dict={
                    "query": query,
                    "history": [],
                    "time": datetime.now().strftime("%Y%m%d"),
                },
                temperature=0.01,
                max_tokens=512,
            )
            search_text = _extract_search_query(rewritten, query)

            if rag:
                # knowledge_base_list defaults to None; treat that as "no extras"
                # instead of letting set.union raise TypeError.
                kb_names = list(set(CH_BASE_NAME).union(knowledge_base_list or []))
                logger.info(f"大纲撰写参考的个人知识库名称:{kb_names}")
            else:
                kb_names = CH_BASE_NAME

            # search_tool receives two JSON objects concatenated back to back:
            # the search request followed by the shared-record id.
            search_request = {"query": search_text, "knowledge_name": kb_names, "keywords": []}
            origin_query = json.dumps(search_request) + json.dumps({"uuid": time_based_uuid})

            # Truncate the retrieved material to stay within the token budget.
            reference_text = search_tool(origin_query)[:_MAX_REFERENCE_CHARS]
        else:
            # NOTE(review): this branch always uses LLM_MODELS[0] rather than
            # the requested model_name — preserved from the original.
            model = get_ChatOpenAI(
                model_name=LLM_MODELS[0],
                temperature=temperature,
                max_tokens=token_limit,
                callbacks=callbacks,
            )

        prompt_template = get_prompt_template("knowledge_base_chat", "default_outlines")
        system_prompt = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([system_prompt])
        format_template = get_format_template("knowledge_base_chat", "default_outlines")
        chain = load_qa_chain(
            model,
            chain_type="stuff",
            memory=None,
            prompt=chat_prompt,
            verbose=True,
        )

        outline_index, outline = _resolve_outline(outline_num)
        logger.info(f"outline_num:{outline_index}")

        documents = [DocumentWithVSId(page_content=reference_text or _NO_REFERENCE_CONTEXT)]

        task = asyncio.create_task(
            wrap_done(
                chain.acall({
                    "input_documents": documents,
                    "self_knowledge": outline,
                    "history": [],
                    "question": query,
                    "format_template": format_template,
                    "time": datetime.now().strftime("%Y%m%d"),
                }),
                callback.done,
            )
        )

        try:
            # Relay tokens to the client as server-sent events as they arrive.
            async for token in callback.aiter():
                yield json.dumps(
                    {"text": token, "message_id": message_id},
                    ensure_ascii=False,
                )
            # TODO: optionally emit the retrieved source documents stored under
            # time_based_uuid (previously sketched here, never enabled).
            await task
        finally:
            # If the consumer stops early (e.g. client disconnect), do not leave
            # the generation task running in the background.
            if not task.done():
                task.cancel()

    return EventSourceResponse(chat_iterator())