Files
gangyan/langchain-chat/server/agent/tools/duckduckgo_search.py

187 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import re
import aiohttp
import json
import logging
from pydantic import BaseModel, Field
from server.chat import utils
# Configure logging for this module.
# NOTE: logging.basicConfig() runs at import time and configures the root
# logger (INFO level, "timestamp [LEVEL] message" format), so importing this
# module affects logging output process-wide if no handlers were set up yet.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
def _split_limit(limit: int) -> tuple[int, int, int]:
    """Split *limit* into per-endpoint quotas ``(text, news, video)``.

    The quotas always sum to *limit*; any remainder of the division by three
    is handed out one item at a time, first to text and then to news.
    """
    base, extra = divmod(limit, 3)
    text_limit = base + (1 if extra >= 1 else 0)
    news_limit = base + (1 if extra == 2 else 0)
    return text_limit, news_limit, base


async def duckduckgo_search_iter(query: str, uuid: str = "", time: str = "", resource_type: str = None, limit: int = 3):
    """Query the search endpoints and format the hits as numbered references.

    Args:
        query: Search text sent to the endpoints.
        uuid: Key passed to ``utils.get_shared_variable`` /
            ``utils.set_shared_variable`` to load and store the shared state.
        time: Time filter forwarded verbatim in the request body.
        resource_type: ``'video'`` queries only the video endpoint with the
            full *limit*. Any other value (including ``None``) queries the
            text, video and news endpoints concurrently and splits *limit*
            between them. News-only / text-only modes are not implemented.
        limit: Total number of results requested.

    Returns:
        A ``(res, source)`` tuple of string lists: ``res`` holds the full
        reference lines (title, link and content), ``source`` the short
        markdown-style citations. ``source`` is also appended to
        ``info["source_docs"]`` in the shared state.

    Raises:
        KeyError: If the shared state lacks ``"num"`` / ``"source_docs"`` or
            a result item lacks one of the expected fields.
    """
    text_url = 'http://43.251.225.121/inspur/search_text'
    video_url = 'http://43.251.225.121/inspur/search_video'
    news_url = 'http://43.251.225.121/inspur/search_new'
    payload = {
        "query": query,
        "time": time,
    }

    async def fetch(session, url, json_payload, limit):
        """POST to *url* and return the decoded JSON list, or ``[]`` on any failure."""
        # Build a per-request body: the base payload is shared by requests
        # that run concurrently, so it must never be mutated in place.
        request_body = {**json_payload, "limit": limit}
        logger.info(f"{url} 获取数据,请求参数: {request_body}")
        try:
            async with session.post(url, json=request_body) as response:
                if response.status != 200:
                    logger.error(f"{url} 请求失败,状态码 {response.status}")
                    # An error body is not a result list; do not pass it on.
                    return []
                data = await response.json()
        except Exception as e:
            # Best effort: one failing endpoint must not sink the others.
            logger.error(f"获取 {url} 数据时发生错误: {e}")
            return []
        if not isinstance(data, list):
            # The formatting code below iterates a list of dicts; anything
            # else would crash it, so treat it as "no results".
            logger.error(f"{url} 返回的数据不是列表,已忽略")
            return []
        logger.info(f"{url} 获取的资料数: {len(data)}")
        return data

    async with aiohttp.ClientSession() as session:
        logger.info("发起请求duckduckgo...")
        if resource_type != 'video':
            text_limit, news_limit, video_limit = _split_limit(limit)
            text_result, video_result, news_result = await asyncio.gather(
                fetch(session, text_url, payload, text_limit),
                fetch(session, video_url, payload, video_limit),
                fetch(session, news_url, payload, news_limit),
            )
            logger.info("合并结果...")
            logger.info("合并结果完成")
            combined_result = {
                "text": text_result,
                "video": video_result,
                "news": news_result,
            }
        else:
            video_result = await fetch(session, video_url, payload, limit)
            combined_result = {
                "video": video_result,
            }
    logger.info("请求已完成")

    res = []
    source = []
    info = utils.get_shared_variable(uuid)
    # Numbering continues from the counter kept in the shared state.
    # NOTE(review): info["num"] itself is not advanced here — confirm that
    # the caller is responsible for updating it.
    index = info["num"]
    for item in combined_result.get("text", []):
        index += 1
        res.append(f'资料[{index}] 资料标题{item["title"]}({item["href"]}) 资料内容为: {item["body"]}')
        source.append(f'资料[{index}] [{item["title"]}]({item["href"]})')
    for item in combined_result.get("video", []):
        index += 1
        res.append(f'资料[{index}] 视频标题[{item["title"]}]({item["content"]}) 视频内容为: {item["description"]}')
        source.append(f'视频资料[{index}] [{item["title"]}]({item["content"]})')
    for item in combined_result.get("news", []):
        index += 1
        res.append(f'资料[{index}] 新闻标题[{item["title"]}]({item["url"]}) 新闻内容为: {item["body"]}')
        source.append(f'资料[{index}] [{item["title"]}]({item["url"]})')
    info["source_docs"].extend(source)
    utils.set_shared_variable(uuid, info)
    return res, source
def duckduckgo_search(query: str, time: str = "", resource_type: str = None):
    """Synchronous tool entry point: parse the model input and run the search.

    *query* must contain at least two brace-delimited JSON objects. The first
    carries the search arguments (``query``, optional ``limit``,
    ``resource_type`` and ``time``); the second carries ``uuid``.

    Args:
        query: Raw model input containing the two JSON objects.
        time: Default time filter, used when the first object has no
            (or an empty) ``time``.
        resource_type: Default resource type, used when the first object has
            no (or an empty) ``resource_type``.

    Returns:
        The ``(res, source)`` tuple produced by ``duckduckgo_search_iter``,
        or a fixed instruction string telling the model not to call the tool
        again when the input cannot be parsed.
    """
    stop_instruction = "<关键指令>不需要再调用该工具了</关键指令>"
    logger.info(f"模型输入: {query}")
    # Non-greedy match of flat {...} objects. Nested braces, a "}" inside a
    # string value, or objects spanning several lines are not supported.
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return stop_instruction
    try:
        obj1 = json.loads(matches[0])
        parsed_uuid = json.loads(matches[1])["uuid"]
    except (json.JSONDecodeError, KeyError) as e:
        # Without valid arguments and a uuid there is nothing to search for;
        # bail out instead of continuing with unbound variables.
        logger.error(f"解析JSON出错: {e}")
        return stop_instruction
    # The model may emit the limit as a string or null; fall back to the
    # default rather than failing later in the limit arithmetic.
    try:
        parsed_limit = int(obj1.get("limit", 3))
    except (TypeError, ValueError):
        parsed_limit = 3
    # Parsed values override the call arguments only when non-empty. If the
    # object has no usable "query", the raw JSON text is used as the query.
    query = obj1.get("query", "") or matches[0]
    resource_type = obj1.get("resource_type", None) or resource_type
    time = obj1.get("time", time) or time
    logger.info(f"解析完成query: {query}, uuid: {parsed_uuid}, time: {time}, resource_type: {resource_type}, parsed_limit: {parsed_limit}")
    # Run the async search from this synchronous context.
    combined_result = asyncio.run(duckduckgo_search_iter(query, parsed_uuid, time, resource_type, parsed_limit))
    logger.info("返回JSON格式的结果给到模型...")
    return combined_result
class DuckduckgoInput(BaseModel):
    """Pydantic model with a single string field holding the search query.

    NOTE(review): the field is named ``location`` although its description
    says it is a web search query; presumably kept for compatibility with
    callers — confirm before renaming.
    """

    location: str = Field(
        description="网络搜索查询",
    )
if __name__ == "__main__":
    # Manual smoke test. Earlier variants, kept for reference:
    #   1. default (all three APIs):
    #      duckduckgo_search("粉末冶金", "m", "default")
    #   2. video only:
    #      duckduckgo_search("粉末冶金", "m", "video")
    #   3. news only:
    #      duckduckgo_search("粉末冶金", "m", "news")
    # 4. Any other resource type.
    # NOTE(review): duckduckgo_search expects two JSON objects inside `query`;
    # a plain string like the one below returns the stop-instruction string
    # without contacting the search endpoints.
    result_other = duckduckgo_search(
        query="粉末冶金",
        time="m",
        resource_type="other",
    )
    print("duckduckgo输出(其他):\n", result_other)