Files
gangyan/langchain-chat/server/agent/tools/get_statistical_data.py

58 lines
2.2 KiB
Python

from datetime import datetime
import json
import logging
import re
import requests
from configs.model_config import LLM_MODELS
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
def mysql_statistic(query: str) -> str:
    """Answer a statistics question through the local NL2SQL HTTP service.

    The tool input is expected to look like
    ``<param>{"query": "..."}</param>{"uuid": "..."}``: a JSON object holding
    the natural-language question, wrapped in ``<param>`` tags and followed
    by a JSON object holding the request's ``uuid``. Newlines are removed
    before the input is matched.

    Args:
        query: Raw tool input in the format described above.

    Returns:
        An instruction string for the calling agent. One of:

        * the retrieved data together with guidance to verify it and then
          chart it with the chart-drawing tool;
        * a hint to retry with the "联网思索" tool when the service reports an
          empty result set;
        * the fixed message "暂时无法查询" when the input is malformed or the
          service call does not succeed.
    """
    unavailable = "暂时无法查询"

    logging.info(f"\n🔍 统计工具查询query: \n{query}\n")

    # No shared-variable bookkeeping happens on the failure paths below: the
    # request uuid is only known once parsing has succeeded, and the former
    # read-then-write-back of the shared variable stored the value unchanged.
    try:
        matches = re.search(r'<param>(.*?)</param>(\{.*\})', query.replace("\n", ""))
        if matches is None:
            logging.error("Tool input does not match '<param>...</param>{...}'.")
            return unavailable
        # The uuid is mandatory in the trailing JSON object; it is used below
        # to correlate error logs with the originating request.
        request_uuid = json.loads(matches.group(2))["uuid"]
        question = json.loads(matches.group(1).strip())["query"]
    except (AttributeError, TypeError, ValueError, KeyError):
        # json.JSONDecodeError is a ValueError; KeyError/TypeError cover a
        # payload that parses but lacks the expected "uuid"/"query" members.
        logging.error("Invalid JSON format in query.")
        return unavailable

    # The question is forwarded verbatim; an LLM-based rewrite step (prompt
    # template "sql_query_rewrite") used to run here and is disabled.
    logging.info(f"\n🔍 NL2SQL检索question: \n{question}\n")

    # NOTE(review): no timeout is set, so a hung NL2SQL service blocks this
    # tool indefinitely. Pick a value together with the service owners.
    try:
        res = requests.post(
            url="http://127.0.0.1:6008/query",
            json={"question": question},
            headers={"Content-Type": "application/json"},
        )
    except requests.RequestException:
        logging.exception("NL2SQL service request failed (uuid=%s).", request_uuid)
        return unavailable

    # A Response is falsy for 4xx/5xx statuses; report those explicitly
    # instead of falling off the end of the function and returning None.
    if not res.ok:
        logging.error(
            "NL2SQL service returned HTTP %s (uuid=%s).", res.status_code, request_uuid
        )
        return unavailable

    try:
        data = res.json()["result"]
        # Presumably "result" is the string form of the service's result
        # dict, so an empty data list shows up as this substring - confirm
        # against the NL2SQL service.
        no_rows = "'data': []" in data
    except (ValueError, KeyError, TypeError):
        logging.error("Unexpected NL2SQL response body (uuid=%s).", request_uuid)
        return unavailable

    if no_rows:
        return f"统计库未检索到数据,使用“联网思索”工具检索该请求:{question}\n"
    return f"判断得到的数据是否准确,如果不准确,则使用“联网思索”工具检索。如果准确,则根据数据表格并使用图表绘制工具制图。\n 数据如下所示: \n{data}\n"