Files
gangyan/langchain-chat/server/chat/chat_comparison.py

134 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
import json
from uuid import uuid4
from fastapi import Body, HTTPException
from fastapi.responses import FileResponse
from typing import Any, Dict, List, Optional
from sse_starlette.sse import EventSourceResponse
from langchain.chains import LLMChain
from configs.comparison_config import comparison
from configs import LLM_MODELS
from configs.model_config import MAX_TOKENS
from server.agent.tools.draw_plot import create_and_save_plot
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
from typing import AsyncIterable
from configs.basic_config import *
from server.chat.utils import History
from server.custom.AsyncIteratorCallbackHandlerNew import AsyncIteratorCallbackHandler
from server.utils import get_ChatOpenAI, get_prompt_template, wrap_done
from langchain.prompts.chat import ChatPromptTemplate
# Instructions handed to the "make_comparison_pic" prompt describing the JSON
# the model must return for the plotting tool. This is runtime text sent to the
# model and must not be edited casually.
_PLOT_TOOL_INSTRUCTIONS = "【参数说明】:参数格式以JSON格式提供 必需参数:\n 1. data图表数据格式如下{\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X}其中XXX代表分类名称xx代表分类数据量。\n2.title图表标题3.xlabel横轴标题你按照分的几类属于哪个大类必须有\n4.ylabel纵轴标题你的数值数据是什么必须有\n 5.plot_type:图表类型必须从以下几个当中选一个作为输入【bar,pie,line】其中bar代表柱状图pie代表饼图line代表折线图你只能选一个作为输入\n【使用指南】:\n当需要生成图表时,请使用此工具。\n输入必须包含以下参数::{\n\"data\": {\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X},\"title\": \"X\",\"xlabel\": \"X\",\"ylabel\": \"X\",\"plot_type\": \"X\"}\n示例:{'\"data\": {\"Category A\": 23, \"Category B\": 17, \"Category C\": 35, \"Category D\": 29},\"title\": \"My Chart\",\"xlabel\": \"Category\",\"ylabel\": \"Value\",\"plot_type\": \"pie\"}\n请务必以json格式输入方便使用。\n"


def _parse_plot_specs(raw_reply: str) -> List[Any]:
    """Convert the model's chart-parameter reply into a list of plot specs.

    Markdown code fences are stripped before parsing. A single JSON object is
    wrapped in a one-element list so callers can always iterate over specs
    (iterating a bare dict would yield its keys, not a spec). Unparseable or
    unexpectedly shaped replies yield an empty list so the report can still be
    produced without charts.
    """
    cleaned = raw_reply.replace("```json", "").replace("```", "").strip()
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError as exc:
        print(f"Comparison chart skipped, model returned invalid JSON: {exc}")
        return []
    if isinstance(parsed, dict):
        return [parsed]
    if isinstance(parsed, list):
        return parsed
    print(f"Comparison chart skipped, unexpected JSON type: {type(parsed).__name__}")
    return []


async def chat_comparison_test(
        uid: Optional[str] = Body(None, description="时间"),
        content_list: Optional[List[str]] = Body(None, description="对比文献内容,需要是个包含所有文件内容的字符串列表"),
):
    """Stream a comparison report for several documents as server-sent events.

    Pipeline:
      1. Key points are extracted from every entry of ``content_list`` in
         parallel (one LLM call per entry); a ``{"progress": "NN.NN%"}`` event
         is emitted after each finished extraction (this phase covers 0-50%).
      2. The model is asked for chart parameters and the charts are rendered
         through ``create_and_save_plot``.
      3. The final report is generated with the ``comparison_chat`` /
         ``write_report`` prompt and emitted as one ``{"text": ...}`` event,
         followed by ``{"progress": "100%"}``.

    Args:
        uid: Accepted for API compatibility; not used by this handler.
        content_list: Texts of the documents to compare, one string each.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator.

    Raises:
        HTTPException: 400 when ``content_list`` is missing or empty.
    """
    # No explicit top-level ``import asyncio`` is visible in this module, so
    # import it here instead of relying on a star import to provide it.
    import asyncio

    # Validate before the stream starts so the client gets a proper HTTP error
    # instead of a stream that breaks on ``len(None)``.
    if not content_list:
        raise HTTPException(
            status_code=400,
            detail="content_list must be a non-empty list of document texts",
        )

    async def chat_iterator() -> AsyncIterable[str]:
        loop = asyncio.get_running_loop()
        today = datetime.now().strftime("%Y%m%d")
        # Comparison dimensions taken from the first comparison template.
        context = comparison[0]["summary"]
        total_tasks = len(content_list)

        def extract_key_points(content: str):
            # Blocking LLM call; always executed in a worker thread.
            return get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[1],
                template_prompt_name="extract_key_points",
                prompt_param_dict={"time": today, "context": context, "content": content},
                temperature=0.01,
                max_tokens=512,
            )

        async def extract_indexed(index: int, content: str):
            # Carry the input position along so results can be stored in the
            # order of ``content_list`` rather than in completion order.
            extracted = await loop.run_in_executor(None, extract_key_points, content)
            return index, extracted

        jobs = [
            asyncio.create_task(extract_indexed(index, content))
            for index, content in enumerate(content_list)
        ]
        # Useful information distilled from each source document, by position.
        useful_content = [None] * total_tasks
        completed_tasks = 0
        try:
            # asyncio.as_completed keeps the event loop free while waiting,
            # unlike concurrent.futures.as_completed which blocks the thread.
            for finished in asyncio.as_completed(jobs):
                index, extracted = await finished
                useful_content[index] = extracted
                completed_tasks += 1
                progress = (completed_tasks / total_tasks) * 50
                print(f"Progress: {progress:.2f}%")
                yield json.dumps(
                    {"progress": f"{progress:.2f}%"},
                    ensure_ascii=False)
        finally:
            # No-op for finished jobs; drops pending ones if the stream is
            # closed early or one extraction failed.
            for job in jobs:
                job.cancel()

        # Merge all extracted key points into one numbered text.
        useful_res = "\n".join(
            f"{position}篇文章:\n{text}\n"
            for position, text in enumerate(useful_content, start=1)
        )

        def request_plot_spec():
            # Blocking LLM call asking for the chart parameters as JSON.
            return get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="make_comparison_pic",
                prompt_param_dict={"time": today, "context": useful_res, "content": _PLOT_TOOL_INSTRUCTIONS},
                temperature=0.01,
                max_tokens=MAX_TOKENS,
            )

        plot_specs = _parse_plot_specs(
            await loop.run_in_executor(None, request_plot_spec)
        )

        # Register a shared slot under a fresh id and pass that id to the
        # plotting tool appended to each spec.
        pic_content = {}
        plot_uuid = str(uuid4())
        utils.set_shared_variable(plot_uuid, pic_content)
        uuid_suffix = {"uuid": plot_uuid}
        # NOTE(review): the tool receives the Python repr of the spec followed
        # by the repr of the uuid dict, exactly as before; confirm this is the
        # input format create_and_save_plot expects.
        pic_res = "".join(
            create_and_save_plot(f"{spec}{uuid_suffix}") + "\n"
            for spec in plot_specs
        )

        # Generate the comparison report.
        callback = AsyncIteratorCallbackHandler()
        conversation_callback = ConversationCallbackHandler(conversation_id="", message_id="",
                                                            chat_type="llm_chat",
                                                            query=useful_res)
        callbacks = [callback, conversation_callback]
        model = get_ChatOpenAI(
            model_name=LLM_MODELS[0],
            temperature=0.01,
            max_tokens=MAX_TOKENS,
            callbacks=callbacks,
        )
        prompt_template = get_prompt_template("comparison_chat", "write_report")
        input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_prompt])
        chain = LLMChain(prompt=chat_prompt, llm=model)
        task = asyncio.create_task(wrap_done(
            chain.acall({
                "context": comparison[0]["content"],
                "time": today,
                "content": useful_res,
                "pic": pic_res,
            }),
            callback.done),
        )

        # Collect every streamed token, then emit the whole report at once.
        tokens = [token async for token in callback.aiter()]
        yield json.dumps(
            {"text": "".join(tokens)},
            ensure_ascii=False)
        yield json.dumps(
            {"progress": "100%"},
            ensure_ascii=False)
        await task

    return EventSourceResponse(chat_iterator())