134 lines
6.7 KiB
Python
134 lines
6.7 KiB
Python
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from datetime import datetime
|
||
import json
|
||
from uuid import uuid4
|
||
from fastapi import Body, HTTPException
|
||
from fastapi.responses import FileResponse
|
||
from typing import Any, Dict, List, Optional
|
||
from sse_starlette.sse import EventSourceResponse
|
||
from langchain.chains import LLMChain
|
||
from configs.comparison_config import comparison
|
||
from configs import LLM_MODELS
|
||
from configs.model_config import MAX_TOKENS
|
||
from server.agent.tools.draw_plot import create_and_save_plot
|
||
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
||
from server.chat import utils
|
||
from server.chat.policy_fun_iast import get_llm_model_response
|
||
from typing import AsyncIterable
|
||
from configs.basic_config import *
|
||
from server.chat.utils import History
|
||
from server.custom.AsyncIteratorCallbackHandlerNew import AsyncIteratorCallbackHandler
|
||
from server.utils import get_ChatOpenAI, get_prompt_template, wrap_done
|
||
from langchain.prompts.chat import ChatPromptTemplate
|
||
|
||
async def chat_comparison_test(
        uid: Optional[str] = Body(None, description="时间"),
        content_list: Optional[List[str]] = Body(None, description="对比文献内容,需要是个包含所有文件内容的字符串列表"),
):
    """Stream a literature-comparison report as server-sent events.

    The stream runs in three stages:

    1. Each entry of ``content_list`` is sent to ``get_llm_model_response``
       (``extract_key_points`` template) on a thread pool. One
       ``{"progress": "NN.NN%"}`` event is emitted per finished document,
       scaled so that this stage ends at 50%.
    2. The extracted key points are sent to the model again
       (``make_comparison_pic`` template) to obtain chart parameters, which
       are passed to ``create_and_save_plot``.
    3. The ``comparison_chat`` / ``write_report`` prompt is run through an
       ``LLMChain``; once generation ends, the full text is emitted as one
       ``{"text": ...}`` event followed by ``{"progress": "100%"}``.

    Args:
        uid: Accepted as part of the request body; not used by this function.
        content_list: Text of every document to compare, one string each.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator.

    Raises:
        HTTPException: 400 when ``content_list`` is missing or empty.
    """
    # Imported locally so the name is guaranteed to be in scope; the module
    # otherwise only receives it (if at all) through a star import.
    import asyncio

    # Fail before the stream starts instead of raising TypeError on
    # len(None) in the middle of the SSE response.
    if not content_list:
        raise HTTPException(
            status_code=400,
            detail="content_list must be a non-empty list of document texts",
        )

    async def chat_iterator() -> AsyncIterable[str]:
        # Comparison dimensions taken from the comparison template.
        context = comparison[0]["summary"]

        # Used to report extraction progress.
        total_tasks = len(content_list)

        # Key information extracted from each source document,
        # stored in completion order.
        useful_content = []

        # Run the extraction calls on worker threads. Wrapping the
        # futures lets us await them, so the event loop stays responsive
        # while the model calls are in flight.
        with ThreadPoolExecutor() as executor:
            pending = [
                asyncio.wrap_future(
                    executor.submit(
                        get_llm_model_response,
                        strategy_name="query rewrite",
                        llm_model_name=LLM_MODELS[1],
                        template_prompt_name="extract_key_points",
                        prompt_param_dict={
                            "time": datetime.now().strftime("%Y%m%d"),
                            "context": context,
                            "content": content,
                        },
                        temperature=0.01,
                        max_tokens=512,
                    )
                )
                for content in content_list
            ]

            # Collect results as they finish and report progress.
            for completed_tasks, finished in enumerate(asyncio.as_completed(pending), start=1):
                useful_content.append(await finished)
                progress = (completed_tasks / total_tasks) * 50
                print(f"Progress: {progress:.2f}%")
                yield json.dumps(
                    {"progress": f"{progress:.2f}%"},
                    ensure_ascii=False)

        # Merge all extracted results into one numbered text block.
        useful_res = "\n".join(
            f"第{index}篇文章:\n{content}\n"
            for index, content in enumerate(useful_content, start=1)
        )

        # Instructions telling the model which chart parameters to return.
        pic_instructions = "【参数说明】:参数格式:以JSON格式提供, 必需参数:\n 1. data:图表数据格式如下{\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X}其中XXX代表分类名称,xx代表分类数据量。\n2.title:图表标题3.xlabel:横轴标题(你按照分的几类属于哪个大类)必须有\n4.ylabel:纵轴标题(你的数值数据是什么)必须有\n 5.plot_type:图表类型必须从以下几个当中选一个作为输入【bar,pie,line】其中bar代表柱状图pie代表饼图line代表折线图你只能选一个作为输入,\n【使用指南】:\n当需要生成图表时,请使用此工具。\n输入必须包含以下参数::{\n\"data\": {\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X},\"title\": \"X\",\"xlabel\": \"X\",\"ylabel\": \"X\",\"plot_type\": \"X\"}\n示例:{'\"data\": {\"Category A\": 23, \"Category B\": 17, \"Category C\": 35, \"Category D\": 29},\"title\": \"My Chart\",\"xlabel\": \"Category\",\"ylabel\": \"Value\",\"plot_type\": \"pie\"}\n请务必以json格式输入方便使用。\n"

        # Build the comparison charts from the key points first. The model
        # call is synchronous, so it is pushed to the default executor
        # rather than run on the event-loop thread.
        loop = asyncio.get_running_loop()
        raw_pic_param = await loop.run_in_executor(
            None,
            lambda: get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="make_comparison_pic",
                prompt_param_dict={
                    "time": datetime.now().strftime("%Y%m%d"),
                    "context": useful_res,
                    "content": pic_instructions,
                },
                temperature=0.01,
                max_tokens=MAX_TOKENS,
            ),
        )
        # Strip Markdown code fences the model may wrap around the JSON.
        pic_param = raw_pic_param.replace("```json", "").replace("```", "").strip()

        pic_res = ""
        try:
            pic_params = json.loads(pic_param)
        except json.JSONDecodeError as exc:
            # Model output is not guaranteed to be valid JSON; continue
            # with a chart-less report instead of aborting the stream.
            print(f"Skipping comparison charts: model output is not valid JSON ({exc})")
        else:
            # A single JSON object is one chart spec; iterating over it
            # directly would walk its keys instead of the spec.
            if isinstance(pic_params, dict):
                pic_params = [pic_params]
            plot_uuid = str(uuid4())
            utils.set_shared_variable(plot_uuid, {})
            temp = {"uuid": plot_uuid}
            # create_and_save_plot receives the chart spec and the uuid
            # mapping concatenated into one string.
            pic_res = "".join(
                create_and_save_plot(f"{param}{temp}") + "\n"
                for param in pic_params
            )

        # Generate the comparison report.
        callback = AsyncIteratorCallbackHandler()
        conversation_callback = ConversationCallbackHandler(
            conversation_id="",
            message_id="",
            chat_type="llm_chat",
            query=useful_res,
        )
        callbacks = [callback, conversation_callback]

        model = get_ChatOpenAI(
            model_name=LLM_MODELS[0],
            temperature=0.01,
            max_tokens=MAX_TOKENS,
            callbacks=callbacks,
        )

        prompt_template = get_prompt_template("comparison_chat", "write_report")
        input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_prompt])
        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Run the chain in the background; tokens arrive through `callback`.
        task = asyncio.create_task(wrap_done(
            chain.acall({
                "context": comparison[0]["content"],
                "time": datetime.now().strftime("%Y%m%d"),
                "content": useful_res,
                "pic": pic_res,
            }),
            callback.done),
        )

        # Accumulate the tokens; the report is emitted once, in full,
        # after generation has finished.
        history_summary = ""
        async for token in callback.aiter():
            history_summary += token

        yield json.dumps(
            {"text": history_summary},
            ensure_ascii=False)
        yield json.dumps(
            {"progress": "100%"},
            ensure_ascii=False)
        await task

    return EventSourceResponse(chat_iterator())