import asyncio
import logging
from functools import partial

# NOTE: the original import order below is preserved on purpose: the wildcard
# import from configs.basic_config can rebind names imported before it.
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
import json
from uuid import uuid4
from fastapi import Body, HTTPException
from fastapi.responses import FileResponse
from typing import Any, Dict, List, Optional
from sse_starlette.sse import EventSourceResponse
from langchain.chains import LLMChain
from configs.comparison_config import comparison
from configs import LLM_MODELS
from configs.model_config import MAX_TOKENS
from server.agent.tools.draw_plot import create_and_save_plot
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
from typing import AsyncIterable
from configs.basic_config import *
from server.chat.utils import History
from server.custom.AsyncIteratorCallbackHandlerNew import AsyncIteratorCallbackHandler
from server.utils import get_ChatOpenAI, get_prompt_template, wrap_done
from langchain.prompts.chat import ChatPromptTemplate

_logger = logging.getLogger(__name__)

# Instructions passed as the "content" of the "make_comparison_pic" prompt,
# describing the JSON plot specification the LLM must return.
# NOTE(review): the example after "示例:" starts with a stray `'`, so the
# example itself is not valid JSON — consider fixing the prompt text.
_PLOT_PARAM_GUIDE = "【参数说明】:参数格式:以JSON格式提供, 必需参数:\n 1. data:图表数据格式如下{\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X}其中XXX代表分类名称,xx代表分类数据量。\n2.title:图表标题3.xlabel:横轴标题(你按照分的几类属于哪个大类)必须有\n4.ylabel:纵轴标题(你的数值数据是什么)必须有\n 5.plot_type:图表类型必须从以下几个当中选一个作为输入【bar,pie,line】其中bar代表柱状图pie代表饼图line代表折线图你只能选一个作为输入,\n【使用指南】:\n当需要生成图表时,请使用此工具。\n输入必须包含以下参数::{\n\"data\": {\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X},\"title\": \"X\",\"xlabel\": \"X\",\"ylabel\": \"X\",\"plot_type\": \"X\"}\n示例:{'\"data\": {\"Category A\": 23, \"Category B\": 17, \"Category C\": 35, \"Category D\": 29},\"title\": \"My Chart\",\"xlabel\": \"Category\",\"ylabel\": \"Value\",\"plot_type\": \"pie\"}\n请务必以json格式输入方便使用。\n"


def _today() -> str:
    """Return the current local date formatted as ``YYYYMMDD`` for prompt templates."""
    return datetime.now().strftime("%Y%m%d")


async def _extract_key_points(index: int, content: str, context: str):
    """Run the blocking "extract_key_points" LLM call in a worker thread.

    Returns ``(index, result)`` so the caller can restore input order even
    though results are consumed in completion order.
    """
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        None,
        partial(
            get_llm_model_response,
            strategy_name="query rewrite",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="extract_key_points",
            prompt_param_dict={"time": _today(), "context": context, "content": content},
            temperature=0.01,
            max_tokens=512,
        ),
    )
    return index, result


def _parse_plot_params(raw: str) -> List[Dict[str, Any]]:
    """Parse the LLM's plot specification into a list of plot-parameter dicts.

    Markdown code fences are stripped first. A single JSON object (what the
    prompt asks for) is wrapped in a one-element list; a JSON array is
    returned as-is.

    Raises:
        ValueError: if the text is not JSON (``json.JSONDecodeError`` is a
            subclass) or is neither an object nor an array.
    """
    cleaned = raw.replace("```json", "").replace("```", "").strip()
    parsed = json.loads(cleaned)
    if isinstance(parsed, dict):
        return [parsed]
    if isinstance(parsed, list):
        return parsed
    raise ValueError(f"unexpected plot specification type: {type(parsed).__name__}")


async def _build_comparison_plots(useful_res: str) -> str:
    """Ask the LLM for plot parameters and render them.

    Returns the newline-terminated outputs of ``create_and_save_plot``
    concatenated, or an empty string when the LLM output cannot be parsed.
    In that case the report is still generated, just without charts.
    """
    loop = asyncio.get_running_loop()
    raw = await loop.run_in_executor(
        None,
        partial(
            get_llm_model_response,
            strategy_name="query rewrite",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="make_comparison_pic",
            prompt_param_dict={"time": _today(), "context": useful_res, "content": _PLOT_PARAM_GUIDE},
            temperature=0.01,
            max_tokens=MAX_TOKENS,
        ),
    )
    try:
        plot_params = _parse_plot_params(raw)
    except ValueError:
        _logger.exception("Could not parse plot parameters from LLM output: %r", raw)
        return ""

    # Register an (initially empty) shared slot under a fresh id; the id is
    # appended to every plot request string.
    plot_id = str(uuid4())
    utils.set_shared_variable(plot_id, {})
    id_suffix = {"uuid": plot_id}
    # Plotting stays on the event-loop thread (not the executor) so the
    # threading behaviour of create_and_save_plot is unchanged.
    return "".join(create_and_save_plot(f"{spec}{id_suffix}") + "\n" for spec in plot_params)


async def chat_comparison_test(
        uid: Optional[str] = Body(None, description="时间"),
        content_list: Optional[List[str]] = Body(None, description="对比文献内容,需要是个包含所有文件内容的字符串列表"),
):
    """Stream a comparison report over several documents as server-sent events.

    Pipeline:
      1. Extract key points from every document concurrently. Each finished
         document emits a ``{"progress": "xx.xx%"}`` event covering 0-50%.
      2. Ask the LLM for chart parameters and render the charts.
      3. Stream the "write_report" answer. Each event is ``{"text": ...}``
         carrying the full text generated so far, not just the new token.
      4. Emit a final ``{"progress": "100%"}`` event.

    Args:
        uid: Accepted for API compatibility; not used by this function.
        content_list: Full text of each document to compare.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator.

    Raises:
        HTTPException: 400 when ``content_list`` is missing or empty.
    """
    if not content_list:
        raise HTTPException(status_code=400, detail="content_list must contain at least one document")

    async def chat_iterator() -> AsyncIterable[str]:
        # Comparison dimensions defined by the comparison template.
        context = comparison[0]["summary"]
        total_tasks = len(content_list)

        # Results are stored by input index so the article numbering below
        # matches the order of content_list, not completion order.
        key_points: List[Optional[str]] = [None] * total_tasks
        pending = [_extract_key_points(i, content, context) for i, content in enumerate(content_list)]
        for completed_tasks, next_done in enumerate(asyncio.as_completed(pending), start=1):
            index, result = await next_done
            key_points[index] = result
            progress = (completed_tasks / total_tasks) * 50
            _logger.info("Progress: %.2f%%", progress)
            yield json.dumps(
                {"progress": f"{progress:.2f}%"},
                ensure_ascii=False)

        useful_res = "\n".join(f"第{i+1}篇文章:\n{content}\n" for i, content in enumerate(key_points))

        # Build the comparison charts from the extracted key points.
        pic_res = await _build_comparison_plots(useful_res)

        # Generate the comparison report, streaming tokens via the callback.
        callback = AsyncIteratorCallbackHandler()
        conversation_callback = ConversationCallbackHandler(conversation_id="", message_id="", chat_type="llm_chat",
                                                            query=useful_res)
        model = get_ChatOpenAI(
            model_name=LLM_MODELS[0],
            temperature=0.01,
            max_tokens=MAX_TOKENS,
            callbacks=[callback, conversation_callback],
        )
        prompt_template = get_prompt_template("comparison_chat", "write_report")
        input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
        chain = LLMChain(prompt=ChatPromptTemplate.from_messages([input_prompt]), llm=model)

        task = asyncio.create_task(wrap_done(
            chain.acall({
                "context": comparison[0]["content"],
                "time": _today(),
                "content": useful_res,
                "pic": pic_res,
            }),
            callback.done),
        )

        report = ""
        async for token in callback.aiter():
            report += token
            # Each event carries the cumulative report text.
            yield json.dumps(
                {"text": report},
                ensure_ascii=False)
        yield json.dumps(
            {"progress": "100%"},
            ensure_ascii=False)
        await task

    return EventSourceResponse(chat_iterator())