[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
134
langchain-chat/server/chat/chat_comparison.py
Normal file
134
langchain-chat/server/chat/chat_comparison.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime
|
||||
import json
|
||||
from uuid import uuid4
|
||||
from fastapi import Body, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from typing import Any, Dict, List, Optional
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from langchain.chains import LLMChain
|
||||
from configs.comparison_config import comparison
|
||||
from configs import LLM_MODELS
|
||||
from configs.model_config import MAX_TOKENS
|
||||
from server.agent.tools.draw_plot import create_and_save_plot
|
||||
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
||||
from server.chat import utils
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from typing import AsyncIterable
|
||||
from configs.basic_config import *
|
||||
from server.chat.utils import History
|
||||
from server.custom.AsyncIteratorCallbackHandlerNew import AsyncIteratorCallbackHandler
|
||||
from server.utils import get_ChatOpenAI, get_prompt_template, wrap_done
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
async def chat_comparison_test(
|
||||
uid: Optional[str] = Body(None, description="时间"),
|
||||
content_list: Optional[List[str]] = Body(None, description="对比文献内容,需要是个包含所有文件内容的字符串列表"),
|
||||
):
|
||||
async def chat_iterator() -> AsyncIterable[str]:
|
||||
nonlocal content_list
|
||||
|
||||
# 按照对比模板提取的对比维度
|
||||
context = comparison[0]["summary"]
|
||||
|
||||
# 用于处理提炼资料的进度条
|
||||
total_tasks = len(content_list)
|
||||
completed_tasks = 0
|
||||
|
||||
|
||||
# 用于存储从原文资料里提炼出来的所有有用的信息
|
||||
useful_content = []
|
||||
|
||||
# 使用线程池执行异步任务对文献所有内容进行提取
|
||||
with ThreadPoolExecutor() as executor:
|
||||
# 创建一个任务列表,每个任务处理 content_list 中的一个元素
|
||||
futures = [
|
||||
executor.submit(
|
||||
get_llm_model_response,
|
||||
strategy_name="query rewrite",
|
||||
llm_model_name=LLM_MODELS[1],
|
||||
template_prompt_name="extract_key_points",
|
||||
prompt_param_dict={"time": datetime.now().strftime("%Y%m%d"), "context": context, "content": content},
|
||||
temperature=0.01,
|
||||
max_tokens=512
|
||||
)
|
||||
for content in content_list
|
||||
]
|
||||
|
||||
# 收集所有任务的结果
|
||||
for future in as_completed(futures):
|
||||
useful_content.append(future.result())
|
||||
completed_tasks += 1
|
||||
progress = (completed_tasks / total_tasks) * 50
|
||||
print(f"Progress: {progress:.2f}%")
|
||||
yield json.dumps(
|
||||
{"progress": f"{progress:.2f}%"},
|
||||
ensure_ascii=False)
|
||||
|
||||
# 对所有结果进行处理
|
||||
useful_res = "\n".join(f"第{i+1}篇文章:\n{content}\n"for i,content in enumerate(useful_content) )
|
||||
|
||||
# 根据要点数据先制作出对比图
|
||||
pic_param=get_llm_model_response(
|
||||
strategy_name="query rewrite",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="make_comparison_pic",
|
||||
prompt_param_dict={"time": datetime.now().strftime("%Y%m%d"), "context": useful_res, "content": "【参数说明】:参数格式:以JSON格式提供, 必需参数:\n 1. data:图表数据格式如下{\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X}其中XXX代表分类名称,xx代表分类数据量。\n2.title:图表标题3.xlabel:横轴标题(你按照分的几类属于哪个大类)必须有\n4.ylabel:纵轴标题(你的数值数据是什么)必须有\n 5.plot_type:图表类型必须从以下几个当中选一个作为输入【bar,pie,line】其中bar代表柱状图pie代表饼图line代表折线图你只能选一个作为输入,\n【使用指南】:\n当需要生成图表时,请使用此工具。\n输入必须包含以下参数::{\n\"data\": {\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X},\"title\": \"X\",\"xlabel\": \"X\",\"ylabel\": \"X\",\"plot_type\": \"X\"}\n示例:{'\"data\": {\"Category A\": 23, \"Category B\": 17, \"Category C\": 35, \"Category D\": 29},\"title\": \"My Chart\",\"xlabel\": \"Category\",\"ylabel\": \"Value\",\"plot_type\": \"pie\"}\n请务必以json格式输入方便使用。\n" },
|
||||
temperature=0.01,
|
||||
max_tokens=MAX_TOKENS
|
||||
).replace("```json","").replace("```","").strip()
|
||||
pic_params = json.loads(pic_param)
|
||||
pic_content={}
|
||||
uuid = str(uuid4())
|
||||
utils.set_shared_variable(uuid, pic_content)
|
||||
temp = {}
|
||||
temp["uuid"] = uuid
|
||||
pic_res = ""
|
||||
for param in pic_params:
|
||||
pic_res+=create_and_save_plot(f"{param}{temp}")+"\n"
|
||||
#pic_res = create_and_save_plot(f"{pic_param}{temp}")
|
||||
# 生成对比文件
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
callbacks = [callback]
|
||||
conversation_callback = ConversationCallbackHandler(conversation_id="", message_id="",
|
||||
chat_type="llm_chat",
|
||||
query=useful_res)
|
||||
callbacks.append(conversation_callback)
|
||||
model = get_ChatOpenAI(
|
||||
model_name=LLM_MODELS[0],
|
||||
temperature=0.01,
|
||||
max_tokens=MAX_TOKENS,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
memory = None
|
||||
prompt_template = get_prompt_template("comparison_chat", "write_report")
|
||||
input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages([input_prompt])
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({
|
||||
# "context": res,
|
||||
"context": comparison[0]["content"],
|
||||
"time": datetime.now().strftime("%Y%m%d"),
|
||||
"content": useful_res,
|
||||
"pic": pic_res,
|
||||
}),
|
||||
callback.done),
|
||||
)
|
||||
history_summary = ""
|
||||
# if stream:
|
||||
async for token in callback.aiter():
|
||||
# Use server-sent-events to stream the response
|
||||
history_summary += token
|
||||
# yield json.dumps(
|
||||
# {"text": token},
|
||||
# ensure_ascii=False)
|
||||
yield json.dumps(
|
||||
{"text": history_summary},
|
||||
ensure_ascii=False)
|
||||
yield json.dumps(
|
||||
{"progress": "100%"},
|
||||
ensure_ascii=False)
|
||||
await task
|
||||
|
||||
return EventSourceResponse(chat_iterator())
|
||||
Reference in New Issue
Block a user