Files
gangyan/langchain-chat/server/chat/chat_comparison.py

134 lines
6.7 KiB
Python
Raw Normal View History

from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
import json
from uuid import uuid4
from fastapi import Body, HTTPException
from fastapi.responses import FileResponse
from typing import Any, Dict, List, Optional
from sse_starlette.sse import EventSourceResponse
from langchain.chains import LLMChain
from configs.comparison_config import comparison
from configs import LLM_MODELS
from configs.model_config import MAX_TOKENS
from server.agent.tools.draw_plot import create_and_save_plot
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
from typing import AsyncIterable
from configs.basic_config import *
from server.chat.utils import History
from server.custom.AsyncIteratorCallbackHandlerNew import AsyncIteratorCallbackHandler
from server.utils import get_ChatOpenAI, get_prompt_template, wrap_done
from langchain.prompts.chat import ChatPromptTemplate
async def chat_comparison_test(
        uid: Optional[str] = Body(None, description="时间"),
        content_list: Optional[List[str]] = Body(None, description="对比文献内容,需要是个包含所有文件内容的字符串列表"),
):
    """Stream a document-comparison report as server-sent events.

    The stream is produced in three phases:

    1. Every entry of ``content_list`` is sent to ``get_llm_model_response``
       with the ``extract_key_points`` template. After each entry finishes a
       ``{"progress": "<pct>%"}`` event is emitted (this phase covers 0-50%).
    2. The extracted key points are sent to the ``make_comparison_pic``
       template; the JSON answer is parsed and passed to
       ``create_and_save_plot``.
    3. An ``LLMChain`` built from the ``comparison_chat`` / ``write_report``
       prompt writes the report. The complete text is emitted once as
       ``{"text": ...}``, followed by ``{"progress": "100%"}``.

    Args:
        uid: Accepted from the request body but not used by this function.
        content_list: Text of the documents to compare, one string per
            document. ``None`` is treated as an empty list.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator.

    Raises:
        json.JSONDecodeError: Raised while streaming if the plot-parameter
            answer of the model is not valid JSON.
    """
    # Imported locally so this function does not rely on ``asyncio`` leaking
    # in through the wildcard import of ``configs.basic_config``.
    import asyncio

    async def chat_iterator() -> AsyncIterable[str]:
        # A missing body field arrives as None; treat it like "no documents"
        # instead of failing on len(None).
        documents = content_list or []
        loop = asyncio.get_running_loop()

        # Comparison dimensions taken from the comparison template.
        context = comparison[0]["summary"]
        # Bookkeeping for the extraction progress bar.
        total_tasks = len(documents)
        completed_tasks = 0
        # Extracted key points, stored at the index of their source document
        # so the article numbering below follows the input order rather than
        # the order in which the worker threads happen to finish.
        useful_content = [None] * total_tasks

        def _extract_key_points(content):
            # Blocking LLM call; only ever invoked from a worker thread.
            return get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[1],
                template_prompt_name="extract_key_points",
                prompt_param_dict={"time": datetime.now().strftime("%Y%m%d"), "context": context, "content": content},
                temperature=0.01,
                max_tokens=512,
            )

        async def _extract_indexed(index, content, pool):
            # Pair the result with its position in the input list.
            extracted = await loop.run_in_executor(pool, _extract_key_points, content)
            return index, extracted

        # Fan the extraction out to a thread pool, but await the results
        # through asyncio so the event loop stays responsive while the
        # worker threads are busy.
        executor = ThreadPoolExecutor()
        pending = [
            asyncio.ensure_future(_extract_indexed(index, content, executor))
            for index, content in enumerate(documents)
        ]
        try:
            for finished in asyncio.as_completed(pending):
                index, extracted = await finished
                useful_content[index] = extracted
                completed_tasks += 1
                progress = (completed_tasks / total_tasks) * 50
                print(f"Progress: {progress:.2f}%")
                yield json.dumps(
                    {"progress": f"{progress:.2f}%"},
                    ensure_ascii=False)
        finally:
            # On an error or an early client disconnect, drop the work that
            # has not started yet and do not block the loop on the rest.
            for extraction in pending:
                if not extraction.done():
                    extraction.cancel()
            executor.shutdown(wait=False)

        # Merge all extracted key points into one numbered text.
        useful_res = "\n".join(f"{i+1}篇文章:\n{content}\n" for i, content in enumerate(useful_content))

        def _request_plot_params():
            # Blocking LLM call; only ever invoked from a worker thread.
            return get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="make_comparison_pic",
                prompt_param_dict={"time": datetime.now().strftime("%Y%m%d"), "context": useful_res, "content": "【参数说明】:参数格式以JSON格式提供 必需参数:\n 1. data图表数据格式如下{\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X}其中XXX代表分类名称xx代表分类数据量。\n2.title图表标题3.xlabel横轴标题你按照分的几类属于哪个大类必须有\n4.ylabel纵轴标题你的数值数据是什么必须有\n 5.plot_type:图表类型必须从以下几个当中选一个作为输入【bar,pie,line】其中bar代表柱状图pie代表饼图line代表折线图你只能选一个作为输入\n【使用指南】:\n当需要生成图表时,请使用此工具。\n输入必须包含以下参数::{\n\"data\": {\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X},\"title\": \"X\",\"xlabel\": \"X\",\"ylabel\": \"X\",\"plot_type\": \"X\"}\n示例:{'\"data\": {\"Category A\": 23, \"Category B\": 17, \"Category C\": 35, \"Category D\": 29},\"title\": \"My Chart\",\"xlabel\": \"Category\",\"ylabel\": \"Value\",\"plot_type\": \"pie\"}\n请务必以json格式输入方便使用。\n"},
                temperature=0.01,
                max_tokens=MAX_TOKENS,
            )

        # Build the comparison chart(s) from the key points first.
        raw_pic_param = await loop.run_in_executor(None, _request_plot_params)
        # The model may wrap its answer in a Markdown code fence; strip it.
        pic_param = raw_pic_param.replace("```json", "").replace("```", "").strip()
        pic_params = json.loads(pic_param)
        # The prompt asks for a single JSON object. Iterating a dict would
        # walk its keys, so wrap it to plot the object itself.
        # NOTE(review): assumes create_and_save_plot accepts the same
        # "<params><uuid dict>" string for a single object as it does for
        # each element of a list answer - confirm against that tool.
        if isinstance(pic_params, dict):
            pic_params = [pic_params]

        pic_content = {}
        plot_uuid = str(uuid4())
        utils.set_shared_variable(plot_uuid, pic_content)
        temp = {"uuid": plot_uuid}
        pic_res = "".join(
            create_and_save_plot(f"{param}{temp}") + "\n"
            for param in pic_params
        )

        # Generate the comparison report.
        callback = AsyncIteratorCallbackHandler()
        conversation_callback = ConversationCallbackHandler(conversation_id="", message_id="",
                                                            chat_type="llm_chat",
                                                            query=useful_res)
        callbacks = [callback, conversation_callback]
        model = get_ChatOpenAI(
            model_name=LLM_MODELS[0],
            temperature=0.01,
            max_tokens=MAX_TOKENS,
            callbacks=callbacks,
        )
        prompt_template = get_prompt_template("comparison_chat", "write_report")
        input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_prompt])
        chain = LLMChain(prompt=chat_prompt, llm=model)
        task = asyncio.create_task(wrap_done(
            chain.acall({
                "context": comparison[0]["content"],
                "time": datetime.now().strftime("%Y%m%d"),
                "content": useful_res,
                "pic": pic_res,
            }),
            callback.done),
        )
        try:
            # Tokens are collected and sent as one event once the model has
            # finished, rather than being forwarded one by one.
            tokens = []
            async for token in callback.aiter():
                tokens.append(token)
            history_summary = "".join(tokens)
            yield json.dumps(
                {"text": history_summary},
                ensure_ascii=False)
            yield json.dumps(
                {"progress": "100%"},
                ensure_ascii=False)
            await task
        finally:
            # Do not leave the chain running if the stream is closed early.
            if not task.done():
                task.cancel()

    return EventSourceResponse(chat_iterator())