[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
175
langchain-chat/server/agent/tools/draw_plot.py
Normal file
175
langchain-chat/server/agent/tools/draw_plot.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import requests
|
||||
from matplotlib import pyplot as plt
|
||||
from pydantic import BaseModel, Field
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.chat import utils
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from configs.kb_config import GENERATED_IMAGES_BASE_PATH, realistic_url,ink_url
|
||||
from matplotlib import font_manager
|
||||
|
||||
my_font = font_manager.FontProperties(fname="/usr/share/fonts/MicroSoft-YaHei/MSYH.TTC")
|
||||
def create_and_save_plot(query:str) -> str:
|
||||
try:
|
||||
query = query.replace(" ","").replace("'","\"")
|
||||
json_str ='{\n"data": {"XXX": XX, "XXX": XX, "XXX": X, "XXX": X},"title": "X","xlabel": "X","ylabel": "X","plot_type": "X"}'
|
||||
datas = {}
|
||||
try:
|
||||
match = re.search(r'<param>(.*?)</param>(\{.*\})', query.replace("\n", ""))
|
||||
if match:
|
||||
query = match.group(1).strip()
|
||||
datas = json.loads(query)
|
||||
else:
|
||||
print(f"Invalid JSON format in query:\n{query}")
|
||||
return"暂时无法画图"
|
||||
except:
|
||||
query = get_llm_model_response(
|
||||
strategy_name="query rewrite",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="check_plot",
|
||||
prompt_param_dict={"user_input": query,"json":json_str },
|
||||
temperature=0.01,
|
||||
max_tokens=512
|
||||
)
|
||||
re.search(r"```json\n(.*?)\n```", query.replace("\n", ""), re.DOTALL)
|
||||
query = match.group(1).strip()
|
||||
datas = json.loads(query)
|
||||
data = datas["data"]
|
||||
xlabel = datas["xlabel"]
|
||||
ylabel = datas["ylabel"]
|
||||
title = datas["title"]
|
||||
plot_type = datas["plot_type"]
|
||||
# 分析和汇总数据
|
||||
categories = list(data.keys())
|
||||
values = list(data.values())
|
||||
|
||||
# 创建图表
|
||||
plt.figure(figsize=(10, 6))
|
||||
|
||||
if plot_type == 'bar':
|
||||
plt.bar(categories, values, color='skyblue')
|
||||
elif plot_type == 'pie':
|
||||
plt.pie(values, labels=categories, autopct='%1.1f%%', startangle=140,textprops={'fontproperties': my_font})
|
||||
elif plot_type == 'line':
|
||||
plt.plot(categories, values, marker='o', linestyle='-')
|
||||
else:
|
||||
raise ValueError("Unsupported plot type. Choose from 'bar', 'pie', or 'line'.")
|
||||
|
||||
# 添加标题和标签
|
||||
plt.title(title,fontproperties=my_font)
|
||||
if plot_type != 'pie': # 饼状图不需要轴标签
|
||||
plt.xlabel(xlabel,fontproperties=my_font)
|
||||
plt.ylabel(ylabel,fontproperties=my_font)
|
||||
plt.xticks(fontproperties=my_font,rotation=45)
|
||||
plt.yticks(fontproperties=my_font)
|
||||
|
||||
namesid = uuid.uuid1()
|
||||
# 保存图表为图片文件
|
||||
file_path = f'{GENERATED_IMAGES_BASE_PATH}/plot{namesid}.png'
|
||||
absolute_path = os.path.abspath(file_path)
|
||||
# sources = utils.get_shared_variable(uuids)
|
||||
# sources["url"] = f"plot{namesid}.png"
|
||||
# sources["END"]="ok"
|
||||
# utils.set_shared_variable(uuids, sources)
|
||||
# plt.figure(figsize=(50, 60))
|
||||
# plt.subplots_adjust(left=0.1, right=0.95, top=0.9, bottom=0.15)
|
||||
plt.tight_layout()
|
||||
plt.savefig(absolute_path)
|
||||
plt.close()
|
||||
image_name = f"plot{namesid}.png"
|
||||
|
||||
return f"图片如下:"
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
return "不要再调用该工具了,无法提供该功能,你只能按工具描述提供有的入参,其他不支持"
|
||||
|
||||
|
||||
class drawPlotInput(BaseModel):
|
||||
query: str = Field(...,description="输入要画图的内容")
|
||||
|
||||
def draw_realistic_pic(query:str) -> str:
|
||||
try:
|
||||
matches = re.findall(r'\{.*?\}', query)
|
||||
if len(matches)>=2:
|
||||
uuid = json.loads(matches[1])["uuid"]
|
||||
query = query.replace(matches[1],"")
|
||||
else:
|
||||
return"暂时无法画图"
|
||||
except:
|
||||
logging.error("Invalid JSON format in query.")
|
||||
|
||||
# 请求体数据
|
||||
data = {
|
||||
'prompt': json.loads(query)["query"]
|
||||
}
|
||||
|
||||
try:
|
||||
# 发起 POST 请求
|
||||
response = requests.post(realistic_url, json=data)
|
||||
|
||||
# 检查响应状态码
|
||||
if response.status_code == 200:
|
||||
# 解析 JSON 响应
|
||||
result = response.json()
|
||||
print("Image path:", result.get('image_path'))
|
||||
file_path = result.get('image_path')
|
||||
sources = utils.get_shared_variable(uuid)
|
||||
sources["url"] = file_path
|
||||
# sources["END"]="ok"
|
||||
utils.set_shared_variable(uuid, sources)
|
||||
return f"<关键指令>你必须以markdown格式的图片路径,,之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
|
||||
else:
|
||||
print("Failed to get response:", response.status_code)
|
||||
return "暂时无法画图"
|
||||
except requests.exceptions.RequestException as e:
|
||||
print("An error occurred:", e)
|
||||
return "暂时无法画图"
|
||||
|
||||
class drawRealisticInput(BaseModel):
|
||||
query: str = Field(...,description="输入要画图的内容")
|
||||
|
||||
def draw_ink_pic(query:str) -> str:
|
||||
get_llm_model_response()
|
||||
try:
|
||||
matches = re.findall(r'\{.*?\}', query)
|
||||
if len(matches)>=2:
|
||||
uuid = json.loads(matches[1])["uuid"]
|
||||
query = query.replace(matches[1],"")
|
||||
else:
|
||||
return"暂时无法画图"
|
||||
except:
|
||||
logging.error("Invalid JSON format in query.")
|
||||
|
||||
# 请求体数据
|
||||
data = {
|
||||
'prompt': json.loads(query)["query"]
|
||||
}
|
||||
|
||||
try:
|
||||
# 发起 POST 请求
|
||||
response = requests.post(ink_url, json=data)
|
||||
|
||||
# 检查响应状态码
|
||||
if response.status_code == 200:
|
||||
# 解析 JSON 响应
|
||||
result = response.json()
|
||||
print("Image path:", result.get('image_path'))
|
||||
file_path = result.get('image_path')
|
||||
sources = utils.get_shared_variable(uuid)
|
||||
sources["url"] = file_path
|
||||
# sources["END"]="ok"
|
||||
utils.set_shared_variable(uuid, sources)
|
||||
return f"<关键指令>你必须以markdown格式的图片路径,,之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
|
||||
else:
|
||||
print("Failed to get response:", response.status_code)
|
||||
return "暂时无法画图"
|
||||
except requests.exceptions.RequestException as e:
|
||||
print("An error occurred:", e)
|
||||
return "暂时无法画图"
|
||||
|
||||
class drawInkInput(BaseModel):
|
||||
query: str = Field(...,description="输入要画图的内容")
|
||||
Reference in New Issue
Block a user