[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
16
langchain-chat/server/agent/tools/__init__.py
Normal file
16
langchain-chat/server/agent/tools/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
## 导入所有的工具类
|
||||
# from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
|
||||
# from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
|
||||
# from .chat_with_Yi34B import chat_with_Yi34B, ChatWithYi34BInput
|
||||
# from .search_youtube import search_youtube, YoutubeInput
|
||||
from .calculate import calculate, CalculatorInput
|
||||
from .weather_check import weathercheck, WeatherInput
|
||||
from .shell import shell, ShellInput
|
||||
from .search_internet import search_internet, SearchInternetInput
|
||||
from .wolfram import wolfram, WolframInput
|
||||
from .arxiv import arxiv, ArxivInput
|
||||
from .knowledgebase_kgo_search import knowledgebase_kgo_search, KnowledgeKgoInput
|
||||
from .policy_knowledgebase_search import policy_knowledgebase_search, PolicyKnowledgeInput
|
||||
from .report_knowledgebase_search import report_knowledgebase_search, ReportKnowledgeInput
|
||||
from .rag_search import rag_search1, RagSearchInput
|
||||
from .duckduckgo_search import duckduckgo_search, DuckduckgoInput
|
||||
9
langchain-chat/server/agent/tools/arxiv.py
Normal file
9
langchain-chat/server/agent/tools/arxiv.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# LangChain 的 ArxivQueryRun 工具
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.tools.arxiv.tool import ArxivQueryRun
|
||||
def arxiv(query: str):
    """Run an Arxiv search through LangChain's ``ArxivQueryRun`` tool.

    Args:
        query: Search text, passed to the tool as ``tool_input``.

    Returns:
        Whatever ``ArxivQueryRun.run`` returns for the query.
    """
    # A fresh tool instance is created for every call.
    return ArxivQueryRun().run(tool_input=query)
|
||||
|
||||
class ArxivInput(BaseModel):
    """Argument schema for the ``arxiv`` tool."""

    # Text to search for.
    query: str = Field(description="The search query title")
|
||||
10
langchain-chat/server/agent/tools/arxiv.yaml
Normal file
10
langchain-chat/server/agent/tools/arxiv.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
# Tool schema for the `arxiv` tool: one required string argument, `query`.
name: arxiv
description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.
parameters:
  type: object
  properties:
    query:
      type: string
      description: The search query title
  required:
    - query
|
||||
76
langchain-chat/server/agent/tools/calculate.py
Normal file
76
langchain-chat/server/agent/tools/calculate.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.chains import LLMMathChain
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Prompt used by calculate() below. The text is Chinese and is sent to the
# model verbatim: it asks the model to rewrite a math question as a single-line
# numexpr expression inside a ```text fence, followed by worked examples.
# "{question}" is the only template variable (see input_variables below);
# the doubled braces in "${{...}}" are presumably literal-brace escapes.
_PROMPT_TEMPLATE = """
将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
问题: ${{包含数学问题的问题。}}
```text
${{解决问题的单行数学表达式}}
```
...numexpr.evaluate(query)...
```output
${{运行代码的输出}}
```
答案: ${{答案}}

这是两个例子:

问题: 37593 * 67是多少?
```text
37593 * 67
```
...numexpr.evaluate("37593 * 67")...
```output
2518731

答案: 2518731

问题: 37593的五次方根是多少?
```text
37593**(1/5)
```
...numexpr.evaluate("37593**(1/5)")...
```output
8.222831614237718

答案: 8.222831614237718


问题: 2的平方是多少?
```text
2 ** 2
```
...numexpr.evaluate("2 ** 2")...
```output
4

答案: 4


现在,这是我的问题:
问题: {question}
"""

# Template object handed to LLMMathChain.from_llm in calculate().
PROMPT = PromptTemplate(
    input_variables=["question"],
    template=_PROMPT_TEMPLATE,
)
|
||||
|
||||
|
||||
class CalculatorInput(BaseModel):
    """Argument schema for the ``calculate`` tool."""

    # The math question; declared with a bare Field() (no description).
    query: str = Field()
|
||||
|
||||
def calculate(query: str):
    """Answer a math question with an ``LLMMathChain`` built on the shared model.

    Args:
        query: The math question in natural language.

    Returns:
        The value returned by ``LLMMathChain.run`` for the question.
    """
    # The chain is rebuilt on every call from the model currently stored in
    # model_container, using the module-level PROMPT.
    chain = LLMMathChain.from_llm(model_container.MODEL, verbose=True, prompt=PROMPT)
    return chain.run(query)
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: ask for "2 to the power of 3" and print the answer.
    print("答案:", calculate("2的三次方"))
|
||||
|
||||
|
||||
|
||||
10
langchain-chat/server/agent/tools/calculate.yaml
Normal file
10
langchain-chat/server/agent/tools/calculate.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
# Tool schema for the `calculate` tool: one required string argument, `query`.
name: calculate
description: Useful for when you need to answer questions about simple calculations
parameters:
  type: object
  properties:
    query:
      type: string
      description: The formula to be calculated
  required:
    - query
|
||||
43
langchain-chat/server/agent/tools/chat_with_Yi34B.py
Normal file
43
langchain-chat/server/agent/tools/chat_with_Yi34B.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from server.chat.chat import chat
|
||||
from server.chat.utils import History
|
||||
|
||||
|
||||
async def chat_with_Yi34B_iter(query: str,
                               stream=False,
                               model_name="qianfan-api",
                               history: Union[int, List[History]] = None,
                               conversation_id='',
                               temperature=0.7,
                               max_tokens=None,
                               history_len=3,
                               prompt_name="default"
                               ):
    """Call the ``chat`` endpoint and return the full reply as one string.

    Every argument is forwarded to ``chat`` unchanged. The response body is
    consumed chunk by chunk; each chunk is parsed as JSON and its ``"text"``
    value is appended to the result.

    Returns:
        The concatenation of every chunk's ``"text"`` value.
    """
    response = await chat(
        query=query,
        history=history,
        history_len=history_len,
        conversation_id=conversation_id,
        stream=stream,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        prompt_name=prompt_name,
    )

    reply = ""
    async for chunk in response.body_iterator:
        # Each chunk is a JSON string.
        payload = json.loads(chunk)
        reply += payload["text"]
    return reply
|
||||
|
||||
|
||||
def chat_with_Yi34B(query: str, model_name: str = "qianfan-api", conversation_id: str = '',
                    history: Union[int, List[History]] = None):
    """Synchronous wrapper around ``chat_with_Yi34B_iter``.

    Runs the coroutine with ``asyncio.run`` and returns the text it produces.
    """
    coroutine = chat_with_Yi34B_iter(
        query,
        model_name=model_name,
        conversation_id=conversation_id,
        history=history,
    )
    return asyncio.run(coroutine)
|
||||
|
||||
|
||||
class ChatWithYi34BInput(BaseModel):
    """Argument schema for the ``chat_with_Yi34B`` tool."""

    # NOTE(review): the field is called `location`, while chat_with_Yi34B()
    # takes `query` — confirm the name is what the agent framework expects.
    location: str = Field(description="Query for any kind of chats and questions")
|
||||
18
langchain-chat/server/agent/tools/chat_with_Yi34B.yaml
Normal file
18
langchain-chat/server/agent/tools/chat_with_Yi34B.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
# Tool schema for the `chat_with_Yi34B` tool.
name: chat_with_Yi34B
description: Use this tool to chat with human
parameters:
  type: object
  properties:
    query:
      type: string
      description: Query for any kind of chat and questions
    model_name:
      type: string
      # NOTE(review): description is empty.
      description:
    conversation_id:
      type: string
      # NOTE(review): description is empty.
      description:
  required:
    - query
    - model_name
    - conversation_id
|
||||
34
langchain-chat/server/agent/tools/do_nothing.py
Normal file
34
langchain-chat/server/agent/tools/do_nothing.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import json
|
||||
import re
|
||||
import concurrent
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from langchain.tools import YouTubeSearchTool
|
||||
from pydantic import BaseModel, Field
|
||||
from server.chat import utils
|
||||
from server.chat.utils import doc_to_list, solve_knowledge_map,solve_mental_data,shared_variable
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from configs import kb_config
|
||||
|
||||
|
||||
def do_nothing(query: str):
    """No-op tool: ignore ``query`` and return a fixed message.

    The returned text (Chinese) says that no tool call is needed.
    """
    return "\n不需要调用工具了"
|
||||
|
||||
def get_next_tip(query: str):
    """Mark the shared state stored under ``query`` as ended and return a hint.

    The shared variable keyed by ``query`` is read, its ``"END"`` entry is set
    to ``"ok"``, and it is written back under the same key.

    Returns:
        A fixed hint (Chinese) saying the step jump has been used and the
        main text can now be produced.
    """
    state = utils.get_shared_variable(query)
    state["END"] = "ok"
    utils.set_shared_variable(query, state)

    return "\n提示:你已经使用过环节跳转了,可以开始输出正文了"
|
||||
|
||||
class doNothingInput(BaseModel):
    """Argument schema with a single required ``query`` string."""

    # Required (Field(...)); the description is Chinese for "query target".
    query: str = Field(...,description="查询对象")
|
||||
175
langchain-chat/server/agent/tools/draw_plot.py
Normal file
175
langchain-chat/server/agent/tools/draw_plot.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import requests
|
||||
from matplotlib import pyplot as plt
|
||||
from pydantic import BaseModel, Field
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.chat import utils
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from configs.kb_config import GENERATED_IMAGES_BASE_PATH, realistic_url,ink_url
|
||||
from matplotlib import font_manager
|
||||
|
||||
# Font properties applied to chart titles, labels and ticks below; created at
# import time with a fixed font-file path.
my_font = font_manager.FontProperties(fname="/usr/share/fonts/MicroSoft-YaHei/MSYH.TTC")
|
||||
def create_and_save_plot(query:str) -> str:
    """Render a bar, pie or line chart described by ``query`` and save it as a PNG.

    ``query`` must have the shape ``<param>{...}</param>{...}``. Before
    matching, spaces are removed and single quotes become double quotes. The
    JSON inside ``<param>`` supplies ``data`` (label -> value), ``title``,
    ``xlabel``, ``ylabel`` and ``plot_type`` (``bar``, ``pie`` or ``line``).
    When that JSON cannot be parsed, the text is sent to the LLM with the
    ``check_plot`` prompt and the reply is parsed instead.

    Args:
        query: Tool input in the format described above.

    Returns:
        ``"图片如下:"`` on success, ``"暂时无法画图"`` when the ``<param>``
        wrapper is missing, or a fixed refusal message on any other failure.
    """
    try:
        query = query.replace(" ", "").replace("'", "\"")
        # Example structure shown to the LLM when the input needs repairing.
        json_str ='{\n"data": {"XXX": XX, "XXX": XX, "XXX": X, "XXX": X},"title": "X","xlabel": "X","ylabel": "X","plot_type": "X"}'

        match = re.search(r'<param>(.*?)</param>(\{.*\})', query.replace("\n", ""))
        if not match:
            print(f"Invalid JSON format in query:\n{query}")
            return "暂时无法画图"

        query = match.group(1).strip()
        try:
            datas = json.loads(query)
        except ValueError:
            # The parameters are not valid JSON: ask the LLM to rewrite them.
            reply = get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="check_plot",
                prompt_param_dict={"user_input": query, "json": json_str},
                temperature=0.01,
                max_tokens=512
            )
            # Search the raw reply: the fence pattern relies on line breaks,
            # so they must not be stripped first. Fall back to the whole
            # reply when the model did not use a ```json fence.
            fenced = re.search(r"```json\s*(.*?)\s*```", reply, re.DOTALL)
            query = (fenced.group(1) if fenced else reply).strip()
            datas = json.loads(query)

        data = datas["data"]
        xlabel = datas["xlabel"]
        ylabel = datas["ylabel"]
        title = datas["title"]
        plot_type = datas["plot_type"]

        categories = list(data.keys())
        values = list(data.values())

        # Reject unknown chart types before a figure is allocated.
        if plot_type not in ('bar', 'pie', 'line'):
            raise ValueError("Unsupported plot type. Choose from 'bar', 'pie', or 'line'.")

        namesid = uuid.uuid1()
        file_path = f'{GENERATED_IMAGES_BASE_PATH}/plot{namesid}.png'
        absolute_path = os.path.abspath(file_path)

        plt.figure(figsize=(10, 6))
        try:
            if plot_type == 'bar':
                plt.bar(categories, values, color='skyblue')
            elif plot_type == 'pie':
                plt.pie(values, labels=categories, autopct='%1.1f%%', startangle=140,
                        textprops={'fontproperties': my_font})
            else:
                plt.plot(categories, values, marker='o', linestyle='-')

            plt.title(title, fontproperties=my_font)
            if plot_type != 'pie':  # pie charts have no axis labels
                plt.xlabel(xlabel, fontproperties=my_font)
                plt.ylabel(ylabel, fontproperties=my_font)
            plt.xticks(fontproperties=my_font, rotation=45)
            plt.yticks(fontproperties=my_font)

            plt.tight_layout()
            plt.savefig(absolute_path)
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close()

        # NOTE(review): the saved file name (plot<uuid>.png) is not part of
        # the returned message — confirm how callers locate the image.
        return "图片如下:"
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        return "不要再调用该工具了,无法提供该功能,你只能按工具描述提供有的入参,其他不支持"
|
||||
|
||||
|
||||
class drawPlotInput(BaseModel):
    """Argument schema with a single required ``query`` string."""

    # Required (Field(...)); the description is Chinese for
    # "the content to draw".
    query: str = Field(...,description="输入要画图的内容")
|
||||
|
||||
def draw_realistic_pic(query:str) -> str:
    """Request an image from the ``realistic_url`` service and record its path.

    ``query`` must contain at least two ``{...}`` groups. The second one is
    JSON with a ``uuid`` key. After that group is removed, the remaining text
    must be JSON with a ``query`` key, whose value is sent as the prompt.
    On HTTP 200 the returned ``image_path`` is stored under ``"url"`` in the
    shared variable identified by the uuid.

    Args:
        query: Tool input in the format described above.

    Returns:
        A fixed instruction string on success, otherwise ``"暂时无法画图"``.
    """
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            return "暂时无法画图"
        # A distinct local name avoids shadowing the module-level uuid import.
        session_id = json.loads(matches[1])["uuid"]
        query = query.replace(matches[1], "")
        prompt = json.loads(query)["query"]
    except (ValueError, KeyError, TypeError):
        # Malformed input: stop here instead of continuing with unbound names.
        logging.error("Invalid JSON format in query.")
        return "暂时无法画图"

    try:
        response = requests.post(realistic_url, json={'prompt': prompt})
        if response.status_code != 200:
            print("Failed to get response:", response.status_code)
            return "暂时无法画图"

        result = response.json()
        print("Image path:", result.get('image_path'))
        file_path = result.get('image_path')
        sources = utils.get_shared_variable(session_id)
        sources["url"] = file_path
        utils.set_shared_variable(session_id, sources)
        return "<关键指令>你必须以markdown格式的图片路径,,之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a response body that is not JSON.
        print("An error occurred:", e)
        return "暂时无法画图"
|
||||
|
||||
class drawRealisticInput(BaseModel):
    """Argument schema with a single required ``query`` string."""

    # Required (Field(...)); the description is Chinese for
    # "the content to draw".
    query: str = Field(...,description="输入要画图的内容")
|
||||
|
||||
def draw_ink_pic(query:str) -> str:
    """Request an image from the ``ink_url`` service and record its path.

    ``query`` must contain at least two ``{...}`` groups. The second one is
    JSON with a ``uuid`` key. After that group is removed, the remaining text
    must be JSON with a ``query`` key, whose value is sent as the prompt.
    On HTTP 200 the returned ``image_path`` is stored under ``"url"`` in the
    shared variable identified by the uuid.

    Args:
        query: Tool input in the format described above.

    Returns:
        A fixed instruction string on success, otherwise ``"暂时无法画图"``.
    """
    # The original opened with a stray, argument-less
    # get_llm_model_response() call whose result was discarded; it is removed.
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            return "暂时无法画图"
        # A distinct local name avoids shadowing the module-level uuid import.
        session_id = json.loads(matches[1])["uuid"]
        query = query.replace(matches[1], "")
        prompt = json.loads(query)["query"]
    except (ValueError, KeyError, TypeError):
        # Malformed input: stop here instead of continuing with unbound names.
        logging.error("Invalid JSON format in query.")
        return "暂时无法画图"

    try:
        response = requests.post(ink_url, json={'prompt': prompt})
        if response.status_code != 200:
            print("Failed to get response:", response.status_code)
            return "暂时无法画图"

        result = response.json()
        print("Image path:", result.get('image_path'))
        file_path = result.get('image_path')
        sources = utils.get_shared_variable(session_id)
        sources["url"] = file_path
        utils.set_shared_variable(session_id, sources)
        return "<关键指令>你必须以markdown格式的图片路径,,之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a response body that is not JSON.
        print("An error occurred:", e)
        return "暂时无法画图"
|
||||
|
||||
class drawInkInput(BaseModel):
    """Argument schema with a single required ``query`` string."""

    # Required (Field(...)); the description is Chinese for
    # "the content to draw".
    query: str = Field(...,description="输入要画图的内容")
|
||||
186
langchain-chat/server/agent/tools/duckduckgo_search.py
Normal file
186
langchain-chat/server/agent/tools/duckduckgo_search.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import asyncio
|
||||
import re
|
||||
import aiohttp
|
||||
import json
|
||||
import logging
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from server.chat import utils
|
||||
|
||||
# 配置日志记录器
|
||||
# Logging setup for this module.
# NOTE(review): basicConfig runs at import time and configures the root
# logger — confirm this is intended for a library module.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)
|
||||
|
||||
async def duckduckgo_search_iter(query: str, uuid: str = "",time: str = "", resource_type: str = None, limit: int = 3):
    """Query the text/video/news search APIs and format the hits.

    Unless ``resource_type`` is ``'video'``, all three APIs are queried
    concurrently and ``limit`` is split between them (text takes the first
    extra slot, news the second). For ``'video'`` only the video API is
    queried, with the full ``limit``. The source list is appended to
    ``"source_docs"`` of the shared variable stored under ``uuid``.

    Args:
        query: Search text.
        uuid: Key of the shared variable; it must hold ``"num"`` and
            ``"source_docs"``.
        time: Time filter forwarded to the APIs.
        resource_type: ``'video'`` for video-only search; anything else
            searches all three APIs.
        limit: Total number of results requested.

    Returns:
        ``(res, source)``: the formatted result strings and the formatted
        source strings.
    """
    text_url = 'http://43.251.225.121/inspur/search_text'
    video_url = 'http://43.251.225.121/inspur/search_video'
    news_url = 'http://43.251.225.121/inspur/search_new'

    payload = {
        "query": query,
        "time": time
    }

    async def fetch(session, url, json_payload, limit):
        """POST the payload plus ``limit`` to ``url``; return [] on any failure."""
        logger.info(f"从 {url} 获取数据,请求参数: {json_payload}")
        try:
            # Copy the payload: the same dict is shared by concurrent fetches.
            request_body = {**json_payload, "limit": limit}
            async with session.post(url, json=request_body) as response:
                if response.status != 200:
                    logger.error(f"向 {url} 请求失败,状态码 {response.status}")
                    return []
                data = await response.json()
                logger.info(f"从 {url} 获取的资料数: {len(data) if isinstance(data, list) else '未知'}")
                return data
        except Exception as e:
            logger.error(f"获取 {url} 数据时发生错误: {e}")
            return []

    # Split the limit three ways; the remainder goes to text first, then news.
    base, extra = divmod(limit, 3)
    text_limit = base + (1 if extra >= 1 else 0)
    news_limit = base + (1 if extra == 2 else 0)
    video_limit = base

    async with aiohttp.ClientSession() as session:
        logger.info("发起请求duckduckgo...")
        if resource_type != 'video':
            text_result, video_result, news_result = await asyncio.gather(
                fetch(session, text_url, payload, text_limit),
                fetch(session, video_url, payload, video_limit),
                fetch(session, news_url, payload, news_limit),
            )
            logger.info("合并结果...")
            combined_result = {
                "text": text_result,
                "video": video_result,
                "news": news_result
            }
            logger.info("合并结果完成")
        else:
            video_result = await fetch(session, video_url, payload, limit)
            combined_result = {
                "video": video_result
            }

    logger.info("请求已完成")
    res = []
    source = []
    info = utils.get_shared_variable(uuid)
    # NOTE(review): numbering starts after info["num"], but the counter is
    # never written back — confirm whether it should be.
    index = info["num"]
    if "text" in combined_result:
        for item in combined_result["text"]:
            index += 1
            res.append(f'资料[{index}] 资料标题{item["title"]}({item["href"]}) 资料内容为: {item["body"]}')
            source.append(f'资料[{index}] [{item["title"]}]({item["href"]})')
    if "video" in combined_result:
        for item in combined_result["video"]:
            index += 1
            res.append(f'资料[{index}] 视频标题[{item["title"]}]({item["content"]}) 视频内容为: {item["description"]}')
            source.append(f'视频资料[{index}] [{item["title"]}]({item["content"]})')
    if "news" in combined_result:
        for item in combined_result["news"]:
            index += 1
            res.append(f'资料[{index}] 新闻标题[{item["title"]}]({item["url"]}) 新闻内容为: {item["body"]}')
            source.append(f'资料[{index}] [{item["title"]}]({item["url"]})')
    info["source_docs"].extend(source)
    utils.set_shared_variable(uuid, info)
    return res,source
|
||||
|
||||
|
||||
def duckduckgo_search(query: str, time: str = "", resource_type: str = None):
    """Parse the agent's tool input and run ``duckduckgo_search_iter``.

    ``query`` must contain at least two ``{...}`` groups. The first is JSON
    that may hold ``query``, ``limit``, ``resource_type`` and ``time``; the
    second is JSON with a ``uuid`` key. Values from the first group override
    the function arguments when they are non-empty.

    Args:
        query: Raw tool input in the format described above.
        time: Default time filter, used when the JSON has none.
        resource_type: Default resource type, used when the JSON has none.

    Returns:
        The ``(res, source)`` tuple from ``duckduckgo_search_iter``, or a
        fixed instruction string when the input cannot be parsed.
    """
    logger.info(f"模型输入: {query}")
    stop_hint = "<关键指令>不需要再调用该工具了</关键指令>"

    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return stop_hint

    try:
        obj1 = json.loads(matches[0])
        parsed_uuid = json.loads(matches[1])["uuid"]
    except (json.JSONDecodeError, KeyError) as e:
        # Without a parsed query and uuid the search cannot run.
        logger.error(f"解析JSON出错: {e}")
        return stop_hint

    parsed_limit = obj1.get("limit", 3)
    # Fall back to the raw first group when the JSON has no usable "query".
    query = obj1.get("query", "") or matches[0]
    resource_type = obj1.get("resource_type", None) or resource_type
    time = obj1.get("time", time) or time

    logger.info(f"解析完成,query: {query}, uuid: {parsed_uuid}, time: {time}, resource_type: {resource_type}, parsed_limit: {parsed_limit}")

    # Run the async search from this synchronous entry point.
    combined_result = asyncio.run(duckduckgo_search_iter(query, parsed_uuid, time, resource_type, parsed_limit))
    logger.info("返回JSON格式的结果给到模型...")
    return combined_result
|
||||
class DuckduckgoInput(BaseModel):
    """Argument schema for the ``duckduckgo_search`` tool."""

    # NOTE(review): the field is called `location`, while duckduckgo_search()
    # takes `query` — confirm the name is what the agent framework expects.
    # The description is Chinese for "web search query".
    location: str = Field(description="网络搜索查询")
|
||||
|
||||
if __name__ == "__main__":
    # Manual test calls
    # 1. Default: query all three APIs
    # result_default = duckduckgo_search("粉末冶金", "m", "default")
    # print("duckduckgo输出(默认):\n", result_default)

    # # 2. Video only
    # result_video = duckduckgo_search("粉末冶金", "m", "video")
    # print("duckduckgo输出(视频):\n", result_video)

    # # 3. News only
    # result_news = duckduckgo_search("粉末冶金", "m", "news")
    # print("duckduckgo输出(新闻):\n", result_news)

    # 4. Any other type: text only
    # NOTE(review): the query is a plain string; duckduckgo_search returns its
    # fixed instruction message unless the query contains two {...} groups —
    # confirm this call still exercises what it is meant to.
    result_other = duckduckgo_search("粉末冶金", "m", "other")
    print("duckduckgo输出(其他):\n", result_other)
|
||||
57
langchain-chat/server/agent/tools/get_statistical_data.py
Normal file
57
langchain-chat/server/agent/tools/get_statistical_data.py
Normal file
@@ -0,0 +1,57 @@
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
import requests
|
||||
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.chat import utils
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
|
||||
|
||||
def mysql_statistic(query:str) -> str:
    """Send a natural-language statistics question to the local NL2SQL service.

    ``query`` must have the shape ``<param>{...}</param>{...}``. The JSON
    inside ``<param>`` must hold a ``query`` key (the question) and the
    trailing JSON must hold a ``uuid`` key.

    Args:
        query: Tool input in the format described above.

    Returns:
        An instruction string that embeds the returned data, a hint to use
        the web-search tool when the service reports no data, or
        ``"暂时无法查询"`` when the input or the request fails.
    """
    failure = "暂时无法查询"
    logging.info(f"\n🔍 统计工具查询query: \n{query}\n")

    matches = re.search(r'<param>(.*?)</param>(\{.*\})', query.replace("\n", ""))
    if not matches:
        return failure

    try:
        # The uuid is not used further, but its presence is still required.
        json.loads(matches.group(2))["uuid"]
        question = json.loads(matches.group(1).strip())["query"]
    except (ValueError, KeyError, TypeError):
        logging.error("Invalid JSON format in query.")
        return failure

    logging.info(f"\n🔍 NL2SQL检索question: \n{question}\n")
    try:
        res = requests.post(
            url="http://127.0.0.1:6008/query",
            json={"question": question},
            headers={"Content-Type": "application/json"}
        )
        if not res.ok:
            logging.error(f"NL2SQL request failed with status {res.status_code}")
            return failure
        data = res.json()["result"]
    except (requests.exceptions.RequestException, ValueError, KeyError) as e:
        # Connection problems, a non-JSON body, or a body without "result".
        logging.error(f"NL2SQL request failed: {e}")
        return failure

    if "'data': []" in data:
        return f"统计库未检索到数据,使用“联网思索”工具检索该请求:{question}\n"
    return f"判断得到的数据是否准确,如果不准确,则使用“联网思索”工具检索。如果准确,则根据数据表格并使用图表绘制工具制图。\n 数据如下所示: \n{data}\n"
|
||||
170
langchain-chat/server/agent/tools/knowledgebase_kgo_search.py
Normal file
170
langchain-chat/server/agent/tools/knowledgebase_kgo_search.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Any, Union
|
||||
import concurrent
|
||||
from pydantic import BaseModel, Field
|
||||
from difflib import SequenceMatcher
|
||||
from configs import (VECTOR_SEARCH_TOP_K,
|
||||
SCORE_THRESHOLD,
|
||||
DEFAULT_POLICY_BASE)
|
||||
from server.agent.tools import search_internet
|
||||
from server.chat import utils
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.utils import BaseResponse
|
||||
|
||||
|
||||
class KnowledgeKgoInput(BaseModel):
    """Argument schema for the ``knowledgebase_kgo_search`` tool."""

    # NOTE(review): the field is called `location`, while
    # knowledgebase_kgo_search() takes `query` — confirm the name is what the
    # agent framework expects.
    location: str = Field(description="Query for Internet search")
|
||||
|
||||
|
||||
|
||||
def preprocess_text(text: str) -> str:
    """Return ``text`` with all whitespace and non-word characters removed.

    Letters, digits and underscores are kept; everything else is dropped.
    """
    trimmed = text.strip()
    return re.sub(r'[\s\W]', '', trimmed)
|
||||
|
||||
|
||||
def knowledge_temperature(a: str, b: str) -> float:
    """Return the ``difflib.SequenceMatcher`` similarity ratio of ``a`` and ``b``.

    The ratio lies between 0.0 and 1.0.
    """
    matcher = SequenceMatcher(None, a, b)
    return matcher.ratio()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# def knowledgebase_kgo_iter(query: str,
|
||||
# fileName: List = [],
|
||||
# knowledge_base_name: str = DEFAULT_POLICY_BASE,
|
||||
# top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
# score_threshold: float = SCORE_THRESHOLD) -> BaseResponse | list[str] | Any:
|
||||
# kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
# if kb is None:
|
||||
# return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
# query = query.strip()
|
||||
|
||||
# docs = search_docs(fileName=fileName,
|
||||
# query=query,
|
||||
# knowledge_base_name=knowledge_base_name,
|
||||
# top_k=top_k,
|
||||
# score_threshold=score_threshold)
|
||||
|
||||
# # 预处理查询文本
|
||||
# processed_query = preprocess_text(query).replace("Observ","")
|
||||
# print("processed_query:", processed_query)
|
||||
# knowledge_docs = []
|
||||
# knowledge_content = []
|
||||
# # 知识库返回的文档与query的相似度
|
||||
# if docs:
|
||||
# for enum, doc in enumerate(docs):
|
||||
# filename = doc.metadata.get("title")
|
||||
# detail_url = "https://policy.ckcest.cn/detail/" + doc.metadata.get("primary_key") + ".html"
|
||||
# if filename:
|
||||
# text = f"""政策类资料[{enum + 1}]: [{filename}]({detail_url})\n"""
|
||||
# else:
|
||||
# text = f"""政策类资料[{enum + 1}]: [{"原文地址"}]({detail_url})\n"""
|
||||
# knowledge_docs.append(text)
|
||||
# # print("knowledge_docs:", knowledge_docs)
|
||||
# knowledge_content = [doc.page_content for doc in docs]
|
||||
# # print("knowledge_content:", knowledge_content)
|
||||
# # 计算知识库返回的文档与query的相似度
|
||||
# titles = [doc.metadata.get("title") for doc in docs]
|
||||
# print("titles:", titles)
|
||||
# def check_similarity_threshold(titles: List[str], query: str, knowledge_docs: List[str], knowledge_content: List[str]) -> Union[
|
||||
# List[str], None]:
|
||||
# # 用于记录是否存在相似度大于0.55的标题
|
||||
# has_similar_title = False
|
||||
# for title in titles:
|
||||
# processed_title = preprocess_text(title)
|
||||
# similarity = knowledge_temperature(processed_query, processed_title)
|
||||
# print("processed_title:", processed_title)
|
||||
# print("similarity:", similarity)
|
||||
# if similarity >= 0.55:
|
||||
# has_similar_title = True
|
||||
# break
|
||||
# # 如果存在相似度大于0.55的标题,则直接返回 knowledge_docs
|
||||
# if has_similar_title:
|
||||
# knowledge = knowledge_content + knowledge_docs
|
||||
# return knowledge
|
||||
# # 如果所有标题的相似度都不大于0.55,则返回 None
|
||||
# return None
|
||||
|
||||
# # 在原函数中使用新的函数进行相似度阈值的判断
|
||||
# similar_docs = check_similarity_threshold(titles, query, knowledge_docs, knowledge_content)
|
||||
# if similar_docs is None:
|
||||
# # 如果所有标题的相似度都不大于0.55,则执行搜索引擎查询
|
||||
# kgo_docs = search_internet(processed_query)
|
||||
# # print("kgo_docs", kgo_docs)
|
||||
# return kgo_docs
|
||||
# else:
|
||||
# kgo_docs = search_internet(processed_query)
|
||||
# # print("similar_docs", similar_docs)
|
||||
# # print("kgo_docs", kgo_docs)
|
||||
# similar_docs.extend(kgo_docs)
|
||||
# return similar_docs
|
||||
# else:
|
||||
# # 执行搜索引擎查询
|
||||
# kgo_docs = search_internet(query)
|
||||
# return kgo_docs
|
||||
|
||||
def knowledgebase_kgo_iter(query: str, uid: str) -> BaseResponse | list[str] | Any:
    """Forward ``query`` and ``uid`` to ``search_internet`` and return its result."""
    return search_internet(query, uid)
|
||||
def knowledgebase_kgo_search(query: str) -> List[str]:
    """Parse the agent's tool input, run the web search and format the hits.

    ``query`` must contain at least two ``{...}`` groups. The first is JSON
    with a ``query`` key; the second is JSON with a ``uuid`` key. The search
    runs through ``knowledgebase_kgo_iter`` in a worker thread, and its
    result is read as ``(documents, sources)``.

    Args:
        query: Raw tool input in the format described above.

    Returns:
        A single string: the documents followed by their sources, a
        titles-only message when only sources came back, or a fixed retry
        hint on any failure. The annotation says ``List[str]``, but every
        path returns ``str``.
    """
    # ``import concurrent`` alone does not load the ``futures`` submodule.
    import concurrent.futures

    retry_hint = "尝试调整入参重新调用联网思索工具(同一个问题调用超过三次就不要再使用该工具了,浪费时间)"
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            logging.error("Invalid JSON format in query.")
            return retry_hint

        parsed_query = json.loads(matches[0])
        time_based_uuid = json.loads(matches[1])["uuid"]

        # The search runs in a worker thread and is awaited immediately.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future = executor.submit(knowledgebase_kgo_iter, parsed_query["query"], time_based_uuid)
            res = future.result()

        docs, sources = res[0], res[1]
        if isinstance(docs, list) and docs:
            return "资料内容" + "".join(docs) + "资料来源" + "".join(sources)
        if isinstance(sources, list) and sources:
            # Only titles came back: report them (the original used an
            # unbound variable here and lacked the f-string prefix).
            doc_content = "资料来源" + "".join(sources)
            return f"只有标题没有内容,标题为:{doc_content}"
        return retry_hint
    except json.JSONDecodeError:
        logging.error("Invalid JSON format in query.")
        return retry_hint
    except Exception as e:
        # Missing keys, search failures and malformed results all end here.
        logging.error(f"Error occurred while processing query: {e}")
        return retry_hint
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test.
    # NOTE(review): knowledgebase_kgo_iter is defined as (query, uid) with no
    # default for uid, so this one-argument call needs a uid to run.
    result = knowledgebase_kgo_iter("《区块链和分布式记账技术标准体系建设指南》")
    print("检索结果:", result)
|
||||
113
langchain-chat/server/agent/tools/math.py
Normal file
113
langchain-chat/server/agent/tools/math.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.agent.tools.search_tool import search_tool
|
||||
from server.chat import utils
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def math_count(query: str):
    """Answer a math question with retrieved context and the math LLM.

    ``query`` must contain at least two ``{...}`` groups. The first is JSON
    with a ``query`` key; the second is JSON with a ``uuid`` key. A scratch
    shared variable keyed ``uuid + "q"`` is created for the ``search_tool``
    call and removed afterwards. The search result and the question are then
    passed to ``get_llm_model_response`` with the ``default_math`` prompt.

    Args:
        query: Raw tool input in the format described above.

    Returns:
        The LLM answer as a string, or a fixed ``<system>`` refusal message
        when the input is malformed or any step fails.
    """
    refusal = "<system>不要再调用该工具了</system>"
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            # Previously this path fell through and returned None.
            logging.error("Invalid JSON format in math query.")
            return refusal

        parsed_query = json.loads(matches[0])["query"]
        time_based_uuid = json.loads(matches[1])["uuid"]

        scratch_key = time_based_uuid + "q"
        tip = {"source_docs": [], "num": 0, "title": []}
        utils.set_shared_variable(scratch_key, tip)

        first_json = {
            "query": parsed_query,
            "knowledge_name": [],
            "keywords": []
        }
        second_json = {
            "uuid": scratch_key
        }
        try:
            math_doc = search_tool(json.dumps(first_json) + json.dumps(second_json))
        finally:
            # Remove the scratch entry even when the search fails.
            utils.remove_shared_variable(scratch_key)

        res = get_llm_model_response(
            strategy_name="default_math",
            llm_model_name=LLM_MODELS[3],
            template_prompt_name="default_math",
            prompt_param_dict={"input": parsed_query, "math_doc": f"{math_doc}", "time": datetime.now().strftime("%Y%m%d")},
            temperature=0.01,
            max_tokens=512
        )
        return f"{res}"
    except Exception as e:
        logging.error(f"Error occurred while processing math query: {e}")
        return refusal
|
||||
|
||||
def code_count(query: str):
    """Answer a coding question with retrieved context and the code LLM.

    ``query`` must contain at least two flat JSON objects back to back: the
    first carries the ``"query"`` text, the second the session ``"uuid"``.
    The question is sent to ``search_tool`` under a temporary shared-variable
    key (``uuid + "q"``); the retrieved text, the question and today's date
    are then rendered into the ``default_code`` prompt.

    Returns:
        The model response with any ``<think>`` opening tag removed, or a
        ``<system>`` marker telling the agent not to call this tool again
        when the input is malformed or any step fails.
    """
    stop_message = "<system>不要再调用该工具了</system>"
    try:
        # Non-greedy pattern: only flat (non-nested) JSON objects are matched.
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            logging.error(
                "Error occurred while processing code query: "
                "expected two JSON objects, found %d", len(matches)
            )
            return stop_message

        parsed_query = json.loads(matches[0])["query"]
        time_based_uuid = json.loads(matches[1])["uuid"]

        # Scratch entry that search_tool writes its sources into.
        scratch_key = time_based_uuid + "q"
        utils.set_shared_variable(scratch_key, {"source_docs": [], "num": 0, "title": []})
        first_json = {"query": parsed_query, "knowledge_name": [], "keywords": []}
        second_json = {"uuid": scratch_key}
        try:
            code_doc = search_tool(json.dumps(first_json) + json.dumps(second_json))
        finally:
            # Always drop the scratch entry, even when the search fails.
            utils.remove_shared_variable(scratch_key)

        res = get_llm_model_response(
            strategy_name="default_code",
            llm_model_name=LLM_MODELS[2],
            template_prompt_name="default_code",
            prompt_param_dict={
                "input": parsed_query,
                "code_doc": f"{code_doc}",
                "time": datetime.now().strftime("%Y%m%d"),
            },
            temperature=0.01,
            max_tokens=512,
        )
        res = res.replace("<think>", "")
        return f"{res}"
    except Exception as e:
        # The message previously said "math query" (copy-paste from math_count).
        logging.error(f"Error occurred while processing code query: {e}")
        return stop_message
|
||||
|
||||
|
||||
|
||||
|
||||
# Argument schema for the tools in this module: one required ``query`` string.
class RagSearchInput(BaseModel):
    query: str = Field(default=..., description="查询对象")
|
||||
@@ -0,0 +1,42 @@
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS, LLM_MODELS
|
||||
import json
|
||||
import asyncio
|
||||
from typing import List, Tuple, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# Argument schema for the policy knowledge-base tool. The field is named
# ``location`` although its description (and the tool function) speak of a query.
class PolicyKnowledgeInput(BaseModel):
    location: str = Field(
        description="The policy related query to be searched",
    )
|
||||
|
||||
|
||||
async def policy_knowledgebase_search_iter(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Query the policy knowledge base and collect the answer and documents.

    Awaits ``knowledge_base_chat`` in non-streaming mode, then reads each JSON
    chunk from ``response.body_iterator``. The last ``"answer"`` and ``"docs"``
    values seen are returned; chunks lacking a key leave the previous value.
    """
    chat_kwargs = dict(
        query=query,
        fileName=None,
        knowledge_base_name_list=["t_policy_total_bge_new_v1"],
        model_name=LLM_MODELS[0],
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="default",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    response = await knowledge_base_chat(**chat_kwargs)

    answer, documents = "", []
    async for raw_chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(raw_chunk)
        print("data>>>>>", payload)
        answer = payload.get("answer", answer)
        documents = payload.get("docs", documents)
    return answer, documents
|
||||
|
||||
|
||||
def policy_knowledgebase_search(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Run ``policy_knowledgebase_search_iter`` to completion on a new event loop."""
    pending_search = policy_knowledgebase_search_iter(query)
    return asyncio.run(pending_search)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: run one sample query and print the (answer, docs) pair.
    sample_question = "大数据男女比例"
    result = policy_knowledgebase_search(sample_question)
    print("答案:", result)
|
||||
108
langchain-chat/server/agent/tools/rag_search.py
Normal file
108
langchain-chat/server/agent/tools/rag_search.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import json
|
||||
import re
|
||||
import concurrent
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from langchain.tools import YouTubeSearchTool
|
||||
from pydantic import BaseModel, Field
|
||||
from server.chat import utils
|
||||
from server.chat.utils import doc_to_list, solve_knowledge_map,solve_mental_data,shared_variable
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from configs import kb_config
|
||||
|
||||
|
||||
def rag_search1(query: str):
|
||||
"""
|
||||
根据用户输入的query,返回rag搜索结果
|
||||
"""
|
||||
try:
|
||||
matches = re.findall(r'\{.*?\}', query)
|
||||
if len(matches)>=2:
|
||||
query = matches[0]
|
||||
else:
|
||||
return "<关键指令>不需要再调用该工具了</关键指令>"
|
||||
time_based_uuid = json.loads(matches[1])["uuid"]
|
||||
search = json.loads(query)
|
||||
search_query = search["query"]
|
||||
search_keywords = []
|
||||
search_text = f"{search_query}"
|
||||
if type(search["keywords"]) == list:
|
||||
search_keywords = search["keywords"]
|
||||
for keyword in search_keywords:
|
||||
search_text += f" {keyword}"
|
||||
else:
|
||||
search_keywords = search["keywords"].split(",")
|
||||
for keyword in search_keywords:
|
||||
search_text += f" {keyword}"
|
||||
result = []
|
||||
source_docs = {}
|
||||
knownledge_name = []
|
||||
if type(search["knowledge_name"]) == list:
|
||||
knownledge_name=search["knowledge_name"]
|
||||
else:
|
||||
knownledge_name=search["knowledge_name"].split(",")
|
||||
for knownledge in knownledge_name:
|
||||
if not knownledge in kb_config.CH_BASE_NAME:
|
||||
knownledge_name.remove(knownledge)
|
||||
if len(knownledge_name)==0:
|
||||
result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
|
||||
return result
|
||||
# knownledge_name=kb_config.CH_BASE_NAME
|
||||
|
||||
knownledge_name=solve_knowledge_map(knownledge_name)
|
||||
num = 0
|
||||
for knownledge in knownledge_name:
|
||||
source_docs[knownledge] = []
|
||||
seen_docs = set()
|
||||
duplicate_indices = []
|
||||
doc_list = search_docs(usr_query=search_text,fileName= [],top_k=5,score_threshold=0.9,query=search_text, knowledge_base_name=knownledge)
|
||||
|
||||
|
||||
for inum,doc in enumerate(doc_list):
|
||||
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
|
||||
|
||||
# 从policydocs中删除重复的文档(从后往前删除以防止索引错位)
|
||||
for index in sorted(duplicate_indices, reverse=True):
|
||||
del doc_list[index]
|
||||
# 处理原文来源进入数组。使用开关语句明确各个条件分支
|
||||
match knownledge:
|
||||
# 属于政策库分支,入参为中文政策库名称
|
||||
case kb_config.DEFAULT_POLICY_BASE:
|
||||
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs[knownledge])
|
||||
# 属于期刊论文库分支,入参为期刊论文库的中文名称
|
||||
case kb_config.DEFAULT_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
|
||||
# 属于报告库分支,入参为报告库中文名称
|
||||
case kb_config.DEFAULT_REPORT_BASE:
|
||||
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
|
||||
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
|
||||
case kb_config.GY_NEWS_BASE:
|
||||
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs[knownledge])
|
||||
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
|
||||
case kb_config.GY_REPORT_BASE:
|
||||
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
|
||||
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
|
||||
case kb_config.GY_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
|
||||
case _:
|
||||
print(f"输入了没有的知识库名称")
|
||||
return("输入了没有的知识库名称")
|
||||
num += len(source_docs[knownledge])
|
||||
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
|
||||
del num
|
||||
source = utils.get_shared_variable(time_based_uuid)
|
||||
print(utils.get_shared_variable(time_based_uuid))
|
||||
source["source_docs"]=source_docs
|
||||
utils.set_shared_variable(time_based_uuid,source)
|
||||
if 0<len(result)<3:
|
||||
return f"当前资料:{result}\n<关键指令>搜索结果较少,更换知识库或联网思索重新搜索!!!</关键指令>"
|
||||
if len(result)==0:
|
||||
return "注意:【指令:更换知识库或联网思索继续搜索!!!】"
|
||||
except:
|
||||
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
|
||||
return "当前工具异常!请换其他工具"
|
||||
|
||||
return f"当前资料:{result}\n<关键指令>总结此内容!!!</关键指令>"
|
||||
|
||||
# Argument schema for ``rag_search1``: one required ``query`` string.
class RagSearchInput(BaseModel):
    query: str = Field(default=..., description="查询对象")
|
||||
@@ -0,0 +1,42 @@
|
||||
from server.chat.report_chat import report_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS, LLM_MODELS, DEFAULT_REPORT_BASE
|
||||
import json
|
||||
import asyncio
|
||||
from typing import List, Tuple, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# Argument schema for the report knowledge-base tool. The field is named
# ``location`` although its description (and the tool function) speak of a query.
class ReportKnowledgeInput(BaseModel):
    location: str = Field(
        description="The report related query to be searched",
    )
|
||||
|
||||
|
||||
async def report_knowledgebase_search_iter(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Query the report knowledge base and collect the answer and documents.

    Awaits ``report_chat`` in non-streaming mode against ``DEFAULT_REPORT_BASE``,
    then reads each JSON chunk from ``response.body_iterator``. The last
    ``"answer"`` and ``"docs"`` values seen are returned; chunks lacking a key
    leave the previous value.
    """
    chat_kwargs = dict(
        query=query,
        fileName=None,
        knowledge_base_name=DEFAULT_REPORT_BASE,
        model_name=LLM_MODELS[0],
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="default",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    response = await report_chat(**chat_kwargs)

    answer, documents = "", []
    async for raw_chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(raw_chunk)
        print("data>>>>>", payload)
        answer = payload.get("answer", answer)
        documents = payload.get("docs", documents)
    return answer, documents
|
||||
|
||||
|
||||
def report_knowledgebase_search(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Run ``report_knowledgebase_search_iter`` to completion on a new event loop."""
    pending_search = report_knowledgebase_search_iter(query)
    return asyncio.run(pending_search)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: run one sample query and print the (answer, docs) pair.
    sample_question = "大数据男女比例"
    result = report_knowledgebase_search(sample_question)
    print("答案:", result)
|
||||
89
langchain-chat/server/agent/tools/search_internet.py
Normal file
89
langchain-chat/server/agent/tools/search_internet.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import json
|
||||
import asyncio
|
||||
import unicodedata
|
||||
|
||||
from server.chat.KgoSearchAPIWrapper import KgoSearchAPIWrapper
|
||||
from server.chat.search_engine_chat import search_engine_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
from configs import LLM_MODELS, TEMPERATURE
|
||||
from configs.basic_config import *
|
||||
|
||||
# def get_kgo_search_type(query: str = "全部"):
|
||||
# # 过滤掉所有非汉字的字符
|
||||
# query = ''.join(char for char in query if unicodedata.category(char) == 'Lo')
|
||||
# search_map = KgoSearchAPIWrapper().search_map
|
||||
|
||||
# if "论文" in query:
|
||||
# return "1001"
|
||||
# elif "外文" in query or "英文" in query:
|
||||
# return "1013"
|
||||
# elif "期刊" in query or "研究进展" in query:
|
||||
# return "1002"
|
||||
# else:
|
||||
# matched_types = [value for key, value in search_map.items() if key in query]
|
||||
# if matched_types:
|
||||
# return ','.join(matched_types)
|
||||
# else:
|
||||
# print("未找到匹配的搜索类型,返回默认值:1000")
|
||||
# return "1000"
|
||||
|
||||
# Awaits search_engine_chat (engine "zhipu_search", model LLM_MODELS[1],
# non-streaming, kgo_search_type "1000") for `query` on behalf of `uid`, then
# reads JSON chunks from response.body_iterator, keeping the latest "answer"
# and accumulating every non-empty "docs" list.
# NOTE(review): indentation is lost in this view, so it is unclear whether
# `return docs` belongs to the `else` branch or to the function body. Either
# way at least one path returns only the docs list, while the last statement
# returns a (contents, docs) tuple - confirm which shape callers expect.
# NOTE(review): `timing_decorator` and `logging` are not imported by name in
# the visible imports; presumably they come from
# `from configs.basic_config import *` - confirm.
@timing_decorator
|
||||
async def search_engine_iter(query: str , uid: str):
|
||||
response = await search_engine_chat(uid = uid,
|
||||
query=query,
|
||||
search_engine_name="zhipu_search",
|
||||
model_name=LLM_MODELS[1],
|
||||
temperature=TEMPERATURE, # original note: temperature is set to 0.1 when the agent searches the internet (the value comes from the TEMPERATURE config constant)
|
||||
history=[],
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
max_tokens=MAX_TOKENS,
|
||||
prompt_name="search",
|
||||
stream=False,
|
||||
kgo_search_type="1000")
|
||||
|
||||
contents = ""
|
||||
docs = []
|
||||
|
||||
async for data in response.body_iterator: # each `data` item is a JSON string
|
||||
data = json.loads(data)
|
||||
# A chunk without "answer" resets contents to an empty list, not a string.
contents = data.get("answer", [])
|
||||
current_docs = data.get("docs", [])
|
||||
if current_docs:
|
||||
docs.extend(current_docs)
|
||||
else:
|
||||
logging.error("No docs found in the response")
|
||||
return docs
|
||||
# print("contents:", contents)
|
||||
# print("docs:", docs)
|
||||
# Return the search summary together with the search results
|
||||
return contents,docs
|
||||
|
||||
|
||||
def search_internet(query: str, uid: str):
    """Run ``search_engine_iter`` for ``query`` and ``uid`` on a new event loop.

    The query is forwarded unmodified: the keyword filtering and search-type
    detection that once lived here are disabled.
    """
    pending_search = search_engine_iter(query, uid)
    return asyncio.run(pending_search)
|
||||
|
||||
|
||||
# Argument schema for the internet search tool. The field is named
# ``location`` although its description speaks of a search query.
class SearchInternetInput(BaseModel):
    location: str = Field(
        description="Query for Internet search",
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test. search_internet requires a ``uid``; the call used to
    # omit it and failed with TypeError before any search ran.
    # NOTE(review): "local-debug" is a placeholder - confirm what uid values
    # search_engine_chat accepts.
    result = search_internet("人工智能领域的政策", uid="local-debug")
    print("答案:", result)
|
||||
15
langchain-chat/server/agent/tools/search_internet.yaml
Normal file
15
langchain-chat/server/agent/tools/search_internet.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
name: search_internet
|
||||
description: Use this tool to surf internet and get information
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: Query for Internet search
|
||||
kgo_search_type:
|
||||
type: int
|
||||
description: the return value 'kgo_search_type' of the 'get_kgo_search_type'
|
||||
default: 1000
|
||||
required:
|
||||
- query
|
||||
- kgo_search_type
|
||||
@@ -0,0 +1,294 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
import warnings
|
||||
from typing import Dict
|
||||
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from typing import List, Any, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
async def search_knowledge_base_iter(database: str, query: str) -> str:
    """Ask ``knowledge_base_chat`` about ``query`` in ``database``.

    Runs in non-streaming mode with the ``default`` prompt and concatenates
    the ``"answer"`` field of every JSON chunk in ``response.body_iterator``.
    """
    chat_kwargs = dict(
        query=query,
        knowledge_base_name=database,
        model_name=model_container.MODEL.model_name,
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="default",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    response = await knowledge_base_chat(**chat_kwargs)

    answer_text = ""
    async for raw_chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(raw_chunk)
        answer_text += payload["answer"]
        # Value unused, but the lookup is kept: a chunk without "docs" raises KeyError.
        _ = payload["docs"]
    return answer_text
|
||||
|
||||
|
||||
async def search_knowledge_multiple(queries) -> List[str]:
    """Search several knowledge bases concurrently.

    ``queries`` is a list of ``(database, query)`` tuples. Each result is
    returned prefixed with a line naming the knowledge base it came from, in
    the same order as ``queries``.
    """
    pending = [search_knowledge_base_iter(db_name, question) for db_name, question in queries]
    answers = await asyncio.gather(*pending)
    return [
        f"\n查询到 {db_name} 知识库的相关信息:\n{answer}"
        for (db_name, _), answer in zip(queries, answers)
    ]
|
||||
|
||||
|
||||
def search_knowledge(queries) -> str:
    """Run ``search_knowledge_multiple`` on a new event loop and merge the output.

    Every per-knowledge-base section is followed by a blank line.
    """
    sections = asyncio.run(search_knowledge_multiple(queries))
    return "".join(section + "\n\n" for section in sections)
|
||||
|
||||
|
||||
_PROMPT_TEMPLATE = """
|
||||
用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。
|
||||
|
||||
对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。
|
||||
|
||||
例子:
|
||||
|
||||
robotic,机器人男女比例是多少
|
||||
bigdata,大数据的就业情况如何
|
||||
|
||||
|
||||
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考
|
||||
|
||||
|
||||
{database_names}
|
||||
|
||||
你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。
|
||||
不要输出中文的逗号,不要输出引号。
|
||||
|
||||
Question: ${{用户的问题}}
|
||||
|
||||
```text
|
||||
${{知识库名称,查询问题,不要带有任何除了,之外的符号,比如不要输出中文的逗号,不要输出引号}}
|
||||
|
||||
```output
|
||||
数据库查询的结果
|
||||
|
||||
现在,我们开始作答
|
||||
问题: {question}
|
||||
"""
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question", "database_names"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
# Chain that asks the LLM for "<knowledge base>,<question>" lines inside a
# ```text block, runs those knowledge-base searches, and returns the combined
# result under ``output_key``.
class LLMKnowledgeChain(LLMChain):
    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    prompt: BasePromptTemplate = PROMPT
    """[Deprecated] Prompt to use to translate to python if necessary."""
    # Overwritten with model_container.DATABASE on every _call / _acall.
    database_names: Optional[Dict[str, str]] = None
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn when built from a bare ``llm`` and derive ``llm_chain`` from it."""
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    @staticmethod
    def _parse_queries(expression: str) -> List[Any]:
        """Turn the LLM's ```text payload into ``(database, query)`` tuples.

        Quotes and code fences are stripped. Each line is split at its FIRST
        comma so a question containing commas stays intact; lines without a
        comma (blank or malformed) are skipped instead of raising IndexError.
        """
        cleaned = (expression.strip().replace("\"", "").replace("“", "")
                   .replace("”", "").replace("```", "").strip())
        queries = []
        for line in cleaned.split("\n"):
            database, separator, question = line.partition(",")
            if separator:
                queries.append((database.strip(), question.strip()))
        return queries

    def _evaluate_expression(self, queries) -> str:
        """Run the searches synchronously; errors are returned as text."""
        try:
            output = search_knowledge(queries)
        except Exception as e:
            output = "输入的信息有误或不存在知识库,错误信息如下:\n"
            return output + str(e)
        return output

    async def _aevaluate_expression(self, queries) -> str:
        """Async twin of ``_evaluate_expression``.

        Awaits the searches directly: ``search_knowledge`` goes through
        ``asyncio.run``, which cannot be used inside a running event loop.
        """
        try:
            sections = await search_knowledge_multiple(queries)
        except Exception as e:
            return "输入的信息有误或不存在知识库,错误信息如下:\n" + str(e)
        return "".join(section + "\n\n" for section in sections)

    def _process_llm_result(
            self,
            llm_output: str,
            run_manager: CallbackManagerForChainRun
    ) -> Dict[str, str]:
        """Interpret the LLM output and run the requested searches.

        Returns a format-error message under ``output_key`` when the output
        has neither a usable ```text block nor an ``Answer:`` marker.
        """
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)

        llm_output = llm_output.strip()
        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
        if text_match:
            queries = self._parse_queries(text_match.group(1))
            if not queries:
                # Nothing parseable in the block: report it like any other bad format.
                return {self.output_key: f"输入的格式不对:\n {llm_output}"}
            run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
            output = self._evaluate_expression(queries)
            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = llm_output.split("Answer:")[-1]
        else:
            return {self.output_key: f"输入的格式不对:\n {llm_output}"}
        return {self.output_key: answer}

    async def _aprocess_llm_result(
            self,
            llm_output: str,
            run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Async variant of ``_process_llm_result``.

        Raises:
            ValueError: when the output has neither a usable ```text block
                nor an ``Answer:`` marker.
        """
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
        if text_match:
            queries = self._parse_queries(text_match.group(1))
            if not queries:
                raise ValueError(f"unknown format from LLM: {llm_output}")
            await run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue",
                                      verbose=self.verbose)
            output = await self._aevaluate_expression(queries)
            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            raise ValueError(f"unknown format from LLM: {llm_output}")
        return {self.output_key: answer}

    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Ask the LLM which knowledge bases to query, then run the searches."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key])
        self.database_names = model_container.DATABASE
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = self.llm_chain.predict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return self._process_llm_result(llm_output, _run_manager)

    async def _acall(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async variant of ``_call``."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        await _run_manager.on_text(inputs[self.input_key])
        self.database_names = model_container.DATABASE
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = await self.llm_chain.apredict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        # _aprocess_llm_result takes (llm_output, run_manager); the call used to
        # pass the question as an extra argument and raised TypeError.
        return await self._aprocess_llm_result(llm_output, _run_manager)

    @property
    def _chain_type(self) -> str:
        return "llm_knowledge_chain"

    @classmethod
    def from_llm(
            cls,
            llm: BaseLanguageModel,
            prompt: BasePromptTemplate = PROMPT,
            **kwargs: Any,
    ) -> LLMKnowledgeChain:
        """Build the chain from an LLM and a prompt."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
|
||||
def search_knowledgebase_complex(query: str):
    """Build an ``LLMKnowledgeChain`` on the shared model and run ``query`` through it."""
    chain = LLMKnowledgeChain.from_llm(model_container.MODEL, verbose=True, prompt=PROMPT)
    return chain.run(query)
|
||||
|
||||
# Argument schema for the knowledge-base search tool. The field is named
# ``location`` although its description speaks of a query.
class KnowledgeSearchInput(BaseModel):
    location: str = Field(
        description="The query to be searched",
    )
|
||||
|
||||
|
||||
# Argument schema with three required string fields: the query text, the
# knowledge-base name, and the search keywords.
class RagSearchInput(BaseModel):
    query: str = Field(
        description="查询对象",
    )
    knowledge_name: str = Field(
        description="The name of the knowledge base to be searched,policy knowledge base name is t_policy_total_bge_new_v2, example: t_policy_total_bge_new_v2]",
    )
    keywords: str = Field(
        description="The keywords to be searched example: age,child]",
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: run one sample question through the chain and print it.
    sample_question = "机器人和大数据在代码教学上有什么区别"
    result = search_knowledgebase_complex(sample_question)
    print(result)
|
||||
|
||||
# 这是一个正常的切割
|
||||
# queries = [
|
||||
# ("bigdata", "大数据专业的男女比例"),
|
||||
# ("robotic", "机器人专业的优势")
|
||||
# ]
|
||||
# result = search_knowledge(queries)
|
||||
# print(result)
|
||||
@@ -0,0 +1,10 @@
|
||||
name: search_knowledgebase_complex
|
||||
description: Use this tool to search local knowledgebase and get information
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: The query to be searched
|
||||
required:
|
||||
- query
|
||||
234
langchain-chat/server/agent/tools/search_knowledgebase_once.py
Normal file
234
langchain-chat/server/agent/tools/search_knowledgebase_once.py
Normal file
@@ -0,0 +1,234 @@
|
||||
from __future__ import annotations
|
||||
import re
|
||||
import warnings
|
||||
from typing import Dict
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from typing import List, Any, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
|
||||
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
async def search_knowledge_base_iter(database: str, query: str):
    """Ask ``knowledge_base_chat`` about ``query`` in ``database``.

    Runs in non-streaming mode with the ``knowledge_base_chat`` prompt and
    concatenates the ``"answer"`` field of every JSON chunk in
    ``response.body_iterator``.
    """
    chat_kwargs = dict(
        query=query,
        knowledge_base_name=database,
        model_name=model_container.MODEL.model_name,
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="knowledge_base_chat",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    response = await knowledge_base_chat(**chat_kwargs)

    answer_text = ""
    async for raw_chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(raw_chunk)
        answer_text += payload["answer"]
        # Value unused, but the lookup is kept: a chunk without "docs" raises KeyError.
        _ = payload["docs"]
    return answer_text
|
||||
|
||||
|
||||
_PROMPT_TEMPLATE = """
|
||||
用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考
|
||||
Question: ${{用户的问题}}
|
||||
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能:
|
||||
|
||||
{database_names}
|
||||
|
||||
你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
|
||||
```text
|
||||
${{知识库的名称}}
|
||||
```
|
||||
```output
|
||||
数据库查询的结果
|
||||
```
|
||||
答案: ${{答案}}
|
||||
|
||||
现在,这是我的问题:
|
||||
问题: {question}
|
||||
|
||||
"""
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question", "database_names"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
class LLMKnowledgeChain(LLMChain):
|
||||
llm_chain: LLMChain
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
prompt: BasePromptTemplate = PROMPT
|
||||
"""[Deprecated] Prompt to use to translate to python if necessary."""
|
||||
database_names: Dict[str, str] = model_container.DATABASE
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
|
||||
"Please instantiate with llm_chain argument or using the from_llm "
|
||||
"class method."
|
||||
)
|
||||
if "llm_chain" not in values and values["llm"] is not None:
|
||||
prompt = values.get("prompt", PROMPT)
|
||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _evaluate_expression(self, dataset, query) -> str:
|
||||
try:
|
||||
output = asyncio.run(search_knowledge_base_iter(dataset, query))
|
||||
except Exception as e:
|
||||
output = "输入的信息有误或不存在知识库"
|
||||
return output
|
||||
return output
|
||||
|
||||
def _process_llm_result(
|
||||
self,
|
||||
llm_output: str,
|
||||
llm_input: str,
|
||||
run_manager: CallbackManagerForChainRun
|
||||
) -> Dict[str, str]:
|
||||
|
||||
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
database = text_match.group(1).strip()
|
||||
output = self._evaluate_expression(database, llm_input)
|
||||
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
elif "Answer:" in llm_output:
|
||||
answer = "Answer: " + llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
return {self.output_key: f"输入的格式不对: {llm_output}"}
|
||||
return {self.output_key: answer}
|
||||
|
||||
async def _aprocess_llm_result(
|
||||
self,
|
||||
llm_output: str,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> Dict[str, str]:
|
||||
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
expression = text_match.group(1)
|
||||
output = self._evaluate_expression(expression)
|
||||
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
elif "Answer:" in llm_output:
|
||||
answer = "Answer: " + llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
raise ValueError(f"unknown format from LLM: {llm_output}")
|
||||
return {self.output_key: answer}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
_run_manager.on_text(inputs[self.input_key])
|
||||
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
|
||||
llm_output = self.llm_chain.predict(
|
||||
database_names=data_formatted_str,
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
await _run_manager.on_text(inputs[self.input_key])
|
||||
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
|
||||
llm_output = await self.llm_chain.apredict(
|
||||
database_names=data_formatted_str,
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
|
||||
|
||||
@property
def _chain_type(self) -> str:
    """Return the identifier string of this chain type."""
    chain_type = "llm_knowledge_chain"
    return chain_type
|
||||
|
||||
@classmethod
def from_llm(
    cls,
    llm: BaseLanguageModel,
    prompt: BasePromptTemplate = PROMPT,
    **kwargs: Any,
) -> LLMKnowledgeChain:
    """Build the chain from a language model.

    Wraps ``llm`` and ``prompt`` in an ``LLMChain`` and forwards any extra
    keyword arguments to the class constructor.
    """
    return cls(llm_chain=LLMChain(llm=llm, prompt=prompt), **kwargs)
|
||||
|
||||
|
||||
def search_knowledgebase_once(query: str):
    """Answer ``query`` with a verbose LLMKnowledgeChain built on the shared model."""
    chain = LLMKnowledgeChain.from_llm(model_container.MODEL, verbose=True, prompt=PROMPT)
    return chain.run(query)
|
||||
|
||||
|
||||
# Argument schema for the knowledge-base search tool.
class KnowledgeSearchInput(BaseModel):
    # NOTE(review): the field is called `location`, yet its description and the
    # `query` parameter of search_knowledgebase_once both refer to a search
    # query. Confirm whether callers rely on the name before renaming it.
    location: str = Field(description="The query to be searched")
|
||||
|
||||
|
||||
# Manual smoke test: send one sample question through the tool and print the answer.
if __name__ == "__main__":
    result = search_knowledgebase_once("大数据的男女比例")
    print(result)
|
||||
@@ -0,0 +1,32 @@
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
|
||||
import json
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
|
||||
async def search_knowledge_base_iter(database: str, query: str) -> str:
    """Ask one knowledge base a question and return the answer text.

    Args:
        database: Name of the knowledge base, passed on as ``knowledge_base_name``.
        query: The question to ask.

    Returns:
        The "answer" field of the last JSON chunk yielded by the response
        body, or "" when the body yields nothing.
    """
    response = await knowledge_base_chat(
        query=query,
        knowledge_base_name=database,
        model_name=model_container.MODEL.model_name,
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="knowledge_base_chat",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )

    contents = ""
    # Each chunk is a JSON string. Only the answer is used; the accompanying
    # "docs" entry was read into an unused local before and is no longer touched.
    async for chunk in response.body_iterator:
        contents = json.loads(chunk)["answer"]
    return contents
|
||||
|
||||
def search_knowledgebase_simple(query: str, database: str = ""):
    """Synchronous wrapper around ``search_knowledge_base_iter``.

    The coroutine takes ``(database, query)``, but this wrapper used to
    forward only ``query`` and therefore always failed with a TypeError.
    The knowledge base can now be supplied through ``database``.

    Args:
        query: The question to ask.
        database: Name of the knowledge base to search.

    Returns:
        The answer text produced by ``search_knowledge_base_iter``.

    Raises:
        ValueError: If no knowledge base name is given.
    """
    if not database:
        raise ValueError("search_knowledgebase_simple requires a knowledge base name in `database`")
    return asyncio.run(search_knowledge_base_iter(database, query))
|
||||
|
||||
|
||||
# Manual smoke test: send one sample question through the tool and print the answer.
# NOTE(review): search_knowledge_base_iter is declared as (database, query) while
# search_knowledgebase_simple forwards a single argument; verify this call runs.
if __name__ == "__main__":
    result = search_knowledgebase_simple("大数据男女比例")
    print("答案:",result)
|
||||
56
langchain-chat/server/agent/tools/search_picture.py
Normal file
56
langchain-chat/server/agent/tools/search_picture.py
Normal file
@@ -0,0 +1,56 @@
|
||||
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import List
|
||||
from urllib.parse import quote
|
||||
from server.agent.tools.search_tool import rag_search
|
||||
from server.chat import utils
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
|
||||
|
||||
def search_pic(query: str) -> List[str]:
    """Search the "p_meiyupic" knowledge base for pictures matching a tool call.

    ``query`` is scanned for ``{...}`` spans. The first span is parsed as JSON
    and its "query" value is searched; the second is parsed as JSON and its
    "uuid" value selects the shared state read via ``utils.get_shared_variable``.

    Every path returns a string (not the annotated ``List[str]``): an
    instruction message when fewer than two spans are found or nothing new was
    collected, a "no result" message, the picture message, or a failure
    message built from the caught exception.
    """
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches)>=2:
            query = matches[0]
        else:
            return "<关键指令>不需要再调用该工具了</关键指令>"
        uuid = json.loads(matches[1])["uuid"]
        tip = utils.get_shared_variable(uuid)
        # tip["END"] ="ok"
        temp = {}
        temp = json.loads(query)
        res = search_docs(usr_query=temp["query"],fileName= [],top_k=10,score_threshold=0.9,query=temp["query"], knowledge_base_name="p_meiyupic")
        if len(res)==0 and len(tip["source_docs"])==0:
            utils.set_shared_variable(uuid,tip)
            return "工具没有找到结果"
        # Walk over every hit in res.
        result = ""
        for item in res:
            # Source path without its extension (not used afterwards).
            source_dir = os.path.splitext(item.metadata['source'])[0]
            # Text content of the hit.
            page_content = item.page_content
            # Skip sources already recorded in the shared state; record new ones.
            if item.metadata['source'] in tip["source_docs"]:
                continue
            else:
                tip["source_docs"].append(item.metadata['source'])
            # NOTE(review): the replace() maps a URL prefix onto itself, so it is a no-op as written.
            page_content = quote(page_content.replace("http://127.0.0.1:8099/chat_web_backend", "http://127.0.0.1:8099/chat_web_backend"),safe='/:?=&#+')
            # NOTE(review): only a newline is appended and page_content is never used;
            # presumably an image link built from page_content belongs here. Confirm
            # against the original file.
            result += f'\n'

        utils.set_shared_variable(uuid,tip)
        if len(result)>0:
            print(f"美术作品链接:{result}")
            return f"注意:以下链接是图片不是参考文献,以下链接不要放到引文小标的格式输出而是以图片格式输出,禁止转义后面链接的编码,这个链接不能带中文。图片如下:{result}"
        else:
            return "<关键指令>不需要再调用该工具了</关键指令>"

    except Exception as e:
        return f"Failed to get picture.{e}"
|
||||
|
||||
331
langchain-chat/server/agent/tools/search_tool.py
Normal file
331
langchain-chat/server/agent/tools/search_tool.py
Normal file
@@ -0,0 +1,331 @@
|
||||
import asyncio
|
||||
import concurrent
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from fastapi import logger
|
||||
|
||||
from configs import kb_config
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.agent.tools import duckduckgo_search
|
||||
from server.agent.tools.duckduckgo_search import duckduckgo_search_iter
|
||||
from server.agent.tools.knowledgebase_kgo_search import knowledgebase_kgo_iter
|
||||
from server.agent.tools.rag_search import rag_search1
|
||||
from server.chat import utils
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from server.chat.utils import doc_to_list, get_similar_documents1, solve_knowledge_map, solve_mental_data
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
|
||||
|
||||
|
||||
def rag_search(query: str,uid):
|
||||
"""
|
||||
根据用户输入的query,返回rag搜索结果
|
||||
"""
|
||||
source_docs = []
|
||||
try:
|
||||
search = json.loads(query)
|
||||
logging.info(f'模型输入: {search["query"]}')
|
||||
original_query = search["query"]
|
||||
search_query = get_llm_model_response(
|
||||
strategy_name="rag_search_rewrite",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="rag_search_rewrite",
|
||||
prompt_param_dict={"input": search["query"], "year": datetime.now().strftime("%Y")},
|
||||
temperature=0.3,
|
||||
max_tokens=512
|
||||
)
|
||||
logging.info(f'模型改写: {search_query}')
|
||||
search_keywords = []
|
||||
search_text = f"{search_query}"
|
||||
# if type(search["keywords"]) == list:
|
||||
# search_keywords = search["keywords"]
|
||||
# for keyword in search_keywords:
|
||||
# search_text += f" {keyword}"
|
||||
# else:
|
||||
# search_keywords = search["keywords"].split(",")
|
||||
# for keyword in search_keywords:
|
||||
# search_text += f" {keyword}"
|
||||
self_database = utils.get_shared_variable(uid)
|
||||
result = []
|
||||
|
||||
knownledge_name = []
|
||||
if type(search["knowledge_name"]) == list:
|
||||
knownledge_name=search["knowledge_name"]
|
||||
else:
|
||||
knownledge_name=search["knowledge_name"].split(",")
|
||||
if "美术专业知识库" in knownledge_name:
|
||||
knownledge_name.remove("美术专业知识库")
|
||||
if "database" in self_database:
|
||||
self_database["database"]= self_database["database"].append("p_cafa0101011")
|
||||
else:
|
||||
self_database["database"] = ["p_cafa0101011"]
|
||||
# 添加个人知识库
|
||||
if "database" in self_database:
|
||||
knownledge_name.extend(self_database["database"])
|
||||
knownledge_name = [knownledge for knownledge in knownledge_name
|
||||
if (knownledge in kb_config.CH_BASE_NAME
|
||||
or knownledge in kb_config.EN_BASE_NAME
|
||||
or knownledge in getattr(kb_config, "YJ_BASE_NAME", [])
|
||||
or kb_config.SELF_KNOWLEDGE_BASE.match(knownledge)
|
||||
or knownledge == "coding")]
|
||||
if len(knownledge_name)==0:
|
||||
#result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
|
||||
return result,source_docs
|
||||
# knownledge_name=kb_config.CH_BASE_NAME
|
||||
|
||||
knownledge_name=solve_knowledge_map(knownledge_name)
|
||||
#knownledge_name = ["p_c88859a3d06e4265bd01d816ef2650d1"]
|
||||
num = 0
|
||||
temp=utils.get_shared_variable(uid)
|
||||
for knownledge in knownledge_name:
|
||||
seen_docs = set()
|
||||
duplicate_indices = []
|
||||
# 针对中国钢铁行业动态库增加日期范围过滤
|
||||
expr_param = ""
|
||||
if knownledge == kb_config.STEEL_KB:
|
||||
time_today = datetime.now().strftime("%Y-%m-%d")
|
||||
# 调用LLM生成日期表达式,模板沿用 get_policy_time
|
||||
try:
|
||||
expr_candidate = get_llm_model_response(
|
||||
strategy_name="get steel time",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="get_steel_time",
|
||||
prompt_param_dict={"query": original_query, "time": time_today},
|
||||
temperature=0.01,
|
||||
max_tokens=512
|
||||
).replace("None", "").strip()
|
||||
expr_param = expr_candidate if expr_candidate else ""
|
||||
except Exception as _:
|
||||
expr_param = ""
|
||||
|
||||
doc_list = search_docs(
|
||||
usr_query=original_query,
|
||||
fileName=[],
|
||||
top_k=20,
|
||||
score_threshold=1.0,
|
||||
query=search_text,
|
||||
knowledge_base_name=knownledge,
|
||||
expr=expr_param
|
||||
)
|
||||
|
||||
if len(doc_list)==0:
|
||||
return result,source_docs
|
||||
titles = temp["title"]
|
||||
doc_list,title = utils.remove_docs1(titles,doc_list)
|
||||
titles.extend(title)
|
||||
for inum,doc in enumerate(doc_list):
|
||||
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
|
||||
|
||||
# 从policydocs中删除重复的文档(从后往前删除以防止索引错位)
|
||||
for index in sorted(duplicate_indices, reverse=True):
|
||||
del doc_list[index]
|
||||
# 处理原文来源进入数组。使用开关语句明确各个条件分支
|
||||
match knownledge:
|
||||
# 属于政策库分支,入参为中文政策库名称
|
||||
case kb_config.DEFAULT_POLICY_BASE:
|
||||
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs)
|
||||
# 属于期刊论文库分支,入参为期刊论文库的中文名称
|
||||
case kb_config.DEFAULT_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 属于报告库分支,入参为报告库中文名称
|
||||
case kb_config.DEFAULT_REPORT_BASE1:
|
||||
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs)
|
||||
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
|
||||
case kb_config.GY_NEWS_BASE:
|
||||
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs)
|
||||
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
|
||||
case kb_config.GY_REPORT_BASE:
|
||||
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs)
|
||||
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
|
||||
case kb_config.GY_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金新闻库(2024年以及之前)
|
||||
case kb_config.YJ_NEWS_BASE:
|
||||
doc_to_list(num,kb_config.YJ_NEWS_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金中文期刊库
|
||||
case kb_config.YJ_CH_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.YJ_CH_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金外文期刊库
|
||||
case kb_config.YJ_FOR_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.YJ_FOR_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金OA期刊库
|
||||
case kb_config.YJ_OA_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.YJ_OA_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金政策库
|
||||
case kb_config.YJ_POLICYS_BASE:
|
||||
doc_to_list(num,kb_config.YJ_POLICYS_BASE_NAME,doc_list,source_docs)
|
||||
# 新增中国钢铁行业动态库
|
||||
case kb_config.STEEL_KB:
|
||||
doc_to_list(num,kb_config.STEEL_KB_NAME,doc_list,source_docs)
|
||||
# 属于个人知识库分支
|
||||
case _ if kb_config.SELF_KNOWLEDGE_BASE.match(knownledge) or knownledge == "coding":
|
||||
doc_to_list(num,knownledge,doc_list,source_docs)
|
||||
case _:
|
||||
print(f"输入了没有的知识库名称")
|
||||
return "输入了没有的知识库名称",source_docs
|
||||
# num += len(source_docs[knownledge])
|
||||
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
|
||||
# del num
|
||||
# source = utils.get_shared_variable(uid)
|
||||
# print(utils.get_shared_variable(uid))
|
||||
# source["source_docs"]=source_docs
|
||||
# utils.set_shared_variable(uid,source)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in rag_search: {e}")
|
||||
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
|
||||
return "当前工具异常!请换其他工具",source_docs
|
||||
|
||||
return result,source_docs
|
||||
|
||||
|
||||
|
||||
|
||||
def knowledgebase_kgo_search(query: str, uid) -> List[str]:
    """Run the KGO knowledge-base search and normalise its result.

    Args:
        query: Search text handed to ``knowledgebase_kgo_iter``.
        uid: Shared-state key handed to ``knowledgebase_kgo_iter``.

    Returns:
        The value from ``knowledgebase_kgo_iter`` when its first two items are
        lists (or when only the second is a non-empty list, in which case the
        first item is replaced by ``[]``). Otherwise the empty result
        ``{0: [], 1: []}``. If ``knowledgebase_kgo_iter`` itself raises, a
        message string is returned instead.
    """
    try:
        res = knowledgebase_kgo_iter(query,uid)
        try:
            if type(res[0])==list and type(res[1])==list:
                return res
            if type(res[1])==list and len(res[1])>0:
                res[0]=[]
                return res
            # No usable documents. The old code assigned into an empty list
            # here, raising IndexError just to reach the handler below and
            # logging a spurious error; return the same empty result directly.
            return {0: [], 1: []}
        except Exception as e:
            # res could not be indexed or updated as expected.
            logging.error(f"No docs: {e}")
            return {0: [], 1: []}
    except json.JSONDecodeError:
        # Invalid JSON: tell the agent to stop calling this tool.
        logging.error("Invalid JSON format in query.")
        return "<关键指令>不需要再调用该工具了</关键指令>"
    except KeyError:
        # A required key is missing from the parsed JSON.
        return "请尝试调用其他工具"
    except Exception as e:
        # Any other failure is reported as a generic error message.
        return f"发生错误:{str(e)},请尝试调用其他工具"
|
||||
|
||||
|
||||
def inner_duckduckgo_search(query: str, uuid:str,) :
    """Run ``duckduckgo_search_iter`` to completion and return its result."""
    logging.info(f"模型输入: {query}")
    search_coro = duckduckgo_search_iter(query, uuid, "y", "default")
    combined_result = asyncio.run(search_coro)
    # The result is handed back to the model unchanged.
    logging.info("返回JSON格式的结果给到模型...")
    return combined_result
|
||||
|
||||
|
||||
def search_tool(query: str):
    """Run the RAG search and the KGO search concurrently and merge the results.

    ``query`` is scanned for ``{...}`` spans (after stripping ``<param>``
    tags). The first span is the JSON search request (its "query" may be a
    string or a list whose first item is used); the second is JSON whose
    "uuid" selects the shared state.

    Returns:
        A message string for the agent: the merged, renumbered texts and
        sources; a retry hint when nothing was found; or a stop hint when
        fewer than two spans are present or the search raises.
    """
    if "<param>" in query:
        query = query.replace("<param>","").replace("</param>","")
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return "<关键指令>当前工具不需要再调用</关键指令>"
    query = matches[0]
    time_based_uuid = json.loads(matches[1])["uuid"]
    search = json.loads(query)
    raw_query = search["query"]
    if isinstance(raw_query, list):
        searches = raw_query[0] if raw_query else "无"
    else:
        searches = raw_query

    try:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            kgo_key = time_based_uuid+"¥"
            utils.set_shared_variable(kgo_key, {"num": 0, "source_docs": [], "END": "", "title": []})
            rag_future = executor.submit(rag_search,query,time_based_uuid)
            # Decide once whether the KGO search runs. The old code read the
            # shared state twice and could hit an unbound future if the state
            # changed between submitting and collecting.
            kgo_future = None
            if "type" not in utils.get_shared_variable(time_based_uuid):
                kgo_future = executor.submit(knowledgebase_kgo_search,searches,kgo_key)

            result1,sourcedocs = rag_future.result()
            result2 = kgo_future.result() if kgo_future is not None else {0: [], 1: []}

        # NOTE(review): the state above is stored under the "¥" suffix while this
        # removes the "q" suffix; confirm which key is meant to be cleaned up.
        utils.remove_shared_variable(time_based_uuid+"q")

        kgo_texts, kgo_sources = _split_kgo_result(result2)
        if not isinstance(sourcedocs, list):
            sourcedocs = []
        sourcedocs.extend(kgo_sources)
        if isinstance(result1, list):
            result1.extend(kgo_texts)
            result3 = result1
        else:
            # rag_search reported an error message; fall back to the KGO texts.
            result3 = kgo_texts
        logging.info(f"result2:{kgo_sources}")

        sources = utils.get_shared_variable(time_based_uuid)
        first_index = sources["num"]

        # Renumber the leading "[n]" marker of every source, continuing from the
        # count stored in the shared state; unusable entries consume no number.
        i = first_index
        source = []
        for doc in sourcedocs:
            try:
                renumbered = re.sub(r'\[\d+\]', f"[{i + 1}]", doc.replace("\n",""), count=1)
            except (AttributeError, TypeError):
                continue
            if renumbered:
                i += 1
                source.append(renumbered)

        res = [
            re.sub(r'\[\d+\]', f"[{j}]", text, count=1)
            for j, text in enumerate(result3, start=first_index + 1)
        ]

        print(utils.get_shared_variable(time_based_uuid))
        sources["source_docs"].extend(source)
        sources["num"]=i
        utils.set_shared_variable(time_based_uuid,sources)
        logging.info(f"result1:{result1},sourcedocs:{sourcedocs}")
        logging.info(f"result2:{result2}")
        logging.info(f"{res}")
        if len(res) ==0 and len(source)==0:
            return f"尝试调整入参重新调用知识库联想工具(同一个问题调用超过三次就不要再使用知识库联想工具了,浪费时间)"
        return f"<关键指令>如果你在写文章禁止在非规定位置输出参考资料</关键指令>资料:{res}\n资料来源为:{source}\n 注意:如果你在根据大纲撰写文章,撰写中间部分章节禁止输出综上所述之类的影响文风的话,撰写中间部分禁止输出附录引用文献等!!!"
    except Exception as e:
        logging.error(f"Error occurred during search_tool execution.{e}")
        return "同一个问题调用知识库联想工具超过5次就不要再调用知识库联想"


def _split_kgo_result(kgo_result):
    """Return ``(texts, sources)`` lists taken from a knowledgebase_kgo_search result."""
    # On failure the KGO search returns a plain message string; indexing it
    # would yield single characters, so treat it as "no results".
    if isinstance(kgo_result, str):
        return [], []
    try:
        texts, sources = kgo_result[0], kgo_result[1]
    except (KeyError, IndexError, TypeError):
        return [], []
    return (
        texts if isinstance(texts, list) else [],
        sources if isinstance(sources, list) else [],
    )
|
||||
|
||||
9
langchain-chat/server/agent/tools/search_youtube.py
Normal file
9
langchain-chat/server/agent/tools/search_youtube.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Langchain 自带的 YouTube 搜索工具封装
|
||||
from langchain.tools import YouTubeSearchTool
|
||||
from pydantic import BaseModel, Field
|
||||
def search_youtube(query: str):
    """Run ``query`` through a fresh YouTubeSearchTool and return its output."""
    return YouTubeSearchTool().run(tool_input=query)
|
||||
|
||||
# Argument schema for the YouTube search tool.
class YoutubeInput(BaseModel):
    # NOTE(review): the field is called `location`, while search_youtube takes
    # `query` and the description talks about a video query. Confirm the
    # intended field name.
    location: str = Field(description="Query for Videos search")
|
||||
10
langchain-chat/server/agent/tools/search_youtube.yaml
Normal file
10
langchain-chat/server/agent/tools/search_youtube.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
name: search_youtube
|
||||
description: Use this tool to search YouTube videos
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: Query for Videos search
|
||||
required:
|
||||
- query
|
||||
9
langchain-chat/server/agent/tools/shell.py
Normal file
9
langchain-chat/server/agent/tools/shell.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# LangChain 的 Shell 工具
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.tools import ShellTool
|
||||
def shell(query: str):
    """Run ``query`` through a fresh ShellTool and return its output.

    NOTE(review): the command text is passed to the shell tool as received;
    confirm that callers restrict what may reach this function.
    """
    return ShellTool().run(tool_input=query)
|
||||
|
||||
# Argument schema for the shell tool: a single command string.
class ShellInput(BaseModel):
    # The command to run; the field name matches the `query` parameter of shell().
    query: str = Field(description="一个能在Linux命令行运行的Shell命令")
|
||||
10
langchain-chat/server/agent/tools/shell.yaml
Normal file
10
langchain-chat/server/agent/tools/shell.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
name: shell
|
||||
description: Use Linux Shell to execute Linux commands
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: The command to execute
|
||||
required:
|
||||
- query
|
||||
49
langchain-chat/server/agent/tools/weather_check.py
Normal file
49
langchain-chat/server/agent/tools/weather_check.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
更简单的单参数输入工具实现,用于查询现在天气的情况
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
import requests
|
||||
from configs.kb_config import SENIVERSE_API_KEY
|
||||
from server.chat import utils
|
||||
|
||||
|
||||
def weather(location: str, api_key: str):
    """Fetch the 5-day daily forecast for ``location`` from the Seniverse API.

    Args:
        location: Place to query, sent as the ``location`` parameter.
        api_key: Seniverse API key, sent as the ``key`` parameter.

    Returns:
        A JSON string with two keys: "today" (the first daily entry, itself
        JSON-encoded) and "others" (the remaining daily entries, JSON-encoded).

    Raises:
        Exception: If the API answers with a status code other than 200.
    """
    # Pass the values as params so requests URL-encodes them; a location
    # containing "&", "#" or spaces would otherwise corrupt the query string.
    # The timeout keeps a stalled connection from blocking the caller forever.
    response = requests.get(
        "https://api.seniverse.com/v3/weather/daily.json",
        params={
            "key": api_key,
            "location": location,
            "language": "zh-Hans",
            "unit": "c",
            "start": 0,
            "days": 5,
        },
        timeout=10,
    )
    if response.status_code != 200:
        raise Exception(
            f"Failed to retrieve weather: {response.status_code}")

    daily = response.json()["results"][0]["daily"]
    forecast = {
        "today": json.dumps(daily[0]),
        "others": json.dumps(daily[1:]),
    }
    return json.dumps(forecast)
|
||||
|
||||
|
||||
def weathercheck(query: str):
    """Extract the location from a tool call and return its weather forecast.

    ``query`` is scanned for ``{...}`` spans: the first must be JSON with a
    "location" key, the second JSON with a "uuid" key. With fewer than two
    spans an instruction message is returned; any exception is turned into a
    failure message.
    """
    try:
        json_spans = re.findall(r'\{.*?\}', query)
        if len(json_spans) < 2:
            return "<关键指令>不需要再调用该工具了</关键指令>"
        location = json.loads(json_spans[0])["location"]
        # The uuid is not used further, but a missing or invalid one must still
        # end in the failure message below.
        json.loads(json_spans[1])["uuid"]
        return weather(location, SENIVERSE_API_KEY)
    except Exception as err:
        return f"Failed to retrieve weather.{err}"
|
||||
|
||||
|
||||
# Argument schema for the weather tool.
class WeatherInput(BaseModel):
    # Matches the "location" key that weathercheck reads from the tool-call JSON.
    # NOTE(review): weather_check.yaml declares the parameter as `query`; confirm
    # which name the agent actually sends.
    location: str = Field(description="City name,include city and county")
|
||||
10
langchain-chat/server/agent/tools/weather_check.yaml
Normal file
10
langchain-chat/server/agent/tools/weather_check.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
name: weather_check
|
||||
description: Use Weather API to get weather information
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: City name,include city and county,like "厦门市思明区"
|
||||
required:
|
||||
- query
|
||||
11
langchain-chat/server/agent/tools/wolfram.py
Normal file
11
langchain-chat/server/agent/tools/wolfram.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# Langchain 自带的 Wolfram Alpha API 封装
|
||||
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
|
||||
from pydantic import BaseModel, Field
|
||||
wolfram_alpha_appid = "your key"
|
||||
def wolfram(query: str):
    """Send ``query`` to Wolfram Alpha using the module-level app id and return the answer."""
    client = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid)
    return client.run(query)
|
||||
|
||||
# Argument schema for the Wolfram Alpha tool.
class WolframInput(BaseModel):
    # NOTE(review): the field is called `location`, while wolfram() takes `query`
    # and the description refers to the problem to compute. Confirm the intended
    # field name.
    location: str = Field(description="需要运算的具体问题")
|
||||
10
langchain-chat/server/agent/tools/wolfram.yaml
Normal file
10
langchain-chat/server/agent/tools/wolfram.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
name: wolfram
|
||||
description: Useful for when you need to calculate difficult math formulas
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: The formula to be calculated
|
||||
required:
|
||||
- query
|
||||
Reference in New Issue
Block a user