175 lines
6.9 KiB
Python
175 lines
6.9 KiB
Python
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
import uuid
|
||
import requests
|
||
from matplotlib import pyplot as plt
|
||
from pydantic import BaseModel, Field
|
||
from configs.model_config import LLM_MODELS
|
||
from server.chat import utils
|
||
from server.chat.policy_fun_iast import get_llm_model_response
|
||
from configs.kb_config import GENERATED_IMAGES_BASE_PATH, realistic_url,ink_url
|
||
from matplotlib import font_manager
|
||
|
||
# Font file shared by every chart label below; presumably Microsoft YaHei is
# used so Chinese titles/labels render correctly — confirm on the deploy host.
_YAHEI_FONT_FILE = "/usr/share/fonts/MicroSoft-YaHei/MSYH.TTC"
my_font = font_manager.FontProperties(fname=_YAHEI_FONT_FILE)
|
||
def _parse_plot_spec_with_llm(raw_spec: str) -> dict:
    """Ask the LLM to rewrite a plot spec that is not valid JSON, then decode it.

    Args:
        raw_spec: The text found between ``<param>`` tags that failed to parse.

    Returns:
        The decoded plot specification.

    Raises:
        ValueError: if the LLM reply does not contain valid JSON.
    """
    # Shape hint handed to the "check_plot" prompt template.
    json_template = '{\n"data": {"XXX": XX, "XXX": XX, "XXX": X, "XXX": X},"title": "X","xlabel": "X","ylabel": "X","plot_type": "X"}'
    reply = get_llm_model_response(
        strategy_name="query rewrite",
        llm_model_name=LLM_MODELS[0],
        template_prompt_name="check_plot",
        prompt_param_dict={"user_input": raw_spec, "json": json_template},
        temperature=0.01,
        max_tokens=512,
    )
    # NOTE(review): assumes the reply is a string, as the previous code did.
    # Search with newlines intact: stripping them first (as before) meant a
    # fence pattern containing "\n" could never match. Fall back to the whole
    # reply when the model did not use a fenced block.
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", reply, re.DOTALL)
    payload = fenced.group(1) if fenced else reply
    return json.loads(payload.strip())


def _render_plot(spec: dict) -> None:
    """Draw *spec* with pyplot and save it as a PNG under GENERATED_IMAGES_BASE_PATH.

    Raises:
        KeyError: if ``data``, ``xlabel``, ``ylabel``, ``title`` or ``plot_type`` is missing.
        ValueError: if ``plot_type`` is not 'bar', 'pie' or 'line'.
    """
    data = spec["data"]
    xlabel = spec["xlabel"]
    ylabel = spec["ylabel"]
    title = spec["title"]
    plot_type = spec["plot_type"]

    # Category names become the x values / pie labels; dict values are the sizes.
    categories = list(data.keys())
    values = list(data.values())

    # NOTE(review): pyplot keeps global state; concurrent calls from several
    # threads could interleave drawing on the "current" figure.
    fig = plt.figure(figsize=(10, 6))
    try:
        if plot_type == 'bar':
            plt.bar(categories, values, color='skyblue')
        elif plot_type == 'pie':
            plt.pie(values, labels=categories, autopct='%1.1f%%', startangle=140,
                    textprops={'fontproperties': my_font})
        elif plot_type == 'line':
            plt.plot(categories, values, marker='o', linestyle='-')
        else:
            raise ValueError("Unsupported plot type. Choose from 'bar', 'pie', or 'line'.")

        plt.title(title, fontproperties=my_font)
        if plot_type != 'pie':  # pie charts do not need axis labels
            plt.xlabel(xlabel, fontproperties=my_font)
            plt.ylabel(ylabel, fontproperties=my_font)
            plt.xticks(fontproperties=my_font, rotation=45)
            plt.yticks(fontproperties=my_font)

        file_path = f'{GENERATED_IMAGES_BASE_PATH}/plot{uuid.uuid1()}.png'
        plt.tight_layout()
        plt.savefig(os.path.abspath(file_path))
    finally:
        # Always release the figure. Previously any exception after
        # plt.figure() (e.g. an unsupported plot_type) skipped plt.close()
        # and left the figure open in pyplot's registry.
        plt.close(fig)


def create_and_save_plot(query: str) -> str:
    """Parse a plot request from *query*, render it and save it as a PNG.

    *query* must match ``<param>...</param>{...}``: the text inside the
    ``<param>`` tags is the plot spec (a JSON object with ``data``, ``title``,
    ``xlabel``, ``ylabel`` and ``plot_type``); a ``{...}`` object must follow
    the closing tag but is not used here. If the spec is not valid JSON, the
    LLM is asked to rewrite it.

    Args:
        query: Raw tool input. All spaces are removed and single quotes are
            turned into double quotes before parsing.

    Returns:
        "图片如下:" on success. NOTE(review): the saved file name is not
        returned — confirm this is intended. "暂时无法画图" when *query* does
        not match the expected pattern. A fixed "do not call again" message
        on any other failure.
    """
    try:
        query = query.replace(" ", "").replace("'", "\"")
        match = re.search(r'<param>(.*?)</param>(\{.*\})', query.replace("\n", ""))
        if not match:
            logging.warning("Invalid JSON format in query:\n%s", query)
            return "暂时无法画图"

        raw_spec = match.group(1).strip()
        try:
            spec = json.loads(raw_spec)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            spec = _parse_plot_spec_with_llm(raw_spec)

        _render_plot(spec)
        return "图片如下:"
    except Exception as e:
        # Tool boundary: log with traceback and steer the agent away from retrying.
        logging.exception("An error occurred: %s", e)
        return "不要再调用该工具了,无法提供该功能,你只能按工具描述提供有的入参,其他不支持"
|
||
|
||
|
||
class drawPlotInput(BaseModel):
    # Argument schema for the plot tool: one required free-text field.
    query: str = Field(default=..., description="输入要画图的内容")
|
||
|
||
def draw_realistic_pic(query: str) -> str:
    """Request an image from the ``realistic_url`` service and record its path.

    *query* must contain at least two non-nested ``{...}`` objects.
    - The second must decode to an object with a ``"uuid"`` key. That key
      identifies the shared-variable entry that receives the image path.
    - Once that second object is removed, the remaining text must decode to an
      object with a ``"query"`` key. That value is the prompt.

    Args:
        query: Tool input as produced by the agent.

    Returns:
        An instruction string for the LLM on success, otherwise "暂时无法画图".
    """
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return "暂时无法画图"
    try:
        session_id = json.loads(matches[1])["uuid"]
        # Strip the uuid object so what remains is the prompt JSON.
        prompt = json.loads(query.replace(matches[1], ""))["query"]
    except (ValueError, KeyError, TypeError):
        # Previously this was only logged and execution continued. That ended in
        # an UnboundLocalError for the uuid or an uncaught JSONDecodeError.
        logging.error("Invalid JSON format in query.")
        return "暂时无法画图"

    try:
        # Without a timeout, requests can wait forever on a stalled service.
        # 300 s is a generous bound for image generation; tune as needed.
        response = requests.post(realistic_url, json={'prompt': prompt}, timeout=300)
        if response.status_code != 200:
            logging.error("Failed to get response: %s", response.status_code)
            return "暂时无法画图"
        # ValueError covers a 200 response whose body is not JSON.
        file_path = response.json().get('image_path')
    except (requests.exceptions.RequestException, ValueError) as e:
        logging.error("An error occurred: %s", e)
        return "暂时无法画图"

    logging.info("Image path: %s", file_path)
    sources = utils.get_shared_variable(session_id)
    sources["url"] = file_path
    utils.set_shared_variable(session_id, sources)
    return "<关键指令>你必须以markdown格式的图片路径,,之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
|
||
|
||
class drawRealisticInput(BaseModel):
    # Argument schema for the realistic-image tool: one required free-text field.
    query: str = Field(default=..., description="输入要画图的内容")
|
||
|
||
def draw_ink_pic(query: str) -> str:
    """Request an image from the ``ink_url`` service and record its path.

    *query* must contain at least two non-nested ``{...}`` objects.
    - The second must decode to an object with a ``"uuid"`` key. That key
      identifies the shared-variable entry that receives the image path.
    - Once that second object is removed, the remaining text must decode to an
      object with a ``"query"`` key. That value is the prompt.

    Args:
        query: Tool input as produced by the agent.

    Returns:
        An instruction string for the LLM on success, otherwise "暂时无法画图".
    """
    # A stray argument-less get_llm_model_response() call used to sit here.
    # Its result was discarded, and the other call site passes keyword
    # arguments it needs, so it was removed.
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return "暂时无法画图"
    try:
        session_id = json.loads(matches[1])["uuid"]
        # Strip the uuid object so what remains is the prompt JSON.
        prompt = json.loads(query.replace(matches[1], ""))["query"]
    except (ValueError, KeyError, TypeError):
        # Previously this was only logged and execution continued. That ended in
        # an UnboundLocalError for the uuid or an uncaught JSONDecodeError.
        logging.error("Invalid JSON format in query.")
        return "暂时无法画图"

    try:
        # Without a timeout, requests can wait forever on a stalled service.
        # 300 s is a generous bound for image generation; tune as needed.
        response = requests.post(ink_url, json={'prompt': prompt}, timeout=300)
        if response.status_code != 200:
            logging.error("Failed to get response: %s", response.status_code)
            return "暂时无法画图"
        # ValueError covers a 200 response whose body is not JSON.
        file_path = response.json().get('image_path')
    except (requests.exceptions.RequestException, ValueError) as e:
        logging.error("An error occurred: %s", e)
        return "暂时无法画图"

    logging.info("Image path: %s", file_path)
    sources = utils.get_shared_variable(session_id)
    sources["url"] = file_path
    utils.set_shared_variable(session_id, sources)
    return "<关键指令>你必须以markdown格式的图片路径,,之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
|
||
|
||
class drawInkInput(BaseModel):
    # Argument schema for the ink-image tool: one required free-text field.
    query: str = Field(default=..., description="输入要画图的内容")