"""Chart / image generation tools whose return strings are addressed to the LLM.

``create_and_save_plot`` renders a bar, pie or line chart with matplotlib from a
JSON spec and saves it under ``GENERATED_IMAGES_BASE_PATH``.
``draw_realistic_pic`` and ``draw_ink_pic`` forward a prompt to the external
services at ``realistic_url`` / ``ink_url``.
"""
import json
import logging
import os
import re
import uuid

import requests
from matplotlib import pyplot as plt
from pydantic import BaseModel, Field

from configs.model_config import LLM_MODELS
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
from configs.kb_config import GENERATED_IMAGES_BASE_PATH, realistic_url, ink_url
from matplotlib import font_manager

logger = logging.getLogger(__name__)

# CJK-capable font used for all chart text; the path is specific to the host machine.
my_font = font_manager.FontProperties(fname="/usr/share/fonts/MicroSoft-YaHei/MSYH.TTC")

# Prefix of the markdown image links returned to the LLM; the image file name is appended.
_IMAGE_URL_PREFIX = "http://127.0.0.1:8099/chat_web_backend/get-image?file_name="

# Generic "cannot draw right now" reply.
_DRAW_FAILED_MSG = "暂时无法画图"

# Keys a chart spec must contain.
_PLOT_SPEC_KEYS = ("data", "title", "xlabel", "ylabel", "plot_type")

# Shape hint passed to the "check_plot" rewrite prompt.
_PLOT_SPEC_TEMPLATE = '{\n"data": {"XXX": XX, "XXX": XX, "XXX": X, "XXX": X},"title": "X","xlabel": "X","ylabel": "X","plot_type": "X"}'

# (connect, read) timeout in seconds for the image-generation services. Without a
# timeout, requests waits indefinitely on an unresponsive server.
_IMAGE_SERVICE_TIMEOUT = (10, 600)


def _load_plot_spec(text):
    """Parse ``text`` as a chart spec and return the dict, or None if it is not one.

    A spec is a JSON object that contains every key in ``_PLOT_SPEC_KEYS``.
    """
    try:
        spec = json.loads(text)
    except ValueError:
        return None
    if isinstance(spec, dict) and all(key in spec for key in _PLOT_SPEC_KEYS):
        return spec
    return None


def _extract_llm_json(reply):
    """Return the JSON text inside an LLM reply.

    Prefers a fenced ```json block. Otherwise returns the span from the first '{'
    to the last '}'. Otherwise returns the reply unchanged.
    """
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", reply, re.DOTALL)
    if fenced:
        return fenced.group(1)
    braces = re.search(r"\{.*\}", reply, re.DOTALL)
    return braces.group(0) if braces else reply


def _render_plot(spec):
    """Draw ``spec`` with matplotlib, save it as a PNG and return the markdown reply.

    Raises:
        ValueError: if ``spec["plot_type"]`` is not 'bar', 'pie' or 'line'.
    """
    data = spec["data"]
    title = spec["title"]
    plot_type = spec["plot_type"]

    # Analyse and summarise the data.
    categories = list(data.keys())
    values = list(data.values())

    # Use an explicit Figure/Axes rather than pyplot's implicit "current figure".
    # Close it in `finally` so a failed render does not leak the figure.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if plot_type == 'bar':
            ax.bar(categories, values, color='skyblue')
        elif plot_type == 'pie':
            ax.pie(values, labels=categories, autopct='%1.1f%%', startangle=140,
                   textprops={'fontproperties': my_font})
        elif plot_type == 'line':
            ax.plot(categories, values, marker='o', linestyle='-')
        else:
            raise ValueError("Unsupported plot type. Choose from 'bar', 'pie', or 'line'.")

        ax.set_title(title, fontproperties=my_font)
        if plot_type != 'pie':  # pie charts need no axis labels
            ax.set_xlabel(spec["xlabel"], fontproperties=my_font)
            ax.set_ylabel(spec["ylabel"], fontproperties=my_font)
            plt.setp(ax.get_xticklabels(), fontproperties=my_font, rotation=45)
            plt.setp(ax.get_yticklabels(), fontproperties=my_font)

        image_name = f"plot{uuid.uuid4()}.png"
        fig.tight_layout()
        fig.savefig(os.path.abspath(os.path.join(GENERATED_IMAGES_BASE_PATH, image_name)))
    finally:
        plt.close(fig)
    return f"图片如下:![图片{title}]({_IMAGE_URL_PREFIX}{image_name})"


def create_and_save_plot(query: str) -> str:
    """Render the chart described by ``query`` and return a markdown link to it.

    ``query`` should contain a JSON object shaped like ``_PLOT_SPEC_TEMPLATE``:
    ``data`` maps category -> value, plus ``title``, ``xlabel``, ``ylabel``
    and ``plot_type`` (``'bar'``, ``'pie'`` or ``'line'``). If the embedded JSON
    is invalid or incomplete, the LLM is asked (prompt ``check_plot``) to rewrite
    the input into that shape.

    Returns:
        A markdown image link on success. Returns ``"暂时无法画图"`` when
        ``query`` contains no ``{...}`` at all. On any other failure, returns a
        message telling the LLM to stop calling this tool.
    """
    try:
        # Drop spaces and turn single quotes into double quotes (JSON requires
        # double-quoted strings).
        query = query.replace(" ", "").replace("'", "\"")
        match = re.search(r'(.*?)(\{.*\})', query.replace("\n", ""))
        if not match:
            logger.warning("Invalid JSON format in query:\n%s", query)
            return _DRAW_FAILED_MSG

        # Group 2 is the JSON object; group 1 is only the text preceding it.
        datas = _load_plot_spec(match.group(2).strip())
        if datas is None:
            llm_reply = get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="check_plot",
                prompt_param_dict={"user_input": query, "json": _PLOT_SPEC_TEMPLATE},
                temperature=0.01,
                max_tokens=512,
            )
            datas = _load_plot_spec(_extract_llm_json(llm_reply))
            if datas is None:
                raise ValueError(f"LLM rewrite did not produce a valid plot spec: {llm_reply!r}")

        return _render_plot(datas)
    except Exception as e:
        # Tool boundary: report failure to the LLM instead of raising.
        logger.exception("An error occurred: %s", e)
        return "不要再调用该工具了,无法提供该功能,你只能按工具描述提供有的入参,其他不支持"


# Argument schema for create_and_save_plot.
class drawPlotInput(BaseModel):
    query: str = Field(..., description="输入要画图的内容")


def _parse_draw_query(query):
    """Split a draw-tool query into ``(prompt, var_uuid)``.

    The code expects two ``{...}`` objects in ``query``. The second carries a
    ``"uuid"`` key. What remains after removing the second must parse as JSON
    with a ``"query"`` key, which is the drawing prompt. Returns None when fewer
    than two ``{...}`` spans are present.

    Raises:
        ValueError, KeyError, TypeError: if either object is malformed.
    """
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return None
    var_uuid = json.loads(matches[1])["uuid"]
    prompt = json.loads(query.replace(matches[1], ""))["query"]
    return prompt, var_uuid


def _request_generated_image(query, service_url):
    """POST the prompt in ``query`` to ``service_url`` and return the reply for the LLM.

    On success, the service's ``image_path`` is stored as ``"url"`` in the shared
    variable keyed by the query's ``uuid``. Returns ``"暂时无法画图"`` if the query
    cannot be parsed, the request fails, or the reply lacks an ``image_path``.
    """
    try:
        parsed = _parse_draw_query(query)
    except (ValueError, KeyError, TypeError):
        logger.exception("Invalid JSON format in query.")
        return _DRAW_FAILED_MSG
    if parsed is None:
        return _DRAW_FAILED_MSG
    prompt, var_uuid = parsed

    try:
        response = requests.post(service_url, json={'prompt': prompt}, timeout=_IMAGE_SERVICE_TIMEOUT)
    except requests.exceptions.RequestException as e:
        logger.error("An error occurred: %s", e)
        return _DRAW_FAILED_MSG
    if response.status_code != 200:
        logger.error("Failed to get response: %s", response.status_code)
        return _DRAW_FAILED_MSG

    try:
        file_path = response.json().get('image_path')
    except (ValueError, AttributeError):
        # Body was not JSON, or not a JSON object.
        logger.error("Unexpected response body: %r", response.text)
        return _DRAW_FAILED_MSG
    if not file_path:
        logger.error("Response did not contain an image_path: %r", response.text)
        return _DRAW_FAILED_MSG
    logger.info("Image path: %s", file_path)

    sources = utils.get_shared_variable(var_uuid)
    sources["url"] = file_path
    utils.set_shared_variable(var_uuid, sources)
    return f"<关键指令>你必须以markdown格式的图片路径,![图片]({_IMAGE_URL_PREFIX}{file_path}),之后你也需要按照要求给出更多的markdown格式的数据表格"


def draw_realistic_pic(query: str) -> str:
    """Request an image for ``query`` from the ``realistic_url`` service.

    See ``_request_generated_image`` for the query format and return values.
    """
    return _request_generated_image(query, realistic_url)


# Argument schema for draw_realistic_pic.
class drawRealisticInput(BaseModel):
    query: str = Field(..., description="输入要画图的内容")


def draw_ink_pic(query: str) -> str:
    """Request an image for ``query`` from the ``ink_url`` service.

    See ``_request_generated_image`` for the query format and return values.
    """
    return _request_generated_image(query, ink_url)


# Argument schema for draw_ink_pic.
class drawInkInput(BaseModel):
    query: str = Field(..., description="输入要画图的内容")