187 lines
7.4 KiB
Python
187 lines
7.4 KiB
Python
import asyncio
|
||
import re
|
||
import aiohttp
|
||
import json
|
||
import logging
|
||
from pydantic import BaseModel, Field
|
||
|
||
from server.chat import utils
|
||
|
||
# Configure process-wide logging: INFO level, timestamped single-line records.
# NOTE(review): basicConfig at import time touches the root logger for the whole
# process (and does nothing if the root logger already has handlers) — consider
# moving this call to the application entry point.
_LOG_FORMAT = '%(asctime)s [%(levelname)s] %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-scoped logger used by the search functions in this file.
logger = logging.getLogger(__name__)


async def duckduckgo_search_iter(query: str, uuid: str = "", time: str = "", resource_type: str | None = None, limit: int = 3):
    """Query the search backends and format the hits as numbered references.

    Unless ``resource_type == "video"``, the text, video and news endpoints are
    queried concurrently and ``limit`` is split across them as evenly as
    possible: the first remainder item goes to text, the second to news, and
    video never gets one (e.g. 5 -> text 2, news 2, video 1).  For
    ``resource_type == "video"`` only the video endpoint is queried, with the
    full ``limit``.

    Numbering continues from ``info["num"]`` of the shared state stored under
    ``uuid``.  The citation lines are appended to that state's ``source_docs``,
    and the state is written back.

    Args:
        query: Search text, sent as ``"query"`` in every request body.
        uuid: Key of the shared session state in ``server.chat.utils``.
        time: Time filter, forwarded verbatim as ``"time"``.
        resource_type: ``"video"`` for a video-only search; anything else
            (including ``None``) searches all three endpoints.
        limit: Total number of results requested; negative values count as 0.

    Returns:
        ``(res, source)``: ``res`` holds one descriptive line per hit, and
        ``source`` holds the matching citation line.
    """
    text_url = 'http://43.251.225.121/inspur/search_text'
    video_url = 'http://43.251.225.121/inspur/search_video'
    news_url = 'http://43.251.225.121/inspur/search_new'

    payload = {
        "query": query,
        "time": time
    }

    async def fetch(session, url, json_payload, limit):
        """POST one search request and return its list of hits ([] on any failure)."""
        # Fresh body per request: the concurrent tasks share ``json_payload``,
        # so writing "limit" into it would leak one task's limit into another.
        request_body = {**json_payload, "limit": limit}
        logger.info(f"从 {url} 获取数据,请求参数: {request_body}")
        try:
            async with session.post(url, json=request_body) as response:
                if response.status != 200:
                    logger.error(f"向 {url} 请求失败,状态码 {response.status}")
                    return []
                data = await response.json()
        except Exception as e:
            # Best effort: one failing backend must not sink the other searches.
            logger.error(f"获取 {url} 数据时发生错误: {e}")
            return []
        if not isinstance(data, list):
            # The formatting below iterates the hits and indexes them by key,
            # so anything other than a list is unusable.
            logger.error(f"从 {url} 获取的数据不是列表: {type(data).__name__}")
            return []
        logger.info(f"从 {url} 获取的资料数: {len(data)}")
        return data

    # Even split of the total; the remainder goes to text first, then news.
    limit = max(limit, 0)
    base, remainder = divmod(limit, 3)
    text_limit = base + 1 if remainder >= 1 else base
    news_limit = base + 1 if remainder >= 2 else base
    video_limit = base

    async with aiohttp.ClientSession() as session:
        logger.info("发起请求duckduckgo...")

        if resource_type == 'video':
            video_result = await fetch(session, video_url, payload, limit)
            combined_result = {
                "video": video_result
            }
        else:
            # Every non-video type (including None) searches all three
            # endpoints; dedicated news-only / text-only routing is not enabled.
            text_result, video_result, news_result = await asyncio.gather(
                fetch(session, text_url, payload, text_limit),
                fetch(session, video_url, payload, video_limit),
                fetch(session, news_url, payload, news_limit),
            )
            logger.info("合并结果...")
            logger.info("合并结果完成")
            combined_result = {
                "text": text_result,
                "video": video_result,
                "news": news_result
            }

    logger.info("请求已完成")

    res = []
    source = []
    info = utils.get_shared_variable(uuid)
    index = info["num"]
    # NOTE(review): ``index`` advances per hit but ``info["num"]`` is never
    # updated here, so a later call for the same uuid restarts numbering from
    # the same value unless another component maintains "num" — confirm.

    for item in combined_result.get("text", []):
        index += 1
        res.append(f'资料[{index}] 资料标题{item["title"]}({item["href"]}) 资料内容为: {item["body"]}')
        source.append(f'资料[{index}] [{item["title"]}]({item["href"]})')

    for item in combined_result.get("video", []):
        index += 1
        res.append(f'资料[{index}] 视频标题[{item["title"]}]({item["content"]}) 视频内容为: {item["description"]}')
        source.append(f'视频资料[{index}] [{item["title"]}]({item["content"]})')

    for item in combined_result.get("news", []):
        index += 1
        res.append(f'资料[{index}] 新闻标题[{item["title"]}]({item["url"]}) 新闻内容为: {item["body"]}')
        source.append(f'资料[{index}] [{item["title"]}]({item["url"]})')

    info["source_docs"].extend(source)
    utils.set_shared_variable(uuid, info)
    return res, source


def duckduckgo_search(query: str, time: str = "", resource_type: str | None = None):
    """Parse a model tool call and run the web search synchronously.

    ``query`` is the raw tool input and must contain at least two flat
    (non-nested) JSON objects.  The first carries the search parameters
    (``query``, ``limit``, ``resource_type``, ``time``) and the second carries
    the session ``uuid``.  Values missing from the first object fall back to
    the ``time`` / ``resource_type`` arguments; ``limit`` falls back to 3.

    Returns:
        The ``(res, source)`` tuple from ``duckduckgo_search_iter``, or the
        stop-instruction string when the input cannot be parsed.
    """
    logger.info(f"模型输入: {query}")
    stop_instruction = "<关键指令>不需要再调用该工具了</关键指令>"

    # Non-greedy match of flat {...} objects; DOTALL lets an object span lines.
    matches = re.findall(r'\{.*?\}', query, re.DOTALL)
    if len(matches) < 2:
        return stop_instruction

    try:
        obj1 = json.loads(matches[0])
        parsed_uuid = json.loads(matches[1])["uuid"]
    except (json.JSONDecodeError, KeyError) as e:
        # Without both objects (and the uuid) there is nothing safe to search
        # with; before, this fell through to a NameError / uncaught KeyError.
        logger.error(f"解析JSON出错: {e}")
        return stop_instruction

    parsed_query = obj1.get("query", "")
    parsed_resource_type = obj1.get("resource_type", None)
    parsed_time = obj1.get("time", time)  # fall back to the argument when absent
    try:
        # The model may send the limit as a string (or garbage); the limit
        # split downstream needs an int.
        parsed_limit = int(obj1.get("limit", 3))
    except (TypeError, ValueError):
        logger.error(f"limit 参数无效: {obj1.get('limit')!r}, 使用默认值 3")
        parsed_limit = 3

    # Parsed values override the arguments.  Without a "query" key the raw
    # first JSON object is used as the search text (kept for compatibility).
    query = parsed_query if parsed_query else matches[0]
    resource_type = parsed_resource_type if parsed_resource_type else resource_type
    time = parsed_time if parsed_time else time

    logger.info(f"解析完成,query: {query}, uuid: {parsed_uuid}, time: {time}, resource_type: {resource_type}, parsed_limit: {parsed_limit}")

    # Drive the async search from this synchronous tool entry point.
    combined_result = asyncio.run(duckduckgo_search_iter(query, parsed_uuid, time, resource_type, parsed_limit))

    # The result is the (res, source) tuple, not a JSON string.
    logger.info("返回JSON格式的结果给到模型...")
    return combined_result


# Pydantic argument schema for the web-search tool: one string field, "location".
# (Comments, not a class docstring: pydantic copies docstrings into the JSON schema.)
class DuckduckgoInput(BaseModel):
    # NOTE(review): the field is named "location" while its description reads
    # "网络搜索查询" (web search query) — looks copied from another tool; confirm
    # before renaming, since tool schemas and callers depend on the field name.
    location: str = Field(description="网络搜索查询")


if __name__ == "__main__":
    # Manual smoke test; needs the search backend and server.chat.utils.
    # duckduckgo_search expects the tool-call format: a JSON object with the
    # search parameters followed by a JSON object with the session uuid.  A bare
    # query such as "粉末冶金" never reaches the search; it only returns the
    # stop instruction.
    demo_uuid = "demo"
    # duckduckgo_search_iter reads info["num"] and extends info["source_docs"].
    # NOTE(review): assumes set_shared_variable can register a new uuid — confirm.
    utils.set_shared_variable(demo_uuid, {"num": 0, "source_docs": []})

    demo_cases = [
        ("默认", {"query": "粉末冶金", "time": "m"}),
        ("视频", {"query": "粉末冶金", "time": "m", "resource_type": "video"}),
        ("其他", {"query": "粉末冶金", "time": "m", "resource_type": "other"}),
    ]
    for label, params in demo_cases:
        tool_input = json.dumps(params, ensure_ascii=False) + json.dumps({"uuid": demo_uuid})
        print(f"duckduckgo输出({label}):\n", duckduckgo_search(tool_input))