Files
gangyan/langchain-chat/server/chat/ZhipuSearchAPI.py

122 lines
4.3 KiB
Python
Raw Normal View History

from datetime import datetime
import logging
import requests
# import uuid
from configs.model_config import LLM_MODELS
from server.chat.policy_fun_iast import get_llm_model_response
# NOTE(review): hard-coded credential committed to source control. No live
# code in this file reads it (presumably a key for the Zhipu open platform,
# given the module name -- confirm). It should be rotated and loaded from
# configuration or an environment variable instead of living in the repo.
api_key = "2b35424a76188ea96558f9631890ecd3.Wl9stbNr8TJ9L5PJ"
class ZhipuSearchAPIWrapper:
    """Rewrite a user query with an LLM, then run it against an HTTP search endpoint.

    Historical note: an earlier implementation posted the rewritten query to
    the Zhipu ``web-search-pro`` tool at
    ``https://open.bigmodel.cn/api/paas/v4/tools`` and authenticated with the
    module-level ``api_key``. It was replaced by the plain HTTP search
    endpoint used below; the old code is available in version control.
    """

    # NOTE(review): this URL embeds basic-auth credentials and uses plain
    # HTTP. It is kept unchanged so behaviour stays the same, but it should be
    # moved to configuration (and the password rotated).
    SEARCH_URL = "http://ywk3hvt4d:01Jp2V1tR9PdTsYSz919779Rb9_@134.122.191.214/search"

    # `requests` timeouts are expressed in seconds. The previous value of
    # 5000 would have let a hung request block the caller for ~83 minutes.
    REQUEST_TIMEOUT_SECONDS = 30

    # Maximum number of search hits handed back to the caller.
    MAX_RESULTS = 7

    def zhipu_search(self, origin_query) -> list:
        """Search the web for ``origin_query`` and return the top hits.

        The query is first rewritten by the LLM using the
        ``zhipu_search_rewrite`` strategy/prompt (the current year is passed
        to the prompt), then sent to ``SEARCH_URL`` as a GET request.

        Args:
            origin_query: The user's original query text.

        Returns:
            A list of at most ``MAX_RESULTS`` result entries taken from the
            ``results`` field of the JSON response. For entries carrying a
            non-empty ``publishedDate``, the date is appended to ``content``.
            An empty list is returned when the request fails or the response
            cannot be interpreted.
        """
        search_query = get_llm_model_response(
            strategy_name="zhipu_search_rewrite",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="zhipu_search_rewrite",
            prompt_param_dict={
                "input": origin_query,
                "year": datetime.now().strftime("%Y"),
            },
            temperature=0.3,
            max_tokens=512,
        )
        logging.info(f"Zhipu检索内容:{search_query}")

        # Queries mentioning "天气" (weather) go to google; all others to baidu.
        engines = "google" if "天气" in search_query else "baidu"
        params = {
            "format": "json",
            "q": search_query,
            "engines": engines,
            "limit": 10,
        }

        try:
            resp = requests.get(
                self.SEARCH_URL,
                params=params,
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )
            resp.raise_for_status()  # turn HTTP 4xx/5xx into an exception
            resp_json = resp.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body on `requests` versions whose
            # JSON decode error is not a RequestException subclass.
            logging.error(f"请求错误: {e}")
            return []

        return self._extract_results(resp_json)

    @classmethod
    def _extract_results(cls, resp_json) -> list:
        """Pull the top hits out of a decoded response and annotate their dates."""
        if not isinstance(resp_json, dict):
            return []
        raw_results = resp_json.get("results")
        if not isinstance(raw_results, list):
            return []

        search_results = []
        # Only the entries that will actually be returned are processed.
        for item in raw_results[:cls.MAX_RESULTS]:
            if isinstance(item, dict):
                published = item.get("publishedDate")
                if published:
                    # A missing or null `content` is treated as empty so the
                    # date can still be attached without raising.
                    content = item.get("content") or ""
                    item["content"] = f"{content}发布于{published}"
            search_results.append(item)
        return search_results
if __name__ == "__main__":
    # Manual smoke test: run one fixed query and report how many hits came back.
    # Without explicit configuration the root logger only emits WARNING and
    # above, so the INFO message below would be dropped silently.
    logging.basicConfig(level=logging.INFO)

    searcher = ZhipuSearchAPIWrapper()
    query = "粉末冶金产业技术创新战略联盟 粉末冶金领域关键词"
    results = searcher.zhipu_search(query)
    if results:
        print(results)
        print(f"成功获取到 {len(results)} 个搜索结果")
    else:
        logging.info("没有找到任何搜索结果")