Files
gangyan/langchain-chat/server/agent/tools/draw_plot.py

175 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import os
import re
import uuid
import requests
from matplotlib import pyplot as plt
from pydantic import BaseModel, Field
from configs.model_config import LLM_MODELS
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
from configs.kb_config import GENERATED_IMAGES_BASE_PATH, realistic_url,ink_url
from matplotlib import font_manager
# Font file used for every piece of chart text (titles, axis labels, tick labels, pie labels).
# Microsoft YaHei is presumably chosen so that CJK characters in user-supplied labels render.
# NOTE(review): hard-coded host path; if the file is absent, text rendering will likely fail —
# confirm the font is installed on every deployment host.
_YAHEI_FONT_FILE = "/usr/share/fonts/MicroSoft-YaHei/MSYH.TTC"
my_font = font_manager.FontProperties(fname=_YAHEI_FONT_FILE)
def create_and_save_plot(query:str) -> str:
    """Draw a bar, pie or line chart described in *query* and save it as a PNG.

    The agent is expected to pass text of the form ``<param>{spec}</param>{...}``.
    ``spec`` is JSON with the keys ``data`` (a category -> value mapping), ``title``,
    ``xlabel``, ``ylabel`` and ``plot_type`` (``'bar'``, ``'pie'`` or ``'line'``).
    Single quotes are tolerated: all spaces are removed and ``'`` is turned into
    ``"`` before parsing. When ``spec`` is not valid JSON, the LLM is asked
    (prompt template ``check_plot``) to rewrite it, and the JSON in its reply is
    used instead.

    Args:
        query: Raw tool input produced by the agent.

    Returns:
        On success, a markdown image link served by the backend ``get-image``
        endpoint. ``"暂时无法画图"`` when no ``<param>...</param>{...}`` section is
        present. For any other failure (unparseable spec, missing keys, unsupported
        plot type, rendering or saving error), a message telling the agent to stop
        calling this tool.
    """
    try:
        # Remove spaces and turn Python-style quotes into JSON quotes.
        query = query.replace(" ", "").replace("'", "\"")
        match = re.search(r'<param>(.*?)</param>(\{.*\})', query.replace("\n", ""))
        if not match:
            logging.warning("No <param>...</param>{...} section in plot query:\n%s", query)
            return "暂时无法画图"
        spec = _load_plot_spec(match.group(1).strip())
        return _render_plot(spec)
    except Exception:
        # Tool boundary: every failure is reported back to the agent as text.
        logging.exception("An error occurred while drawing a plot")
        return "不要再调用该工具了,无法提供该功能,你只能按工具描述提供有的入参,其他不支持"


def _load_plot_spec(spec_text: str) -> dict:
    """Parse *spec_text* as JSON; if it is invalid, ask the LLM to rewrite it and parse the reply.

    Raises:
        json.JSONDecodeError: if the LLM reply does not contain valid JSON either.
    """
    try:
        return json.loads(spec_text)
    except json.JSONDecodeError:
        pass
    # Shape hint for the rewrite prompt; the X/XX placeholders are intentional.
    json_template = '{\n"data": {"XXX": XX, "XXX": XX, "XXX": X, "XXX": X},"title": "X","xlabel": "X","ylabel": "X","plot_type": "X"}'
    reply = get_llm_model_response(
        strategy_name="query rewrite",
        llm_model_name=LLM_MODELS[0],
        template_prompt_name="check_plot",
        prompt_param_dict={"user_input": spec_text, "json": json_template},
        temperature=0.01,
        max_tokens=512,
    )
    return json.loads(_extract_json_text(reply))


def _extract_json_text(reply: str) -> str:
    """Return the body of the first ``` fenced block in *reply* (``json`` tag optional), else *reply* stripped."""
    # Search the reply as-is: newlines must be kept, or the fence delimiters cannot be found.
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", reply, re.DOTALL)
    return fenced.group(1) if fenced else reply.strip()


def _render_plot(spec: dict) -> str:
    """Render *spec* with matplotlib, save it under GENERATED_IMAGES_BASE_PATH and return a markdown link.

    Raises:
        KeyError: if a required spec key is missing.
        ValueError: if ``plot_type`` is not 'bar', 'pie' or 'line'.
    """
    data = spec["data"]
    xlabel = spec["xlabel"]
    ylabel = spec["ylabel"]
    title = spec["title"]
    plot_type = spec["plot_type"]
    # Validate before creating a figure so a bad request never leaves one open.
    if plot_type not in ("bar", "pie", "line"):
        raise ValueError("Unsupported plot type. Choose from 'bar', 'pie', or 'line'.")
    categories = list(data.keys())
    values = list(data.values())
    image_name = f"plot{uuid.uuid1()}.png"
    absolute_path = os.path.abspath(os.path.join(GENERATED_IMAGES_BASE_PATH, image_name))

    fig = plt.figure(figsize=(10, 6))
    try:
        if plot_type == "bar":
            plt.bar(categories, values, color="skyblue")
        elif plot_type == "pie":
            plt.pie(values, labels=categories, autopct="%1.1f%%", startangle=140,
                    textprops={"fontproperties": my_font})
        else:  # "line"
            plt.plot(categories, values, marker="o", linestyle="-")
        plt.title(title, fontproperties=my_font)
        if plot_type != "pie":  # pie charts have no axes to label
            plt.xlabel(xlabel, fontproperties=my_font)
            plt.ylabel(ylabel, fontproperties=my_font)
            plt.xticks(fontproperties=my_font, rotation=45)
            plt.yticks(fontproperties=my_font)
        plt.tight_layout()
        plt.savefig(absolute_path)
    finally:
        # Always release the figure; otherwise a long-running server accumulates open figures.
        plt.close(fig)
    return f"图片如下:![图片{title}](http://127.0.0.1:8099/chat_web_backend/get-image?file_name={image_name})"
class drawPlotInput(BaseModel):
    # Argument schema presumably paired with the create_and_save_plot tool (single `query` field).
    # NOTE(review): intentionally no class docstring — pydantic copies it into the JSON-schema
    # description, which would change what the LLM is shown.
    # The field description reads, in English: "the content to be plotted".
    query: str = Field(...,description="输入要画图的内容")
def draw_realistic_pic(query:str) -> str:
    """Generate a realistic-style image through the service at ``realistic_url``.

    *query* is expected to contain two JSON objects: the prompt under ``"query"``
    and the chat-session id under ``"uuid"``, e.g. ``{"query": "..."}{"uuid": "..."}``.
    The objects are located with a non-greedy ``{...}`` scan, so nested objects are
    not supported.

    On success the service's ``image_path`` is stored as ``sources["url"]`` in the
    shared variable keyed by the session id.

    Returns:
        On success, an instruction string that embeds the image as a markdown link.
        ``"暂时无法画图"`` when the input cannot be parsed, the request fails, or the
        service answers with a non-200 status or a non-JSON body.
    """
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return "暂时无法画图"
    try:
        session_id = json.loads(matches[1])["uuid"]
        # The prompt object is whatever remains once the session object is cut out.
        prompt = json.loads(query.replace(matches[1], ""))["query"]
    except (ValueError, KeyError, TypeError):
        logging.error("Invalid JSON format in query.")
        return "暂时无法画图"

    try:
        # NOTE(review): no timeout, so a stalled image service blocks this call indefinitely.
        response = requests.post(realistic_url, json={'prompt': prompt})
        if response.status_code != 200:
            logging.error("Failed to get response: %s", response.status_code)
            return "暂时无法画图"
        file_path = response.json().get('image_path')
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a response body that is not valid JSON.
        logging.error("An error occurred: %s", e)
        return "暂时无法画图"

    logging.info("Image path: %s", file_path)
    sources = utils.get_shared_variable(session_id)
    sources["url"] = file_path
    utils.set_shared_variable(session_id, sources)
    return f"<关键指令>你必须以markdown格式的图片路径![图片](http://127.0.0.1:8099/chat_web_backend/get-image?file_name={file_path}),之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
class drawRealisticInput(BaseModel):
    # Argument schema presumably paired with the draw_realistic_pic tool (single `query` field).
    # NOTE(review): intentionally no class docstring — pydantic copies it into the JSON-schema
    # description, which would change what the LLM is shown.
    # The field description reads, in English: "the content to be drawn".
    query: str = Field(...,description="输入要画图的内容")
def draw_ink_pic(query:str) -> str:
    """Generate an ink-painting-style image through the service at ``ink_url``.

    *query* is expected to contain two JSON objects: the prompt under ``"query"``
    and the chat-session id under ``"uuid"``, e.g. ``{"query": "..."}{"uuid": "..."}``.
    The objects are located with a non-greedy ``{...}`` scan, so nested objects are
    not supported.

    On success the service's ``image_path`` is stored as ``sources["url"]`` in the
    shared variable keyed by the session id.

    Returns:
        On success, an instruction string that embeds the image as a markdown link.
        ``"暂时无法画图"`` when the input cannot be parsed, the request fails, or the
        service answers with a non-200 status or a non-JSON body.
    """
    # The former argument-less get_llm_model_response() call was removed: its result was
    # discarded, and the function is called with six keyword arguments elsewhere, so the
    # bare call presumably failed on every invocation.
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return "暂时无法画图"
    try:
        session_id = json.loads(matches[1])["uuid"]
        # The prompt object is whatever remains once the session object is cut out.
        prompt = json.loads(query.replace(matches[1], ""))["query"]
    except (ValueError, KeyError, TypeError):
        logging.error("Invalid JSON format in query.")
        return "暂时无法画图"

    try:
        # NOTE(review): no timeout, so a stalled image service blocks this call indefinitely.
        response = requests.post(ink_url, json={'prompt': prompt})
        if response.status_code != 200:
            logging.error("Failed to get response: %s", response.status_code)
            return "暂时无法画图"
        file_path = response.json().get('image_path')
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a response body that is not valid JSON.
        logging.error("An error occurred: %s", e)
        return "暂时无法画图"

    logging.info("Image path: %s", file_path)
    sources = utils.get_shared_variable(session_id)
    sources["url"] = file_path
    utils.set_shared_variable(session_id, sources)
    return f"<关键指令>你必须以markdown格式的图片路径![图片](http://127.0.0.1:8099/chat_web_backend/get-image?file_name={file_path}),之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
class drawInkInput(BaseModel):
    # Argument schema presumably paired with the draw_ink_pic tool (single `query` field).
    # NOTE(review): intentionally no class docstring — pydantic copies it into the JSON-schema
    # description, which would change what the LLM is shown.
    # The field description reads, in English: "the content to be drawn".
    query: str = Field(...,description="输入要画图的内容")