[全量] 初始化项目代码、配置、文档及Agent协同harness

This commit is contained in:
2026-04-02 11:36:05 +08:00
parent 0553309cdf
commit 87e571d9ec
1133 changed files with 221948 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
from .model_contain import *
from .callbacks import *
from .custom_template import *
from .tools import *

View File

@@ -0,0 +1,8 @@
class Agent:
    """Plain mutable record holding four text fields of an agent run.

    Every attribute (``step``, ``knowledge``, ``question``, ``res``) starts
    out as an empty string.
    """

    def __init__(self):
        # Chained assignment is safe because str is immutable.
        self.step = self.knowledge = self.question = self.res = ""

View File

@@ -0,0 +1,101 @@
import uuid
from fastapi import Body
from langchain.memory import (
CombinedMemory,
ConversationBufferMemory,
ConversationSummaryMemory,
ConversationBufferWindowMemory
)
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain, ConversationChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional, Union
from server.chat.utils import History
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template, get_format_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
from datetime import datetime
from langchain_core.messages import SystemMessage
import time as t
from server.utils import replace_variables
from configs.basic_config import *
from configs.outline_config import outlines
async def agent_chat_new(
        user_prompt_name: Optional[str] = Body(None, description="用户输入"),
        query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
        conversation_id: str = Body("", description="对话框ID"),
        history: Union[int, List[History]] = Body([], description="历史对话"),
        model_name: str = Body("default_model", description="LLM 模型名称。"),
        temperature: float = Body(0.7, description="LLM 采样温度", ge=0.0, le=2.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量"),
        prompt_template: str = Body("default", description="使用的prompt模板内容"),
        stream: bool = Body(False, description="流式输出")
) -> AsyncIterable[str]:
    """Run one chat turn through a ``ConversationChain`` and yield JSON strings.

    The chain is built from ``prompt_template`` and a ``CombinedMemory`` made of
    the supplied ``history`` plus a system message carrying the current date.

    Yields:
        When ``stream`` is true, one ``{"text": token}`` JSON string per token.
        Otherwise a single ``{"text": answer, "message_id": ...}`` JSON string
        after the whole answer has been collected.

    NOTE(review): ``user_prompt_name`` and ``conversation_id`` are accepted but
    not referenced in this body. ``history`` is iterated directly, so the
    ``int`` alternative in its annotation would fail here.
    NOTE(review): the source this was taken from had lost its indentation; the
    nesting of the ``else`` after the history loop and of the ``logger.info``
    call was reconstructed — confirm against the original file.
    """
    callback = AsyncIteratorCallbackHandler()
    callbacks = [callback]
    # Date string injected into the prompt (year/month/day).
    time = datetime.now().strftime("%Y年%m月%d")
    message_id = str(uuid.uuid1())+"q"
    # A non-positive token limit is treated as "no limit".
    if isinstance(max_tokens, int) and max_tokens <= 0:
        max_tokens = None
    model = get_ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        callbacks=callbacks,
    )
    history = [History.from_data(h) for h in history]
    chat_prompt = PromptTemplate.from_template(prompt_template)
    # Convert the history into memory.
    buff_memory = ConversationBufferMemory(human_prefix='user', ai_prefix='assistant', memory_key="history", input_key="input")
    if len(history)>0:
        for message in history:
            # Check the role of the message.
            if message.role == 'user':
                # Add a user message.
                buff_memory.chat_memory.add_user_message(message.content)
            elif message.role == 'assistant':
                # Add an AI message.
                buff_memory.chat_memory.add_ai_message(message.content)
    else:
        # No history supplied: seed the buffer with one empty exchange.
        buff_memory.chat_memory.add_user_message("")
        buff_memory.chat_memory.add_ai_message("")
    # Second memory exposing the current date under the "time" key.
    background_memory = ConversationBufferMemory(human_prefix='user', ai_prefix='assistant', memory_key="time", input_key="input")
    message = SystemMessage(content = f'当前的时间是:{time}')
    background_memory.chat_memory.add_message(message)
    memory = CombinedMemory(memories=[background_memory, buff_memory])
    chain = ConversationChain(llm=model, verbose=True, memory=memory, prompt=chat_prompt)
    # Run the chain in the background; tokens arrive through the callback iterator.
    task = asyncio.create_task(wrap_done(
        chain.acall({"input": query, "time": time}),
        callback.done),
    )
    answer = ""
    async for token in callback.aiter():
        if stream:
            yield json.dumps({"text": token}, ensure_ascii=False)
        else:
            answer += token
    # `logger` is not imported by name here — presumably provided by the
    # `configs.basic_config` star import; TODO confirm.
    logger.info(f'solve_problem: {str(answer)}')
    await task
    if stream:
        return
    else:
        yield json.dumps(
            {"text": answer, "message_id": message_id},
            ensure_ascii=False
        )

View File

@@ -0,0 +1,161 @@
from __future__ import annotations
from uuid import UUID
from langchain.callbacks import AsyncIteratorCallbackHandler
import json
import asyncio
from typing import Any, Dict, List, Optional
from langchain.schema import AgentFinish, AgentAction
from langchain.schema.output import LLMResult
def dumps(obj: Dict) -> str:
    """Serialize *obj* to JSON text, leaving non-ASCII characters unescaped."""
    encoder = json.JSONEncoder(ensure_ascii=False)
    return encoder.encode(obj)
class Status:
    """Integer codes placed in the ``status`` field of streamed agent events."""

    # LLM call lifecycle
    start: int = 1
    running: int = 2
    complete: int = 3

    # Agent / tool lifecycle
    agent_action: int = 4
    agent_finish: int = 5
    error: int = 6
    tool_finish: int = 7
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
    """Callback handler that streams agent progress as JSON strings.

    Each callback merges its fields into ``self.cur_tool`` (the event currently
    being built) and pushes the JSON dump of that dict onto ``self.queue``.
    ``self.out`` gates whether raw LLM tokens are forwarded: it is switched off
    once a token containing an action marker is seen and switched back on when
    a tool finishes.
    """

    def __init__(self):
        super().__init__()
        self.queue = asyncio.Queue()
        self.done = asyncio.Event()
        self.cur_tool = {}
        self.out = True

    def _emit(self, **fields: Any) -> None:
        """Merge *fields* into the current event and enqueue its JSON dump."""
        self.cur_tool.update(**fields)
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
        """Start a new event for a tool call and enqueue it.

        The reported ``input_str`` is cut at the first stop word (in list
        priority order) that occurs in it, for models that do not stop
        generating on their own.
        """
        # The list previously also contained an empty string. ``str.find("")``
        # always returns 0, so any input without one of the earlier stop words
        # was truncated to "". NOTE(review): the empty entry may be the residue
        # of a character lost in transit — restore it here if so.
        stop_words = ["Observation:", "Thought", "\"", "\n", "\t"]
        for stop_word in stop_words:
            index = input_str.find(stop_word)
            if index != -1:
                input_str = input_str[:index]
                break
        # A tool start replaces the whole event rather than updating it.
        self.cur_tool = {
            "tool_name": serialized["name"],
            "input_str": input_str,
            "output_str": "",
            "status": Status.agent_action,
            "run_id": run_id.hex,
            "llm_token": "",
            "final_answer": "",
            "error": "",
        }
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
                          tags: List[str] | None = None, **kwargs: Any) -> None:
        """Record the tool output (with ``"Answer:"`` removed) and re-enable token output."""
        self.out = True  # reset token forwarding
        self._emit(
            status=Status.tool_finish,
            output_str=output.replace("Answer:", ""),
        )

    async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
                            parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
        """Report a tool failure."""
        self._emit(status=Status.error, error=str(error))

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward an LLM token unless output has been muted by an action marker."""
        special_tokens = ["Action", "<|observation|>"]
        for stoken in special_tokens:
            if stoken in token:
                # Emit only the text preceding the marker, then mute further
                # tokens until the tool finishes.
                before_action = token.split(stoken)[0]
                self._emit(
                    status=Status.running,
                    llm_token=before_action + "\n",
                )
                self.out = False
                break
        if token and self.out:
            self._emit(status=Status.running, llm_token=token)

    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        """Report that an LLM call has started."""
        self._emit(status=Status.start, llm_token="")

    async def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List],
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            metadata: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> None:
        """Report that a chat-model call has started."""
        self._emit(status=Status.start, llm_token="")

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Report that the LLM call completed."""
        self._emit(status=Status.complete, llm_token="\n")

    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        """Report an LLM failure."""
        self._emit(status=Status.error, error=str(error))

    async def on_agent_finish(
            self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            **kwargs: Any,
    ) -> None:
        """Emit the final answer and reset the current event."""
        self._emit(
            status=Status.agent_finish,
            final_answer=finish.return_values["output"],
        )
        self.cur_tool = {}

View File

@@ -0,0 +1,228 @@
"""
This file is a version of the original glm3_agent.py file from the langchain repo, modified for ChatGLM3-6B.
"""
from __future__ import annotations
import json
import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
from pydantic.schema import model_schema
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.agents.agent import AgentOutputParser
from langchain.output_parsers import OutputFixingParser
from langchain.pydantic_v1 import Field
from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
logger = logging.getLogger(__name__)
class StructuredChatOutputParserWithRetries(AgentOutputParser):
    """Output parser with retries for the structured chat agent."""

    # The base parser to use.
    base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
    # The output fixing parser to use; when set it takes precedence over base_parser.
    output_fixing_parser: Optional[OutputFixingParser] = None

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Convert raw model text into an ``AgentAction`` or ``AgentFinish``.

        The text is cut at the first ``"Action:"`` / ``"<|observation|>"``
        marker. If what remains contains ``"tool_call"``, the part before the
        first code fence is taken as the tool name and the ``key=value`` pairs
        between the first ``(`` and the last ``)`` as its arguments; otherwise
        the text is treated as the final answer. The result is re-serialized
        into the structured-chat ``Action:`` format and handed to the
        configured parser.

        Raises:
            OutputParserException: if the underlying parser fails.
        """
        special_tokens = ["Action:", "<|observation|>"]
        first_index = min(
            text.find(token) if token in text else len(text)
            for token in special_tokens
        )
        text = text[:first_index]
        if "tool_call" in text:
            action_end = text.find("```")
            action = text[:action_end].strip()
            params_str_start = text.find("(") + 1
            params_str_end = text.rfind(")")
            params_str = text[params_str_start:params_str_end]
            # Split on the FIRST "=" only, so values that themselves contain
            # "=" (e.g. query='a=b') are not truncated.
            params_pairs = [
                param.split("=", 1) for param in params_str.split(",") if "=" in param
            ]
            params = {
                key.strip(): value.strip().strip("'\"") for key, value in params_pairs
            }
            action_json = {
                "action": action,
                "action_input": params
            }
        else:
            action_json = {
                "action": "Final Answer",
                "action_input": text
            }
        # The literal below is deliberately flush-left: its content is parsed.
        action_str = f"""
Action:
```
{json.dumps(action_json, ensure_ascii=False)}
```"""
        try:
            if self.output_fixing_parser is not None:
                parsed_obj: Union[
                    AgentAction, AgentFinish
                ] = self.output_fixing_parser.parse(action_str)
            else:
                parsed_obj = self.base_parser.parse(action_str)
            return parsed_obj
        except Exception as e:
            raise OutputParserException(f"Could not parse LLM output: {text}") from e

    @property
    def _type(self) -> str:
        return "structured_chat_ChatGLM3_6b_with_retries"
class StructuredGLM3ChatAgent(Agent):
    """Structured Chat Agent."""

    # Output parser for the agent.
    output_parser: AgentOutputParser = Field(
        default_factory=StructuredChatOutputParserWithRetries
    )

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the ChatGLM3-6B observation with."""
        return "Observation:"

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the llm call with."""
        return "Thought:"

    def _construct_scratchpad(
            self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
        """Build the scratchpad text, prefixed with a reminder when non-empty."""
        scratchpad = super()._construct_scratchpad(intermediate_steps)
        if not isinstance(scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if not scratchpad:
            return scratchpad
        return (
            "This was your previous work "
            "(but I haven't seen any of it! I only see what "
            f"you return as final answer):\n{scratchpad}"
        )

    @classmethod
    def _get_default_output_parser(
            cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
    ) -> AgentOutputParser:
        return StructuredChatOutputParserWithRetries(llm=llm)

    @property
    def _stop(self) -> List[str]:
        return ["<|observation|>"]

    @classmethod
    def create_prompt(
            cls,
            tools: Sequence[BaseTool],
            prompt: str = None,
            input_variables: Optional[List[str]] = None,
            memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        """Render *prompt* with the tool listing and wrap it as a chat prompt.

        ``{input}`` and ``{agent_scratchpad}`` are re-inserted as placeholders
        so they stay available as template variables.
        """
        tool_names = [tool.name for tool in tools]
        tool_lines = []
        for tool in tools:
            schema = model_schema(tool.args_schema) if tool.args_schema else {}
            parameters = schema.get("properties", {})
            tool_lines.append(f"{tool.name}: {tool.description}, args: {parameters}")
        formatted_tools = "\n".join(tool_lines)
        # Escape quotes, and double the braces so they survive templating.
        for raw, escaped in (("'", "\\'"), ("{", "{{"), ("}", "}}")):
            formatted_tools = formatted_tools.replace(raw, escaped)
        template = prompt.format(
            tool_names=tool_names,
            tools=formatted_tools,
            history="None",
            input="{input}",
            agent_scratchpad="{agent_scratchpad}",
        )
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *(memory_prompts or []),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def from_llm_and_tools(
            cls,
            llm: BaseLanguageModel,
            tools: Sequence[BaseTool],
            prompt: str = None,
            callback_manager: Optional[BaseCallbackManager] = None,
            output_parser: Optional[AgentOutputParser] = None,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            input_variables: Optional[List[str]] = None,
            memory_prompts: Optional[List[BasePromptTemplate]] = None,
            **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        chat_prompt = cls.create_prompt(
            tools,
            prompt=prompt,
            input_variables=input_variables,
            memory_prompts=memory_prompts,
        )
        chain = LLMChain(
            llm=llm,
            prompt=chat_prompt,
            callback_manager=callback_manager,
        )
        allowed = [tool.name for tool in tools]
        parser = output_parser or cls._get_default_output_parser(llm=llm)
        return cls(
            llm_chain=chain,
            allowed_tools=allowed,
            output_parser=parser,
            **kwargs,
        )

    @property
    def _agent_type(self) -> str:
        raise ValueError
def initialize_glm3_agent(
        tools: Sequence[BaseTool],
        llm: BaseLanguageModel,
        prompt: str = None,
        memory: Optional[ConversationBufferWindowMemory] = None,
        agent_kwargs: Optional[dict] = None,
        *,
        tags: Optional[Sequence[str]] = None,
        **kwargs: Any,
) -> AgentExecutor:
    """Build a ``StructuredGLM3ChatAgent`` and wrap it in an ``AgentExecutor``.

    ``agent_kwargs`` are forwarded to the agent constructor; any remaining
    keyword arguments go to the executor.
    """
    executor_tags = list(tags) if tags else []
    agent = StructuredGLM3ChatAgent.from_llm_and_tools(
        llm=llm,
        tools=tools,
        prompt=prompt,
        **(agent_kwargs or {}),
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        memory=memory,
        tags=executor_tags,
        **kwargs,
    )

View File

@@ -0,0 +1,162 @@
from __future__ import annotations
import json
import re
from langchain.agents import Tool, AgentOutputParser
from langchain.prompts import StringPromptTemplate
from typing import List
from langchain.schema import AgentAction, AgentFinish
from configs.basic_config import *
from configs import SUPPORT_AGENT_MODEL
from server.agent import model_container
from server.chat import utils
from server.chat.knowledge_base_name import KnowledgeBase
from collections import defaultdict
# Prompt template that injects the agent scratchpad and tool listing, and
# substitutes a visible marker for any placeholder the caller did not supply.
class CustomPromptTemplate(StringPromptTemplate):
    template: str
    tools: List[Tool]

    def format(self, **kwargs) -> str:
        """Render the template with scratchpad, tool list and safe defaults.

        ``intermediate_steps`` (optional) is consumed to build
        ``agent_scratchpad``. Positional fields such as ``{0}`` are rewritten
        to ``{arg0}`` in ``self.template``, and every placeholder without a
        value is filled with ``<missing NAME>``.

        Raises:
            ValueError: if formatting still fails.
        """
        steps = kwargs.pop("intermediate_steps", [])
        scratchpad = "".join(
            action.log + f"\nObservation: {observation}\n"
            for action, observation in steps
        )
        kwargs["agent_scratchpad"] = scratchpad
        kwargs["tools"] = "\n".join(
            f"【工具名称】{tool.name}: 【工具描述】{tool.description}"
            for tool in getattr(self, 'tools', [])
        )
        kwargs["tool_names"] = ", ".join(
            tool.name for tool in getattr(self, 'tools', [])
        )

        # Turn positional fields into named ones so format_map can serve them.
        self.template = re.sub(
            r'\{(\d+)\}',
            lambda found: f"{{arg{found.group(1)}}}",
            self.template,
        )

        # Give every remaining placeholder a visible default.
        for name in re.findall(r'\{(\w+)\}', self.template):
            kwargs.setdefault(name, f"<missing {name}>")

        try:
            return self.template.format_map(kwargs)
        except KeyError as e:
            raise ValueError(f"Missing key in template formatting: {e}")
        except ValueError as e:
            raise ValueError(f"Format string contains positional fields: {e}")
# Parses raw LLM text into an AgentAction / AgentFinish for the custom agent.
# NOTE(review): the source this was taken from had lost its indentation and,
# apparently, some characters inside string literals (see the empty-string
# arguments below); nesting was reconstructed — confirm against the original.
class CustomOutputParser(AgentOutputParser):
    begin: bool = False
    knowledge_base_name: KnowledgeBase
    time_based_uuid:str

    def __init__(self, knowledge_base_name: KnowledgeBase,time_based_uuid:str):
        super().__init__(knowledge_base_name = knowledge_base_name,time_based_uuid=time_based_uuid)
        self.begin = True
        self.knowledge_base_name = knowledge_base_name
        self.time_based_uuid = time_based_uuid

    def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
        """Turn one LLM completion into the next agent step.

        Returns an ``AgentFinish`` when the text contains ``"Final Answer:"``,
        contains no ``"Action:"`` section, or when building the action raises.
        Otherwise returns an ``AgentAction`` for one of the known tools, for
        the "环节跳转" tool, or for the "无需调用工具" fallback tool.

        Side effect: when the action input is JSON containing
        ``"knowledge_name"``, that value is written to
        ``self.knowledge_base_name.name``.
        """
        # For models not listed in SUPPORT_AGENT_MODEL, cut the output at the
        # first "Observation:" (only while self.begin is set).
        if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
            self.begin = False
            stop_words = ["Observation:"]
            min_index = len(llm_output)
            for stop_word in stop_words:
                index = llm_output.find(stop_word)
                if index != -1 and index < min_index:
                    min_index = index
            llm_output = llm_output[:min_index]
        if "Final Answer:" in llm_output:
            self.begin = True
            return AgentFinish(
                return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
                log=llm_output,
            )
        # print("llm_output>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",llm_output)
        # logger.info(f"llm_output: {llm_output}")
        parts = llm_output.split("Action:")
        # print("parts:", parts)
        if len(parts) < 2:
            # No action section at all: return the text as the answer.
            return AgentFinish(
                # return_values={"output": f"If the following is not what you expected, try asking in a different way:\n\n {llm_output} "},
                return_values={"output": f" {llm_output} "},
                log=llm_output,
            )
        # action = parts[1].split("Action Input:")[0].strip()
        match = re.search(r'Action:\s*(.+)', llm_output)
        try:
            action = match.group(1).strip().strip('\'').replace("【工具名称】","")
            # print("parts[1]>>>>",parts[1])
            # if "\nAction Input" not in parts:
            action_input = parts[1].split("Action Input:")[1].strip()
            action_input = action_input.replace("【调用结束】","")
            action_input = action_input.replace("调用结束","")
        except:
            # NOTE(review): if the failure happened before `action` was
            # assigned, `action` stays unbound and the final try below falls
            # through to its bare except.
            action_input = ""
        # END = utils.get_shared_variable(self.time_based_uuid)
        # END["END"] = "ok"
        false_tokens = ["","None"]
        # use_tools = ["联网思索","知识库联想","图表绘制","实景绘制","水墨画绘制"]
        use_tools = ["联网思索","知识库联想","图表绘制","数学运算","代码专家","天气工具","美术作品获取","统计数据查询"]
        # utils.set_shared_variable(self.time_based_uuid,END)
        print("action_input: ",action_input)
        # Best effort: pick up a knowledge-base name from a JSON action input.
        try:
            if "knowledge_name" in json.loads(action_input):
                match = re.search(r'\{.*?\}', action_input,re.DOTALL)
                query = match.group(0)
                self.knowledge_base_name.name = json.loads(query)["knowledge_name"]
        except Exception as e:
            pass
        try:
            # NOTE(review): strip("") and replace("","") are no-ops as written;
            # their arguments presumably held characters that were lost — verify.
            if not (any(false_token == action.strip("").replace("","").replace("","") for false_token in false_tokens)) and any(tool_temp in action for tool_temp in use_tools):
                # Known tool: append the request uuid as a trailing JSON object.
                ans = AgentAction(
                    tool=action.replace("","").replace("",""),
                    tool_input=action_input.strip(" ").strip('"')+"{\"uuid\":\""+self.time_based_uuid+"\"}",
                    log=llm_output
                )
                return ans
            elif action.strip("").replace("","").replace("","") == "环节跳转":
                ans = AgentAction(
                    tool="环节跳转",
                    tool_input=self.time_based_uuid,
                    log=llm_output
                )
                return ans
            else:
                ans = AgentAction(
                    tool="无需调用工具",
                    tool_input=self.time_based_uuid,
                    log=llm_output
                )
                # END = utils.get_shared_variable(self.time_based_uuid)
                # END["END"] = "ok"
                # utils.set_shared_variable(self.time_based_uuid,END)
                return ans
        except:
            # END = utils.get_shared_variable(self.time_based_uuid)
            # END["END"] = "ok"
            return AgentFinish(
                return_values={"output": f"调用agent失败: `{llm_output}`"},
                log=llm_output,
            )

View File

@@ -0,0 +1,6 @@
class ModelContainer:
    """Mutable holder for a model and a database handle; both start as None."""

    def __init__(self):
        self.MODEL = self.DATABASE = None


# Module-level instance shared by importers of this module.
model_container = ModelContainer()

View File

@@ -0,0 +1,16 @@
## 导入所有的工具类
# from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
# from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
# from .chat_with_Yi34B import chat_with_Yi34B, ChatWithYi34BInput
# from .search_youtube import search_youtube, YoutubeInput
from .calculate import calculate, CalculatorInput
from .weather_check import weathercheck, WeatherInput
from .shell import shell, ShellInput
from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput
from .arxiv import arxiv, ArxivInput
from .knowledgebase_kgo_search import knowledgebase_kgo_search, KnowledgeKgoInput
from .policy_knowledgebase_search import policy_knowledgebase_search, PolicyKnowledgeInput
from .report_knowledgebase_search import report_knowledgebase_search, ReportKnowledgeInput
from .rag_search import rag_search1, RagSearchInput
from .duckduckgo_search import duckduckgo_search, DuckduckgoInput

View File

@@ -0,0 +1,9 @@
# LangChain 的 ArxivQueryRun 工具
from pydantic import BaseModel, Field
from langchain.tools.arxiv.tool import ArxivQueryRun
def arxiv(query: str):
    """Run *query* through LangChain's ``ArxivQueryRun`` tool and return its result."""
    return ArxivQueryRun().run(tool_input=query)
# Argument schema for the arxiv tool: a single search string.
class ArxivInput(BaseModel):
    query: str = Field(
        description="The search query title",
    )

View File

@@ -0,0 +1,10 @@
name: arxiv
description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.
parameters:
type: object
properties:
query:
type: string
description: The search query title
required:
- query

View File

@@ -0,0 +1,76 @@
from langchain.prompts import PromptTemplate
from langchain.chains import LLMMathChain
from server.agent import model_container
from pydantic import BaseModel, Field
# Chinese prompt for LLMMathChain: asks the model to translate a math question
# into a single-line numexpr expression and shows worked examples.
# Doubled braces ("${{...}}") are literal braces after templating; {question}
# is the only real input variable (see PROMPT below).
_PROMPT_TEMPLATE = """
将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
问题: ${{包含数学问题的问题。}}
```text
${{解决问题的单行数学表达式}}
```
...numexpr.evaluate(query)...
```output
${{运行代码的输出}}
```
答案: ${{答案}}
这是两个例子:
问题: 37593 * 67是多少
```text
37593 * 67
```
...numexpr.evaluate("37593 * 67")...
```output
2518731
答案: 2518731
问题: 37593的五次方根是多少
```text
37593**(1/5)
```
...numexpr.evaluate("37593**(1/5)")...
```output
8.222831614237718
答案: 8.222831614237718
问题: 2的平方是多少
```text
2 ** 2
```
...numexpr.evaluate("2 ** 2")...
```output
4
答案: 4
现在,这是我的问题:
问题: {question}
"""
# Prompt object handed to LLMMathChain in calculate().
PROMPT = PromptTemplate(
    input_variables=["question"],
    template=_PROMPT_TEMPLATE,
)
# Argument schema for the calculate tool: the question/expression text.
class CalculatorInput(BaseModel):
    query: str = Field()
def calculate(query: str):
    """Answer a math question with ``LLMMathChain`` using the shared model.

    The chain is built on ``model_container.MODEL`` with the module's
    ``PROMPT`` and run on *query*; its result is returned unchanged.
    """
    chain = LLMMathChain.from_llm(model_container.MODEL, verbose=True, prompt=PROMPT)
    return chain.run(query)


if __name__ == "__main__":
    # Manual smoke test.
    print("答案:", calculate("2的三次方"))

View File

@@ -0,0 +1,10 @@
name: calculate
description: Useful for when you need to answer questions about simple calculations
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query

View File

@@ -0,0 +1,43 @@
import asyncio
import json
from typing import List, Union
from pydantic import BaseModel, Field
from server.chat.chat import chat
from server.chat.utils import History
async def chat_with_Yi34B_iter(query: str,
                               stream=False,
                               model_name="qianfan-api",
                               history: Union[int, List[History]] = None,
                               conversation_id='',
                               temperature=0.7,
                               max_tokens=None,
                               history_len=3,
                               prompt_name="default"
                               ):
    """Call the project ``chat`` endpoint and return the concatenated reply text.

    Each item of the response's ``body_iterator`` is a JSON string; the
    ``"text"`` values are joined in arrival order.
    """
    response = await chat(
        query=query,
        history=history,
        history_len=history_len,
        conversation_id=conversation_id,
        stream=stream,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        prompt_name=prompt_name,
    )
    pieces = [json.loads(chunk)["text"] async for chunk in response.body_iterator]
    return "".join(pieces)
def chat_with_Yi34B(query: str, model_name: str = "qianfan-api", conversation_id: str = '',
                    history: Union[int, List[History]] = None):
    """Synchronous wrapper: run ``chat_with_Yi34B_iter`` to completion and return its text."""
    pending = chat_with_Yi34B_iter(
        query,
        model_name=model_name,
        conversation_id=conversation_id,
        history=history,
    )
    return asyncio.run(pending)
# Argument schema for the chat_with_Yi34B tool.
# NOTE(review): the field is named `location` while the tool function takes
# `query` — looks like a copy/paste leftover; confirm before renaming.
class ChatWithYi34BInput(BaseModel):
    location: str = Field(
        description="Query for any kind of chats and questions",
    )

View File

@@ -0,0 +1,18 @@
name: chat_with_Yi34B
description: Use this tool to chat with human
parameters:
type: object
properties:
query:
type: string
description: Query for any kind of chat and questions
model_name:
type: string
description:
conversation_id:
type: string
description:
required:
- query
- model_name
- conversation_id

View File

@@ -0,0 +1,34 @@
import json
import re
import concurrent
from fastapi.concurrency import run_in_threadpool
from langchain.tools import YouTubeSearchTool
from pydantic import BaseModel, Field
from server.chat import utils
from server.chat.utils import doc_to_list, solve_knowledge_map,solve_mental_data,shared_variable
from server.knowledge_base.kb_doc_api import search_docs
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import kb_config
def do_nothing(query: str):
    """No-op tool: ignore *query* and report that no tool call is needed."""
    message = "\n不需要调用工具了"
    return message
def get_next_tip(query: str):
    """Mark the shared state stored under *query* as ended and return a hint.

    Loads the shared variable keyed by *query*, sets its ``"END"`` entry to
    ``"ok"``, stores it back, and returns a fixed hint string for the agent.
    """
    state = utils.get_shared_variable(query)
    state["END"] = "ok"
    utils.set_shared_variable(query, state)
    return "\n提示:你已经使用过环节跳转了,可以开始输出正文了"
# Argument schema shared by the no-op style tools: one required string.
class doNothingInput(BaseModel):
    query: str = Field(
        ...,
        description="查询对象",
    )

View File

@@ -0,0 +1,175 @@
import json
import logging
import os
import re
import uuid
import requests
from matplotlib import pyplot as plt
from pydantic import BaseModel, Field
from configs.model_config import LLM_MODELS
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
from configs.kb_config import GENERATED_IMAGES_BASE_PATH, realistic_url,ink_url
from matplotlib import font_manager
# Font applied to chart titles, axis labels, ticks and pie labels below.
# The path to the font file is hard-coded.
my_font = font_manager.FontProperties(fname="/usr/share/fonts/MicroSoft-YaHei/MSYH.TTC")
def create_and_save_plot(query:str) -> str:
    """Draw a bar, pie or line chart from *query* and return a markdown image link.

    *query* is expected to look like ``<param>{...}</param>{...}`` where the
    first JSON object carries ``data``, ``title``, ``xlabel``, ``ylabel`` and
    ``plot_type``. If that JSON cannot be parsed, the LLM is asked to rewrite
    the query and the JSON is read from the fenced ``json`` block of its reply.

    Returns:
        A markdown image string on success; "暂时无法画图" when the query has
        no ``<param>`` section; otherwise a fixed message telling the agent
        not to call the tool again.
    """
    try:
        query = query.replace(" ","").replace("'","\"")
        json_str ='{\n"data": {"XXX": XX, "XXX": XX, "XXX": X, "XXX": X},"title": "X","xlabel": "X","ylabel": "X","plot_type": "X"}'
        try:
            match = re.search(r'<param>(.*?)</param>(\{.*\})', query.replace("\n", ""))
            if match:
                query = match.group(1).strip()
                datas = json.loads(query)
            else:
                print(f"Invalid JSON format in query:\n{query}")
                return"暂时无法画图"
        except Exception:
            # The parameter JSON was malformed: ask the LLM to repair it.
            query = get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="check_plot",
                prompt_param_dict={"user_input": query,"json":json_str },
                temperature=0.01,
                max_tokens=512
            )
            # Previously the search result was discarded and the stale `match`
            # from above was reused, and newlines were stripped before matching
            # a pattern that required them — so this fallback could never work.
            fenced = re.search(r"```json\s*(.*?)\s*```", query, re.DOTALL)
            if fenced is None:
                raise ValueError("no fenced json block in LLM response")
            query = fenced.group(1).strip()
            datas = json.loads(query)
        data = datas["data"]
        xlabel = datas["xlabel"]
        ylabel = datas["ylabel"]
        title = datas["title"]
        plot_type = datas["plot_type"]
        # Split the data into category labels and values.
        categories = list(data.keys())
        values = list(data.values())
        # Create the chart; always close the figure, even if drawing fails.
        plt.figure(figsize=(10, 6))
        try:
            if plot_type == 'bar':
                plt.bar(categories, values, color='skyblue')
            elif plot_type == 'pie':
                plt.pie(values, labels=categories, autopct='%1.1f%%', startangle=140,textprops={'fontproperties': my_font})
            elif plot_type == 'line':
                plt.plot(categories, values, marker='o', linestyle='-')
            else:
                raise ValueError("Unsupported plot type. Choose from 'bar', 'pie', or 'line'.")
            # Title and labels.
            plt.title(title,fontproperties=my_font)
            if plot_type != 'pie':  # a pie chart needs no axis labels
                plt.xlabel(xlabel,fontproperties=my_font)
                plt.ylabel(ylabel,fontproperties=my_font)
                plt.xticks(fontproperties=my_font,rotation=45)
                plt.yticks(fontproperties=my_font)
            namesid = uuid.uuid1()
            # Save the chart as an image file.
            file_path = f'{GENERATED_IMAGES_BASE_PATH}/plot{namesid}.png'
            absolute_path = os.path.abspath(file_path)
            plt.tight_layout()
            plt.savefig(absolute_path)
        finally:
            plt.close()
        image_name = f"plot{namesid}.png"
        return f"图片如下:![图片{title}](http://127.0.0.1:8099/chat_web_backend/get-image?file_name={image_name})"
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        return "不要再调用该工具了,无法提供该功能,你只能按工具描述提供有的入参,其他不支持"
# Argument schema for the chart-drawing tool: one required string.
class drawPlotInput(BaseModel):
    query: str = Field(
        ...,
        description="输入要画图的内容",
    )
def draw_realistic_pic(query:str) -> str:
    """Request a realistic-style image for *query* and return an instruction string.

    *query* must contain at least two ``{...}`` JSON objects: the second one
    carries the request ``uuid``; after it is removed, the remaining text must
    be JSON with a ``"query"`` key used as the image prompt. On success the
    returned image path is stored under ``"url"`` in the shared variable for
    that uuid.

    Returns:
        The markdown instruction string on HTTP 200, otherwise "暂时无法画图"
        (also for malformed input or a request error).
    """
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            return "暂时无法画图"
        # Named request_uuid so the imported `uuid` module is not shadowed.
        request_uuid = json.loads(matches[1])["uuid"]
        query = query.replace(matches[1], "")
        # Request body.
        data = {
            'prompt': json.loads(query)["query"]
        }
    except (ValueError, KeyError, TypeError):
        # Previously execution continued past this point with the uuid unbound
        # and raised later; return the fallback message instead.
        logging.error("Invalid JSON format in query.")
        return "暂时无法画图"
    try:
        # NOTE(review): no timeout is set on this request.
        response = requests.post(realistic_url, json=data)
        if response.status_code == 200:
            result = response.json()
            print("Image path:", result.get('image_path'))
            file_path = result.get('image_path')
            sources = utils.get_shared_variable(request_uuid)
            sources["url"] = file_path
            utils.set_shared_variable(request_uuid, sources)
            return f"<关键指令>你必须以markdown格式的图片路径![图片](http://127.0.0.1:8099/chat_web_backend/get-image?file_name={file_path}),之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
        else:
            print("Failed to get response:", response.status_code)
            return "暂时无法画图"
    except requests.exceptions.RequestException as e:
        print("An error occurred:", e)
        return "暂时无法画图"
# Argument schema for the realistic-image tool: one required string.
class drawRealisticInput(BaseModel):
    query: str = Field(
        ...,
        description="输入要画图的内容",
    )
def draw_ink_pic(query:str) -> str:
    """Request an ink-style image for *query* and return an instruction string.

    *query* must contain at least two ``{...}`` JSON objects: the second one
    carries the request ``uuid``; after it is removed, the remaining text must
    be JSON with a ``"query"`` key used as the image prompt. On success the
    returned image path is stored under ``"url"`` in the shared variable for
    that uuid.

    Returns:
        The markdown instruction string on HTTP 200, otherwise "暂时无法画图"
        (also for malformed input or a request error).
    """
    # A stray `get_llm_model_response()` call (no arguments, result discarded)
    # used to run here; it has been removed.
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            return "暂时无法画图"
        # Named request_uuid so the imported `uuid` module is not shadowed.
        request_uuid = json.loads(matches[1])["uuid"]
        query = query.replace(matches[1], "")
        # Request body.
        data = {
            'prompt': json.loads(query)["query"]
        }
    except (ValueError, KeyError, TypeError):
        # Previously execution continued past this point with the uuid unbound
        # and raised later; return the fallback message instead.
        logging.error("Invalid JSON format in query.")
        return "暂时无法画图"
    try:
        # NOTE(review): no timeout is set on this request.
        response = requests.post(ink_url, json=data)
        if response.status_code == 200:
            result = response.json()
            print("Image path:", result.get('image_path'))
            file_path = result.get('image_path')
            sources = utils.get_shared_variable(request_uuid)
            sources["url"] = file_path
            utils.set_shared_variable(request_uuid, sources)
            return f"<关键指令>你必须以markdown格式的图片路径![图片](http://127.0.0.1:8099/chat_web_backend/get-image?file_name={file_path}),之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
        else:
            print("Failed to get response:", response.status_code)
            return "暂时无法画图"
    except requests.exceptions.RequestException as e:
        print("An error occurred:", e)
        return "暂时无法画图"
# Argument schema for the ink-image tool: one required string.
class drawInkInput(BaseModel):
    query: str = Field(
        ...,
        description="输入要画图的内容",
    )

View File

@@ -0,0 +1,186 @@
import asyncio
import re
import aiohttp
import json
import logging
from pydantic import BaseModel, Field
from server.chat import utils
# 配置日志记录器
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)
async def duckduckgo_search_iter(query: str, uuid: str = "",time: str = "", resource_type: str = None, limit: int = 3):
    """Query the remote search endpoints and format the hits as numbered references.

    Args:
        query: Search text sent to every endpoint.
        uuid: Key of the per-conversation shared variable. The stored dict
            must contain ``"num"`` (last used reference number) and
            ``"source_docs"`` (list that the new sources are appended to).
        time: Time filter forwarded verbatim in the request body.
        resource_type: ``"video"`` queries only the video endpoint with the
            full ``limit``; any other value (including ``None``) queries the
            text, news and video endpoints concurrently.
        limit: Total number of results wanted. When all three endpoints are
            queried it is split as evenly as possible, text first, then news.

    Returns:
        ``(res, source)``: ``res`` holds one descriptive line per hit for the
        model, ``source`` holds the matching markdown links.
    """
    # Endpoints of the three search APIs.
    text_url = 'http://43.251.225.121/inspur/search_text'
    video_url = 'http://43.251.225.121/inspur/search_video'
    news_url = 'http://43.251.225.121/inspur/search_new'
    payload = {
        "query": query,
        "time": time
    }

    async def fetch(session, url, json_payload, limit):
        """POST one search request; return the result list, or [] on any failure."""
        # Build a private body so concurrent requests never overwrite each
        # other's "limit" entry in the shared payload dict.
        request_body = {**json_payload, "limit": limit}
        logger.info(f"{url} 获取数据,请求参数: {request_body}")
        try:
            async with session.post(url, json=request_body) as response:
                if response.status != 200:
                    logger.error(f"{url} 请求失败,状态码 {response.status}")
                    return []
                data = await response.json()
                logger.info(f"{url} 获取的资料数: {len(data) if isinstance(data, list) else '未知'}")
                # The formatting code below iterates over result dicts, so
                # anything that is not a list cannot be used.
                return data if isinstance(data, list) else []
        except Exception as e:
            logger.error(f"获取 {url} 数据时发生错误: {e}")
            return []

    async with aiohttp.ClientSession() as session:
        logger.info("发起请求duckduckgo...")
        if resource_type != 'video':
            # Split the budget across the three endpoints; the remainder goes
            # to text first, then news (e.g. 5 -> text 2, news 2, video 1).
            base, remainder = divmod(limit, 3)
            text_limit = base + (1 if remainder >= 1 else 0)
            news_limit = base + (1 if remainder == 2 else 0)
            video_limit = base
            text_result, video_result, news_result = await asyncio.gather(
                fetch(session, text_url, payload, text_limit),
                fetch(session, video_url, payload, video_limit),
                fetch(session, news_url, payload, news_limit),
            )
            logger.info("合并结果...")
            logger.info("合并结果完成")
            combined_result = {
                "text": text_result,
                "video": video_result,
                "news": news_result
            }
        else:
            video_result = await fetch(session, video_url, payload, limit)
            combined_result = {
                "video": video_result
            }
    logger.info("请求已完成")

    res = []
    source = []
    info = utils.get_shared_variable(uuid)
    # Continue numbering after the references already recorded for this conversation.
    index = info["num"]
    if "text" in combined_result:
        for item in combined_result["text"]:
            index += 1
            res.append(f'资料[{index}] 资料标题{item["title"]}({item["href"]}) 资料内容为: {item["body"]}')
            source.append(f'资料[{index}] [{item["title"]}]({item["href"]})')
    if "video" in combined_result:
        for item in combined_result["video"]:
            index += 1
            res.append(f'资料[{index}] 视频标题[{item["title"]}]({item["content"]}) 视频内容为: {item["description"]}')
            source.append(f'视频资料[{index}] [{item["title"]}]({item["content"]})')
    if "news" in combined_result:
        for item in combined_result["news"]:
            index += 1
            res.append(f'资料[{index}] 新闻标题[{item["title"]}]({item["url"]}) 新闻内容为: {item["body"]}')
            source.append(f'资料[{index}] [{item["title"]}]({item["url"]})')
    # NOTE(review): info["num"] is not advanced here (same as before) — confirm
    # whether another component is expected to update it.
    info["source_docs"].extend(source)
    utils.set_shared_variable(uuid, info)
    return res,source
def duckduckgo_search(query: str, time: str = "", resource_type: str = None):
    """Synchronous tool entry point for the web search.

    ``query`` is expected to carry two JSON objects back to back: the search
    arguments (``query``, optional ``limit``, ``resource_type``, ``time``)
    followed by ``{"uuid": ...}`` identifying the conversation.

    Args:
        query: Raw tool input as described above.
        time: Default time filter, used when the arguments carry none.
        resource_type: Default resource type, used when the arguments carry none.

    Returns:
        The ``(res, source)`` tuple produced by ``duckduckgo_search_iter``,
        or an instruction string telling the model to stop calling this tool
        when the input cannot be parsed.
    """
    stop_message = "<关键指令>不需要再调用该工具了</关键指令>"
    logger.info(f"模型输入: {query}")
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return stop_message
    query = matches[0]
    try:
        obj1 = json.loads(query)
        parsed_uuid = json.loads(matches[1])["uuid"]
    except (json.JSONDecodeError, KeyError) as e:
        # Without valid arguments and a uuid the search cannot run; bail out
        # instead of continuing with undefined values.
        logger.error(f"解析JSON出错: {e}")
        return stop_message
    parsed_query = obj1.get("query", "")
    parsed_limit = obj1.get("limit", 3)
    parsed_resource_type = obj1.get("resource_type", None)
    parsed_time = obj1.get("time", time)  # fall back to the caller's default
    # Parsed values override the defaults only when they are non-empty.
    query = parsed_query if parsed_query else query
    resource_type = parsed_resource_type if parsed_resource_type else resource_type
    time = parsed_time if parsed_time else time
    logger.info(f"解析完成query: {query}, uuid: {parsed_uuid}, time: {time}, resource_type: {resource_type}, parsed_limit: {parsed_limit}")
    # Run the async implementation from this synchronous context.
    combined_result = asyncio.run(duckduckgo_search_iter(query, parsed_uuid, time, resource_type, parsed_limit))
    logger.info("返回JSON格式的结果给到模型...")
    return combined_result
# Argument schema for the web-search tool.
class DuckduckgoInput(BaseModel):
    # The web search query.
    location: str = Field(
        description="网络搜索查询",
    )
if __name__ == "__main__":
    # Manual smoke test.
    # Earlier variants, kept for reference:
    #   1. default, all three APIs:
    #      result_default = duckduckgo_search("粉末冶金", "m", "default")
    #      print("duckduckgo输出(默认):\n", result_default)
    #   2. video only:
    #      result_video = duckduckgo_search("粉末冶金", "m", "video")
    #      print("duckduckgo输出(视频):\n", result_video)
    #   3. news only:
    #      result_news = duckduckgo_search("粉末冶金", "m", "news")
    #      print("duckduckgo输出(新闻):\n", result_news)
    # 4. any other resource type
    smoke_result = duckduckgo_search("粉末冶金", "m", "other")
    print("duckduckgo输出(其他):\n", smoke_result)

View File

@@ -0,0 +1,57 @@
from datetime import datetime
import json
import logging
import re
import requests
from configs.model_config import LLM_MODELS
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
def mysql_statistic(query:str) -> str:
    """Answer a statistics question through the local NL2SQL service.

    ``query`` must look like ``<param>{"query": ...}</param>{"uuid": ...}``
    (newlines are ignored). The question is POSTed to the NL2SQL endpoint and
    the reply is wrapped in an instruction for the agent.

    Args:
        query: Raw tool input in the format above.

    Returns:
        An instruction string containing the data, a hint to use the web
        search tool when no rows were found, or ``"暂时无法查询"`` when the input
        is malformed or the service call fails.
    """
    failure_message = "暂时无法查询"
    logging.info(f"\n🔍 统计工具查询query: \n{query}\n")
    matches = re.search(r'<param>(.*?)</param>(\{.*\})', query.replace("\n", ""))
    if not matches:
        logging.error("Invalid JSON format in query.")
        return failure_message
    try:
        # The uuid envelope is not used further here, but a call without it
        # is still rejected as malformed.
        json.loads(matches.group(2))["uuid"]
        question = json.loads(matches.group(1).strip())["query"]
    except (json.JSONDecodeError, KeyError, TypeError):
        logging.error("Invalid JSON format in query.")
        return failure_message

    logging.info(f"\n🔍 NL2SQL检索question: \n{question}\n")
    try:
        res = requests.post(
            url="http://127.0.0.1:6008/query",
            json={"question": question},
            headers={"Content-Type": "application/json"}
        )
    except requests.exceptions.RequestException as e:
        logging.error(f"NL2SQL request failed: {e}")
        return failure_message
    # A requests.Response is falsy for 4xx/5xx status codes.
    if not res:
        logging.error(f"NL2SQL request returned status {res.status_code}")
        return failure_message
    try:
        data = res.json()["result"]
    except (ValueError, KeyError, TypeError) as e:
        logging.error(f"Unexpected NL2SQL response: {e}")
        return failure_message

    # The service signals an empty result set with this marker in its text.
    if "'data': []" in data:
        return f"统计库未检索到数据,使用“联网思索”工具检索该请求:{question}\n"
    return f"判断得到的数据是否准确,如果不准确,则使用“联网思索”工具检索。如果准确,则根据数据表格并使用图表绘制工具制图。\n 数据如下所示: \n{data}\n"

View File

@@ -0,0 +1,170 @@
import json
import logging
import re
from typing import List, Any, Union
import concurrent
from pydantic import BaseModel, Field
from difflib import SequenceMatcher
from configs import (VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
DEFAULT_POLICY_BASE)
from server.agent.tools import search_internet
from server.chat import utils
from server.knowledge_base.kb_doc_api import search_docs
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.utils import BaseResponse
# Argument schema for the knowledge/web search tool.
class KnowledgeKgoInput(BaseModel):
    # The search query.
    location: str = Field(
        description="Query for Internet search",
    )
def preprocess_text(text: str) -> str:
    """Return *text* with all whitespace and non-word characters removed."""
    trimmed = text.strip()
    return re.sub(r'[\s\W]', '', trimmed)
def knowledge_temperature(a: str, b: str) -> float:
    """Return the difflib similarity ratio (0.0 to 1.0) between *a* and *b*."""
    matcher = SequenceMatcher(None, a, b)
    similarity = matcher.ratio()
    return similarity
# def knowledgebase_kgo_iter(query: str,
# fileName: List = [],
# knowledge_base_name: str = DEFAULT_POLICY_BASE,
# top_k: int = VECTOR_SEARCH_TOP_K,
# score_threshold: float = SCORE_THRESHOLD) -> BaseResponse | list[str] | Any:
# kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
# if kb is None:
# return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
# query = query.strip()
# docs = search_docs(fileName=fileName,
# query=query,
# knowledge_base_name=knowledge_base_name,
# top_k=top_k,
# score_threshold=score_threshold)
# # 预处理查询文本
# processed_query = preprocess_text(query).replace("Observ","")
# print("processed_query:", processed_query)
# knowledge_docs = []
# knowledge_content = []
# # 知识库返回的文档与query的相似度
# if docs:
# for enum, doc in enumerate(docs):
# filename = doc.metadata.get("title")
# detail_url = "https://policy.ckcest.cn/detail/" + doc.metadata.get("primary_key") + ".html"
# if filename:
# text = f"""政策类资料[{enum + 1}]: [{filename}]({detail_url})\n"""
# else:
# text = f"""政策类资料[{enum + 1}]: [{"原文地址"}]({detail_url})\n"""
# knowledge_docs.append(text)
# # print("knowledge_docs:", knowledge_docs)
# knowledge_content = [doc.page_content for doc in docs]
# # print("knowledge_content:", knowledge_content)
# # 计算知识库返回的文档与query的相似度
# titles = [doc.metadata.get("title") for doc in docs]
# print("titles:", titles)
# def check_similarity_threshold(titles: List[str], query: str, knowledge_docs: List[str], knowledge_content: List[str]) -> Union[
# List[str], None]:
# # 用于记录是否存在相似度大于0.55的标题
# has_similar_title = False
# for title in titles:
# processed_title = preprocess_text(title)
# similarity = knowledge_temperature(processed_query, processed_title)
# print("processed_title:", processed_title)
# print("similarity:", similarity)
# if similarity >= 0.55:
# has_similar_title = True
# break
# # 如果存在相似度大于0.55的标题,则直接返回 knowledge_docs
# if has_similar_title:
# knowledge = knowledge_content + knowledge_docs
# return knowledge
# # 如果所有标题的相似度都不大于0.55,则返回 None
# return None
# # 在原函数中使用新的函数进行相似度阈值的判断
# similar_docs = check_similarity_threshold(titles, query, knowledge_docs, knowledge_content)
# if similar_docs is None:
# # 如果所有标题的相似度都不大于0.55,则执行搜索引擎查询
# kgo_docs = search_internet(processed_query)
# # print("kgo_docs", kgo_docs)
# return kgo_docs
# else:
# kgo_docs = search_internet(processed_query)
# # print("similar_docs", similar_docs)
# # print("kgo_docs", kgo_docs)
# similar_docs.extend(kgo_docs)
# return similar_docs
# else:
# # 执行搜索引擎查询
# kgo_docs = search_internet(query)
# return kgo_docs
def knowledgebase_kgo_iter(query: str, uid: str) -> BaseResponse | list[str] | Any:
    """Forward *query* and the conversation id *uid* to ``search_internet`` and return its result."""
    return search_internet(query, uid)
def knowledgebase_kgo_search(query: str) -> List[str]:
    """Tool entry point: run the web search and format the result for the agent.

    ``query`` must carry two JSON objects back to back: ``{"query": ...}``
    followed by ``{"uuid": ...}``.

    Args:
        query: Raw tool input as described above.

    Returns:
        A string with the document contents and sources, a "titles only"
        string when only sources were found, or a retry hint when the input
        is malformed or the search fails or finds nothing.
    """
    # A bare ``import concurrent`` does not load the ``futures`` submodule,
    # so import the executor explicitly.
    from concurrent.futures import ThreadPoolExecutor

    retry_hint = "尝试调整入参重新调用联网思索工具(同一个问题调用超过三次就不要再使用该工具了,浪费时间)"
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            logging.error("Invalid JSON format in query.")
            return retry_hint
        parsed_query = json.loads(matches[0])
        time_based_uuid = json.loads(matches[1])["uuid"]
        # Run the search on a worker thread. NOTE(review): presumably so the
        # search's own event loop does not clash with one already running in
        # the calling thread — confirm.
        with ThreadPoolExecutor() as executor:
            future = executor.submit(knowledgebase_kgo_iter, parsed_query["query"], time_based_uuid)
            res = future.result()
        try:
            contents, sources = res[0], res[1]
            if isinstance(contents, list) and contents:
                return "资料内容" + "".join(contents) + "资料来源" + "".join(sources)
            if isinstance(sources, list) and sources:
                doc_content = "资料来源" + "".join(sources)
                return f"只有标题没有内容,标题为:{doc_content}"
            return retry_hint
        except Exception as e:
            logging.error(f"Error occurred while processing query: {e}")
            return retry_hint
    except json.JSONDecodeError:
        # One of the two JSON objects could not be decoded.
        logging.error("Invalid JSON format in query.")
        return retry_hint
    except KeyError:
        # A required key ("query" or "uuid") is missing.
        return retry_hint
    except Exception as e:
        # Anything else, including failures raised by the search itself.
        logging.error(f"Error occurred while processing query: {e}")
        return retry_hint
if __name__ == "__main__":
    # Manual smoke test. knowledgebase_kgo_iter requires a conversation id.
    # NOTE(review): "debug" is a placeholder uid — confirm what
    # search_internet expects to exist for it.
    result = knowledgebase_kgo_iter("《区块链和分布式记账技术标准体系建设指南》", "debug")
    print("检索结果:", result)

View File

@@ -0,0 +1,113 @@
from datetime import datetime
import json
import logging
import re
from pydantic import BaseModel, Field
from configs.model_config import LLM_MODELS
from server.agent.tools.search_tool import search_tool
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
# Reply that tells the agent to stop calling the tool.
_TOOL_STOP_MESSAGE = "<system>不要再调用该工具了</system>"


def _answer_with_retrieved_docs(query: str, *, strategy_name: str, model_index: int, doc_param: str):
    """Retrieve supporting documents for the question, then ask the LLM.

    Shared pipeline behind ``math_count`` and ``code_count``.

    Args:
        query: Raw tool input: ``{"query": ...}`` followed by ``{"uuid": ...}``.
        strategy_name: Used as both strategy name and prompt template name.
        model_index: Index into ``LLM_MODELS`` selecting the model.
        doc_param: Prompt parameter name under which the documents are passed.

    Returns:
        Whatever ``get_llm_model_response`` returns.

    Raises:
        ValueError: If ``query`` does not contain two JSON objects.
    """
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        raise ValueError("tool input must contain a query object and a uuid object")
    parsed_query = json.loads(matches[0])["query"]
    time_based_uuid = json.loads(matches[1])["uuid"]

    # The retrieval tool writes its sources into a shared variable; give it a
    # scratch entry (uuid + "q") so the conversation's own entry is untouched.
    scratch_key = time_based_uuid + "q"
    utils.set_shared_variable(scratch_key, {"source_docs": [], "num": 0, "title": []})
    try:
        first_json = {
            "query": parsed_query,
            "knowledge_name": [],
            "keywords": []
        }
        second_json = {
            "uuid": scratch_key
        }
        retrieved_doc = search_tool(json.dumps(first_json) + json.dumps(second_json))
    finally:
        # Always drop the scratch entry, even when retrieval fails.
        utils.remove_shared_variable(scratch_key)

    return get_llm_model_response(
        strategy_name=strategy_name,
        llm_model_name=LLM_MODELS[model_index],
        template_prompt_name=strategy_name,
        prompt_param_dict={"input": parsed_query, doc_param: f"{retrieved_doc}", "time": datetime.now().strftime("%Y%m%d")},
        temperature=0.01,
        max_tokens=512
    )


def math_count(query: str):
    """Tool entry point: answer a math question with the ``default_math`` prompt.

    Args:
        query: Raw tool input: ``{"query": ...}`` followed by ``{"uuid": ...}``.

    Returns:
        The model's reply as a string, or a stop instruction when the input
        is malformed or any step fails.
    """
    try:
        res = _answer_with_retrieved_docs(
            query,
            strategy_name="default_math",
            model_index=3,
            doc_param="math_doc",
        )
        return f"{res}"
    except Exception as e:
        logging.error(f"Error occurred while processing math query: {e}")
        return _TOOL_STOP_MESSAGE


def code_count(query: str):
    """Tool entry point: answer a coding question with the ``default_code`` prompt.

    Args:
        query: Raw tool input: ``{"query": ...}`` followed by ``{"uuid": ...}``.

    Returns:
        The model's reply with any ``<think>`` marker removed, or a stop
        instruction when the input is malformed or any step fails.
    """
    try:
        res = _answer_with_retrieved_docs(
            query,
            strategy_name="default_code",
            model_index=2,
            doc_param="code_doc",
        )
        res = res.replace("<think>","")
        return f"{res}"
    except Exception as e:
        logging.error(f"Error occurred while processing code query: {e}")
        return _TOOL_STOP_MESSAGE
# Argument schema shared by the math/code tools.
class RagSearchInput(BaseModel):
    # What to look up (required).
    query: str = Field(
        ...,
        description="查询对象",
    )

View File

@@ -0,0 +1,42 @@
from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS, LLM_MODELS
import json
import asyncio
from typing import List, Tuple, Any
from pydantic import BaseModel, Field
# Argument schema for the policy knowledge-base tool.
class PolicyKnowledgeInput(BaseModel):
    # The policy-related query.
    location: str = Field(
        description="The policy related query to be searched",
    )
async def policy_knowledgebase_search_iter(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Ask the policy knowledge base and return ``(answer, docs)``.

    Every chunk of the response body is decoded as JSON; the last chunk that
    carries an ``"answer"`` / ``"docs"`` key determines the returned value.
    """
    response = await knowledge_base_chat(
        query=query,
        fileName=None,
        knowledge_base_name_list = ["t_policy_total_bge_new_v1"],
        model_name=LLM_MODELS[0],
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="default",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    answer = ""
    documents = []
    async for chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(chunk)
        print("data>>>>>", payload)
        if "answer" in payload.keys():
            answer = payload["answer"]
        if "docs" in payload.keys():
            documents = payload["docs"]
    return answer, documents
def policy_knowledgebase_search(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Synchronous wrapper: run ``policy_knowledgebase_search_iter`` on a new event loop."""
    pending_search = policy_knowledgebase_search_iter(query)
    return asyncio.run(pending_search)
if __name__ == "__main__":
    # Manual smoke test against the policy knowledge base.
    answer_and_docs = policy_knowledgebase_search("大数据男女比例")
    print("答案:", answer_and_docs)

View File

@@ -0,0 +1,108 @@
import json
import re
import concurrent
from fastapi.concurrency import run_in_threadpool
from langchain.tools import YouTubeSearchTool
from pydantic import BaseModel, Field
from server.chat import utils
from server.chat.utils import doc_to_list, solve_knowledge_map,solve_mental_data,shared_variable
from server.knowledge_base.kb_doc_api import search_docs
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import kb_config
def rag_search1(query: str):
"""
根据用户输入的query返回rag搜索结果
"""
try:
matches = re.findall(r'\{.*?\}', query)
if len(matches)>=2:
query = matches[0]
else:
return "<关键指令>不需要再调用该工具了</关键指令>"
time_based_uuid = json.loads(matches[1])["uuid"]
search = json.loads(query)
search_query = search["query"]
search_keywords = []
search_text = f"{search_query}"
if type(search["keywords"]) == list:
search_keywords = search["keywords"]
for keyword in search_keywords:
search_text += f" {keyword}"
else:
search_keywords = search["keywords"].split(",")
for keyword in search_keywords:
search_text += f" {keyword}"
result = []
source_docs = {}
knownledge_name = []
if type(search["knowledge_name"]) == list:
knownledge_name=search["knowledge_name"]
else:
knownledge_name=search["knowledge_name"].split(",")
for knownledge in knownledge_name:
if not knownledge in kb_config.CH_BASE_NAME:
knownledge_name.remove(knownledge)
if len(knownledge_name)==0:
result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
return result
# knownledge_name=kb_config.CH_BASE_NAME
knownledge_name=solve_knowledge_map(knownledge_name)
num = 0
for knownledge in knownledge_name:
source_docs[knownledge] = []
seen_docs = set()
duplicate_indices = []
doc_list = search_docs(usr_query=search_text,fileName= [],top_k=5,score_threshold=0.9,query=search_text, knowledge_base_name=knownledge)
for inum,doc in enumerate(doc_list):
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
# 从policydocs中删除重复的文档从后往前删除以防止索引错位
for index in sorted(duplicate_indices, reverse=True):
del doc_list[index]
# 处理原文来源进入数组。使用开关语句明确各个条件分支
match knownledge:
# 属于政策库分支,入参为中文政策库名称
case kb_config.DEFAULT_POLICY_BASE:
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs[knownledge])
# 属于期刊论文库分支,入参为期刊论文库的中文名称
case kb_config.DEFAULT_JOURNAL_BASE:
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
# 属于报告库分支,入参为报告库中文名称
case kb_config.DEFAULT_REPORT_BASE:
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
case kb_config.GY_NEWS_BASE:
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs[knownledge])
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
case kb_config.GY_REPORT_BASE:
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
case kb_config.GY_JOURNAL_BASE:
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
case _:
print(f"输入了没有的知识库名称")
return("输入了没有的知识库名称")
num += len(source_docs[knownledge])
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
del num
source = utils.get_shared_variable(time_based_uuid)
print(utils.get_shared_variable(time_based_uuid))
source["source_docs"]=source_docs
utils.set_shared_variable(time_based_uuid,source)
if 0<len(result)<3:
return f"当前资料:{result}\n<关键指令>搜索结果较少,更换知识库或联网思索重新搜索!!!</关键指令>"
if len(result)==0:
return "注意:【指令:更换知识库或联网思索继续搜索!!!】"
except:
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
return "当前工具异常!请换其他工具"
return f"当前资料:{result}\n<关键指令>总结此内容!!!</关键指令>"
# Argument schema for the RAG search tool.
class RagSearchInput(BaseModel):
    # What to look up (required).
    query: str = Field(
        ...,
        description="查询对象",
    )

View File

@@ -0,0 +1,42 @@
from server.chat.report_chat import report_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS, LLM_MODELS, DEFAULT_REPORT_BASE
import json
import asyncio
from typing import List, Tuple, Any
from pydantic import BaseModel, Field
# Argument schema for the report knowledge-base tool.
class ReportKnowledgeInput(BaseModel):
    # The report-related query.
    location: str = Field(
        description="The report related query to be searched",
    )
async def report_knowledgebase_search_iter(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Ask the report knowledge base and return ``(answer, docs)``.

    Every chunk of the response body is decoded as JSON; the last chunk that
    carries an ``"answer"`` / ``"docs"`` key determines the returned value.
    """
    response = await report_chat(
        query=query,
        fileName=None,
        knowledge_base_name=DEFAULT_REPORT_BASE,
        model_name=LLM_MODELS[0],
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="default",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    answer = ""
    documents = []
    async for chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(chunk)
        print("data>>>>>", payload)
        if "answer" in payload.keys():
            answer = payload["answer"]
        if "docs" in payload.keys():
            documents = payload["docs"]
    return answer, documents
def report_knowledgebase_search(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Synchronous wrapper: run ``report_knowledgebase_search_iter`` on a new event loop."""
    pending_search = report_knowledgebase_search_iter(query)
    return asyncio.run(pending_search)
if __name__ == "__main__":
    # Manual smoke test against the report knowledge base.
    answer_and_docs = report_knowledgebase_search("大数据男女比例")
    print("答案:", answer_and_docs)

View File

@@ -0,0 +1,89 @@
import json
import asyncio
import unicodedata
from server.chat.KgoSearchAPIWrapper import KgoSearchAPIWrapper
from server.chat.search_engine_chat import search_engine_chat
from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
from server.agent import model_container
from pydantic import BaseModel, Field
from configs import LLM_MODELS, TEMPERATURE
from configs.basic_config import *
# def get_kgo_search_type(query: str = "全部"):
# # 过滤掉所有非汉字的字符
# query = ''.join(char for char in query if unicodedata.category(char) == 'Lo')
# search_map = KgoSearchAPIWrapper().search_map
# if "论文" in query:
# return "1001"
# elif "外文" in query or "英文" in query:
# return "1013"
# elif "期刊" in query or "研究进展" in query:
# return "1002"
# else:
# matched_types = [value for key, value in search_map.items() if key in query]
# if matched_types:
# return ','.join(matched_types)
# else:
# print("未找到匹配的搜索类型返回默认值1000")
# return "1000"
@timing_decorator
async def search_engine_iter(query: str , uid: str):
    """Run one web search through ``search_engine_chat`` and collect the result.

    Args:
        query: Search text.
        uid: Conversation id forwarded to ``search_engine_chat``.

    Returns:
        ``(contents, docs)``: the ``"answer"`` value of the last chunk read
        and the accumulated ``"docs"`` of all chunks. Reading stops at the
        first chunk without docs. The shape is the same on every path, so
        callers can always index ``[0]`` and ``[1]``.
    """
    # Only logging.error is needed; the import is local so the function does
    # not depend on a star import providing the name.
    import logging

    response = await search_engine_chat(uid = uid,
                                        query=query,
                                        search_engine_name="zhipu_search",
                                        model_name=LLM_MODELS[1],
                                        temperature=TEMPERATURE,  # temperature used for agent web searches
                                        history=[],
                                        top_k=VECTOR_SEARCH_TOP_K,
                                        max_tokens=MAX_TOKENS,
                                        prompt_name="search",
                                        stream=False,
                                        kgo_search_type="1000")
    contents = ""
    docs = []
    async for data in response.body_iterator:  # each chunk is a JSON string
        data = json.loads(data)
        contents = data.get("answer", [])
        current_docs = data.get("docs", [])
        if not current_docs:
            logging.error("No docs found in the response")
            # Stop reading, but fall through to the common return so the
            # caller still receives a (contents, docs) pair.
            break
        docs.extend(current_docs)
    return contents,docs
def search_internet(query: str, uid: str):
    """Synchronous wrapper: run ``search_engine_iter`` on a new event loop.

    The query is forwarded unchanged; an earlier step that stripped filter
    words and derived a ``kgo_search_type`` from the query is disabled.
    """
    pending_search = search_engine_iter(query, uid)
    return asyncio.run(pending_search)
# Argument schema for the internet search tool.
class SearchInternetInput(BaseModel):
    # The search query.
    location: str = Field(
        description="Query for Internet search",
    )
if __name__ == "__main__":
    # Manual smoke test. search_internet requires a conversation id.
    # NOTE(review): "debug" is a placeholder uid — confirm what
    # search_engine_chat expects to exist for it.
    result = search_internet("人工智能领域的政策", "debug")
    print("答案:", result)

View File

@@ -0,0 +1,15 @@
name: search_internet
description: Use this tool to surf internet and get information
parameters:
type: object
properties:
query:
type: string
description: Query for Internet search
kgo_search_type:
type: int
description: the return value 'kgo_search_type' of the 'get_kgo_search_type'
default: 1000
required:
- query
- kgo_search_type

View File

@@ -0,0 +1,294 @@
from __future__ import annotations
import json
import re
import warnings
from typing import Dict
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from typing import List, Any, Optional
from langchain.prompts import PromptTemplate
from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import asyncio
from server.agent import model_container
from pydantic import BaseModel, Field
async def search_knowledge_base_iter(database: str, query: str) -> str:
    """Ask one knowledge base and return the concatenated ``"answer"`` text of all chunks."""
    response = await knowledge_base_chat(
        query=query,
        knowledge_base_name=database,
        model_name=model_container.MODEL.model_name,
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="default",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    answer_text = ""
    async for chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(chunk)
        answer_text += payload["answer"]
        # Read but not used; kept so a chunk without "docs" still raises KeyError.
        _ = payload["docs"]
    return answer_text
async def search_knowledge_multiple(queries) -> List[str]:
    """Query several knowledge bases concurrently.

    Args:
        queries: List of ``(database, query)`` tuples.

    Returns:
        One message per query, in input order, each prefixed with the name of
        the knowledge base it came from.
    """
    pending = [search_knowledge_base_iter(db_name, question) for db_name, question in queries]
    answers = await asyncio.gather(*pending)
    return [
        f"\n查询到 {db_name} 知识库的相关信息:\n{answer}"
        for (db_name, _), answer in zip(queries, answers)
    ]
def search_knowledge(queries) -> str:
    """Synchronous wrapper: run all queries and join the messages, each followed by a blank line."""
    messages = asyncio.run(search_knowledge_multiple(queries))
    return "".join(message + "\n\n" for message in messages)
# Prompt that asks the LLM to turn the user's question into one
# "knowledge_base_name,query" line per knowledge base, inside a ```text
# block. {database_names} and {question} are filled in at run time; the
# doubled braces render as literal braces.
_PROMPT_TEMPLATE = """
用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。
对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。
例子:
robotic,机器人男女比例是多少
bigdata,大数据的就业情况如何
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考
{database_names}
你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。
不要输出中文的逗号,不要输出引号。
Question: ${{用户的问题}}
```text
${{知识库名称,查询问题,不要带有任何除了,之外的符号,比如不要输出中文的逗号,不要输出引号}}
```output
数据库查询的结果
现在,我们开始作答
问题: {question}
"""
# Template object used as the chain's default prompt.
PROMPT = PromptTemplate(
input_variables=["question", "database_names"],
template=_PROMPT_TEMPLATE,
)
class LLMKnowledgeChain(LLMChain):
    # Chain that lets the LLM pick knowledge bases and queries, then runs them.
    #
    # The LLM is prompted to answer with a ```text block holding one
    # "database,query" line per knowledge base. The chain parses those lines,
    # runs the searches and returns the combined results under ``output_key``.
    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    prompt: BasePromptTemplate = PROMPT
    """[Deprecated] Prompt to use to translate to python if necessary."""
    database_names: Dict[str, str] = None
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""
        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn about the deprecated ``llm`` argument and build ``llm_chain`` from it if needed."""
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.
        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.
        :meta private:
        """
        return [self.output_key]

    @staticmethod
    def _parse_queries(llm_output: str):
        """Extract ``(database, query)`` tuples from the ```text block of *llm_output*.

        Returns ``None`` when there is no ```text marker. Straight and curly
        double quotes and stray ``` fences are removed; each non-empty line is
        split on an ASCII comma, or on a full-width comma when the line has no
        ASCII one. Only the first two fields of a line are used.

        Raises:
            ValueError: If a non-empty line contains neither comma.
        """
        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
        if not text_match:
            return None
        cleaned_input_str = text_match.group(1).strip()
        for unwanted in ("\"", "“", "”", "```"):
            cleaned_input_str = cleaned_input_str.replace(unwanted, "")
        queries = []
        for line in cleaned_input_str.strip().split("\n"):
            if not line.strip():
                continue
            separator = "," if "," in line else ","
            fields = line.split(separator)
            if len(fields) < 2:
                raise ValueError(f"unknown format from LLM: {line}")
            queries.append((fields[0].strip(), fields[1].strip()))
        return queries

    def _evaluate_expression(self, queries) -> str:
        """Run the searches synchronously; on failure return the error text instead of raising."""
        try:
            output = search_knowledge(queries)
        except Exception as e:
            output = "输入的信息有误或不存在知识库,错误信息如下:\n"
            return output + str(e)
        return output

    async def _aevaluate_expression(self, queries) -> str:
        """Async twin of ``_evaluate_expression``.

        Awaits the searches directly: ``search_knowledge`` uses
        ``asyncio.run``, which cannot be called from a running event loop.
        """
        try:
            responses = await search_knowledge_multiple(queries)
        except Exception as e:
            return "输入的信息有误或不存在知识库,错误信息如下:\n" + str(e)
        return "".join(response + "\n\n" for response in responses)

    def _process_llm_result(
            self,
            llm_output: str,
            run_manager: CallbackManagerForChainRun
    ) -> Dict[str, str]:
        """Turn the raw LLM output into the chain result (sync path)."""
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        queries = self._parse_queries(llm_output)
        if queries is not None:
            run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
            output = self._evaluate_expression(queries)
            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = llm_output.split("Answer:")[-1]
        else:
            return {self.output_key: f"输入的格式不对:\n {llm_output}"}
        return {self.output_key: answer}

    async def _aprocess_llm_result(
            self,
            llm_output: str,
            run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Turn the raw LLM output into the chain result (async path).

        Raises:
            ValueError: If the output has neither a ```text block nor an
                ``Answer:`` marker.
        """
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        queries = self._parse_queries(llm_output)
        if queries is not None:
            await run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue",
                                      verbose=self.verbose)
            output = await self._aevaluate_expression(queries)
            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            raise ValueError(f"unknown format from LLM: {llm_output}")
        return {self.output_key: answer}

    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Ask the LLM which knowledge bases to query, then run the queries."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key])
        self.database_names = model_container.DATABASE
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = self.llm_chain.predict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return self._process_llm_result(llm_output, _run_manager)

    async def _acall(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async version of ``_call``."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        await _run_manager.on_text(inputs[self.input_key])
        self.database_names = model_container.DATABASE
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = await self.llm_chain.apredict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        # _aprocess_llm_result takes (llm_output, run_manager) only.
        return await self._aprocess_llm_result(llm_output, _run_manager)

    @property
    def _chain_type(self) -> str:
        return "llm_knowledge_chain"

    @classmethod
    def from_llm(
            cls,
            llm: BaseLanguageModel,
            prompt: BasePromptTemplate = PROMPT,
            **kwargs: Any,
    ) -> LLMKnowledgeChain:
        """Build the chain from an LLM and an optional prompt."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
def search_knowledgebase_complex(query: str):
    """Tool entry point: answer *query* with an ``LLMKnowledgeChain`` built on the shared model."""
    knowledge_chain = LLMKnowledgeChain.from_llm(model_container.MODEL, verbose=True, prompt=PROMPT)
    return knowledge_chain.run(query)
# Argument schema for the knowledge-base search tool.
class KnowledgeSearchInput(BaseModel):
    # The query to search for.
    location: str = Field(
        description="The query to be searched",
    )
# Argument schema for the RAG search tool: query, target knowledge base, keywords.
class RagSearchInput(BaseModel):
    # What to look up.
    query: str = Field(
        description="查询对象",
    )
    # Name of the knowledge base to search.
    knowledge_name: str = Field(
        description="The name of the knowledge base to be searched,policy knowledge base name is t_policy_total_bge_new_v2, example: t_policy_total_bge_new_v2]",
    )
    # Comma-separated search keywords.
    keywords: str = Field(
        description="The keywords to be searched example: age,child]",
    )
if __name__ == "__main__":
    # Manual smoke test of the full chain.
    answer = search_knowledgebase_complex("机器人和大数据在代码教学上有什么区别")
    print(answer)
    # Example of calling the search layer directly with pre-split queries:
    # queries = [
    #     ("bigdata", "大数据专业的男女比例"),
    #     ("robotic", "机器人专业的优势")
    # ]
    # result = search_knowledge(queries)
    # print(result)

View File

@@ -0,0 +1,10 @@
name: search_knowledgebase_complex
description: Use this tool to search local knowledgebase and get information
parameters:
type: object
properties:
query:
type: string
description: The query to be searched
required:
- query

View File

@@ -0,0 +1,234 @@
from __future__ import annotations
import re
import warnings
from typing import Dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from typing import List, Any, Optional
from langchain.prompts import PromptTemplate
import sys
import os
import json
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import asyncio
from server.agent import model_container
from pydantic import BaseModel, Field
async def search_knowledge_base_iter(database: str, query: str):
    """Query one knowledge base and return the concatenated answer text.

    Calls ``knowledge_base_chat`` in non-streaming mode against *database*
    and appends the ``answer`` field of every JSON chunk from the response
    body iterator.
    """
    chat_options = dict(
        query=query,
        knowledge_base_name=database,
        model_name=model_container.MODEL.model_name,
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="knowledge_base_chat",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    response = await knowledge_base_chat(**chat_options)

    contents = ""
    async for chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(chunk)
        contents += payload["answer"]
        # Read but not used; kept so a chunk without "docs" still raises KeyError.
        docs = payload["docs"]
    return contents
# Prompt text sent to the LLM. It has two placeholders, {database_names} and
# {question}, and asks the model to put the chosen knowledge-base name inside
# a ```text fence, which LLMKnowledgeChain later extracts with a regex.
# NOTE(review): the doubled braces (${{...}}) are presumably literal-brace
# escapes for the template's format syntax — confirm against PromptTemplate.
_PROMPT_TEMPLATE = """
用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考
Question: ${{用户的问题}}
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能:
{database_names}
你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
```text
${{知识库的名称}}
```
```output
数据库查询的结果
```
答案: ${{答案}}
现在,这是我的问题:
问题: {question}
"""
# Template object built from the text above; declares both placeholders.
PROMPT = PromptTemplate(
input_variables=["question", "database_names"],
template=_PROMPT_TEMPLATE,
)
class LLMKnowledgeChain(LLMChain):
    """Chain that lets the LLM pick a knowledge base and then queries it.

    The LLM is shown the available databases and the user question and must
    answer with the chosen database name inside a ```text fence. The chain
    extracts that name, runs ``search_knowledge_base_iter`` against it with
    the original question and returns the result prefixed with ``Answer: ``.
    """

    llm_chain: LLMChain
    # [Deprecated] LLM wrapper to use.
    llm: Optional[BaseLanguageModel] = None
    # [Deprecated] Prompt to use to translate to python if necessary.
    prompt: BasePromptTemplate = PROMPT
    # Database name -> description pairs rendered into the prompt.
    database_names: Dict[str, str] = model_container.DATABASE
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn when built from a bare ``llm`` and synthesize ``llm_chain`` from it."""
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _evaluate_expression(self, dataset, query) -> str:
        """Synchronously query knowledge base *dataset* with *query*.

        Any failure is reported as a fixed fallback message instead of being
        raised. Uses ``asyncio.run``, so it must not be called from code that
        is already running inside an event loop.
        """
        try:
            return asyncio.run(search_knowledge_base_iter(dataset, query))
        except Exception:
            return "输入的信息有误或不存在知识库"

    def _process_llm_result(
        self,
        llm_output: str,
        llm_input: str,
        run_manager: CallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Turn the raw LLM completion into the chain output (sync path).

        Args:
            llm_output: Raw completion text from the LLM.
            llm_input: The original user question, used as the search query.
            run_manager: Callback manager used for progress text.

        Returns:
            ``{output_key: answer}``; an explanatory message is returned when
            the completion matches none of the expected formats.
        """
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if text_match:
            database = text_match.group(1).strip()
            output = self._evaluate_expression(database, llm_input)
            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            return {self.output_key: f"输入的格式不对: {llm_output}"}
        return {self.output_key: answer}

    async def _aprocess_llm_result(
        self,
        llm_output: str,
        llm_input: str,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Turn the raw LLM completion into the chain output (async path).

        Takes the same ``llm_input`` argument that ``_acall`` passes, and
        awaits the knowledge-base search directly: ``asyncio.run`` cannot be
        used here because an event loop is already running.

        Raises:
            ValueError: If the completion matches none of the expected formats.
        """
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if text_match:
            database = text_match.group(1).strip()
            try:
                output = await search_knowledge_base_iter(database, llm_input)
            except Exception:
                # Same best-effort fallback as the synchronous path.
                output = "输入的信息有误或不存在知识库"
            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            raise ValueError(f"unknown format from LLM: {llm_output}")
        return {self.output_key: answer}

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the chain synchronously for ``inputs[self.input_key]``."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key])
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = self.llm_chain.predict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager)

    async def _acall(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the chain asynchronously for ``inputs[self.input_key]``."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        await _run_manager.on_text(inputs[self.input_key])
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = await self.llm_chain.apredict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)

    @property
    def _chain_type(self) -> str:
        """Constant string identifying this chain type."""
        return "llm_knowledge_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate = PROMPT,
        **kwargs: Any,
    ) -> LLMKnowledgeChain:
        """Build the chain from a bare LLM by wrapping it in an ``LLMChain``."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
def search_knowledgebase_once(query: str):
    """Answer *query* with a single ``LLMKnowledgeChain`` run on the shared model."""
    chain = LLMKnowledgeChain.from_llm(
        model_container.MODEL,
        verbose=True,
        prompt=PROMPT,
    )
    return chain.run(query)
# Argument schema for the knowledge-base search tool.
# NOTE(review): the field is named `location` although it carries the query text.
class KnowledgeSearchInput(BaseModel):
    location: str = Field(
        description="The query to be searched",
    )
if __name__ == "__main__":
    # Manual smoke test: run one query and print the answer.
    print(search_knowledgebase_once("大数据的男女比例"))

View File

@@ -0,0 +1,32 @@
from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import json
import asyncio
from server.agent import model_container
async def search_knowledge_base_iter(database: str, query: str) -> str:
    """Query one knowledge base and return the answer of the last chunk.

    Calls ``knowledge_base_chat`` in non-streaming mode against *database*.
    Each JSON chunk overwrites the previous answer; an empty string is
    returned when the response yields no chunks.
    """
    chat_options = dict(
        query=query,
        knowledge_base_name=database,
        model_name=model_container.MODEL.model_name,
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="knowledge_base_chat",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    response = await knowledge_base_chat(**chat_options)

    latest_answer = ""
    async for chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(chunk)
        latest_answer = payload["answer"]
        # Read but not used; kept so a chunk without "docs" still raises KeyError.
        docs = payload["docs"]
    return latest_answer
def search_knowledgebase_simple(query: str):
# Synchronous wrapper that drives search_knowledge_base_iter via asyncio.run.
# NOTE(review): search_knowledge_base_iter is declared as (database, query) but
# is called here with a single argument, so this call raises TypeError before
# any search happens. The intended knowledge-base name is not visible in this
# module — confirm which database should be passed.
return asyncio.run(search_knowledge_base_iter(query))
if __name__ == "__main__":
    # Manual smoke test: run one query and print the answer.
    answer = search_knowledgebase_simple("大数据男女比例")
    print("答案:", answer)

View File

@@ -0,0 +1,56 @@
import json
import os
import re
from typing import List
from urllib.parse import quote
from server.agent.tools.search_tool import rag_search
from server.chat import utils
from server.knowledge_base.kb_doc_api import search_docs
def search_pic(query: str) -> str:
    """Look up artwork images and return them as markdown image links.

    *query* must contain at least two ``{...}`` JSON objects: the first holds
    the search payload (key ``query``), the second the conversation id (key
    ``uuid``). Sources already recorded in the conversation's shared state
    (``source_docs``) are skipped, and newly returned ones are appended to it.

    Returns:
        Always a string: an instruction message with the markdown image
        links, a "nothing found"/"stop calling" message, or an error message
        when anything raises.
    """
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            return "<关键指令>不需要再调用该工具了</关键指令>"
        payload_text = matches[0]

        # Named session_id (not uuid) to avoid shadowing the stdlib module name.
        session_id = json.loads(matches[1])["uuid"]
        tip = utils.get_shared_variable(session_id)
        payload = json.loads(payload_text)
        res = search_docs(usr_query=payload["query"], fileName=[], top_k=10, score_threshold=0.9,
                          query=payload["query"], knowledge_base_name="p_meiyupic")
        if not res and not tip["source_docs"]:
            utils.set_shared_variable(session_id, tip)
            return "工具没有找到结果"

        result = ""
        for item in res:
            source = item.metadata['source']
            if source in tip["source_docs"]:
                # Already shown earlier in this conversation.
                continue
            tip["source_docs"].append(source)
            # Image title: the source path without its file extension.
            source_dir = os.path.splitext(source)[0]
            # NOTE(review): both arguments of replace() are the same URL, so it is
            # currently a no-op; presumably a host rewrite was intended — confirm.
            page_content = quote(
                item.page_content.replace("http://127.0.0.1:8099/chat_web_backend",
                                          "http://127.0.0.1:8099/chat_web_backend"),
                safe='/:?=&#+',
            )
            result += f'![{source_dir}]({page_content})\n'
        utils.set_shared_variable(session_id, tip)

        if result:
            print(f"美术作品链接:{result}")
            return f"注意:以下链接是图片不是参考文献,以下链接不要放到引文小标的格式输出而是以图片格式输出,禁止转义后面链接的编码,这个链接不能带中文。图片如下:{result}"
        return "<关键指令>不需要再调用该工具了</关键指令>"
    except Exception as e:
        return f"Failed to get picture.{e}"

View File

@@ -0,0 +1,331 @@
import asyncio
import concurrent
from datetime import datetime
import json
import logging
import re
from typing import List
from fastapi import logger
from configs import kb_config
from configs.model_config import LLM_MODELS
from server.agent.tools import duckduckgo_search
from server.agent.tools.duckduckgo_search import duckduckgo_search_iter
from server.agent.tools.knowledgebase_kgo_search import knowledgebase_kgo_iter
from server.agent.tools.rag_search import rag_search1
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
from server.chat.utils import doc_to_list, get_similar_documents1, solve_knowledge_map, solve_mental_data
from server.knowledge_base.kb_doc_api import search_docs
def rag_search(query: str,uid):
"""
根据用户输入的query返回rag搜索结果
"""
source_docs = []
try:
search = json.loads(query)
logging.info(f'模型输入: {search["query"]}')
original_query = search["query"]
search_query = get_llm_model_response(
strategy_name="rag_search_rewrite",
llm_model_name=LLM_MODELS[0],
template_prompt_name="rag_search_rewrite",
prompt_param_dict={"input": search["query"], "year": datetime.now().strftime("%Y")},
temperature=0.3,
max_tokens=512
)
logging.info(f'模型改写: {search_query}')
search_keywords = []
search_text = f"{search_query}"
# if type(search["keywords"]) == list:
# search_keywords = search["keywords"]
# for keyword in search_keywords:
# search_text += f" {keyword}"
# else:
# search_keywords = search["keywords"].split(",")
# for keyword in search_keywords:
# search_text += f" {keyword}"
self_database = utils.get_shared_variable(uid)
result = []
knownledge_name = []
if type(search["knowledge_name"]) == list:
knownledge_name=search["knowledge_name"]
else:
knownledge_name=search["knowledge_name"].split(",")
if "美术专业知识库" in knownledge_name:
knownledge_name.remove("美术专业知识库")
if "database" in self_database:
self_database["database"]= self_database["database"].append("p_cafa0101011")
else:
self_database["database"] = ["p_cafa0101011"]
# 添加个人知识库
if "database" in self_database:
knownledge_name.extend(self_database["database"])
knownledge_name = [knownledge for knownledge in knownledge_name
if (knownledge in kb_config.CH_BASE_NAME
or knownledge in kb_config.EN_BASE_NAME
or knownledge in getattr(kb_config, "YJ_BASE_NAME", [])
or kb_config.SELF_KNOWLEDGE_BASE.match(knownledge)
or knownledge == "coding")]
if len(knownledge_name)==0:
#result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
return result,source_docs
# knownledge_name=kb_config.CH_BASE_NAME
knownledge_name=solve_knowledge_map(knownledge_name)
#knownledge_name = ["p_c88859a3d06e4265bd01d816ef2650d1"]
num = 0
temp=utils.get_shared_variable(uid)
for knownledge in knownledge_name:
seen_docs = set()
duplicate_indices = []
# 针对中国钢铁行业动态库增加日期范围过滤
expr_param = ""
if knownledge == kb_config.STEEL_KB:
time_today = datetime.now().strftime("%Y-%m-%d")
# 调用LLM生成日期表达式模板沿用 get_policy_time
try:
expr_candidate = get_llm_model_response(
strategy_name="get steel time",
llm_model_name=LLM_MODELS[0],
template_prompt_name="get_steel_time",
prompt_param_dict={"query": original_query, "time": time_today},
temperature=0.01,
max_tokens=512
).replace("None", "").strip()
expr_param = expr_candidate if expr_candidate else ""
except Exception as _:
expr_param = ""
doc_list = search_docs(
usr_query=original_query,
fileName=[],
top_k=20,
score_threshold=1.0,
query=search_text,
knowledge_base_name=knownledge,
expr=expr_param
)
if len(doc_list)==0:
return result,source_docs
titles = temp["title"]
doc_list,title = utils.remove_docs1(titles,doc_list)
titles.extend(title)
for inum,doc in enumerate(doc_list):
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
# 从policydocs中删除重复的文档从后往前删除以防止索引错位
for index in sorted(duplicate_indices, reverse=True):
del doc_list[index]
# 处理原文来源进入数组。使用开关语句明确各个条件分支
match knownledge:
# 属于政策库分支,入参为中文政策库名称
case kb_config.DEFAULT_POLICY_BASE:
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs)
# 属于期刊论文库分支,入参为期刊论文库的中文名称
case kb_config.DEFAULT_JOURNAL_BASE:
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs)
# 属于报告库分支,入参为报告库中文名称
case kb_config.DEFAULT_REPORT_BASE1:
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs)
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
case kb_config.GY_NEWS_BASE:
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs)
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
case kb_config.GY_REPORT_BASE:
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs)
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
case kb_config.GY_JOURNAL_BASE:
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金新闻库2024年以及之前
case kb_config.YJ_NEWS_BASE:
doc_to_list(num,kb_config.YJ_NEWS_BASE_NAME,doc_list,source_docs)
# 新增冶金中文期刊库
case kb_config.YJ_CH_JOURNAL_BASE:
doc_to_list(num,kb_config.YJ_CH_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金外文期刊库
case kb_config.YJ_FOR_JOURNAL_BASE:
doc_to_list(num,kb_config.YJ_FOR_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金OA期刊库
case kb_config.YJ_OA_JOURNAL_BASE:
doc_to_list(num,kb_config.YJ_OA_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金政策库
case kb_config.YJ_POLICYS_BASE:
doc_to_list(num,kb_config.YJ_POLICYS_BASE_NAME,doc_list,source_docs)
# 新增中国钢铁行业动态库
case kb_config.STEEL_KB:
doc_to_list(num,kb_config.STEEL_KB_NAME,doc_list,source_docs)
# 属于个人知识库分支
case _ if kb_config.SELF_KNOWLEDGE_BASE.match(knownledge) or knownledge == "coding":
doc_to_list(num,knownledge,doc_list,source_docs)
case _:
print(f"输入了没有的知识库名称")
return "输入了没有的知识库名称",source_docs
# num += len(source_docs[knownledge])
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
# del num
# source = utils.get_shared_variable(uid)
# print(utils.get_shared_variable(uid))
# source["source_docs"]=source_docs
# utils.set_shared_variable(uid,source)
except Exception as e:
logging.error(f"Error in rag_search: {e}")
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
return "当前工具异常!请换其他工具",source_docs
return result,source_docs
def knowledgebase_kgo_search(query: str, uid) -> List[str]:
    """Call ``knowledgebase_kgo_iter`` and normalize its result.

    Returns:
        * the iterator's own result when its first two items are lists, or
          when the second is a non-empty list (the first is then reset to
          ``[]``);
        * ``{0: [], 1: []}`` when the result has any other shape or cannot be
          indexed;
        * an instruction/error string when the call itself raises.

    NOTE(review): the ``List[str]`` annotation does not match these shapes;
    it is kept unchanged for interface stability.
    """
    try:
        res = knowledgebase_kgo_iter(query, uid)
        try:
            if type(res[0]) == list and type(res[1]) == list:
                return res
            if type(res[1]) == list and len(res[1]) > 0:
                res[0] = []
                return res
            # Unusable shape: return the empty placeholder directly. The old code
            # indexed into an empty list here, which always raised IndexError and
            # reached this same value only through the except handler below.
            return {0: [], 1: []}
        except Exception as e:
            logging.error(f"No docs: {e}")
            return {0: [], 1: []}
    except json.JSONDecodeError:
        # JSON decoding failed inside the lookup.
        logging.error("Invalid JSON format in query.")
        return "<关键指令>不需要再调用该工具了</关键指令>"
    except KeyError:
        # A required key was missing from the parsed JSON object.
        return "请尝试调用其他工具"
    except Exception as e:
        # Any other failure: return a generic error message.
        return f"发生错误:{str(e)},请尝试调用其他工具"
def inner_duckduckgo_search(query: str, uuid: str):
    """Run ``duckduckgo_search_iter`` synchronously and return its combined result.

    The coroutine is called with the fixed extra arguments ``"y"`` and
    ``"default"`` and driven with ``asyncio.run``.
    """
    logging.info(f"模型输入: {query}")
    search_coro = duckduckgo_search_iter(query, uuid, "y", "default")
    outcome = asyncio.run(search_coro)
    # Hand the result back to the model as-is.
    logging.info("返回JSON格式的结果给到模型...")
    return outcome
def search_tool(query: str):
"""获取到uid并拆分query"""
# (Docstring above, in English: "get the uid and split the query".)
# Runs rag_search and knowledgebase_kgo_search in a thread pool, merges their
# results, renumbers the first [n] marker of every item so numbering continues
# from the conversation's running counter, and returns one prompt string.
# Strip the <param>...</param> wrapper tags if present.
if "<param>"in query:
query = query.replace("<param>","").replace("</param>","")
# Collect every non-greedy {...} span: the first is the search payload, the
# second must carry the "uuid" key. Because the match is non-greedy, a payload
# containing nested braces is cut at the first closing brace.
matches = re.findall(r'\{.*?\}', query)
if len(matches)>=2:
query = matches[0]
else:
return "<关键指令>当前工具不需要再调用</关键指令>"
# NOTE(review): the two json.loads calls below run outside the try block, so
# malformed JSON raises to the caller instead of returning a message.
time_based_uuid = json.loads(matches[1])["uuid"]
search = json.loads(query)
# `searches` is the plain query text: the first list element, "" for an empty
# list, or the value itself when it is not a list.
if type(search["query"])==list and len(search["query"])>0:
searches = search["query"][0]
elif type(search["query"])==list and len(search["query"]) == 0:
searches = ""
else:
searches = search["query"]
"""
根据用户输入的query返回rag搜索结果
"""
try:
with concurrent.futures.ThreadPoolExecutor() as executor:
# Submit the tasks for concurrent execution.
# Reset the per-conversation shared state before the searches start.
test = {}
test["num"]=0
test["source_docs"]=[]
test["END"] = ""
test["title"] = []
utils.set_shared_variable(time_based_uuid+"",test)
# future2 = executor.submit(knowledgebase_kgo_search,search["query"],time_based_uuid+"q")
future1 = executor.submit(rag_search,query,time_based_uuid)
# if not "type" in utils.get_shared_variable(time_based_uuid):
# future2 = executor.submit(knowledgebase_kgo_search,searches,time_based_uuid+"¥")
# The second search is only submitted when the shared state has no "type" key.
if not "type" in utils.get_shared_variable(time_based_uuid):
future2 = executor.submit(knowledgebase_kgo_search,searches,time_based_uuid+"")
result3 = []
# Collect the results.
result1,sourcedocs = future1.result()
result2 = {}
# NOTE(review): the shared state is re-read here; if "type" appeared or
# disappeared since the submit check above, future2 may be unbound (the
# NameError would be swallowed by the outer except) — confirm.
if "type" in utils.get_shared_variable(time_based_uuid):
result2[0] =[]
result2[1] = []
else:
result2 = future2.result()
# if "type" in utils.get_shared_variable(time_based_uuid):
# result2[0] =[]
# result2[1] = []
# else:
# result2 = future2.result()
# result2[0] = []
# result2[1] = []
# NOTE(review): this removes the key with a "q" suffix, while the state above
# was stored under the bare uuid — confirm which key is intended.
utils.remove_shared_variable(time_based_uuid+"q")
# Merge the second search's sources and results into the first's.
if type(result2[1]) == list:
if type(sourcedocs) == list:
sourcedocs.extend(result2[1])
else:
sourcedocs = []
if type(result1) == list:
result1.extend(result2[0])
result3 = result1
else:
result3 = result2[0]
logging.info(f"result2:{result2[1]}")
source = []
res=[]
sources = utils.get_shared_variable(time_based_uuid)
i = sources["num"]
num = sources["num"]
# Renumber the first [n] marker of each source entry, continuing from the
# stored counter; entries that fail or come out empty do not consume a number.
for result in sourcedocs:
try:
i+=1
res3 = re.sub(r'\[\d+\]', f"[{i}]", result.replace("\n",""), count=1)
if res3:
source.append(re.sub(r'\[\d+\]', f"[{i}]", result.replace("\n",""), count=1))
else:
i -= 1
except Exception as e:
i -= 1
pass
# internet_search_res = f"参考资料[{len(result1)+1}-{len(source)}]:{result2[0]}"
# internet_search_res = f"参考资料:{result2[0]}"
# Renumber the result texts the same way, starting from the same counter.
j = sources["num"]
for result in result3:
j+=1
res.append(re.sub(r'\[\d+\]', f"[{j}]", result, count=1))
print(utils.get_shared_variable(time_based_uuid))
# sources["source_docs"]=source
# Persist the new sources and the advanced counter for later tool calls.
sources["source_docs"].extend(source)
sources["num"]=i
# sources["END"] = "ok"
utils.set_shared_variable(time_based_uuid,sources)
logging.info(f"result1:{result1},sourcedocs:{sourcedocs}")
logging.info(f"result2:{result2}")
logging.info(f"{res}")
if len(res) ==0 and len(source)==0:
return f"尝试调整入参重新调用知识库联想工具(同一个问题调用超过三次就不要再使用知识库联想工具了,浪费时间)"
return f"<关键指令>如果你在写文章禁止在非规定位置输出参考资料</关键指令>资料:{res}\n资料来源为:{source}\n 注意:如果你在根据大纲撰写文章,撰写中间部分章节禁止输出综上所述之类的影响文风的话,撰写中间部分禁止输出附录引用文献等!!!"
except Exception as e:
# Any failure is logged and turned into a fixed instruction string.
logging.error(f"Error occurred during search_tool execution.{e}")
return "同一个问题调用知识库联想工具超过5次就不要再调用知识库联想"

View File

@@ -0,0 +1,9 @@
# Langchain 自带的 YouTube 搜索工具封装
from langchain.tools import YouTubeSearchTool
from pydantic import BaseModel, Field
def search_youtube(query: str):
    """Run LangChain's ``YouTubeSearchTool`` on *query* and return its output."""
    return YouTubeSearchTool().run(tool_input=query)
# Argument schema for the YouTube search tool.
# NOTE(review): the field is named `location` although it carries the search query.
class YoutubeInput(BaseModel):
    location: str = Field(
        description="Query for Videos search",
    )

View File

@@ -0,0 +1,10 @@
name: search_youtube
description: Use this tool to search youtube videos
parameters:
type: object
properties:
query:
type: string
description: Query for Videos search
required:
- query

View File

@@ -0,0 +1,9 @@
# LangChain 的 Shell 工具
from pydantic import BaseModel, Field
from langchain.tools import ShellTool
def shell(query: str):
    """Run *query* through LangChain's ``ShellTool`` and return its output.

    SECURITY NOTE(review): *query* is passed to the shell tool unchanged. If
    it can come from model output or user input, this executes arbitrary
    commands on the host — confirm that callers sandbox or vet it.
    """
    return ShellTool().run(tool_input=query)
# Argument schema for the shell tool: a single command string.
class ShellInput(BaseModel):
    query: str = Field(
        description="一个能在Linux命令行运行的Shell命令",
    )

View File

@@ -0,0 +1,10 @@
name: shell
description: Use Linux Shell to execute Linux commands
parameters:
type: object
properties:
query:
type: string
description: The command to execute
required:
- query

View File

@@ -0,0 +1,49 @@
"""
更简单的单参数输入工具实现,用于查询现在天气的情况
"""
import json
import re
from pydantic import BaseModel, Field
import requests
from configs.kb_config import SENIVERSE_API_KEY
from server.chat import utils
def weather(location: str, api_key: str):
    """Fetch the 5-day daily forecast for *location* from the Seniverse API.

    Args:
        location: Location string passed as the API's ``location`` parameter.
        api_key: Seniverse API key.

    Returns:
        A JSON string with two keys, ``today`` and ``others``, each holding a
        JSON-encoded string: the first daily entry and the remaining entries.

    Raises:
        Exception: If the HTTP status is not 200. Network errors and timeouts
            from ``requests`` propagate as their own exception types.
    """
    # Query values go through `params` so requests URL-encodes them, and a
    # timeout keeps a stalled connection from blocking the caller forever.
    response = requests.get(
        "https://api.seniverse.com/v3/weather/daily.json",
        params={
            "key": api_key,
            "location": location,
            "language": "zh-Hans",
            "unit": "c",
            "start": 0,
            "days": 5,
        },
        timeout=10,
    )
    if response.status_code != 200:
        raise Exception(
            f"Failed to retrieve weather: {response.status_code}")

    daily = response.json()["results"][0]["daily"]
    forecast = {
        "today": json.dumps(daily[0]),
        "others": json.dumps(daily[1:]),
    }
    return json.dumps(forecast)
def weathercheck(query: str):
    """Parse the tool input and return the forecast for the requested location.

    *query* must contain at least two ``{...}`` JSON objects: the first with a
    ``location`` key, the second with a ``uuid`` key. With fewer than two, a
    "do not call this tool again" instruction is returned. Any exception is
    turned into a ``Failed to retrieve weather.`` message.
    """
    try:
        json_chunks = re.findall(r'\{.*?\}', query)
        if len(json_chunks) < 2:
            return "<关键指令>不需要再调用该工具了</关键指令>"

        payload_text, session_text = json_chunks[0], json_chunks[1]
        location = json.loads(payload_text)["location"]
        # The uuid is parsed but not used further; a missing key still fails the call.
        time_based_uuid = json.loads(session_text)["uuid"]
        return weather(location, SENIVERSE_API_KEY)
    except Exception as e:
        return f"Failed to retrieve weather.{e}"
# Argument schema for the weather tool: the location to query.
class WeatherInput(BaseModel):
    location: str = Field(
        description="City name,include city and county",
    )

View File

@@ -0,0 +1,10 @@
name: weather_check
description: Use Weather API to get weather information
parameters:
type: object
properties:
query:
type: string
description: City name,include city and county,like "厦门市思明区"
required:
- query

View File

@@ -0,0 +1,11 @@
# Langchain 自带的 Wolfram Alpha API 封装
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
from pydantic import BaseModel, Field
# Wolfram Alpha application id; the value here is a placeholder that must be
# replaced with a real key before the tool can be used.
wolfram_alpha_appid = "your key"


def wolfram(query: str):
    """Send *query* to Wolfram Alpha via LangChain's API wrapper and return the answer."""
    client = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid)
    return client.run(query)
# Argument schema for the Wolfram Alpha tool.
# NOTE(review): the field is named `location` although it carries the problem to compute.
class WolframInput(BaseModel):
    location: str = Field(
        description="需要运算的具体问题",
    )

View File

@@ -0,0 +1,10 @@
name: wolfram
description: Useful for when you need to calculate difficult math formulas
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query

View File

@@ -0,0 +1,174 @@
from langchain.tools import Tool
from server.agent.tools import *
from server.agent.tools import rag_search1
from server.agent.tools import duckduckgo_search
from server.agent.tools.do_nothing import do_nothing, doNothingInput, get_next_tip
from server.agent.tools.draw_plot import create_and_save_plot, draw_ink_pic, draw_realistic_pic, drawInkInput, drawPlotInput, drawRealisticInput
from server.agent.tools.get_statistical_data import mysql_statistic
from server.agent.tools.math import code_count, math_count
from server.agent.tools.search_picture import search_pic
from server.agent.tools.search_tool import search_tool
# NOTE: if you intend to use AgentLM, use the English version of these descriptions here.
# Registry of the tools exposed to the agent. Each active entry wraps a
# function with a display name, a usage description shown to the model and an
# argument schema. Commented-out entries are disabled tools.
# NOTE(review): several unrelated tools below reuse `drawPlotInput` as their
# args_schema — confirm this is intentional.
tools = [
# Tool.from_function(
# func=calculate,
# name="calculate",
# description="Useful for when you need to answer questions about simple calculations",
# args_schema=CalculatorInput,
# ),
# Tool.from_function(
# func=arxiv,
# name="arxiv",
# description="A wrapper around Arxiv.org for the original English paper.",
# args_schema=ArxivInput,
# ),
# Tool.from_function(
# func=weathercheck,
# name="weather_check",
# description="",
# args_schema=WeatherInput,
# ),
# Tool.from_function(
# func=shell,
# name="shell",
# description="Use Shell to execute Linux commands",
# args_schema=ShellInput,
# ),
# Tool.from_function(
# func=search_knowledgebase_complex,
# name="search_knowledgebase_complex",
# description="Use this tool to check out local knowledgebase",
# args_schema=KnowledgeSearchInput,
# ),
# Tool.from_function(
# func=search_internet,
# name="search_internet",
# description="Use this tool to search the internet and retrieve information",
# args_schema=SearchInternetInput,
# ),
# Tool.from_function(
# func=wolfram,
# name="Wolfram",
# description="Useful for when you need to calculate difficult formulas",
# args_schema=WolframInput,
# ),
# Tool.from_function(
# func=search_youtube,
# name="search_youtube",
# description="use this tools to get videos",
# args_schema=YoutubeInput,
# ),
# Tool.from_function(
# func=chat_with_Yi34B,
# name="chat_with_Yi34B",
# description="Use this tool to chat",
# args_schema=ChatWithYi34BInput,
# ),
# Knowledge-base search tool.
Tool.from_function(
func=search_tool,
name="知识库联想",
# description="\n【工具参数说明】\n 参数格式参数需以逗号分隔并以JSON格式提供\n 必需参数:\n 1. query查询内容。\n2.knowledge_name知识库名称例如 政策库、项目库、期刊论文库、冶金行业新闻库、冶金专业知识库、冶金行业报告库\\n3.keywords【使用指南】\n当需要利用文献资源辅助回答问题时请使用此工具。\n输入必须包含三个参数query、knowledge_name 和 keywords\nknowledge_name 必须是以下之一:政策库、期刊论文库、新闻库。\nkeywords 应尽可能多地涵盖相关领域,并确保与搜索主题高度相关。\n尽量提供中文关键词。\n示例\n{\"knowledge_name\":\"政策库\",\"keywords\": [\"technology\", \"era\"]\n确保所有参数以JSON格式提供以便工具能够正确解析和使用}",
description="\n【参数说明】:\n 参数格式参数必须使用JSON格式提供\n 必需参数:\n 1. query查询内容。\n2.knowledge_name知识库名称必须从【中国钢铁行业动态库、政策库、期刊论文库、冶金新闻库2024年以及之前、冶金中文期刊库、冶金外文期刊库、冶金OA期刊库、冶金行业新闻库、冶金专业知识库、冶金行业报告库、报告库、美术专业知识库】中选择\n3.keywords与搜索主题高度相关\n【使用指南】:\n当需要利用文献资源辅助回答问题时,请使用此工具。\n输入必须包含三个参数:{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}\n示例:\n{\"query\":\"人工智能的相关政策\",\"knowledge_name\":[\"政策库\"],\"keywords\": [\"人工智能\", \"国家级\"]}\n务必以JSON格式输入参数以便工具能够正确解析和使用。",
args_schema=RagSearchInput
),
# Web-backed lookup tool.
Tool.from_function(
func=knowledgebase_kgo_search,
name="联网思索",
description="\n【参数说明】:参数格式以JSON格式提供参数为query\n query 为【user input】的原文例如\n{\"query\":\"人工智能是什么\"}\n注意禁止改写【user input】的内容",
# description="提示:联网思索工具仅供娱乐,名词解释,或知识库联想无法满足需求时才使用\n【参数说明】:参数格式以JSON格式提供参数为query\n query 为【user input】的原文例如\n{\"query\":\"人工智能是什么\"}\n注意禁止改写【user input】的内容",
args_schema=KnowledgeKgoInput
),
# Tool.from_function(
# func=duckduckgo_search,
# name="联网思索",
# description="【参数说明】:参数格式必须以JSON格式提供\n必需参数\n1.query查询内容\"粉末冶金\")。\n可选参数\n2.time查询时间范围例如 \"m\" 表示过去一个月,\"w\" 表示过去一周,不传则默认 \"d\"\n3.resource_type资源类型控制\n - \"default\" 或不传该参数:检索所有资源\n - \"video\":用户明确要求检索视频的时候,检索视频 - \"limit\"查询资料数量如果用户没有明确要求默认传3 \n\n【使用示例】:\n{\"query\":\"人工智能\",\"time\":\"m\",\"resource_type\":\"default\",\"limit\":3}\n确保参数以JSO-N格式提供以便正确解析和使用。",
# args_schema=DuckduckgoInput
# ),
# Tool.from_function(
# func=do_nothing,
# name="无需调用工具",
# description="不需要使用工具的时候调用本方法以获取指示",
# args_schema=doNothingInput
# ),
# Tool.from_function(
# func=get_next_tip,
# name="环节跳转",
# description="第一阶段结束获取第二阶段提示",
# args_schema=doNothingInput
# ),
# Chart drawing tool.
Tool.from_function(
func=create_and_save_plot,
name="图表绘制",
description="注意本工具一次只能画一张图<关键指令>使用本工具后你必须按工具返回要求输出图片</关键指令>【参数说明】:参数以<param></param>标签包裹。必须提供以下参数格式: 必需参数:\n 1. data图表数据格式如下{\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X}其中XXX代表分类名称xx代表分类数据量。\n2.title图表标题3.xlabel横轴标题你按照分的几类属于哪个大类必须有\n4.ylabel纵轴标题你的数值数据是什么必须有\n 5.plot_type:图表类型必须从以下几个当中选一个作为输入【bar,pie,line】其中bar代表柱状图pie代表饼图line代表折线图你只能选一个作为输入\n【使用指南】:\n当需要生成图表时,请使用此工具。\n输入必须包含以下参数::<tool_input><param>{\n\"data\": {\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X},\"title\": \"X\",\"xlabel\": \"X\",\"ylabel\": \"X\",\"plot_type\": \"X\"}\n示例:{'\"data\": {\"Category A\": 23, \"Category B\": 17, \"Category C\": 35, \"Category D\": 29},\"title\": \"My Chart\",\"xlabel\": \"Category\",\"ylabel\": \"Value\",\"plot_type\": \"pie\"}</param></tool_input>\n请务必以json格式输入方便使用。\n",
args_schema=drawPlotInput
),
# Math tool.
Tool.from_function(
func=math_count,
name="数学运算",
description="【参数说明】:参数格式以JSON格式提供参数为query\n query 为数学问题描述,例如:\n{\"query\":\"9.9和9.11谁大\"}\n",
args_schema=drawPlotInput
),
# Coding tool.
Tool.from_function(
func=code_count,
name="代码专家",
description="【参数说明】:参数格式以JSON格式提供参数为query\n query 为代码问题描述,例如:\n{\"query\":\"写一个es增量数据处理的脚本\"}\n",
args_schema=drawPlotInput
),
# Weather tool.
Tool.from_function(
func=weathercheck,
name="天气工具",
description="【参数说明】:注意仅支持三天内天气查询参数格式以JSON格式提供参数为location\n location 为查询天气的城市名称,例如:\n{\"location\":\"北京\"}\n",
args_schema=drawPlotInput
),
# Artwork image lookup tool.
Tool.from_function(
func=search_pic,
name="美术作品获取",
description="【参数说明】:参数格式以JSON格式提供参数为query\n query 为查询美术作品的描述,务必是美术作品的描述,不要直接输入美术作品四个字而是如山水画写生画,草原,太阳等等,例如:\n{\"query\":\"山水画\"}\n需将该工具返回的图片链接以markdown格式给出",
args_schema=drawPlotInput
),
# Statistical data query tool.
Tool.from_function(
func=mysql_statistic,
name="统计数据查询",
description="提示统计库数据最多只有2023年及之前的199几年的最多【参数说明】:参数格式以JSON格式提供并必须以<param></param>标签包裹。参数为query\n query 为详细的查询问题例如:\n<tool_input> <param>{\"query\":\"原油出口统计数据\"}</param> </tool_input>\n,需将返回数据做成数据表和图表",
args_schema=drawPlotInput
),
# Tool.from_function(
# func=draw_realistic_pic,
# name="实景绘制",
# description="【参数说明】:参数格式以JSON格式提供参数为query\n query 为图片需求描述,例如:\n{\"query\":\"画一张风景图\"}\n",
# args_schema=drawRealisticInput
# ),
# Tool.from_function(
# func=draw_ink_pic,
# name="水墨画绘制",
# description="【参数说明】:参数格式以JSON格式提供参数为query\n query 为图片需求描述,例如:\n{\"query\":\"画一个红苹果\"}\n",
# args_schema=drawInkInput
# ),
# Tool.from_function(
# func=rag_search1,
# name="rag_search",
# description="Use this tool to search for relevant information in the policy knowledge base, with the query is the user's questions ,with the knowledge base name being t_policy_total_bge_new_v2.",
# args_schema=RagSearchInput,
# ),
# Tool.from_function(
# func=policy_knowledgebase_search,
# name="policy_knowledgebase",
# description="Use this tool to search for policy knowledge base",
# args_schema=PolicyKnowledgeInput,
# ),
# Tool.from_function(
# func=report_knowledgebase_search,
# name="report_knowledgebase",
# description="Use this tool to search for report knowledge base",
# args_schema=ReportKnowledgeInput,
# ),
]
# Names of all registered tools, in registration order.
tool_names = [tool.name for tool in tools]
# Web search tools.
# NOTE(review): this comprehension yields whatever `search_internet` is bound to
# once per entry in `tools`, not a list of tool names; `search_internet` is not
# imported explicitly in this module — confirm intent.
search_tool_names = [search_internet for tool in tools]