[全量] 初始化项目代码、配置、文档及Agent协同harness

This commit is contained in:
2026-04-02 11:36:05 +08:00
parent 0553309cdf
commit 87e571d9ec
1133 changed files with 221948 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
## 导入所有的工具类
# from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
# from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
# from .chat_with_Yi34B import chat_with_Yi34B, ChatWithYi34BInput
# from .search_youtube import search_youtube, YoutubeInput
from .calculate import calculate, CalculatorInput
from .weather_check import weathercheck, WeatherInput
from .shell import shell, ShellInput
from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput
from .arxiv import arxiv, ArxivInput
from .knowledgebase_kgo_search import knowledgebase_kgo_search, KnowledgeKgoInput
from .policy_knowledgebase_search import policy_knowledgebase_search, PolicyKnowledgeInput
from .report_knowledgebase_search import report_knowledgebase_search, ReportKnowledgeInput
from .rag_search import rag_search1, RagSearchInput
from .duckduckgo_search import duckduckgo_search, DuckduckgoInput

View File

@@ -0,0 +1,9 @@
# LangChain 的 ArxivQueryRun 工具
from pydantic import BaseModel, Field
from langchain.tools.arxiv.tool import ArxivQueryRun
def arxiv(query: str):
    """Run an Arxiv search through LangChain's ``ArxivQueryRun`` tool.

    Args:
        query: Search text, passed to the tool as ``tool_input``.

    Returns:
        Whatever ``ArxivQueryRun.run`` produces for the query.
    """
    return ArxivQueryRun().run(tool_input=query)
class ArxivInput(BaseModel):
    """Input model with the single ``query`` string accepted by ``arxiv``."""

    query: str = Field(
        description="The search query title",
    )

View File

@@ -0,0 +1,10 @@
name: arxiv
description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.
parameters:
type: object
properties:
query:
type: string
description: The search query title
required:
- query

View File

@@ -0,0 +1,76 @@
from langchain.prompts import PromptTemplate
from langchain.chains import LLMMathChain
from server.agent import model_container
from pydantic import BaseModel, Field
# Prompt text handed to LLMMathChain: it asks the model to turn a math question
# into a single-line numexpr expression and shows worked examples.
# The text is consumed at runtime, so it must not be edited cosmetically.
# NOTE(review): the examples open an ```output fence without closing it, and the
# text announces "two examples" while listing three — confirm this is intended.
_PROMPT_TEMPLATE = """
将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
问题: ${{包含数学问题的问题。}}
```text
${{解决问题的单行数学表达式}}
```
...numexpr.evaluate(query)...
```output
${{运行代码的输出}}
```
答案: ${{答案}}
这是两个例子:
问题: 37593 * 67是多少
```text
37593 * 67
```
...numexpr.evaluate("37593 * 67")...
```output
2518731
答案: 2518731
问题: 37593的五次方根是多少
```text
37593**(1/5)
```
...numexpr.evaluate("37593**(1/5)")...
```output
8.222831614237718
答案: 8.222831614237718
问题: 2的平方是多少
```text
2 ** 2
```
...numexpr.evaluate("2 ** 2")...
```output
4
答案: 4
现在,这是我的问题:
问题: {question}
"""
# ``question`` is the only declared template variable.
PROMPT = PromptTemplate(
    input_variables=["question"],
    template=_PROMPT_TEMPLATE,
)
class CalculatorInput(BaseModel):
    """Input model with the single ``query`` string accepted by ``calculate``."""

    # No description or default is attached to the field.
    query: str = Field()
def calculate(query: str):
    """Answer a math question with an ``LLMMathChain`` built on the shared model.

    Args:
        query: The math question in natural language.

    Returns:
        The value returned by ``LLMMathChain.run`` for the question.
    """
    chain = LLMMathChain.from_llm(model_container.MODEL, verbose=True, prompt=PROMPT)
    return chain.run(query)
if __name__ == "__main__":
    # Manual smoke test: ask for two cubed and print the chain's answer.
    print("答案:", calculate("2的三次方"))

View File

@@ -0,0 +1,10 @@
name: calculate
description: Useful for when you need to answer questions about simple calculations
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query

View File

@@ -0,0 +1,43 @@
import asyncio
import json
from typing import List, Union
from pydantic import BaseModel, Field
from server.chat.chat import chat
from server.chat.utils import History
async def chat_with_Yi34B_iter(query: str,
                               stream=False,
                               model_name="qianfan-api",
                               history: Union[int, List[History]] = None,
                               conversation_id='',
                               temperature=0.7,
                               max_tokens=None,
                               history_len=3,
                               prompt_name="default"
                               ):
    """Call the project's ``chat`` endpoint and return the full reply text.

    Every chunk yielded by the response's ``body_iterator`` is decoded as JSON
    and its ``"text"`` entry is appended to the result.

    Args:
        query: User message forwarded to ``chat``.
        stream, model_name, history, conversation_id, temperature, max_tokens,
        history_len, prompt_name: Forwarded unchanged to ``chat``.

    Returns:
        The concatenation of all ``"text"`` fragments.
    """
    reply = await chat(
        query=query,
        history=history,
        history_len=history_len,
        conversation_id=conversation_id,
        stream=stream,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        prompt_name=prompt_name,
    )
    fragments = []
    async for raw_chunk in reply.body_iterator:  # each chunk is a JSON string
        fragments.append(json.loads(raw_chunk)["text"])
    return "".join(fragments)
def chat_with_Yi34B(query: str, model_name: str = "qianfan-api", conversation_id: str = '',
                    history: Union[int, List[History]] = None):
    """Blocking wrapper that runs ``chat_with_Yi34B_iter`` with ``asyncio.run``.

    Args:
        query: User message.
        model_name: Model identifier forwarded to the async helper.
        conversation_id: Conversation identifier forwarded to the async helper.
        history: Chat history forwarded to the async helper.

    Returns:
        The reply text produced by ``chat_with_Yi34B_iter``.
    """
    pending = chat_with_Yi34B_iter(
        query,
        model_name=model_name,
        conversation_id=conversation_id,
        history=history,
    )
    return asyncio.run(pending)
class ChatWithYi34BInput(BaseModel):
    """Input model for the chat tool (one string field)."""

    # NOTE(review): the field is called ``location`` although ``chat_with_Yi34B``
    # takes ``query`` — looks like a copy-paste leftover; confirm before renaming.
    location: str = Field(
        description="Query for any kind of chats and questions",
    )

View File

@@ -0,0 +1,18 @@
name: chat_with_Yi34B
description: Use this tool to chat with human
parameters:
type: object
properties:
query:
type: string
description: Query for any kind of chat and questions
model_name:
type: string
description:
conversation_id:
type: string
description:
required:
- query
- model_name
- conversation_id

View File

@@ -0,0 +1,34 @@
import json
import re
import concurrent
from fastapi.concurrency import run_in_threadpool
from langchain.tools import YouTubeSearchTool
from pydantic import BaseModel, Field
from server.chat import utils
from server.chat.utils import doc_to_list, solve_knowledge_map,solve_mental_data,shared_variable
from server.knowledge_base.kb_doc_api import search_docs
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import kb_config
def do_nothing(query: str):
    """Return a fixed notice saying that no tool call is needed.

    Args:
        query: Ignored.

    Returns:
        The constant notice string.
    """
    # Original note: this reply controls the "ask back" behaviour; removing it
    # may leave the agent always calling some tool.
    notice = "\n不需要调用工具了"
    return notice
def get_next_tip(query: str):
    """Flag the shared state stored under ``query`` as finished and return a hint.

    The state object is fetched with ``utils.get_shared_variable(query)``, its
    ``"END"`` entry is set to ``"ok"`` and it is written back under the same key.

    Args:
        query: Key of the shared variable to update.

    Returns:
        A fixed hint string telling the model to start writing the answer body.
    """
    state = utils.get_shared_variable(query)
    state["END"] = "ok"
    utils.set_shared_variable(query, state)
    hint = "\n提示:你已经使用过环节跳转了,可以开始输出正文了"
    return hint
class doNothingInput(BaseModel):
    """Input model with one required ``query`` string."""

    query: str = Field(
        ...,
        description="查询对象",
    )

View File

@@ -0,0 +1,175 @@
import json
import logging
import os
import re
import uuid
import requests
from matplotlib import pyplot as plt
from pydantic import BaseModel, Field
from configs.model_config import LLM_MODELS
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
from configs.kb_config import GENERATED_IMAGES_BASE_PATH, realistic_url,ink_url
from matplotlib import font_manager
# Font properties passed to every title/label/tick call in create_and_save_plot.
# NOTE(review): the absolute path is machine-specific; presumably chosen so that
# Chinese text renders — confirm the font exists on every deployment host.
my_font = font_manager.FontProperties(fname="/usr/share/fonts/MicroSoft-YaHei/MSYH.TTC")
def create_and_save_plot(query:str) -> str:
    """Draw a bar, pie or line chart from a JSON spec and return a markdown image link.

    ``query`` must look like ``<param>SPEC</param>{...}``.  ``SPEC`` is JSON with
    the keys ``data`` (label -> value mapping), ``title``, ``xlabel``,
    ``ylabel`` and ``plot_type`` (``'bar'``, ``'pie'`` or ``'line'``).  If
    ``SPEC`` is not valid JSON it is sent to the LLM (``check_plot`` prompt) to
    be rewritten, and the rewritten answer is parsed instead.

    Args:
        query: Raw tool input as described above.

    Returns:
        A markdown image string on success, ``"暂时无法画图"`` when the
        ``<param>`` wrapper is missing, or a fixed refusal message when
        anything else fails.  The function never raises.
    """
    try:
        query = query.replace(" ","").replace("'","\"")
        json_str ='{\n"data": {"XXX": XX, "XXX": XX, "XXX": X, "XXX": X},"title": "X","xlabel": "X","ylabel": "X","plot_type": "X"}'
        match = re.search(r'<param>(.*?)</param>(\{.*\})', query.replace("\n", ""))
        if not match:
            print(f"Invalid JSON format in query:\n{query}")
            return"暂时无法画图"
        spec_text = match.group(1).strip()
        try:
            datas = json.loads(spec_text)
        except ValueError:
            # The spec is not valid JSON: ask the LLM to rewrite it.
            rewritten = get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="check_plot",
                prompt_param_dict={"user_input": spec_text,"json":json_str },
                temperature=0.01,
                max_tokens=512
            )
            # The original discarded this search result and re-read the stale
            # ``match`` above, so the rewritten answer was never used.  Newlines
            # are kept here so the fenced block can actually be found.
            fenced = re.search(r"```json\s*(.*?)\s*```", rewritten, re.DOTALL)
            datas = json.loads((fenced.group(1) if fenced else rewritten).strip())
        data = datas["data"]
        xlabel = datas["xlabel"]
        ylabel = datas["ylabel"]
        title = datas["title"]
        plot_type = datas["plot_type"]
        # Reject unknown chart types before a figure is allocated.
        if plot_type not in ('bar', 'pie', 'line'):
            raise ValueError("Unsupported plot type. Choose from 'bar', 'pie', or 'line'.")
        categories = list(data.keys())
        values = list(data.values())
        namesid = uuid.uuid1()
        image_name = f"plot{namesid}.png"
        absolute_path = os.path.abspath(f'{GENERATED_IMAGES_BASE_PATH}/{image_name}')
        fig = plt.figure(figsize=(10, 6))
        try:
            if plot_type == 'bar':
                plt.bar(categories, values, color='skyblue')
            elif plot_type == 'pie':
                plt.pie(values, labels=categories, autopct='%1.1f%%', startangle=140,textprops={'fontproperties': my_font})
            else:
                plt.plot(categories, values, marker='o', linestyle='-')
            plt.title(title,fontproperties=my_font)
            if plot_type != 'pie':  # pie charts have no axis labels
                plt.xlabel(xlabel,fontproperties=my_font)
                plt.ylabel(ylabel,fontproperties=my_font)
                plt.xticks(fontproperties=my_font,rotation=45)
                plt.yticks(fontproperties=my_font)
            plt.tight_layout()
            plt.savefig(absolute_path)
        finally:
            # Always release the figure, including when drawing or saving fails.
            plt.close(fig)
        return f"图片如下:![图片{title}](http://127.0.0.1:8099/chat_web_backend/get-image?file_name={image_name})"
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        return "不要再调用该工具了,无法提供该功能,你只能按工具描述提供有的入参,其他不支持"
class drawPlotInput(BaseModel):
    """Input model with one required ``query`` string for the plotting tool."""

    query: str = Field(
        ...,
        description="输入要画图的内容",
    )
def draw_realistic_pic(query:str) -> str:
    """Request an image from the service at ``realistic_url``.

    ``query`` must contain at least two ``{...}`` groups.  The second one is
    JSON holding ``"uuid"``; what remains after removing it is JSON holding
    ``"query"``, which is sent to the service as the prompt.  On success the
    returned ``image_path`` is stored in the shared state for that uuid under
    ``"url"``.

    Args:
        query: Raw tool input as described above.

    Returns:
        A markdown instruction string containing the image link on success,
        otherwise ``"暂时无法画图"``.
    """
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            return "暂时无法画图"
        session_id = json.loads(matches[1])["uuid"]
        prompt = json.loads(query.replace(matches[1], ""))["query"]
    except (ValueError, KeyError, TypeError):
        # The original only logged here and then crashed further down on the
        # unbound uuid / an unguarded json.loads.
        logging.error("Invalid JSON format in query.")
        return "暂时无法画图"
    try:
        response = requests.post(realistic_url, json={'prompt': prompt})
        if response.status_code != 200:
            print("Failed to get response:", response.status_code)
            return "暂时无法画图"
        result = response.json()
        print("Image path:", result.get('image_path'))
        file_path = result.get('image_path')
        sources = utils.get_shared_variable(session_id)
        sources["url"] = file_path
        utils.set_shared_variable(session_id, sources)
        return f"<关键指令>你必须以markdown格式的图片路径![图片](http://127.0.0.1:8099/chat_web_backend/get-image?file_name={file_path}),之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
    except requests.exceptions.RequestException as e:
        print("An error occurred:", e)
        return "暂时无法画图"
class drawRealisticInput(BaseModel):
    """Input model with one required ``query`` string for ``draw_realistic_pic``."""

    query: str = Field(
        ...,
        description="输入要画图的内容",
    )
def draw_ink_pic(query:str) -> str:
    """Request an image from the service at ``ink_url``.

    ``query`` must contain at least two ``{...}`` groups.  The second one is
    JSON holding ``"uuid"``; what remains after removing it is JSON holding
    ``"query"``, which is sent to the service as the prompt.  On success the
    returned ``image_path`` is stored in the shared state for that uuid under
    ``"url"``.

    The original began with an argument-less ``get_llm_model_response()`` call
    whose result was discarded; it has been removed.

    Args:
        query: Raw tool input as described above.

    Returns:
        A markdown instruction string containing the image link on success,
        otherwise ``"暂时无法画图"``.
    """
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            return "暂时无法画图"
        session_id = json.loads(matches[1])["uuid"]
        prompt = json.loads(query.replace(matches[1], ""))["query"]
    except (ValueError, KeyError, TypeError):
        # The original only logged here and then crashed further down on the
        # unbound uuid / an unguarded json.loads.
        logging.error("Invalid JSON format in query.")
        return "暂时无法画图"
    try:
        response = requests.post(ink_url, json={'prompt': prompt})
        if response.status_code != 200:
            print("Failed to get response:", response.status_code)
            return "暂时无法画图"
        result = response.json()
        print("Image path:", result.get('image_path'))
        file_path = result.get('image_path')
        sources = utils.get_shared_variable(session_id)
        sources["url"] = file_path
        utils.set_shared_variable(session_id, sources)
        return f"<关键指令>你必须以markdown格式的图片路径![图片](http://127.0.0.1:8099/chat_web_backend/get-image?file_name={file_path}),之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
    except requests.exceptions.RequestException as e:
        print("An error occurred:", e)
        return "暂时无法画图"
class drawInkInput(BaseModel):
    """Input model with one required ``query`` string for ``draw_ink_pic``."""

    query: str = Field(
        ...,
        description="输入要画图的内容",
    )

View File

@@ -0,0 +1,186 @@
import asyncio
import re
import aiohttp
import json
import logging
from pydantic import BaseModel, Field
from server.chat import utils
# Configure logging at import time (INFO level, timestamped messages) and
# create this module's logger.
# NOTE(review): basicConfig in a library module affects the whole process.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)
async def duckduckgo_search_iter(query: str, uuid: str = "",time: str = "", resource_type: str = None, limit: int = 3):
    """Query the remote text/video/news search endpoints and format the hits.

    Unless ``resource_type == 'video'`` all three endpoints are queried
    concurrently and ``limit`` is split between them as evenly as possible
    (the remainder goes to text first, then news).  For ``'video'`` only the
    video endpoint is queried, with the full ``limit``.

    Side effects: the shared state stored under ``uuid`` is read, its ``"num"``
    entry is used as the starting citation index, the new source links are
    appended to its ``"source_docs"`` list and the state is written back.
    NOTE(review): ``"num"`` itself is not updated here — confirm callers do not
    rely on it advancing.

    Args:
        query: Search text.
        uuid: Key of the shared state for this conversation.
        time: Time filter forwarded to the endpoints.
        resource_type: ``'video'`` to search videos only; anything else
            searches all three sources.
        limit: Total number of results requested.

    Returns:
        ``(res, source)``: formatted result lines and markdown source links.
    """
    text_url = 'http://43.251.225.121/inspur/search_text'
    video_url = 'http://43.251.225.121/inspur/search_video'
    news_url = 'http://43.251.225.121/inspur/search_new'
    payload = {
        "query": query,
        "time": time
    }

    async def fetch(session, url, json_payload, limit):
        """POST one search request; return the decoded JSON, or [] on any failure."""
        # Build a fresh body per request so concurrent calls never share state.
        body = {**json_payload, "limit": limit}
        logger.info(f"{url} 获取数据,请求参数: {body}")
        try:
            async with session.post(url, json=body) as response:
                if response.status != 200:
                    logger.error(f"{url} 请求失败,状态码 {response.status}")
                    # Do not try to parse an error body as search results.
                    return []
                data = await response.json()
                logger.info(f"{url} 获取的资料数: {len(data) if isinstance(data, list) else '未知'}")
                return data
        except Exception as e:
            logger.error(f"获取 {url} 数据时发生错误: {e}")
            return []

    # Split ``limit`` over text / news / video.  The original ``case 2`` branch
    # assigned ``limit2 = limit`` by mistake, giving news the whole limit and
    # video none.
    base, extra = divmod(limit, 3)
    text_limit = base + (1 if extra >= 1 else 0)
    news_limit = base + (1 if extra >= 2 else 0)
    video_limit = base

    async with aiohttp.ClientSession() as session:
        logger.info("发起请求duckduckgo...")
        if resource_type is None or not resource_type == 'video':
            text_result, video_result, news_result = await asyncio.gather(
                fetch(session, text_url, payload, text_limit),
                fetch(session, video_url, payload, video_limit),
                fetch(session, news_url, payload, news_limit),
            )
            logger.info("合并结果...")
            logger.info("合并结果完成")
            combined_result = {
                "text": text_result,
                "video": video_result,
                "news": news_result
            }
        else:
            video_result = await fetch(session, video_url, payload, limit)
            combined_result = {
                "video": video_result
            }
    logger.info("请求已完成")

    res = []
    source = []
    info = utils.get_shared_variable(uuid)
    index = info["num"]
    for item in combined_result.get("text", []):
        index += 1
        res.append(f'资料[{index}] 资料标题{item["title"]}({item["href"]}) 资料内容为: {item["body"]}')
        source.append(f'资料[{index}] [{item["title"]}]({item["href"]})')
    for item in combined_result.get("video", []):
        index += 1
        res.append(f'资料[{index}] 视频标题[{item["title"]}]({item["content"]}) 视频内容为: {item["description"]}')
        source.append(f'视频资料[{index}] [{item["title"]}]({item["content"]})')
    for item in combined_result.get("news", []):
        index += 1
        res.append(f'资料[{index}] 新闻标题[{item["title"]}]({item["url"]}) 新闻内容为: {item["body"]}')
        source.append(f'资料[{index}] [{item["title"]}]({item["url"]})')
    info["source_docs"].extend(source)
    utils.set_shared_variable(uuid, info)
    return res,source
def duckduckgo_search(query: str, time: str = "", resource_type: str = None):
    """Parse the agent's tool input and run ``duckduckgo_search_iter`` synchronously.

    ``query`` must contain at least two ``{...}`` groups: the first is JSON with
    optional ``query``, ``limit``, ``resource_type`` and ``time`` keys, the
    second is JSON with a ``uuid`` key.  Values found in the first object
    override the corresponding arguments when they are truthy.

    Args:
        query: Raw tool input as described above.
        time: Default time filter when the input carries none.
        resource_type: Default resource type when the input carries none.

    Returns:
        The ``(res, source)`` tuple from ``duckduckgo_search_iter``, or a fixed
        instruction string telling the model to stop calling this tool when the
        input cannot be parsed.
    """
    stop_message = "<关键指令>不需要再调用该工具了</关键指令>"
    logger.info(f"模型输入: {query}")
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return stop_message
    try:
        obj1 = json.loads(matches[0])
        parsed_uuid = json.loads(matches[1])["uuid"]
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        # The original logged and carried on, then failed on the unbound
        # parsed_uuid / parsed_limit names.
        logger.error(f"解析JSON出错: {e}")
        return stop_message
    parsed_limit = obj1.get("limit", 3)
    # Parsed values win only when truthy; otherwise keep the incoming defaults.
    query = obj1.get("query", "") or matches[0]
    resource_type = obj1.get("resource_type", None) or resource_type
    time = obj1.get("time", time) or time
    logger.info(f"解析完成query: {query}, uuid: {parsed_uuid}, time: {time}, resource_type: {resource_type}, parsed_limit: {parsed_limit}")
    # Run the async search from this synchronous entry point.
    combined_result = asyncio.run(duckduckgo_search_iter(query, parsed_uuid, time, resource_type, parsed_limit))
    logger.info("返回JSON格式的结果给到模型...")
    return combined_result
class DuckduckgoInput(BaseModel):
    """Input model for the web-search tool (one string field)."""

    # NOTE(review): the field is called ``location`` although ``duckduckgo_search``
    # takes ``query`` — looks like a copy-paste leftover; confirm before renaming.
    location: str = Field(
        description="网络搜索查询",
    )
if __name__ == "__main__":
    # Manual test calls.
    # NOTE(review): duckduckgo_search expects two ``{...}`` JSON groups in its
    # first argument; a plain string like the one below returns the "stop
    # calling this tool" message without searching.
    # 1. Default: query all three APIs
    # result_default = duckduckgo_search("粉末冶金", "m", "default")
    # print("duckduckgo输出(默认):\n", result_default)
    # # 2. Videos only
    # result_video = duckduckgo_search("粉末冶金", "m", "video")
    # print("duckduckgo输出(视频):\n", result_video)
    # # 3. News only
    # result_news = duckduckgo_search("粉末冶金", "m", "news")
    # print("duckduckgo输出(新闻):\n", result_news)
    # 4. Any other type
    result_other = duckduckgo_search("粉末冶金", "m", "other")
    print("duckduckgo输出(其他):\n", result_other)

View File

@@ -0,0 +1,57 @@
from datetime import datetime
import json
import logging
import re
import requests
from configs.model_config import LLM_MODELS
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
def mysql_statistic(query:str) -> str:
    """Forward a natural-language statistics question to the local query service.

    ``query`` must look like ``<param>{"query": ...}</param>{"uuid": ...}``.
    The question is POSTed to ``http://127.0.0.1:6008/query`` and the
    ``"result"`` field of the JSON reply is wrapped in an instruction string.

    The original referenced ``uuid`` before assignment on every malformed-input
    path, so it raised instead of returning the fallback message.  Those paths
    now return the fallback directly.

    Args:
        query: Raw tool input as described above.

    Returns:
        An instruction string containing the data, an instruction to use web
        search when the reply contains ``'data': []``, or ``"暂时无法查询"``
        when the input is malformed or the request fails.
    """
    unavailable = "暂时无法查询"
    logging.info(f"\n🔍 统计工具查询query: \n{query}\n")
    matches = re.search(r'<param>(.*?)</param>(\{.*\})', query.replace("\n", ""))
    if not matches:
        return unavailable
    try:
        # The uuid is still required in the input even though this tool does
        # not currently use it.
        _ = json.loads(matches.group(2))["uuid"]
        question = json.loads(matches.group(1).strip())["query"]
    except (ValueError, KeyError, TypeError):
        logging.error("Invalid JSON format in query.")
        return unavailable
    logging.info(f"\n🔍 NL2SQL检索question: \n{question}\n")
    try:
        res = requests.post(
            url="http://127.0.0.1:6008/query",
            json={"question": question},
            headers={"Content-Type": "application/json"}
        )
    except requests.RequestException as e:
        logging.error(f"Statistic service request failed: {e}")
        return unavailable
    if not res:
        # A requests.Response is falsy for 4xx/5xx; the original returned None here.
        return unavailable
    data = res.json()["result"]
    if "'data': []" in data:
        return f"统计库未检索到数据,使用“联网思索”工具检索该请求:{question}\n"
    return f"判断得到的数据是否准确,如果不准确,则使用“联网思索”工具检索。如果准确,则根据数据表格并使用图表绘制工具制图。\n 数据如下所示: \n{data}\n"

View File

@@ -0,0 +1,170 @@
import json
import logging
import re
from typing import List, Any, Union
import concurrent
from pydantic import BaseModel, Field
from difflib import SequenceMatcher
from configs import (VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
DEFAULT_POLICY_BASE)
from server.agent.tools import search_internet
from server.chat import utils
from server.knowledge_base.kb_doc_api import search_docs
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.utils import BaseResponse
class KnowledgeKgoInput(BaseModel):
    """Input model for the internet-search tool (one string field)."""

    # NOTE(review): the field is called ``location`` although
    # ``knowledgebase_kgo_search`` takes ``query`` — confirm before renaming.
    location: str = Field(
        description="Query for Internet search",
    )
def preprocess_text(text: str) -> str:
    """Return ``text`` with all whitespace and non-word characters removed.

    Args:
        text: Input string.

    Returns:
        The string reduced to word characters (letters, digits, underscore).
    """
    trimmed = text.strip()
    return re.sub(r'[\s\W]', '', trimmed)
def knowledge_temperature(a: str, b: str) -> float:
    """Return the ``difflib.SequenceMatcher`` similarity ratio of two strings.

    Args:
        a: First string.
        b: Second string.

    Returns:
        A float in ``[0, 1]``; ``1.0`` means the strings are identical.
    """
    matcher = SequenceMatcher(None, a, b)
    return matcher.ratio()
# def knowledgebase_kgo_iter(query: str,
# fileName: List = [],
# knowledge_base_name: str = DEFAULT_POLICY_BASE,
# top_k: int = VECTOR_SEARCH_TOP_K,
# score_threshold: float = SCORE_THRESHOLD) -> BaseResponse | list[str] | Any:
# kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
# if kb is None:
# return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
# query = query.strip()
# docs = search_docs(fileName=fileName,
# query=query,
# knowledge_base_name=knowledge_base_name,
# top_k=top_k,
# score_threshold=score_threshold)
# # 预处理查询文本
# processed_query = preprocess_text(query).replace("Observ","")
# print("processed_query:", processed_query)
# knowledge_docs = []
# knowledge_content = []
# # 知识库返回的文档与query的相似度
# if docs:
# for enum, doc in enumerate(docs):
# filename = doc.metadata.get("title")
# detail_url = "https://policy.ckcest.cn/detail/" + doc.metadata.get("primary_key") + ".html"
# if filename:
# text = f"""政策类资料[{enum + 1}]: [{filename}]({detail_url})\n"""
# else:
# text = f"""政策类资料[{enum + 1}]: [{"原文地址"}]({detail_url})\n"""
# knowledge_docs.append(text)
# # print("knowledge_docs:", knowledge_docs)
# knowledge_content = [doc.page_content for doc in docs]
# # print("knowledge_content:", knowledge_content)
# # 计算知识库返回的文档与query的相似度
# titles = [doc.metadata.get("title") for doc in docs]
# print("titles:", titles)
# def check_similarity_threshold(titles: List[str], query: str, knowledge_docs: List[str], knowledge_content: List[str]) -> Union[
# List[str], None]:
# # 用于记录是否存在相似度大于0.55的标题
# has_similar_title = False
# for title in titles:
# processed_title = preprocess_text(title)
# similarity = knowledge_temperature(processed_query, processed_title)
# print("processed_title:", processed_title)
# print("similarity:", similarity)
# if similarity >= 0.55:
# has_similar_title = True
# break
# # 如果存在相似度大于0.55的标题,则直接返回 knowledge_docs
# if has_similar_title:
# knowledge = knowledge_content + knowledge_docs
# return knowledge
# # 如果所有标题的相似度都不大于0.55,则返回 None
# return None
# # 在原函数中使用新的函数进行相似度阈值的判断
# similar_docs = check_similarity_threshold(titles, query, knowledge_docs, knowledge_content)
# if similar_docs is None:
# # 如果所有标题的相似度都不大于0.55,则执行搜索引擎查询
# kgo_docs = search_internet(processed_query)
# # print("kgo_docs", kgo_docs)
# return kgo_docs
# else:
# kgo_docs = search_internet(processed_query)
# # print("similar_docs", similar_docs)
# # print("kgo_docs", kgo_docs)
# similar_docs.extend(kgo_docs)
# return similar_docs
# else:
# # 执行搜索引擎查询
# kgo_docs = search_internet(query)
# return kgo_docs
def knowledgebase_kgo_iter(query: str, uid: str) -> BaseResponse | list[str] | Any:
    """Delegate to ``search_internet`` and hand its result back unchanged.

    Args:
        query: Search text.
        uid: Identifier forwarded as the second argument of ``search_internet``.

    Returns:
        Whatever ``search_internet(query, uid)`` returns.
    """
    return search_internet(query, uid)
def knowledgebase_kgo_search(query: str) -> str:
    """Parse the agent's tool input, run the internet search and format the result.

    ``query`` must contain at least two ``{...}`` groups: the first is JSON with
    a ``query`` key, the second is JSON with a ``uuid`` key.  The search runs in
    a worker thread through ``knowledgebase_kgo_iter``.

    Fixes over the original: the titles-only branch used ``doc_content`` before
    assigning it and returned a string without the ``f`` prefix, so the
    placeholder was emitted literally; ``concurrent.futures`` is now imported
    explicitly instead of relying on a bare ``import concurrent``.

    Args:
        query: Raw tool input as described above.

    Returns:
        The concatenated contents and sources, a titles-only message, or a
        fixed retry hint when the input is malformed or the search yields
        nothing usable.  The function never raises.
    """
    import concurrent.futures  # ``import concurrent`` alone does not load the submodule

    retry_hint = "尝试调整入参重新调用联网思索工具(同一个问题调用超过三次就不要再使用该工具了,浪费时间)"
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            logging.error("Invalid JSON format in query.")
            return retry_hint
        parsed_query = json.loads(matches[0])
        time_based_uuid = json.loads(matches[1])["uuid"]
        with concurrent.futures.ThreadPoolExecutor() as executor:
            res = executor.submit(knowledgebase_kgo_iter, parsed_query["query"], time_based_uuid).result()
        try:
            contents, sources = res[0], res[1]
            if isinstance(contents, list) and contents:
                return "资料内容" + "".join(contents) + "资料来源" + "".join(sources)
            if isinstance(sources, list) and sources:
                titles = "资料来源" + "".join(sources)
                return f"只有标题没有内容,标题为:{titles}"
            return retry_hint
        except Exception as e:
            logging.error(f"Error occurred while processing query: {e}")
            return retry_hint
    except (json.JSONDecodeError, KeyError):
        # Malformed JSON or a missing "query"/"uuid" key.
        logging.error("Invalid JSON format in query.")
        return retry_hint
    except Exception as e:
        logging.error(f"Error occurred while processing query: {e}")
        return retry_hint
if __name__ == "__main__":
    # Manual smoke test.  knowledgebase_kgo_iter takes (query, uid); the
    # original call omitted uid and failed with TypeError before searching.
    # NOTE(review): "manual-test" is a placeholder id — confirm what state
    # search_internet expects to exist for it.
    result = knowledgebase_kgo_iter("《区块链和分布式记账技术标准体系建设指南》", "manual-test")
    print("检索结果:", result)

View File

@@ -0,0 +1,113 @@
from datetime import datetime
import json
import logging
import re
from pydantic import BaseModel, Field
from configs.model_config import LLM_MODELS
from server.agent.tools.search_tool import search_tool
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
def math_count(query: str):
    """Answer a math question with retrieved context and the ``default_math`` prompt.

    ``query`` must contain at least two ``{...}`` groups: the first is JSON with
    a ``query`` key, the second is JSON with a ``uuid`` key.  A scratch shared
    variable keyed ``uuid + "q"`` is created for the nested ``search_tool``
    call and removed afterwards.

    Fixes over the original: malformed input returned ``None`` implicitly, the
    scratch variable leaked when ``search_tool`` raised, and a second,
    unreachable ``except Exception`` clause is gone.

    Args:
        query: Raw tool input as described above.

    Returns:
        The LLM answer as a string, or a fixed "stop calling this tool" message
        on any failure.
    """
    stop_message = "<system>不要再调用该工具了</system>"
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            return stop_message
        parsed_query = json.loads(matches[0])["query"]
        time_based_uuid = json.loads(matches[1])["uuid"]
        scratch_key = time_based_uuid + "q"
        utils.set_shared_variable(scratch_key, {"source_docs": [], "num": 0, "title": []})
        first_json = {
            "query": parsed_query,
            "knowledge_name": [],
            "keywords": []
        }
        second_json = {
            "uuid": scratch_key
        }
        try:
            math_doc = search_tool(json.dumps(first_json) + json.dumps(second_json))
        finally:
            # Always drop the scratch state, including when the search fails.
            utils.remove_shared_variable(scratch_key)
        res = get_llm_model_response(
            strategy_name="default_math",
            llm_model_name=LLM_MODELS[3],
            template_prompt_name="default_math",
            prompt_param_dict={"input": parsed_query, "math_doc": f"{math_doc}", "time": datetime.now().strftime("%Y%m%d")},
            temperature=0.01,
            max_tokens=512
        )
        return f"{res}"
    except Exception as e:
        logging.error(f"Error occurred while processing math query: {e}")
        return stop_message
def code_count(query: str):
    """Answer a coding question with retrieved context and the ``default_code`` prompt.

    ``query`` must contain at least two ``{...}`` groups: the first is JSON with
    a ``query`` key, the second is JSON with a ``uuid`` key.  A scratch shared
    variable keyed ``uuid + "q"`` is created for the nested ``search_tool``
    call and removed afterwards.  Any ``<think>`` marker is stripped from the
    model's answer.

    Fixes over the original: malformed input returned ``None`` implicitly, the
    scratch variable leaked when ``search_tool`` raised, the log message said
    "math", and a second, unreachable ``except Exception`` clause is gone.

    Args:
        query: Raw tool input as described above.

    Returns:
        The LLM answer as a string, or a fixed "stop calling this tool" message
        on any failure.
    """
    stop_message = "<system>不要再调用该工具了</system>"
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            return stop_message
        parsed_query = json.loads(matches[0])["query"]
        time_based_uuid = json.loads(matches[1])["uuid"]
        scratch_key = time_based_uuid + "q"
        utils.set_shared_variable(scratch_key, {"source_docs": [], "num": 0, "title": []})
        first_json = {
            "query": parsed_query,
            "knowledge_name": [],
            "keywords": []
        }
        second_json = {
            "uuid": scratch_key
        }
        try:
            code_doc = search_tool(json.dumps(first_json) + json.dumps(second_json))
        finally:
            # Always drop the scratch state, including when the search fails.
            utils.remove_shared_variable(scratch_key)
        res = get_llm_model_response(
            strategy_name="default_code",
            llm_model_name=LLM_MODELS[2],
            template_prompt_name="default_code",
            prompt_param_dict={"input": parsed_query, "code_doc": f"{code_doc}", "time": datetime.now().strftime("%Y%m%d")},
            temperature=0.01,
            max_tokens=512
        )
        res = res.replace("<think>","")
        return f"{res}"
    except Exception as e:
        logging.error(f"Error occurred while processing code query: {e}")
        return stop_message
class RagSearchInput(BaseModel):
    """Input model with one required ``query`` string."""

    query: str = Field(
        ...,
        description="查询对象",
    )

View File

@@ -0,0 +1,42 @@
from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS, LLM_MODELS
import json
import asyncio
from typing import List, Tuple, Any
from pydantic import BaseModel, Field
class PolicyKnowledgeInput(BaseModel):
    """Input model for the policy knowledge-base tool (one string field)."""

    # NOTE(review): the field is called ``location`` although
    # ``policy_knowledgebase_search`` takes ``query`` — confirm before renaming.
    location: str = Field(
        description="The policy related query to be searched",
    )
async def policy_knowledgebase_search_iter(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Ask ``knowledge_base_chat`` about ``query`` in the policy knowledge base.

    Each chunk of the response's ``body_iterator`` is decoded as JSON; the most
    recent ``"answer"`` and ``"docs"`` entries seen are kept.

    Args:
        query: Question to look up.

    Returns:
        ``(answer, docs)``; ``""`` and ``[]`` when no chunk carried them.
    """
    response = await knowledge_base_chat(
        query=query,
        fileName=None,
        knowledge_base_name_list=["t_policy_total_bge_new_v1"],
        model_name=LLM_MODELS[0],
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="default",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    answer, documents = "", []
    async for chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(chunk)
        print("data>>>>>", payload)
        answer = payload.get("answer", answer)
        documents = payload.get("docs", documents)
    return answer, documents
def policy_knowledgebase_search(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Blocking wrapper around ``policy_knowledgebase_search_iter``.

    Args:
        query: Question to look up.

    Returns:
        The ``(answer, docs)`` tuple produced by the async helper.
    """
    pending = policy_knowledgebase_search_iter(query)
    return asyncio.run(pending)
if __name__ == "__main__":
    # Manual smoke test: run one sample question and print the result tuple.
    print("答案:", policy_knowledgebase_search("大数据男女比例"))

View File

@@ -0,0 +1,108 @@
import json
import re
import concurrent
from fastapi.concurrency import run_in_threadpool
from langchain.tools import YouTubeSearchTool
from pydantic import BaseModel, Field
from server.chat import utils
from server.chat.utils import doc_to_list, solve_knowledge_map,solve_mental_data,shared_variable
from server.knowledge_base.kb_doc_api import search_docs
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import kb_config
def rag_search1(query: str):
"""
根据用户输入的query返回rag搜索结果
"""
try:
matches = re.findall(r'\{.*?\}', query)
if len(matches)>=2:
query = matches[0]
else:
return "<关键指令>不需要再调用该工具了</关键指令>"
time_based_uuid = json.loads(matches[1])["uuid"]
search = json.loads(query)
search_query = search["query"]
search_keywords = []
search_text = f"{search_query}"
if type(search["keywords"]) == list:
search_keywords = search["keywords"]
for keyword in search_keywords:
search_text += f" {keyword}"
else:
search_keywords = search["keywords"].split(",")
for keyword in search_keywords:
search_text += f" {keyword}"
result = []
source_docs = {}
knownledge_name = []
if type(search["knowledge_name"]) == list:
knownledge_name=search["knowledge_name"]
else:
knownledge_name=search["knowledge_name"].split(",")
for knownledge in knownledge_name:
if not knownledge in kb_config.CH_BASE_NAME:
knownledge_name.remove(knownledge)
if len(knownledge_name)==0:
result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
return result
# knownledge_name=kb_config.CH_BASE_NAME
knownledge_name=solve_knowledge_map(knownledge_name)
num = 0
for knownledge in knownledge_name:
source_docs[knownledge] = []
seen_docs = set()
duplicate_indices = []
doc_list = search_docs(usr_query=search_text,fileName= [],top_k=5,score_threshold=0.9,query=search_text, knowledge_base_name=knownledge)
for inum,doc in enumerate(doc_list):
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
# 从policydocs中删除重复的文档从后往前删除以防止索引错位
for index in sorted(duplicate_indices, reverse=True):
del doc_list[index]
# 处理原文来源进入数组。使用开关语句明确各个条件分支
match knownledge:
# 属于政策库分支,入参为中文政策库名称
case kb_config.DEFAULT_POLICY_BASE:
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs[knownledge])
# 属于期刊论文库分支,入参为期刊论文库的中文名称
case kb_config.DEFAULT_JOURNAL_BASE:
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
# 属于报告库分支,入参为报告库中文名称
case kb_config.DEFAULT_REPORT_BASE:
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
case kb_config.GY_NEWS_BASE:
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs[knownledge])
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
case kb_config.GY_REPORT_BASE:
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
case kb_config.GY_JOURNAL_BASE:
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
case _:
print(f"输入了没有的知识库名称")
return("输入了没有的知识库名称")
num += len(source_docs[knownledge])
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
del num
source = utils.get_shared_variable(time_based_uuid)
print(utils.get_shared_variable(time_based_uuid))
source["source_docs"]=source_docs
utils.set_shared_variable(time_based_uuid,source)
if 0<len(result)<3:
return f"当前资料:{result}\n<关键指令>搜索结果较少,更换知识库或联网思索重新搜索!!!</关键指令>"
if len(result)==0:
return "注意:【指令:更换知识库或联网思索继续搜索!!!】"
except:
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
return "当前工具异常!请换其他工具"
return f"当前资料:{result}\n<关键指令>总结此内容!!!</关键指令>"
class RagSearchInput(BaseModel):
    """Input model with one required ``query`` string for ``rag_search1``."""

    query: str = Field(
        ...,
        description="查询对象",
    )

View File

@@ -0,0 +1,42 @@
from server.chat.report_chat import report_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS, LLM_MODELS, DEFAULT_REPORT_BASE
import json
import asyncio
from typing import List, Tuple, Any
from pydantic import BaseModel, Field
class ReportKnowledgeInput(BaseModel):
    """Input model for the report knowledge-base tool (one string field)."""

    # NOTE(review): the field is called ``location`` although
    # ``report_knowledgebase_search`` takes ``query`` — confirm before renaming.
    location: str = Field(
        description="The report related query to be searched",
    )
async def report_knowledgebase_search_iter(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Ask ``report_chat`` about ``query`` in the default report knowledge base.

    Each chunk of the response's ``body_iterator`` is decoded as JSON; the most
    recent ``"answer"`` and ``"docs"`` entries seen are kept.

    Args:
        query: Question to look up.

    Returns:
        ``(answer, docs)``; ``""`` and ``[]`` when no chunk carried them.
    """
    response = await report_chat(
        query=query,
        fileName=None,
        knowledge_base_name=DEFAULT_REPORT_BASE,
        model_name=LLM_MODELS[0],
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="default",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    answer, documents = "", []
    async for chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(chunk)
        print("data>>>>>", payload)
        answer = payload.get("answer", answer)
        documents = payload.get("docs", documents)
    return answer, documents
def report_knowledgebase_search(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Synchronous entry point: run the async report search on a fresh event loop.

    Args:
        query: The question to ask.

    Returns:
        Whatever ``report_knowledgebase_search_iter`` returns for ``query``.
    """
    search_coroutine = report_knowledgebase_search_iter(query)
    return asyncio.run(search_coroutine)
# Manual smoke test: run one sample query when the module is executed as a script.
if __name__ == "__main__":
    result = report_knowledgebase_search("大数据男女比例")
    print("答案:", result)

View File

@@ -0,0 +1,89 @@
import json
import asyncio
import unicodedata
from server.chat.KgoSearchAPIWrapper import KgoSearchAPIWrapper
from server.chat.search_engine_chat import search_engine_chat
from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
from server.agent import model_container
from pydantic import BaseModel, Field
from configs import LLM_MODELS, TEMPERATURE
from configs.basic_config import *
# def get_kgo_search_type(query: str = "全部"):
# # 过滤掉所有非汉字的字符
# query = ''.join(char for char in query if unicodedata.category(char) == 'Lo')
# search_map = KgoSearchAPIWrapper().search_map
# if "论文" in query:
# return "1001"
# elif "外文" in query or "英文" in query:
# return "1013"
# elif "期刊" in query or "研究进展" in query:
# return "1002"
# else:
# matched_types = [value for key, value in search_map.items() if key in query]
# if matched_types:
# return ','.join(matched_types)
# else:
# print("未找到匹配的搜索类型返回默认值1000")
# return "1000"
@timing_decorator
async def search_engine_iter(query: str , uid: str):
    """Run one internet search through ``search_engine_chat`` and collect the result.

    Args:
        query: The search text.
        uid: Identifier forwarded unchanged to ``search_engine_chat``.

    Returns:
        ``(contents, docs)``: the ``answer`` of the last processed chunk and the
        documents accumulated from all processed chunks. Processing stops at the
        first chunk that carries no documents; the same tuple shape is returned in
        that case (previously a bare ``docs`` list was returned, so callers could
        not unpack the result reliably).
    """
    response = await search_engine_chat(uid = uid,
                                        query=query,
                                        search_engine_name="zhipu_search",
                                        model_name=LLM_MODELS[1],
                                        # Original note: temperature 0.1 for agent internet
                                        # search; the effective value is configs.TEMPERATURE.
                                        temperature=TEMPERATURE,
                                        history=[],
                                        top_k=VECTOR_SEARCH_TOP_K,
                                        max_tokens=MAX_TOKENS,
                                        prompt_name="search",
                                        stream=False,
                                        kgo_search_type="1000")
    contents = ""
    docs = []
    async for data in response.body_iterator:  # each chunk is a JSON string
        data = json.loads(data)
        # Default to "" so `contents` keeps one type whether or not the key exists.
        contents = data.get("answer", "")
        current_docs = data.get("docs", [])
        if not current_docs:
            logging.error("No docs found in the response")
            return contents, docs
        docs.extend(current_docs)
    # Search result summary plus the supporting documents.
    return contents, docs
def search_internet(query: str, uid: str):
    """Synchronous internet-search tool entry point.

    The query is forwarded unchanged (no keyword filtering is applied) together
    with ``uid`` to ``search_engine_iter``, which is run to completion on a fresh
    event loop.

    Args:
        query: The search text.
        uid: Identifier passed through to the search backend.

    Returns:
        The value produced by ``search_engine_iter``.
    """
    search_coroutine = search_engine_iter(query, uid)
    return asyncio.run(search_coroutine)
# Pydantic argument schema for the internet search tool.
class SearchInternetInput(BaseModel):
    # NOTE(review): the field is named `location` but its description (and the
    # `search_internet` parameter) call it a query — confirm the intended name.
    location: str = Field(description="Query for Internet search")
# Manual smoke test: run one sample search when the module is executed as a script.
if __name__ == "__main__":
    import uuid

    # search_internet requires a `uid`; the previous one-argument call raised
    # TypeError. NOTE(review): a throwaway id is used here — confirm whether the
    # backend expects an id that was registered beforehand.
    result = search_internet("人工智能领域的政策", uid=uuid.uuid4().hex)
    print("答案:", result)

View File

@@ -0,0 +1,15 @@
# Tool specification for the `search_internet` agent tool.
# NOTE(review): `type: int` is not a JSON Schema type name (the standard spelling
# is `integer`) — confirm what the consumer of this file expects.
# NOTE(review): `kgo_search_type` is listed as required and refers to
# `get_kgo_search_type`, but in the Python module that function exists only as
# commented-out code and `search_internet(query, uid)` always sends "1000".
name: search_internet
description: Use this tool to surf internet and get information
parameters:
type: object
properties:
query:
type: string
description: Query for Internet search
kgo_search_type:
type: int
description: the return value 'kgo_search_type' of the 'get_kgo_search_type'
default: 1000
required:
- query
- kgo_search_type

View File

@@ -0,0 +1,294 @@
from __future__ import annotations
import json
import re
import warnings
from typing import Dict
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from typing import List, Any, Optional
from langchain.prompts import PromptTemplate
from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import asyncio
from server.agent import model_container
from pydantic import BaseModel, Field
async def search_knowledge_base_iter(database: str, query: str) -> str:
    """Query one knowledge base and return the concatenated answer text.

    Args:
        database: Name of the knowledge base to search.
        query: The question to ask.

    Returns:
        The ``answer`` fields of all response chunks joined in arrival order.

    Raises:
        KeyError: If a response chunk has no ``answer`` field.
    """
    response = await knowledge_base_chat(query=query,
                                         knowledge_base_name=database,
                                         model_name=model_container.MODEL.model_name,
                                         temperature=0.01,
                                         history=[],
                                         top_k=VECTOR_SEARCH_TOP_K,
                                         max_tokens=MAX_TOKENS,
                                         prompt_name="default",
                                         score_threshold=SCORE_THRESHOLD,
                                         stream=False)
    answer_parts = []
    async for data in response.body_iterator:  # each chunk is a JSON string
        # Only the answer is used; the chunk's "docs" entry was previously read
        # into an unused local, which made chunks without "docs" fail needlessly.
        answer_parts.append(json.loads(data)["answer"])
    return "".join(answer_parts)
async def search_knowledge_multiple(queries) -> List[str]:
    """Search several knowledge bases concurrently.

    Args:
        queries: Sequence of ``(database, query)`` pairs.

    Returns:
        One message per pair, in input order, consisting of a header naming the
        knowledge base followed by that search's answer text.
    """
    pending = [search_knowledge_base_iter(db_name, question) for db_name, question in queries]
    answers = await asyncio.gather(*pending)
    # Prefix every answer with the knowledge base it came from.
    return [
        f"\n查询到 {pair[0]} 知识库的相关信息:\n{answer}"
        for pair, answer in zip(queries, answers)
    ]
def search_knowledge(queries) -> str:
    """Synchronously search several knowledge bases and merge the messages.

    Args:
        queries: Sequence of ``(database, query)`` pairs.

    Returns:
        All per-knowledge-base messages, each followed by a blank line.
    """
    messages = asyncio.run(search_knowledge_multiple(queries))
    return "".join(message + "\n\n" for message in messages)
# Prompt that asks the model to split the user's question into one
# "knowledge_base_name,query" line per knowledge base, emitted after a ```text
# marker. `{database_names}` and `{question}` are the two template variables
# declared on PROMPT below; `${{...}}` uses doubled braces so the braces are not
# treated as variables (presumably rendering as literal `${...}` — confirm
# against the template format in use).
# NOTE(review): the ```text fence in this template is never closed before
# ```output; the chain's parser searches for "```text" without a closing fence.
_PROMPT_TEMPLATE = """
用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。
对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。
例子:
robotic,机器人男女比例是多少
bigdata,大数据的就业情况如何
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考
{database_names}
你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。
不要输出中文的逗号,不要输出引号。
Question: ${{用户的问题}}
```text
${{知识库名称,查询问题,不要带有任何除了,之外的符号,比如不要输出中文的逗号,不要输出引号}}
```output
数据库查询的结果
现在,我们开始作答
问题: {question}
"""
# Template object handed to the chain; variables must match the placeholders above.
PROMPT = PromptTemplate(
input_variables=["question", "database_names"],
template=_PROMPT_TEMPLATE,
)
class LLMKnowledgeChain(LLMChain):
    """Chain that lets the LLM pick knowledge bases and queries, then runs them.

    The LLM is prompted with ``PROMPT`` and is expected to answer with a
    "```text" section holding one ``database,query`` pair per line. The pairs
    are parsed, every knowledge base is searched, and the combined search output
    is returned under ``output_key`` prefixed with ``"Answer: "``.
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    prompt: BasePromptTemplate = PROMPT
    """[Deprecated] Prompt to use to translate to python if necessary."""
    # Refreshed from model_container.DATABASE on every call.
    database_names: Dict[str, str] = None
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn about the deprecated ``llm`` argument and derive ``llm_chain`` from it."""
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _parse_queries(self, expression: str) -> List[Any]:
        """Turn the text after the "```text" marker into ``(database, query)`` tuples.

        Straight and curly double quotes and code fences are removed. Each line is
        split on an ASCII comma; if any line lacks one, all lines are re-split on
        the full-width comma (U+FF0C).

        Raises:
            IndexError: If a line contains neither separator.
        """
        # The curly quotes and the full-width comma are written as escapes: the
        # literal characters had been lost, leaving no-op replace("", "") calls and
        # a split("") that always raised ValueError.
        cleaned = (expression.strip()
                   .replace("\"", "")
                   .replace("\u201c", "")
                   .replace("\u201d", "")
                   .replace("```", "")
                   .strip())
        lines = cleaned.split("\n")
        try:
            return [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
        except IndexError:
            return [(line.split("\uff0c")[0].strip(), line.split("\uff0c")[1].strip())
                    for line in lines]

    def _evaluate_expression(self, queries) -> str:
        """Run the searches synchronously; on failure return an error text instead of raising."""
        try:
            output = search_knowledge(queries)
        except Exception as e:
            output = "输入的信息有误或不存在知识库,错误信息如下:\n"
            return output + str(e)
        return output

    async def _aevaluate_expression(self, queries) -> str:
        """Async counterpart of ``_evaluate_expression``.

        Awaits the searches directly: ``search_knowledge`` uses ``asyncio.run``,
        which cannot be called while an event loop is already running.
        """
        try:
            responses = await search_knowledge_multiple(queries)
        except Exception as e:
            return "输入的信息有误或不存在知识库,错误信息如下:\n" + str(e)
        return "".join(response + "\n\n" for response in responses)

    def _process_llm_result(
            self,
            llm_output: str,
            run_manager: CallbackManagerForChainRun
    ) -> Dict[str, str]:
        """Interpret the LLM output: run the requested searches or pass an answer through.

        Returns:
            ``{output_key: answer}``; for unrecognised output the value is a
            format-error message containing the raw LLM output.
        """
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        # Greedy match without a closing fence: generation stops at "```output".
        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
        if text_match:
            queries = self._parse_queries(text_match.group(1))
            run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
            output = self._evaluate_expression(queries)
            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = llm_output.split("Answer:")[-1]
        else:
            return {self.output_key: f"输入的格式不对:\n {llm_output}"}
        return {self.output_key: answer}

    async def _aprocess_llm_result(
            self,
            llm_output: str,
            run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Async variant of ``_process_llm_result``.

        Raises:
            ValueError: If the LLM output matches none of the known formats.
        """
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
        if text_match:
            queries = self._parse_queries(text_match.group(1))
            await run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue",
                                      verbose=self.verbose)
            output = await self._aevaluate_expression(queries)
            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            raise ValueError(f"unknown format from LLM: {llm_output}")
        return {self.output_key: answer}

    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Prompt the LLM with the question and known knowledge bases, then process its output."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key])
        self.database_names = model_container.DATABASE
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = self.llm_chain.predict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return self._process_llm_result(llm_output, _run_manager)

    async def _acall(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async variant of ``_call``."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        await _run_manager.on_text(inputs[self.input_key])
        self.database_names = model_container.DATABASE
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = await self.llm_chain.apredict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        # _aprocess_llm_result takes (llm_output, run_manager); the extra question
        # argument passed here before raised TypeError on every async call.
        return await self._aprocess_llm_result(llm_output, _run_manager)

    @property
    def _chain_type(self) -> str:
        return "llm_knowledge_chain"

    @classmethod
    def from_llm(
            cls,
            llm: BaseLanguageModel,
            prompt: BasePromptTemplate = PROMPT,
            **kwargs: Any,
    ) -> LLMKnowledgeChain:
        """Build the chain from a language model and an optional prompt."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
def search_knowledgebase_complex(query: str):
    """Answer ``query`` by letting the model choose and search knowledge bases.

    Builds an ``LLMKnowledgeChain`` around ``model_container.MODEL`` with the
    module-level ``PROMPT`` (verbose mode on) and runs it once.

    Args:
        query: The user's question.

    Returns:
        The chain's answer.
    """
    chain = LLMKnowledgeChain.from_llm(model_container.MODEL, verbose=True, prompt=PROMPT)
    return chain.run(query)
# Pydantic argument schema for the knowledge-base search tool.
class KnowledgeSearchInput(BaseModel):
    # NOTE(review): named `location` although it holds the search query — confirm.
    location: str = Field(description="The query to be searched")
# Pydantic argument schema for the RAG search tool: query text, target
# knowledge-base name and search keywords, all declared as plain strings.
class RagSearchInput(BaseModel):
    # The description text means "query target".
    query: str = Field(description="查询对象")
    # NOTE(review): the descriptions below end with an unmatched "]" — confirm the
    # intended example format before changing the text, since it is sent to the model.
    knowledge_name: str = Field(description="The name of the knowledge base to be searched,policy knowledge base name is t_policy_total_bge_new_v2, example: t_policy_total_bge_new_v2]")
    keywords: str = Field(description="The keywords to be searched example: age,child]")
# Manual smoke test: run one sample question when executed as a script.
if __name__ == "__main__":
    result = search_knowledgebase_complex("机器人和大数据在代码教学上有什么区别")
    print(result)
# Example of a well-formed, already-split query list for calling search_knowledge directly:
# queries = [
# ("bigdata", "大数据专业的男女比例"),
# ("robotic", "机器人专业的优势")
# ]
# result = search_knowledge(queries)
# print(result)

View File

@@ -0,0 +1,10 @@
name: search_knowledgebase_complex
description: Use this tool to search local knowledgebase and get information
parameters:
type: object
properties:
query:
type: string
description: The query to be searched
required:
- query

View File

@@ -0,0 +1,234 @@
from __future__ import annotations
import re
import warnings
from typing import Dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from typing import List, Any, Optional
from langchain.prompts import PromptTemplate
import sys
import os
import json
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import asyncio
from server.agent import model_container
from pydantic import BaseModel, Field
async def search_knowledge_base_iter(database: str, query: str):
    """Query one knowledge base and return the concatenated answer text.

    Args:
        database: Name of the knowledge base to search.
        query: The question to ask.

    Returns:
        The ``answer`` fields of all response chunks joined in arrival order.

    Raises:
        KeyError: If a response chunk has no ``answer`` field.
    """
    response = await knowledge_base_chat(query=query,
                                         knowledge_base_name=database,
                                         model_name=model_container.MODEL.model_name,
                                         temperature=0.01,
                                         history=[],
                                         top_k=VECTOR_SEARCH_TOP_K,
                                         max_tokens=MAX_TOKENS,
                                         prompt_name="knowledge_base_chat",
                                         score_threshold=SCORE_THRESHOLD,
                                         stream=False)
    answer_parts = []
    async for data in response.body_iterator:  # each chunk is a JSON string
        # Only the answer is used; the chunk's "docs" entry was previously read
        # into an unused local, which made chunks without "docs" fail needlessly.
        answer_parts.append(json.loads(data)["answer"])
    return "".join(answer_parts)
# Prompt that asks the model to name the knowledge base to search inside a
# ```text ... ``` block. `{database_names}` and `{question}` are the two template
# variables declared on PROMPT below; `${{...}}` uses doubled braces so the braces
# are not treated as variables (presumably rendering as literal `${...}` — confirm
# against the template format in use).
_PROMPT_TEMPLATE = """
用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考
Question: ${{用户的问题}}
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能:
{database_names}
你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
```text
${{知识库的名称}}
```
```output
数据库查询的结果
```
答案: ${{答案}}
现在,这是我的问题:
问题: {question}
"""
# Template object handed to the chain; variables must match the placeholders above.
PROMPT = PromptTemplate(
input_variables=["question", "database_names"],
template=_PROMPT_TEMPLATE,
)
class LLMKnowledgeChain(LLMChain):
    """Chain that lets the LLM pick one knowledge base, then searches it.

    The LLM is prompted with ``PROMPT`` and is expected to answer with a
    "```text ... ```" block naming a knowledge base. That knowledge base is
    searched with the user's original question and the search output is returned
    under ``output_key`` prefixed with ``"Answer: "``.
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    prompt: BasePromptTemplate = PROMPT
    """[Deprecated] Prompt to use to translate to python if necessary."""
    # Default is read from model_container.DATABASE when the class is defined.
    database_names: Dict[str, str] = model_container.DATABASE
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn about the deprecated ``llm`` argument and derive ``llm_chain`` from it."""
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _evaluate_expression(self, dataset, query) -> str:
        """Search ``dataset`` for ``query`` synchronously; return an error text on failure."""
        try:
            output = asyncio.run(search_knowledge_base_iter(dataset, query))
        except Exception:
            output = "输入的信息有误或不存在知识库"
        return output

    async def _aevaluate_expression(self, dataset, query) -> str:
        """Async counterpart of ``_evaluate_expression``.

        Awaits the search directly: ``asyncio.run`` cannot be called while an
        event loop is already running.
        """
        try:
            return await search_knowledge_base_iter(dataset, query)
        except Exception:
            return "输入的信息有误或不存在知识库"

    def _process_llm_result(
            self,
            llm_output: str,
            llm_input: str,
            run_manager: CallbackManagerForChainRun
    ) -> Dict[str, str]:
        """Interpret the LLM output: search the named knowledge base or pass an answer through.

        Args:
            llm_output: Raw text produced by the LLM.
            llm_input: The user's original question, used as the search query.
            run_manager: Callback manager receiving progress text.

        Returns:
            ``{output_key: answer}``; for unrecognised output the value is a
            format-error message containing the raw LLM output.
        """
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if text_match:
            database = text_match.group(1).strip()
            output = self._evaluate_expression(database, llm_input)
            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            return {self.output_key: f"输入的格式不对: {llm_output}"}
        return {self.output_key: answer}

    async def _aprocess_llm_result(
            self,
            llm_output: str,
            llm_input: str,
            run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Async variant of ``_process_llm_result``.

        ``llm_input`` is accepted (as ``_acall`` already passes it) and forwarded to
        the search, mirroring the synchronous method; previously the signature
        lacked it and the search was invoked with a missing argument.

        Raises:
            ValueError: If the LLM output matches none of the known formats.
        """
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if text_match:
            database = text_match.group(1).strip()
            output = await self._aevaluate_expression(database, llm_input)
            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            raise ValueError(f"unknown format from LLM: {llm_output}")
        return {self.output_key: answer}

    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Prompt the LLM with the question and known knowledge bases, then process its output."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key])
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = self.llm_chain.predict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager)

    async def _acall(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async variant of ``_call``."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        await _run_manager.on_text(inputs[self.input_key])
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = await self.llm_chain.apredict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)

    @property
    def _chain_type(self) -> str:
        return "llm_knowledge_chain"

    @classmethod
    def from_llm(
            cls,
            llm: BaseLanguageModel,
            prompt: BasePromptTemplate = PROMPT,
            **kwargs: Any,
    ) -> LLMKnowledgeChain:
        """Build the chain from a language model and an optional prompt."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
def search_knowledgebase_once(query: str):
    """Answer ``query`` by letting the model choose one knowledge base to search.

    Builds an ``LLMKnowledgeChain`` around ``model_container.MODEL`` with the
    module-level ``PROMPT`` (verbose mode on) and runs it once.

    Args:
        query: The user's question.

    Returns:
        The chain's answer.
    """
    chain = LLMKnowledgeChain.from_llm(model_container.MODEL, verbose=True, prompt=PROMPT)
    return chain.run(query)
# Pydantic argument schema for the single-knowledge-base search tool.
class KnowledgeSearchInput(BaseModel):
    # NOTE(review): named `location` although it holds the search query — confirm.
    location: str = Field(description="The query to be searched")
# Manual smoke test: run one sample question when executed as a script.
if __name__ == "__main__":
    result = search_knowledgebase_once("大数据的男女比例")
    print(result)

View File

@@ -0,0 +1,32 @@
from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import json
import asyncio
from server.agent import model_container
async def search_knowledge_base_iter(database: str, query: str) -> str:
    """Query one knowledge base and return the answer text.

    Args:
        database: Name of the knowledge base to search.
        query: The question to ask.

    Returns:
        The ``answer`` of the last response chunk (``""`` if there were none).

    Raises:
        KeyError: If a response chunk has no ``answer`` field.
    """
    response = await knowledge_base_chat(query=query,
                                         knowledge_base_name=database,
                                         model_name=model_container.MODEL.model_name,
                                         temperature=0.01,
                                         history=[],
                                         top_k=VECTOR_SEARCH_TOP_K,
                                         max_tokens=MAX_TOKENS,
                                         prompt_name="knowledge_base_chat",
                                         score_threshold=SCORE_THRESHOLD,
                                         stream=False)
    contents = ""
    async for data in response.body_iterator:  # each chunk is a JSON string
        # Each chunk replaces the previous answer (last one wins). The chunk's
        # "docs" entry is not needed here; it used to be read into an unused
        # local, which made chunks without "docs" fail needlessly.
        contents = json.loads(data)["answer"]
    return contents
def search_knowledgebase_simple(query: str, database: str = None):
    """Search one knowledge base synchronously and return the answer text.

    The previous implementation called ``search_knowledge_base_iter`` with the
    query alone although that coroutine takes ``(database, query)``, so every
    call raised ``TypeError``.

    Args:
        query: The question to ask.
        database: Knowledge base to search. When omitted, the name is taken from
            ``model_container.DATABASE`` (the value itself if it is a string,
            otherwise its first entry).

    Returns:
        The answer text produced by ``search_knowledge_base_iter``.

    Raises:
        ValueError: If no knowledge base name can be determined.
    """
    if database is None:
        # NOTE(review): sibling tools iterate model_container.DATABASE as a
        # name -> description mapping; confirm that falling back to its first
        # entry is the intended default.
        configured = model_container.DATABASE
        if isinstance(configured, str):
            database = configured
        else:
            database = next(iter(configured or ()), None)
    if not database:
        raise ValueError("search_knowledgebase_simple needs a knowledge base name")
    return asyncio.run(search_knowledge_base_iter(database, query))
# Manual smoke test: run one sample question when executed as a script.
if __name__ == "__main__":
    result = search_knowledgebase_simple("大数据男女比例")
    print("答案:",result)

View File

@@ -0,0 +1,56 @@
import json
import os
import re
from typing import List
from urllib.parse import quote
from server.agent.tools.search_tool import rag_search
from server.chat import utils
from server.knowledge_base.kb_doc_api import search_docs
def search_pic(query: str) -> str:
    """Search the picture knowledge base and return Markdown image links.

    ``query`` must contain at least two ``{...}`` groups: the first is JSON with
    a ``query`` key, the second is JSON with a ``uuid`` key identifying the
    shared per-conversation state. Sources already listed in that state's
    ``source_docs`` are skipped; newly used sources are appended to it.

    Args:
        query: Raw tool input as produced by the agent.

    Returns:
        Always a string (the former ``List[str]`` annotation matched no return
        path): an instruction text with the image links, a fixed "do not call
        again" instruction when input is malformed or nothing new was found, a
        "no result" text, or ``"Failed to get picture.<error>"`` on any exception.
    """
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            return "<关键指令>不需要再调用该工具了</关键指令>"
        uuid = json.loads(matches[1])["uuid"]
        tip = utils.get_shared_variable(uuid)
        search_args = json.loads(matches[0])
        res = search_docs(usr_query=search_args["query"], fileName=[], top_k=10, score_threshold=0.9,
                          query=search_args["query"], knowledge_base_name="p_meiyupic")
        if len(res) == 0 and len(tip["source_docs"]) == 0:
            utils.set_shared_variable(uuid, tip)
            return "工具没有找到结果"

        links = []
        for item in res:
            source = item.metadata['source']
            if source in tip["source_docs"]:
                continue  # already returned earlier in this conversation
            tip["source_docs"].append(source)
            # Source path without its file extension is used as the image alt text.
            alt_text = os.path.splitext(source)[0]
            # NOTE(review): both arguments of replace() are identical here, so this
            # is a no-op as written; confirm whether a host rewrite was intended.
            url = item.page_content.replace("http://127.0.0.1:8099/chat_web_backend",
                                            "http://127.0.0.1:8099/chat_web_backend")
            links.append(f"![{alt_text}]({quote(url, safe='/:?=&#+')})\n")
        result = "".join(links)

        utils.set_shared_variable(uuid, tip)
        if not result:
            return "<关键指令>不需要再调用该工具了</关键指令>"
        print(f"美术作品链接:{result}")
        return f"注意:以下链接是图片不是参考文献,以下链接不要放到引文小标的格式输出而是以图片格式输出,禁止转义后面链接的编码,这个链接不能带中文。图片如下:{result}"
    except Exception as e:
        return f"Failed to get picture.{e}"

View File

@@ -0,0 +1,331 @@
import asyncio
import concurrent
from datetime import datetime
import json
import logging
import re
from typing import List
from fastapi import logger
from configs import kb_config
from configs.model_config import LLM_MODELS
from server.agent.tools import duckduckgo_search
from server.agent.tools.duckduckgo_search import duckduckgo_search_iter
from server.agent.tools.knowledgebase_kgo_search import knowledgebase_kgo_iter
from server.agent.tools.rag_search import rag_search1
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
from server.chat.utils import doc_to_list, get_similar_documents1, solve_knowledge_map, solve_mental_data
from server.knowledge_base.kb_doc_api import search_docs
def rag_search(query: str,uid):
"""
根据用户输入的query返回rag搜索结果
"""
source_docs = []
try:
search = json.loads(query)
logging.info(f'模型输入: {search["query"]}')
original_query = search["query"]
search_query = get_llm_model_response(
strategy_name="rag_search_rewrite",
llm_model_name=LLM_MODELS[0],
template_prompt_name="rag_search_rewrite",
prompt_param_dict={"input": search["query"], "year": datetime.now().strftime("%Y")},
temperature=0.3,
max_tokens=512
)
logging.info(f'模型改写: {search_query}')
search_keywords = []
search_text = f"{search_query}"
# if type(search["keywords"]) == list:
# search_keywords = search["keywords"]
# for keyword in search_keywords:
# search_text += f" {keyword}"
# else:
# search_keywords = search["keywords"].split(",")
# for keyword in search_keywords:
# search_text += f" {keyword}"
self_database = utils.get_shared_variable(uid)
result = []
knownledge_name = []
if type(search["knowledge_name"]) == list:
knownledge_name=search["knowledge_name"]
else:
knownledge_name=search["knowledge_name"].split(",")
if "美术专业知识库" in knownledge_name:
knownledge_name.remove("美术专业知识库")
if "database" in self_database:
self_database["database"]= self_database["database"].append("p_cafa0101011")
else:
self_database["database"] = ["p_cafa0101011"]
# 添加个人知识库
if "database" in self_database:
knownledge_name.extend(self_database["database"])
knownledge_name = [knownledge for knownledge in knownledge_name
if (knownledge in kb_config.CH_BASE_NAME
or knownledge in kb_config.EN_BASE_NAME
or knownledge in getattr(kb_config, "YJ_BASE_NAME", [])
or kb_config.SELF_KNOWLEDGE_BASE.match(knownledge)
or knownledge == "coding")]
if len(knownledge_name)==0:
#result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
return result,source_docs
# knownledge_name=kb_config.CH_BASE_NAME
knownledge_name=solve_knowledge_map(knownledge_name)
#knownledge_name = ["p_c88859a3d06e4265bd01d816ef2650d1"]
num = 0
temp=utils.get_shared_variable(uid)
for knownledge in knownledge_name:
seen_docs = set()
duplicate_indices = []
# 针对中国钢铁行业动态库增加日期范围过滤
expr_param = ""
if knownledge == kb_config.STEEL_KB:
time_today = datetime.now().strftime("%Y-%m-%d")
# 调用LLM生成日期表达式模板沿用 get_policy_time
try:
expr_candidate = get_llm_model_response(
strategy_name="get steel time",
llm_model_name=LLM_MODELS[0],
template_prompt_name="get_steel_time",
prompt_param_dict={"query": original_query, "time": time_today},
temperature=0.01,
max_tokens=512
).replace("None", "").strip()
expr_param = expr_candidate if expr_candidate else ""
except Exception as _:
expr_param = ""
doc_list = search_docs(
usr_query=original_query,
fileName=[],
top_k=20,
score_threshold=1.0,
query=search_text,
knowledge_base_name=knownledge,
expr=expr_param
)
if len(doc_list)==0:
return result,source_docs
titles = temp["title"]
doc_list,title = utils.remove_docs1(titles,doc_list)
titles.extend(title)
for inum,doc in enumerate(doc_list):
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
# 从policydocs中删除重复的文档从后往前删除以防止索引错位
for index in sorted(duplicate_indices, reverse=True):
del doc_list[index]
# 处理原文来源进入数组。使用开关语句明确各个条件分支
match knownledge:
# 属于政策库分支,入参为中文政策库名称
case kb_config.DEFAULT_POLICY_BASE:
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs)
# 属于期刊论文库分支,入参为期刊论文库的中文名称
case kb_config.DEFAULT_JOURNAL_BASE:
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs)
# 属于报告库分支,入参为报告库中文名称
case kb_config.DEFAULT_REPORT_BASE1:
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs)
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
case kb_config.GY_NEWS_BASE:
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs)
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
case kb_config.GY_REPORT_BASE:
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs)
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
case kb_config.GY_JOURNAL_BASE:
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金新闻库2024年以及之前
case kb_config.YJ_NEWS_BASE:
doc_to_list(num,kb_config.YJ_NEWS_BASE_NAME,doc_list,source_docs)
# 新增冶金中文期刊库
case kb_config.YJ_CH_JOURNAL_BASE:
doc_to_list(num,kb_config.YJ_CH_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金外文期刊库
case kb_config.YJ_FOR_JOURNAL_BASE:
doc_to_list(num,kb_config.YJ_FOR_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金OA期刊库
case kb_config.YJ_OA_JOURNAL_BASE:
doc_to_list(num,kb_config.YJ_OA_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金政策库
case kb_config.YJ_POLICYS_BASE:
doc_to_list(num,kb_config.YJ_POLICYS_BASE_NAME,doc_list,source_docs)
# 新增中国钢铁行业动态库
case kb_config.STEEL_KB:
doc_to_list(num,kb_config.STEEL_KB_NAME,doc_list,source_docs)
# 属于个人知识库分支
case _ if kb_config.SELF_KNOWLEDGE_BASE.match(knownledge) or knownledge == "coding":
doc_to_list(num,knownledge,doc_list,source_docs)
case _:
print(f"输入了没有的知识库名称")
return "输入了没有的知识库名称",source_docs
# num += len(source_docs[knownledge])
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
# del num
# source = utils.get_shared_variable(uid)
# print(utils.get_shared_variable(uid))
# source["source_docs"]=source_docs
# utils.set_shared_variable(uid,source)
except Exception as e:
logging.error(f"Error in rag_search: {e}")
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
return "当前工具异常!请换其他工具",source_docs
return result,source_docs
def knowledgebase_kgo_search(query: str, uid):
    """Run the KGO knowledge-base search and normalise its result.

    Args:
        query: Search text handed to ``knowledgebase_kgo_iter``.
        uid: Identifier handed to ``knowledgebase_kgo_iter``.

    Returns:
        An object indexable with ``0`` (result texts) and ``1`` (source docs):
        the raw result when both entries are lists; ``[[], docs]`` when only a
        non-empty docs list is usable; ``{0: [], 1: []}`` otherwise. If the
        search itself raises, an instruction string is returned instead.
    """
    try:
        res = knowledgebase_kgo_iter(query,uid)
        try:
            answers, docs = res[0], res[1]
        except (TypeError, IndexError, KeyError) as e:
            logging.error(f"No docs: {e}")
            return {0: [], 1: []}
        if isinstance(answers, list) and isinstance(docs, list):
            return res
        if isinstance(docs, list) and docs:
            # Build a new pair instead of assigning into `res`: item assignment
            # fails on tuples and previously discarded the docs.
            return [[], docs]
        # The old code assigned into an empty list here, which always raised
        # IndexError and only reached this result through the except handler.
        return {0: [], 1: []}
    except json.JSONDecodeError:
        # Invalid JSON in the query.
        logging.error("Invalid JSON format in query.")
        return "<关键指令>不需要再调用该工具了</关键指令>"
    except KeyError:
        # A required key is missing from the parsed JSON.
        return "请尝试调用其他工具"
    except Exception as e:
        # Any other failure: report it and suggest another tool.
        return f"发生错误:{str(e)},请尝试调用其他工具"
def inner_duckduckgo_search(query: str, uuid:str,) :
    """Run the DuckDuckGo search coroutine synchronously.

    Args:
        query: The search text.
        uuid: Identifier forwarded to ``duckduckgo_search_iter``.

    Returns:
        The combined result of ``duckduckgo_search_iter``, called with the fixed
        extra arguments ``"y"`` and ``"default"``.
    """
    logging.info(f"模型输入: {query}")
    search_coroutine = duckduckgo_search_iter(query, uuid, "y", "default")
    combined_result = asyncio.run(search_coroutine)
    # Hand the result back to the model.
    logging.info("返回JSON格式的结果给到模型...")
    return combined_result
def search_tool(query: str):
    """Combine the knowledge-base (RAG) search and the KGO search for one tool call.

    ``query`` must contain at least two ``{...}`` groups (``<param>`` tags are
    stripped first): the first is the JSON search request passed on to
    ``rag_search``, the second is JSON whose ``uuid`` keys the shared
    per-conversation state.

    Returns a prompt string containing the renumbered results and sources, or a
    fixed instruction string when fewer than two ``{...}`` groups are present,
    when nothing was found, or when an exception occurs inside the search phase.
    """
    if "<param>"in query:
        query = query.replace("<param>","").replace("</param>","")
    matches = re.findall(r'\{.*?\}', query)
    if len(matches)>=2:
        query = matches[0]
    else:
        return "<关键指令>当前工具不需要再调用</关键指令>"
    # NOTE(review): the parsing below runs outside the try block, so malformed
    # JSON or a missing "uuid"/"query" key raises to the caller.
    time_based_uuid = json.loads(matches[1])["uuid"]
    search = json.loads(query)
    # `searches` is the text for the KGO search: the first element when "query"
    # is a non-empty list, "" for an empty list, otherwise the value itself.
    if type(search["query"])==list and len(search["query"])>0:
        searches = search["query"][0]
    elif type(search["query"])==list and len(search["query"]) == 0:
        searches = ""
    else:
        searches = search["query"]
    # The bare string below is an expression statement with no effect (it says:
    # "return the RAG search result for the user's query"); kept unchanged.
    """
根据用户输入的query返回rag搜索结果
"""
    try:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Submit both searches so they run concurrently.
            # Fresh shared state for this call: counter, sources, flags, seen titles.
            test = {}
            test["num"]=0
            test["source_docs"]=[]
            test["END"] = ""
            test["title"] = []
            utils.set_shared_variable(time_based_uuid+"",test)
            # future2 = executor.submit(knowledgebase_kgo_search,search["query"],time_based_uuid+"q")
            future1 = executor.submit(rag_search,query,time_based_uuid)
            # if not "type" in utils.get_shared_variable(time_based_uuid):
            # future2 = executor.submit(knowledgebase_kgo_search,searches,time_based_uuid+"¥")
            # The KGO search is only started when the shared state has no "type" key.
            # NOTE(review): `+""` is a no-op concatenation; the commented-out
            # variants above used a suffix — confirm which key was intended.
            if not "type" in utils.get_shared_variable(time_based_uuid):
                future2 = executor.submit(knowledgebase_kgo_search,searches,time_based_uuid+"")
            result3 = []
            # Collect the results.
            result1,sourcedocs = future1.result()
            result2 = {}
            # NOTE(review): the shared state is read a second time here; `future2`
            # is only bound if "type" was absent at the first check.
            if "type" in utils.get_shared_variable(time_based_uuid):
                result2[0] =[]
                result2[1] = []
            else:
                result2 = future2.result()
            # if "type" in utils.get_shared_variable(time_based_uuid):
            # result2[0] =[]
            # result2[1] = []
            # else:
            # result2 = future2.result()
            # result2[0] = []
            # result2[1] = []
            # NOTE(review): removes the key with a "q" suffix, which no live code
            # in this function sets.
            utils.remove_shared_variable(time_based_uuid+"q")
            # Merge KGO sources into the RAG sources.
            # NOTE(review): the nesting of this `else` is reconstructed (the
            # original indentation is not available); verify against the repository.
            if type(result2[1]) == list:
                if type(sourcedocs) == list:
                    sourcedocs.extend(result2[1])
                else:
                    sourcedocs = []
            # Merge KGO result texts into the RAG results; when the RAG result is
            # not a list (e.g. an error string) only the KGO results are used.
            if type(result1) == list:
                result1.extend(result2[0])
                result3 = result1
            else:
                result3 = result2[0]
            logging.info(f"result2:{result2[1]}")
            source = []
            res=[]
            sources = utils.get_shared_variable(time_based_uuid)
            i = sources["num"]
            num = sources["num"]
            # Renumber the first "[n]" marker of every source, continuing from the
            # counter stored in the shared state; entries that are empty after
            # substitution or that raise do not consume a number.
            for result in sourcedocs:
                try:
                    i+=1
                    res3 = re.sub(r'\[\d+\]', f"[{i}]", result.replace("\n",""), count=1)
                    if res3:
                        source.append(re.sub(r'\[\d+\]', f"[{i}]", result.replace("\n",""), count=1))
                    else:
                        i -= 1
                except Exception as e:
                    i -= 1
                    pass
            # internet_search_res = f"参考资料[{len(result1)+1}-{len(source)}]:{result2[0]}"
            # internet_search_res = f"参考资料:{result2[0]}"
            # Renumber the first "[n]" marker of every result text the same way.
            j = sources["num"]
            for result in result3:
                j+=1
                res.append(re.sub(r'\[\d+\]', f"[{j}]", result, count=1))
            print(utils.get_shared_variable(time_based_uuid))
            # sources["source_docs"]=source
            # Persist the new sources and the advanced counter.
            sources["source_docs"].extend(source)
            sources["num"]=i
            # sources["END"] = "ok"
            utils.set_shared_variable(time_based_uuid,sources)
            logging.info(f"result1:{result1},sourcedocs:{sourcedocs}")
            logging.info(f"result2:{result2}")
            logging.info(f"{res}")
            if len(res) ==0 and len(source)==0:
                return f"尝试调整入参重新调用知识库联想工具(同一个问题调用超过三次就不要再使用知识库联想工具了,浪费时间)"
            return f"<关键指令>如果你在写文章禁止在非规定位置输出参考资料</关键指令>资料:{res}\n资料来源为:{source}\n 注意:如果你在根据大纲撰写文章,撰写中间部分章节禁止输出综上所述之类的影响文风的话,撰写中间部分禁止输出附录引用文献等!!!"
    except Exception as e:
        logging.error(f"Error occurred during search_tool execution.{e}")
        return "同一个问题调用知识库联想工具超过5次就不要再调用知识库联想"

View File

@@ -0,0 +1,9 @@
# Langchain 自带的 YouTube 搜索工具封装
from langchain.tools import YouTubeSearchTool
from pydantic import BaseModel, Field
def search_youtube(query: str):
    """Search YouTube through LangChain's ``YouTubeSearchTool``.

    Args:
        query: The video search text.

    Returns:
        The tool's output for ``query``.
    """
    return YouTubeSearchTool().run(tool_input=query)
# Pydantic argument schema for the YouTube search tool.
class YoutubeInput(BaseModel):
    # NOTE(review): named `location` although it holds the search query — confirm.
    location: str = Field(description="Query for Videos search")

View File

@@ -0,0 +1,10 @@
name: search_youtube
description: Use this tools to search youtube videos
parameters:
type: object
properties:
query:
type: string
description: Query for Videos search
required:
- query

View File

@@ -0,0 +1,9 @@
# LangChain 的 Shell 工具
from pydantic import BaseModel, Field
from langchain.tools import ShellTool
def shell(query: str):
    """Execute a shell command through LangChain's ``ShellTool``.

    Args:
        query: The command line to run.

    Returns:
        The tool's output for the command.
    """
    # SECURITY: the command text is executed as given; when it originates from a
    # model or a user this is arbitrary command execution on the host.
    return ShellTool().run(tool_input=query)
# Input schema for the shell tool: `query` holds the command line to run.
# (The field description below reads: "a shell command that can be run on the
# Linux command line".)
class ShellInput(BaseModel):
    query: str = Field(description="一个能在Linux命令行运行的Shell命令")

View File

@@ -0,0 +1,10 @@
name: shell
description: Use Linux Shell to execute Linux commands
parameters:
type: object
properties:
query:
type: string
description: The command to execute
required:
- query

View File

@@ -0,0 +1,49 @@
"""
更简单的单参数输入工具实现,用于查询现在天气的情况
"""
import json
import re
from pydantic import BaseModel, Field
import requests
from configs.kb_config import SENIVERSE_API_KEY
from server.chat import utils
def weather(location: str, api_key: str, timeout: float = 10.0):
    """Fetch the 5-day daily forecast for *location* from the Seniverse API.

    Args:
        location: Value sent as the ``location`` query parameter.
        api_key: Value sent as the ``key`` query parameter.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A JSON string with two keys: ``"today"`` (the first daily entry) and
        ``"others"`` (the remaining daily entries). Each value is itself a
        JSON-encoded string, which is the format the original code produced.

    Raises:
        Exception: If the API answers with a non-200 status code.
        requests.RequestException: If the request fails or times out.
    """
    # Let requests build the query string so that characters such as "&" or
    # "#" inside *location* are escaped instead of altering the request.
    params = {
        "key": api_key,
        "location": location,
        "language": "zh-Hans",
        "unit": "c",
        "start": 0,
        "days": 5,
    }
    # Without a timeout a stalled connection would block the caller forever.
    response = requests.get(
        "https://api.seniverse.com/v3/weather/daily.json",
        params=params,
        timeout=timeout,
    )
    if response.status_code != 200:
        raise Exception(
            f"Failed to retrieve weather: {response.status_code}")

    daily = response.json()["results"][0]["daily"]
    forecast = {
        "today": json.dumps(daily[0]),
        "others": json.dumps(daily[1:]),
    }
    return json.dumps(forecast)
def weathercheck(query: str):
    """Pull the location out of *query* and return its weather forecast.

    *query* is searched for ``{...}`` fragments. The first fragment must be
    JSON with a ``"location"`` key and the second must be JSON with a
    ``"uuid"`` key. When fewer than two fragments are found, a fixed
    instruction string is returned instead of a forecast.

    Returns:
        The result of ``weather(location, SENIVERSE_API_KEY)``, the fixed
        instruction string, or a ``"Failed to retrieve weather."`` message
        carrying the text of any exception raised along the way.
    """
    try:
        fragments = re.findall(r'\{.*?\}', query)
        if len(fragments) < 2:
            return "<关键指令>不需要再调用该工具了</关键指令>"
        location_payload, uuid_payload = fragments[0], fragments[1]
        location = json.loads(location_payload)["location"]
        # The uuid value is not used afterwards, but the lookup is kept so a
        # malformed second fragment still produces the failure message.
        _ = json.loads(uuid_payload)["uuid"]
        return weather(location, SENIVERSE_API_KEY)
    except Exception as exc:
        return f"Failed to retrieve weather.{exc}"
# Input schema for the weather tool: a single `location` string, matching the
# "location" key that weathercheck() reads from its JSON payload.
class WeatherInput(BaseModel):
    location: str = Field(description="City name,include city and county")

View File

@@ -0,0 +1,10 @@
name: weather_check
description: Use Weather API to get weather information
parameters:
type: object
properties:
query:
type: string
description: City name, including city and county, e.g. "厦门市思明区"
required:
- query

View File

@@ -0,0 +1,11 @@
# Langchain 自带的 Wolfram Alpha API 封装
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
from pydantic import BaseModel, Field
# App ID handed to WolframAlphaAPIWrapper inside wolfram().
# NOTE(review): "your key" looks like a placeholder -- presumably a real
# Wolfram Alpha App ID has to be supplied here before the tool can work.
wolfram_alpha_appid = "your key"
def wolfram(query: str):
    """Send *query* to Wolfram Alpha through LangChain's API wrapper.

    Args:
        query: Text passed to the wrapper's ``run`` method.

    Returns:
        Whatever ``WolframAlphaAPIWrapper.run`` returns for the query.
    """
    client = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid)
    return client.run(query)
# Input schema for the Wolfram tool: one string field whose description reads
# "the specific problem to be computed".
# NOTE(review): the field is named `location`, while wolfram() takes a `query`
# argument -- this looks like a copy-paste leftover. Confirm with the callers
# before renaming, since they may rely on the current field name.
class WolframInput(BaseModel):
    location: str = Field(description="需要运算的具体问题")

View File

@@ -0,0 +1,10 @@
name: wolfram
description: Useful for when you need to calculate difficult math formulas
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query