[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
4
langchain-chat/server/agent/__init__.py
Normal file
4
langchain-chat/server/agent/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .model_contain import *
|
||||
from .callbacks import *
|
||||
from .custom_template import *
|
||||
from .tools import *
|
||||
8
langchain-chat/server/agent/agent.py
Normal file
8
langchain-chat/server/agent/agent.py
Normal file
@@ -0,0 +1,8 @@
|
||||
|
||||
|
||||
class Agent(object):
    """Mutable holder for four string attributes, all initialised empty.

    The constructor sets ``step``, ``knowledge``, ``question`` and ``res``.
    """

    def __init__(self):
        # Strings are immutable, so one chained assignment is equivalent to
        # assigning "" to each attribute separately.
        self.step = self.knowledge = self.question = self.res = ""
|
||||
101
langchain-chat/server/agent/agent_chat.py
Normal file
101
langchain-chat/server/agent/agent_chat.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import uuid
|
||||
from fastapi import Body
|
||||
from langchain.memory import (
|
||||
CombinedMemory,
|
||||
ConversationBufferMemory,
|
||||
ConversationSummaryMemory,
|
||||
ConversationBufferWindowMemory
|
||||
)
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from langchain.chains import LLMChain, ConversationChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from typing import AsyncIterable
|
||||
import asyncio
|
||||
import json
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from typing import List, Optional, Union
|
||||
from server.chat.utils import History
|
||||
from langchain.prompts import PromptTemplate
|
||||
from server.utils import get_prompt_template, get_format_template
|
||||
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
|
||||
from server.db.repository import add_message_to_db
|
||||
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
||||
from datetime import datetime
|
||||
from langchain_core.messages import SystemMessage
|
||||
import time as t
|
||||
from server.utils import replace_variables
|
||||
from configs.basic_config import *
|
||||
from configs.outline_config import outlines
|
||||
|
||||
async def agent_chat_new(
        user_prompt_name: Optional[str] = Body(None, description="用户输入"),
        query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
        conversation_id: str = Body("", description="对话框ID"),
        history: Union[int, List[History]] = Body([], description="历史对话"),
        model_name: str = Body("default_model", description="LLM 模型名称。"),
        temperature: float = Body(0.7, description="LLM 采样温度", ge=0.0, le=2.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量"),
        prompt_template: str = Body("default", description="使用的prompt模板内容"),
        stream: bool = Body(False, description="流式输出")
) -> AsyncIterable[str]:
    """Run one ``ConversationChain`` turn and yield the reply as JSON strings.

    The chain receives two memories: the supplied ``history`` replayed into a
    buffer under the ``history`` key, and a system message carrying today's
    date under the ``time`` key.

    Args:
        user_prompt_name: Not read by this function.
        query: User input, passed to the chain as ``input``.
        conversation_id: Not read by this function.
        history: Earlier turns. The signature also allows an ``int``; this
            function has no code that loads stored turns, so an ``int`` is
            treated as "no history" instead of failing on iteration.
        model_name: Model name forwarded to ``get_ChatOpenAI``.
        temperature: Sampling temperature forwarded to ``get_ChatOpenAI``.
        max_tokens: Token limit; non-positive values are replaced by ``None``.
        prompt_template: Template text handed to ``PromptTemplate.from_template``.
        stream: When true, yield each token as it arrives.

    Yields:
        With ``stream`` true, one ``{"text": token}`` JSON string per token.
        Otherwise a single ``{"text": answer, "message_id": ...}`` JSON string
        once generation has finished.
    """
    callback = AsyncIteratorCallbackHandler()
    today = datetime.now().strftime("%Y年%m月%d日")
    message_id = str(uuid.uuid1()) + "q"

    if isinstance(max_tokens, int) and max_tokens <= 0:
        max_tokens = None

    model = get_ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        callbacks=[callback],
    )

    # An int cannot be iterated; fall back to an empty history in that case.
    past_turns = [] if isinstance(history, int) else [History.from_data(h) for h in history]
    chat_prompt = PromptTemplate.from_template(prompt_template)

    # Replay the supplied history into a buffer memory.
    buff_memory = ConversationBufferMemory(human_prefix='user', ai_prefix='assistant', memory_key="history", input_key="input")
    if past_turns:
        for turn in past_turns:
            if turn.role == 'user':
                buff_memory.chat_memory.add_user_message(turn.content)
            elif turn.role == 'assistant':
                buff_memory.chat_memory.add_ai_message(turn.content)
    else:
        # NOTE(review): read as the "no history" branch, which seeds one
        # placeholder exchange — confirm against the original indentation.
        buff_memory.chat_memory.add_user_message("无")
        buff_memory.chat_memory.add_ai_message("无")

    # Second memory: a single system message stating the current date.
    background_memory = ConversationBufferMemory(human_prefix='user', ai_prefix='assistant', memory_key="time", input_key="input")
    background_memory.chat_memory.add_message(SystemMessage(content=f'当前的时间是:{today}'))

    memory = CombinedMemory(memories=[background_memory, buff_memory])
    chain = ConversationChain(llm=model, verbose=True, memory=memory, prompt=chat_prompt)

    task = asyncio.create_task(
        wrap_done(chain.acall({"input": query, "time": today}), callback.done)
    )

    tokens = []
    try:
        async for token in callback.aiter():
            if stream:
                yield json.dumps({"text": token}, ensure_ascii=False)
            else:
                tokens.append(token)
        await task
    finally:
        # If the consumer stops iterating early, do not leave the LLM call
        # running in the background.
        if not task.done():
            task.cancel()

    if not stream:
        answer = "".join(tokens)
        # Log the finished answer once, with lazy formatting.
        logger.info("solve_problem: %s", answer)
        yield json.dumps(
            {"text": answer, "message_id": message_id},
            ensure_ascii=False
        )
|
||||
161
langchain-chat/server/agent/callbacks.py
Normal file
161
langchain-chat/server/agent/callbacks.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from __future__ import annotations
|
||||
from uuid import UUID
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.schema import AgentFinish, AgentAction
|
||||
from langchain.schema.output import LLMResult
|
||||
|
||||
|
||||
def dumps(obj: Dict) -> str:
    """Serialise *obj* to a JSON string, leaving non-ASCII characters unescaped."""
    encoded = json.dumps(obj, ensure_ascii=False)
    return encoded
|
||||
|
||||
|
||||
class Status:
    """Integer codes written to the ``status`` field of emitted events."""

    start: int = 1          # set by on_llm_start / on_chat_model_start
    running: int = 2        # set by on_llm_new_token
    complete: int = 3       # set by on_llm_end
    agent_action: int = 4   # set by on_tool_start
    agent_finish: int = 5   # set by on_agent_finish
    error: int = 6          # set by on_tool_error / on_llm_error
    tool_finish: int = 7    # set by on_tool_end
|
||||
|
||||
|
||||
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
    """Callback handler that publishes agent progress as JSON snapshots.

    Each hook updates ``self.cur_tool`` (a dict describing the current step)
    and puts its JSON encoding on ``self.queue``.
    """

    def __init__(self):
        super().__init__()
        self.queue = asyncio.Queue()
        self.done = asyncio.Event()
        # Snapshot of the step in progress; replaced on every tool start and
        # cleared after the agent finishes.
        self.cur_tool = {}
        # When False, streamed LLM tokens are not forwarded.
        self.out = True

    def _push_snapshot(self, **changes: Any) -> None:
        """Merge *changes* into ``cur_tool`` and enqueue its JSON encoding."""
        self.cur_tool.update(changes)
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
        """Start a new snapshot for the tool call and publish it."""
        # Some models do not stop where they should, so the tool input is cut
        # here. Cut at the EARLIEST stop word found in the text; cutting at
        # the first match in list order could leave a later stop word intact.
        stop_words = ["Observation:", "Thought", "\"", "(", "\n", "\t"]
        cut_points = [pos for pos in (input_str.find(word) for word in stop_words) if pos != -1]
        if cut_points:
            input_str = input_str[:min(cut_points)]

        self.cur_tool = {
            "tool_name": serialized["name"],
            "input_str": input_str,
            "output_str": "",
            "status": Status.agent_action,
            "run_id": run_id.hex,
            "llm_token": "",
            "final_answer": "",
            "error": "",
        }
        self._push_snapshot()

    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
                          tags: List[str] | None = None, **kwargs: Any) -> None:
        """Publish the tool output and re-enable token forwarding."""
        self.out = True
        self._push_snapshot(
            status=Status.tool_finish,
            output_str=output.replace("Answer:", ""),
        )

    async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
                            parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
        """Publish a tool failure with the error text."""
        self._push_snapshot(status=Status.error, error=str(error))

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward a streamed token unless forwarding has been switched off.

        When the token contains a special marker, only the text before the
        marker is published and forwarding stays off until ``on_tool_end``.
        """
        special_tokens = ["Action", "<|observation|>"]
        for marker in special_tokens:
            if marker in token:
                self._push_snapshot(
                    status=Status.running,
                    llm_token=token.split(marker)[0] + "\n",
                )
                self.out = False
                break

        if token and self.out:
            self._push_snapshot(status=Status.running, llm_token=token)

    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        """Publish a start event with an empty token."""
        self._push_snapshot(status=Status.start, llm_token="")

    async def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List],
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            metadata: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> None:
        """Publish a start event with an empty token (chat-model variant)."""
        self._push_snapshot(status=Status.start, llm_token="")

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Publish a completion event whose token is a newline."""
        self._push_snapshot(status=Status.complete, llm_token="\n")

    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        """Publish an LLM failure with the error text."""
        self._push_snapshot(status=Status.error, error=str(error))

    async def on_agent_finish(
            self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            **kwargs: Any,
    ) -> None:
        """Publish the final answer, then clear the current snapshot."""
        self._push_snapshot(
            status=Status.agent_finish,
            final_answer=finish.return_values["output"],
        )
        self.cur_tool = {}
|
||||
228
langchain-chat/server/agent/custom_agent/ChatGLM3Agent.py
Normal file
228
langchain-chat/server/agent/custom_agent/ChatGLM3Agent.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
This file is a modified version for ChatGLM3-6B the original glm3_agent.py file from the langchain repo.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, List, Sequence, Tuple, Optional, Union
|
||||
from pydantic.schema import model_schema
|
||||
|
||||
|
||||
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
|
||||
from langchain.memory import ConversationBufferWindowMemory
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
|
||||
from langchain.agents.agent import AgentOutputParser
|
||||
from langchain.output_parsers import OutputFixingParser
|
||||
from langchain.pydantic_v1 import Field
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StructuredChatOutputParserWithRetries(AgentOutputParser):
    """Output parser with retries for the structured chat agent."""

    # Parser used when no output-fixing parser is configured.
    base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
    # Optional parser that takes precedence over ``base_parser`` when set.
    output_fixing_parser: Optional[OutputFixingParser] = None

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Turn raw model output into an ``AgentAction`` or ``AgentFinish``.

        The text is cut at the first ``Action:`` or ``<|observation|>``
        marker. If what remains contains ``tool_call``, the text before the
        code fence becomes the tool name and the ``key=value`` pairs inside
        the call's parentheses become its input. Otherwise the whole text is
        treated as the final answer. The result is re-encoded as a JSON
        action block and handed to the configured parser.

        Raises:
            OutputParserException: If the configured parser rejects the block.
        """
        special_tokens = ["Action:", "<|observation|>"]
        first_index = min(
            [text.find(token) if token in text else len(text) for token in special_tokens]
        )
        text = text[:first_index]

        if "tool_call" in text:
            call_start = text.find("tool_call")
            fence = text.find("```")
            # Without a fence, ``find`` returns -1 and slicing would drop the
            # last character of the reply; use the text before the call instead.
            action = text[:fence if fence != -1 else call_start].strip()

            # Look for the opening parenthesis of the call itself, not one
            # that happens to appear earlier in the text.
            params_str_start = text.find("(", call_start) + 1
            params_str_end = text.rfind(")")
            params_str = text[params_str_start:params_str_end]

            params = {}
            for param in params_str.split(","):
                if "=" not in param:
                    continue
                # Split once only, so a value that contains "=" stays whole.
                key, value = param.split("=", 1)
                params[key.strip()] = value.strip().strip("'\"")

            action_json = {
                "action": action,
                "action_input": params
            }
        else:
            action_json = {
                "action": "Final Answer",
                "action_input": text
            }

        action_str = "\nAction:\n```\n" + json.dumps(action_json, ensure_ascii=False) + "\n```"
        try:
            if self.output_fixing_parser is not None:
                parsed_obj: Union[
                    AgentAction, AgentFinish
                ] = self.output_fixing_parser.parse(action_str)
            else:
                parsed_obj = self.base_parser.parse(action_str)
            return parsed_obj
        except Exception as e:
            raise OutputParserException(f"Could not parse LLM output: {text}") from e

    @property
    def _type(self) -> str:
        return "structured_chat_ChatGLM3_6b_with_retries"
|
||||
|
||||
|
||||
class StructuredGLM3ChatAgent(Agent):
    """Structured chat agent whose prompt and stop token target ChatGLM3."""

    # Parser applied to the model output.
    output_parser: AgentOutputParser = Field(
        default_factory=StructuredChatOutputParserWithRetries
    )

    @property
    def observation_prefix(self) -> str:
        """Prefix placed before each observation."""
        return "Observation:"

    @property
    def llm_prefix(self) -> str:
        """Prefix placed before each LLM call."""
        return "Thought:"

    def _construct_scratchpad(
            self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
        """Wrap the parent's scratchpad in an explanatory sentence.

        An empty scratchpad is returned unchanged.

        Raises:
            ValueError: If the parent implementation returns a non-string.
        """
        scratchpad = super()._construct_scratchpad(intermediate_steps)
        if not isinstance(scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if not scratchpad:
            return scratchpad
        return (
            "This was your previous work "
            "(but I haven't seen any of it! I only see what "
            f"you return as final answer):\n{scratchpad}"
        )

    @classmethod
    def _get_default_output_parser(
            cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
    ) -> AgentOutputParser:
        return StructuredChatOutputParserWithRetries(llm=llm)

    @property
    def _stop(self) -> List[str]:
        return ["<|observation|>"]

    @classmethod
    def create_prompt(
            cls,
            tools: Sequence[BaseTool],
            prompt: str = None,
            input_variables: Optional[List[str]] = None,
            memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        """Build the chat prompt from *prompt* and a description of *tools*.

        *prompt* is formatted with ``tool_names``, ``tools`` and
        ``history="None"``; ``{input}`` and ``{agent_scratchpad}`` are kept
        as placeholders for the returned template.
        """
        tool_names = []
        tool_lines = []
        for tool in tools:
            schema = model_schema(tool.args_schema) if tool.args_schema else {}
            properties = schema.get("properties", {})
            tool_lines.append(f"{tool.name}: {tool.description}, args: {properties}")
            tool_names.append(tool.name)

        # Escape quotes and double the braces, because the result is later
        # parsed as a template again.
        formatted_tools = (
            "\n".join(tool_lines)
            .replace("'", "\\'")
            .replace("{", "{{")
            .replace("}", "}}")
        )
        template = prompt.format(tool_names=tool_names,
                                 tools=formatted_tools,
                                 history="None",
                                 input="{input}",
                                 agent_scratchpad="{agent_scratchpad}")

        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *(memory_prompts or []),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def from_llm_and_tools(
            cls,
            llm: BaseLanguageModel,
            tools: Sequence[BaseTool],
            prompt: str = None,
            callback_manager: Optional[BaseCallbackManager] = None,
            output_parser: Optional[AgentOutputParser] = None,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            input_variables: Optional[List[str]] = None,
            memory_prompts: Optional[List[BasePromptTemplate]] = None,
            **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        chat_prompt = cls.create_prompt(
            tools,
            prompt=prompt,
            input_variables=input_variables,
            memory_prompts=memory_prompts,
        )
        chain = LLMChain(
            llm=llm,
            prompt=chat_prompt,
            callback_manager=callback_manager,
        )
        parser = output_parser or cls._get_default_output_parser(llm=llm)
        return cls(
            llm_chain=chain,
            allowed_tools=[tool.name for tool in tools],
            output_parser=parser,
            **kwargs,
        )

    @property
    def _agent_type(self) -> str:
        raise ValueError
|
||||
|
||||
|
||||
def initialize_glm3_agent(
        tools: Sequence[BaseTool],
        llm: BaseLanguageModel,
        prompt: str = None,
        memory: Optional[ConversationBufferWindowMemory] = None,
        agent_kwargs: Optional[dict] = None,
        *,
        tags: Optional[Sequence[str]] = None,
        **kwargs: Any,
) -> AgentExecutor:
    """Build a ``StructuredGLM3ChatAgent`` and wrap it in an ``AgentExecutor``.

    Args:
        tools: Tools given to both the agent and the executor.
        llm: Language model driving the agent.
        prompt: Prompt text forwarded to the agent factory.
        memory: Optional memory attached to the executor.
        agent_kwargs: Extra keyword arguments for the agent factory.
        tags: Optional tags; copied into a list for the executor.
        **kwargs: Extra keyword arguments for the executor factory.
    """
    executor_tags = list(tags or [])
    agent = StructuredGLM3ChatAgent.from_llm_and_tools(
        llm=llm,
        tools=tools,
        prompt=prompt,
        **(agent_kwargs or {})
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        memory=memory,
        tags=executor_tags,
        **kwargs,
    )
|
||||
162
langchain-chat/server/agent/custom_template.py
Normal file
162
langchain-chat/server/agent/custom_template.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
from langchain.agents import Tool, AgentOutputParser
|
||||
from langchain.prompts import StringPromptTemplate
|
||||
from typing import List
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
from configs.basic_config import *
|
||||
from configs import SUPPORT_AGENT_MODEL
|
||||
from server.agent import model_container
|
||||
from server.chat import utils
|
||||
from server.chat.knowledge_base_name import KnowledgeBase
|
||||
from collections import defaultdict
|
||||
class CustomPromptTemplate(StringPromptTemplate):
    """String prompt that injects the agent scratchpad and tool descriptions."""

    # Template text with named ``{placeholder}`` fields.
    template: str
    # Tools whose names and descriptions are rendered into the prompt.
    tools: List[Tool]

    def format(self, **kwargs) -> str:
        """Render the template.

        ``intermediate_steps`` (optional) is folded into ``agent_scratchpad``;
        ``tools`` and ``tool_names`` are derived from ``self.tools``. Numeric
        fields such as ``{0}`` are renamed to ``{arg0}``, and any placeholder
        without a value is filled with ``<missing NAME>``.

        Raises:
            ValueError: If formatting still fails.
        """
        intermediate_steps = kwargs.pop("intermediate_steps", [])
        thoughts = "".join(
            f"{action.log}\nObservation: {observation}\n"
            for action, observation in intermediate_steps
        )

        tools = getattr(self, 'tools', [])
        kwargs["agent_scratchpad"] = thoughts
        kwargs["tools"] = "\n".join(
            f"【工具名称】{tool.name}: 【工具描述】{tool.description}" for tool in tools
        )
        kwargs["tool_names"] = ", ".join(tool.name for tool in tools)

        # Work on a local copy so that formatting never rewrites the stored
        # template as a side effect.
        template = re.sub(r'\{(\d+)\}', r'{arg\g<1>}', self.template)

        # Give every simple placeholder a value so format_map cannot miss one.
        for placeholder in re.findall(r'\{(\w+)\}', template):
            kwargs.setdefault(placeholder, f"<missing {placeholder}>")

        try:
            return template.format_map(kwargs)
        except KeyError as e:
            raise ValueError(f"Missing key in template formatting: {e}") from e
        except ValueError as e:
            raise ValueError(f"Format string contains positional fields: {e}") from e
|
||||
|
||||
|
||||
|
||||
class CustomOutputParser(AgentOutputParser):
    """Parse ``Action:`` / ``Action Input:`` style output into agent steps."""

    # True when the next parse may truncate the output at "Observation:".
    begin: bool = False
    # Object whose ``name`` is overwritten when the input names a knowledge base.
    knowledge_base_name: KnowledgeBase
    # Identifier appended to, or used as, the tool input.
    time_based_uuid: str

    def __init__(self, knowledge_base_name: KnowledgeBase, time_based_uuid: str):
        super().__init__(knowledge_base_name=knowledge_base_name, time_based_uuid=time_based_uuid)
        self.begin = True
        # Re-assigned after the base initialiser; presumably so the parser
        # holds the caller's own objects rather than validated copies
        # — TODO confirm.
        self.knowledge_base_name = knowledge_base_name
        self.time_based_uuid = time_based_uuid

    def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
        """Convert one chunk of model output into the agent's next step.

        Returns:
            ``AgentFinish`` when the output has a ``Final Answer:``, has no
            ``Action:``, or cannot be turned into an action; otherwise an
            ``AgentAction`` for a recognised tool, for ``环节跳转``, or for
            the ``无需调用工具`` fallback.
        """
        if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
            self.begin = False
            stop_words = ["Observation:"]
            hits = [pos for pos in (llm_output.find(word) for word in stop_words) if pos != -1]
            if hits:
                llm_output = llm_output[:min(hits)]

        if "Final Answer:" in llm_output:
            self.begin = True
            return AgentFinish(
                return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
                log=llm_output,
            )

        parts = llm_output.split("Action:")
        if len(parts) < 2:
            return AgentFinish(
                return_values={"output": f" {llm_output} "},
                log=llm_output,
            )

        def failed() -> AgentFinish:
            return AgentFinish(
                return_values={"output": f"调用agent失败: `{llm_output}`"},
                log=llm_output,
            )

        # The tool name is the rest of the line after "Action:". The input is
        # whatever follows "Action Input:"; "无" stands in when it is absent.
        action = None
        action_input = "无"
        match = re.search(r'Action:\s*(.+)', llm_output)
        if match is not None:
            action = match.group(1).strip().strip('\'').replace("【工具名称】", "")
            input_parts = parts[1].split("Action Input:")
            if len(input_parts) > 1:
                action_input = (
                    input_parts[1].strip()
                    .replace("【调用结束】", "")
                    .replace("调用结束", "")
                )

        false_tokens = ["无", "None"]
        use_tools = ["联网思索","知识库联想","图表绘制","数学运算","代码专家","天气工具","美术作品获取","统计数据查询"]

        print("action_input: ", action_input)

        if action is None:
            # No tool name could be read. Report the failure explicitly
            # instead of relying on a NameError being swallowed further down.
            return failed()

        try:
            if "knowledge_name" in json.loads(action_input):
                query = re.search(r'\{.*?\}', action_input, re.DOTALL).group(0)
                self.knowledge_base_name.name = json.loads(query)["knowledge_name"]
        except Exception:
            # Best effort: input that is not JSON, or has no usable
            # knowledge_name, leaves the knowledge base unchanged.
            pass

        tool_name = action.replace("【", "").replace("】", "")
        try:
            if tool_name not in false_tokens and any(tool in action for tool in use_tools):
                return AgentAction(
                    tool=tool_name,
                    tool_input=action_input.strip(" ").strip('"') + "{\"uuid\":\"" + self.time_based_uuid + "\"}",
                    log=llm_output
                )
            if tool_name == "环节跳转":
                return AgentAction(
                    tool="环节跳转",
                    tool_input=self.time_based_uuid,
                    log=llm_output
                )
            return AgentAction(
                tool="无需调用工具",
                tool_input=self.time_based_uuid,
                log=llm_output
            )
        except Exception:
            return failed()
|
||||
6
langchain-chat/server/agent/model_contain.py
Normal file
6
langchain-chat/server/agent/model_contain.py
Normal file
@@ -0,0 +1,6 @@
|
||||
class ModelContainer:
    """Holder for two module-wide slots, ``MODEL`` and ``DATABASE``."""

    def __init__(self):
        # Both slots start unset.
        self.MODEL = self.DATABASE = None


# Single instance created at import time.
model_container = ModelContainer()
|
||||
16
langchain-chat/server/agent/tools/__init__.py
Normal file
16
langchain-chat/server/agent/tools/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
## 导入所有的工具类
|
||||
# from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
|
||||
# from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
|
||||
# from .chat_with_Yi34B import chat_with_Yi34B, ChatWithYi34BInput
|
||||
# from .search_youtube import search_youtube, YoutubeInput
|
||||
from .calculate import calculate, CalculatorInput
|
||||
from .weather_check import weathercheck, WeatherInput
|
||||
from .shell import shell, ShellInput
|
||||
from .search_internet import search_internet, SearchInternetInput
|
||||
from .wolfram import wolfram, WolframInput
|
||||
from .arxiv import arxiv, ArxivInput
|
||||
from .knowledgebase_kgo_search import knowledgebase_kgo_search, KnowledgeKgoInput
|
||||
from .policy_knowledgebase_search import policy_knowledgebase_search, PolicyKnowledgeInput
|
||||
from .report_knowledgebase_search import report_knowledgebase_search, ReportKnowledgeInput
|
||||
from .rag_search import rag_search1, RagSearchInput
|
||||
from .duckduckgo_search import duckduckgo_search, DuckduckgoInput
|
||||
9
langchain-chat/server/agent/tools/arxiv.py
Normal file
9
langchain-chat/server/agent/tools/arxiv.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# LangChain 的 ArxivQueryRun 工具
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.tools.arxiv.tool import ArxivQueryRun
|
||||
def arxiv(query: str):
    """Run *query* through a fresh ``ArxivQueryRun`` tool and return its result."""
    return ArxivQueryRun().run(tool_input=query)
|
||||
|
||||
class ArxivInput(BaseModel):
    """Argument schema with a single ``query`` string."""

    query: str = Field(description="The search query title")
|
||||
10
langchain-chat/server/agent/tools/arxiv.yaml
Normal file
10
langchain-chat/server/agent/tools/arxiv.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
name: arxiv
|
||||
description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: The search query title
|
||||
required:
|
||||
- query
|
||||
76
langchain-chat/server/agent/tools/calculate.py
Normal file
76
langchain-chat/server/agent/tools/calculate.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.chains import LLMMathChain
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Prompt text for the math chain. Doubled braces are literal braces;
# {question} is the only template variable.
_PROMPT_TEMPLATE = """
将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
问题: ${{包含数学问题的问题。}}
```text
${{解决问题的单行数学表达式}}
```
...numexpr.evaluate(query)...
```output
${{运行代码的输出}}
```
答案: ${{答案}}

这是两个例子:

问题: 37593 * 67是多少?
```text
37593 * 67
```
...numexpr.evaluate("37593 * 67")...
```output
2518731

答案: 2518731

问题: 37593的五次方根是多少?
```text
37593**(1/5)
```
...numexpr.evaluate("37593**(1/5)")...
```output
8.222831614237718

答案: 8.222831614237718


问题: 2的平方是多少?
```text
2 ** 2
```
...numexpr.evaluate("2 ** 2")...
```output
4

答案: 4


现在,这是我的问题:
问题: {question}
"""

# Template object passed to the math chain by ``calculate``.
PROMPT = PromptTemplate(template=_PROMPT_TEMPLATE, input_variables=["question"])
|
||||
|
||||
|
||||
class CalculatorInput(BaseModel):
    """Argument schema with a single ``query`` string."""

    query: str = Field()
|
||||
|
||||
def calculate(query: str):
    """Answer *query* with an ``LLMMathChain`` built on the shared model.

    The chain is created from ``model_container.MODEL`` and the module's
    ``PROMPT`` on every call; its ``run`` result is returned unchanged.
    """
    llm_math = LLMMathChain.from_llm(model_container.MODEL, verbose=True, prompt=PROMPT)
    return llm_math.run(query)
|
||||
|
||||
if __name__ == "__main__":
    # Manual check: run one sample question and print the answer.
    print("答案:", calculate("2的三次方"))
|
||||
|
||||
|
||||
|
||||
10
langchain-chat/server/agent/tools/calculate.yaml
Normal file
10
langchain-chat/server/agent/tools/calculate.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
name: calculate
|
||||
description: Useful for when you need to answer questions about simple calculations
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: The formula to be calculated
|
||||
required:
|
||||
- query
|
||||
43
langchain-chat/server/agent/tools/chat_with_Yi34B.py
Normal file
43
langchain-chat/server/agent/tools/chat_with_Yi34B.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from server.chat.chat import chat
|
||||
from server.chat.utils import History
|
||||
|
||||
|
||||
async def chat_with_Yi34B_iter(query: str,
                               stream=False,
                               model_name="qianfan-api",
                               history: Union[int, List[History]] = None,
                               conversation_id='',
                               temperature=0.7,
                               max_tokens=None,
                               history_len=3,
                               prompt_name="default"
                               ):
    """Call ``chat`` and return the concatenated ``text`` of its response.

    Every argument is forwarded to ``chat``. Each item of the response's
    ``body_iterator`` is decoded as JSON and its ``"text"`` value collected.
    """
    response = await chat(
        query=query,
        history=history,
        history_len=history_len,
        conversation_id=conversation_id,
        stream=stream,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        prompt_name=prompt_name,
    )

    pieces = []
    async for raw in response.body_iterator:  # each item is a JSON string
        pieces.append(json.loads(raw)["text"])
    return "".join(pieces)
|
||||
|
||||
|
||||
def chat_with_Yi34B(query: str, model_name: str = "qianfan-api", conversation_id: str = '',
                    history: Union[int, List[History]] = None):
    """Synchronous wrapper that runs ``chat_with_Yi34B_iter`` via ``asyncio.run``."""
    coroutine = chat_with_Yi34B_iter(
        query,
        model_name=model_name,
        conversation_id=conversation_id,
        history=history,
    )
    return asyncio.run(coroutine)
|
||||
|
||||
|
||||
class ChatWithYi34BInput(BaseModel):
    """Argument schema with a single string field."""

    # NOTE(review): the field is named ``location`` while ``chat_with_Yi34B``
    # takes ``query`` and the description speaks of a query — looks like a
    # copy-paste leftover; confirm before renaming, as the name is interface.
    location: str = Field(description="Query for any kind of chats and questions")
|
||||
18
langchain-chat/server/agent/tools/chat_with_Yi34B.yaml
Normal file
18
langchain-chat/server/agent/tools/chat_with_Yi34B.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
name: chat_with_Yi34B
|
||||
description: Use this tool to chat with human
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: Query for any kind of chat and questions
|
||||
model_name:
|
||||
type: string
|
||||
description:
|
||||
conversation_id:
|
||||
type: string
|
||||
description:
|
||||
required:
|
||||
- query
|
||||
- model_name
|
||||
- conversation_id
|
||||
34
langchain-chat/server/agent/tools/do_nothing.py
Normal file
34
langchain-chat/server/agent/tools/do_nothing.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import json
|
||||
import re
|
||||
import concurrent
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from langchain.tools import YouTubeSearchTool
|
||||
from pydantic import BaseModel, Field
|
||||
from server.chat import utils
|
||||
from server.chat.utils import doc_to_list, solve_knowledge_map,solve_mental_data,shared_variable
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from configs import kb_config
|
||||
|
||||
|
||||
def do_nothing(query: str):
    """No-op tool: ignore ``query`` and tell the agent no tool call is needed."""
    # The three lines of "ask back" logic mentioned by the original comment are
    # not present here; the function only returns the fixed hint below.
    no_tool_hint = "\n不需要调用工具了"
    return no_tool_hint
|
||||
|
||||
def get_next_tip(query: str):
    """Mark the shared state stored under ``query`` as finished and return a hint.

    ``query`` is used directly as the shared-variable key: the stored mapping
    gets ``"END"`` set to ``"ok"`` and is written back.
    """
    state_key = query
    shared_state = utils.get_shared_variable(state_key)
    shared_state["END"] = "ok"
    utils.set_shared_variable(state_key, shared_state)

    next_step_hint = "\n提示:你已经使用过环节跳转了,可以开始输出正文了"
    return next_step_hint
|
||||
|
||||
# Pydantic input model with a single required string field; presumably the
# argument schema for the do_nothing / get_next_tip tools (confirm at registration).
class doNothingInput(BaseModel):
    # Required; the Field description reads "query target".
    query: str = Field(...,description="查询对象")
|
||||
175
langchain-chat/server/agent/tools/draw_plot.py
Normal file
175
langchain-chat/server/agent/tools/draw_plot.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import requests
|
||||
from matplotlib import pyplot as plt
|
||||
from pydantic import BaseModel, Field
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.chat import utils
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from configs.kb_config import GENERATED_IMAGES_BASE_PATH, realistic_url,ink_url
|
||||
from matplotlib import font_manager
|
||||
|
||||
# Font used for all chart text (titles, labels, ticks) in this module.
# NOTE(review): the path is hard-coded to a system font file; presumably needed
# to render Chinese text -- confirm it exists on every deployment host.
my_font = font_manager.FontProperties(fname="/usr/share/fonts/MicroSoft-YaHei/MSYH.TTC")
|
||||
def create_and_save_plot(query: str) -> str:
    """Draw a bar, pie or line chart from the tool input and save it as a PNG.

    The input is expected to look like ``<param>{...}</param>{...}``. The JSON
    inside ``<param>`` must provide ``data`` (label -> value mapping),
    ``title``, ``xlabel``, ``ylabel`` and ``plot_type`` ('bar', 'pie' or
    'line'). If that JSON cannot be parsed, the LLM is asked to rewrite it and
    the first fenced json block of its answer is used instead.

    Returns:
        A short status string for the agent. Every failure is logged and
        turned into a message, so the function does not raise.
    """
    try:
        query = query.replace(" ", "").replace("'", "\"")
        json_str = '{\n"data": {"XXX": XX, "XXX": XX, "XXX": X, "XXX": X},"title": "X","xlabel": "X","ylabel": "X","plot_type": "X"}'

        match = re.search(r'<param>(.*?)</param>(\{.*\})', query.replace("\n", ""))
        if not match:
            print(f"Invalid JSON format in query:\n{query}")
            return "暂时无法画图"
        query = match.group(1).strip()
        try:
            datas = json.loads(query)
        except ValueError:
            # The parameters are not valid JSON: let the model rewrite them.
            answer = get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="check_plot",
                prompt_param_dict={"user_input": query, "json": json_str},
                temperature=0.01,
                max_tokens=512
            )
            # Search the raw answer; the previous code dropped the result of
            # this search and required "\n" after newlines had been removed.
            fenced = re.search(r"```json\s*(.*?)\s*```", answer, re.DOTALL)
            if fenced is None:
                raise ValueError("LLM answer contains no ```json block")
            datas = json.loads(fenced.group(1).strip())

        data = datas["data"]
        xlabel = datas["xlabel"]
        ylabel = datas["ylabel"]
        title = datas["title"]
        plot_type = datas["plot_type"]

        # Split the mapping into parallel label / value lists.
        categories = list(data.keys())
        values = list(data.values())

        namesid = uuid.uuid1()
        file_path = f'{GENERATED_IMAGES_BASE_PATH}/plot{namesid}.png'
        absolute_path = os.path.abspath(file_path)

        plt.figure(figsize=(10, 6))
        try:
            if plot_type == 'bar':
                plt.bar(categories, values, color='skyblue')
            elif plot_type == 'pie':
                plt.pie(values, labels=categories, autopct='%1.1f%%', startangle=140,
                        textprops={'fontproperties': my_font})
            elif plot_type == 'line':
                plt.plot(categories, values, marker='o', linestyle='-')
            else:
                raise ValueError("Unsupported plot type. Choose from 'bar', 'pie', or 'line'.")

            plt.title(title, fontproperties=my_font)
            if plot_type != 'pie':  # pie charts have no axis labels
                plt.xlabel(xlabel, fontproperties=my_font)
                plt.ylabel(ylabel, fontproperties=my_font)
                plt.xticks(fontproperties=my_font, rotation=45)
                plt.yticks(fontproperties=my_font)

            plt.tight_layout()
            plt.savefig(absolute_path)
        finally:
            # Always release the figure, also when plotting or saving fails.
            plt.close()

        # NOTE(review): the saved file name is not part of the returned text --
        # confirm how the caller learns which image to show.
        return "图片如下:"
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        return "不要再调用该工具了,无法提供该功能,你只能按工具描述提供有的入参,其他不支持"
|
||||
|
||||
|
||||
# Pydantic input model with one required string; presumably the argument schema
# of the chart-drawing tool (confirm at registration).
class drawPlotInput(BaseModel):
    # Required; the Field description reads "content to draw".
    query: str = Field(...,description="输入要画图的内容")
|
||||
|
||||
def draw_realistic_pic(query: str) -> str:
    """Request an image from the ``realistic_url`` service and record its path.

    ``query`` must contain two JSON objects: the first one carries the prompt
    under ``"query"``, the second one carries the conversation key under
    ``"uuid"``. On success the returned ``image_path`` is stored as ``"url"``
    in the shared variable of that uuid.

    Returns:
        An instruction string for the agent on success, otherwise
        "暂时无法画图". Malformed input no longer raises.
    """
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            return "暂时无法画图"
        session_uuid = json.loads(matches[1])["uuid"]
        # Strip the uuid object so only the prompt JSON remains.
        prompt = json.loads(query.replace(matches[1], ""))["query"]
    except (ValueError, KeyError, TypeError):
        # Previously this was only logged and execution continued with an
        # unparsed query and an unbound uuid.
        logging.error("Invalid JSON format in query.")
        return "暂时无法画图"

    data = {
        'prompt': prompt
    }

    try:
        response = requests.post(realistic_url, json=data)

        if response.status_code != 200:
            print("Failed to get response:", response.status_code)
            return "暂时无法画图"

        result = response.json()
        file_path = result.get('image_path')
        print("Image path:", file_path)
        sources = utils.get_shared_variable(session_uuid)
        sources["url"] = file_path
        utils.set_shared_variable(session_uuid, sources)
        return "<关键指令>你必须以markdown格式的图片路径,,之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
    except requests.exceptions.RequestException as e:
        print("An error occurred:", e)
        return "暂时无法画图"
|
||||
|
||||
# Pydantic input model with one required string; presumably the argument schema
# of the realistic-picture tool (confirm at registration).
class drawRealisticInput(BaseModel):
    # Required; the Field description reads "content to draw".
    query: str = Field(...,description="输入要画图的内容")
|
||||
|
||||
def draw_ink_pic(query: str) -> str:
    """Request an image from the ``ink_url`` service and record its path.

    ``query`` must contain two JSON objects: the first one carries the prompt
    under ``"query"``, the second one carries the conversation key under
    ``"uuid"``. On success the returned ``image_path`` is stored as ``"url"``
    in the shared variable of that uuid.

    Returns:
        An instruction string for the agent on success, otherwise
        "暂时无法画图". Malformed input no longer raises.
    """
    # The stray argument-less get_llm_model_response() call that used to open
    # this function was removed: its result was discarded.
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            return "暂时无法画图"
        session_uuid = json.loads(matches[1])["uuid"]
        # Strip the uuid object so only the prompt JSON remains.
        prompt = json.loads(query.replace(matches[1], ""))["query"]
    except (ValueError, KeyError, TypeError):
        # Previously this was only logged and execution continued with an
        # unparsed query and an unbound uuid.
        logging.error("Invalid JSON format in query.")
        return "暂时无法画图"

    data = {
        'prompt': prompt
    }

    try:
        response = requests.post(ink_url, json=data)

        if response.status_code != 200:
            print("Failed to get response:", response.status_code)
            return "暂时无法画图"

        result = response.json()
        file_path = result.get('image_path')
        print("Image path:", file_path)
        sources = utils.get_shared_variable(session_uuid)
        sources["url"] = file_path
        utils.set_shared_variable(session_uuid, sources)
        return "<关键指令>你必须以markdown格式的图片路径,,之后你也需要按照要求给出更多的markdown格式的数据表格</关键指令>"
    except requests.exceptions.RequestException as e:
        print("An error occurred:", e)
        return "暂时无法画图"
|
||||
|
||||
# Pydantic input model with one required string; presumably the argument schema
# of the ink-picture tool (confirm at registration).
class drawInkInput(BaseModel):
    # Required; the Field description reads "content to draw".
    query: str = Field(...,description="输入要画图的内容")
|
||||
186
langchain-chat/server/agent/tools/duckduckgo_search.py
Normal file
186
langchain-chat/server/agent/tools/duckduckgo_search.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import asyncio
|
||||
import re
|
||||
import aiohttp
|
||||
import json
|
||||
import logging
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from server.chat import utils
|
||||
|
||||
# Configure logging: INFO level with timestamp and level name in each record.
# Note that basicConfig runs at import time of this module.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
# Module-level logger used by the search functions below.
logger = logging.getLogger(__name__)
|
||||
|
||||
async def duckduckgo_search_iter(query: str, uuid: str = "",time: str = "", resource_type: str = None, limit: int = 3):
    """Query the text / video / news search APIs and format the hits.

    Args:
        query: Search text sent to every API.
        uuid: Key of the shared variable whose ``"num"`` seeds the running
            reference index and whose ``"source_docs"`` receives the sources.
        time: Passed through as the ``"time"`` field of the request body.
        resource_type: ``'video'`` queries only the video API with the full
            ``limit``; any other value queries all three APIs concurrently.
        limit: Total number of results wanted. When three APIs are queried it
            is split as evenly as possible, the remainder going to text first
            and news second.

    Returns:
        ``(res, source)``: the formatted result lines and the formatted source
        links, in the order text, video, news.
    """
    text_url = 'http://43.251.225.121/inspur/search_text'
    video_url = 'http://43.251.225.121/inspur/search_video'
    news_url = 'http://43.251.225.121/inspur/search_new'

    payload = {
        "query": query,
        "time": time
    }

    async def fetch(session, url, json_payload, limit):
        # Build a private body: the three concurrent requests must not share
        # (and overwrite) one "limit" entry.
        request_body = {**json_payload, "limit": limit}
        logger.info(f"从 {url} 获取数据,请求参数: {request_body}")
        try:
            async with session.post(url, json=request_body) as response:
                if response.status != 200:
                    logger.error(f"向 {url} 请求失败,状态码 {response.status}")
                    # Do not try to interpret an error body as results.
                    return []
                data = await response.json()
                logger.info(f"从 {url} 获取的资料数: {len(data) if isinstance(data, list) else '未知'}")
                return data
        except Exception as e:
            logger.error(f"获取 {url} 数据时发生错误: {e}")
            return []

    async with aiohttp.ClientSession() as session:
        logger.info("发起请求duckduckgo...")

        # Even split of `limit`; the old `case 2` branch assigned limit2 twice
        # and left the video share at 0.
        base, extra = divmod(limit, 3)
        text_limit = base + (1 if extra >= 1 else 0)
        news_limit = base + (1 if extra == 2 else 0)
        video_limit = base

        if resource_type != 'video':
            text_result, video_result, news_result = await asyncio.gather(
                fetch(session, text_url, payload, text_limit),
                fetch(session, video_url, payload, video_limit),
                fetch(session, news_url, payload, news_limit),
            )
            logger.info("合并结果...")
            logger.info("合并结果完成")
            combined_result = {
                "text": text_result,
                "video": video_result,
                "news": news_result
            }
        else:
            video_result = await fetch(session, video_url, payload, limit)
            combined_result = {
                "video": video_result
            }

        logger.info("请求已完成")
        res = []
        source = []
        info = utils.get_shared_variable(uuid)
        # Continue numbering after the references already recorded.
        index = info["num"]
        if "text" in combined_result:
            for item in combined_result["text"]:
                index += 1
                res.append(f'资料[{index}] 资料标题{item["title"]}({item["href"]}) 资料内容为: {item["body"]}')
                source.append(f'资料[{index}] [{item["title"]}]({item["href"]})')
        if "video" in combined_result:
            for item in combined_result["video"]:
                index += 1
                res.append(f'资料[{index}] 视频标题[{item["title"]}]({item["content"]}) 视频内容为: {item["description"]}')
                source.append(f'视频资料[{index}] [{item["title"]}]({item["content"]})')
        if "news" in combined_result:
            for item in combined_result["news"]:
                index += 1
                res.append(f'资料[{index}] 新闻标题[{item["title"]}]({item["url"]}) 新闻内容为: {item["body"]}')
                source.append(f'资料[{index}] [{item["title"]}]({item["url"]})')
        info["source_docs"].extend(source)
        utils.set_shared_variable(uuid, info)
        return res,source
|
||||
|
||||
|
||||
def duckduckgo_search(query: str, time: str = "", resource_type: str = None):
    """Parse the agent's tool input and run ``duckduckgo_search_iter``.

    ``query`` must contain two JSON objects. The first may provide ``query``,
    ``limit`` (default 3), ``resource_type`` and ``time``; non-empty values
    override the corresponding arguments. The second must provide ``uuid``.

    Returns:
        The ``(res, source)`` pair of ``duckduckgo_search_iter``, or the
        "no need to call this tool again" instruction string when the input
        does not contain two parseable JSON objects.
    """
    stop_hint = "<关键指令>不需要再调用该工具了</关键指令>"
    logger.info(f"模型输入: {query}")

    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return stop_hint

    try:
        obj1 = json.loads(matches[0])
        parsed_uuid = json.loads(matches[1])["uuid"]
    except (json.JSONDecodeError, KeyError) as e:
        # Stop here: continuing would hit unbound uuid / limit variables.
        logger.error(f"解析JSON出错: {e}")
        return stop_hint

    parsed_limit = obj1.get("limit", 3)
    # Parsed values win only when they are non-empty; an empty "query" falls
    # back to the raw first JSON string, as before.
    query = obj1.get("query", "") or matches[0]
    resource_type = obj1.get("resource_type", None) or resource_type
    time = obj1.get("time", time) or time

    logger.info(f"解析完成,query: {query}, uuid: {parsed_uuid}, time: {time}, resource_type: {resource_type}, parsed_limit: {parsed_limit}")

    # Run the async search from this synchronous entry point.
    combined_result = asyncio.run(duckduckgo_search_iter(query, parsed_uuid, time, resource_type, parsed_limit))
    logger.info("返回JSON格式的结果给到模型...")
    return combined_result
|
||||
# Pydantic input model, presumably the argument schema of the web-search tool.
# NOTE(review): the field is called `location` while duckduckgo_search() takes
# `query` -- confirm which name the tool registration expects.
class DuckduckgoInput(BaseModel):
    # The Field description reads "web search query".
    location: str = Field(description="网络搜索查询")
|
||||
|
||||
if __name__ == "__main__":
    # Manual test calls.
    # 1. Default: request all three APIs
    # result_default = duckduckgo_search("粉末冶金", "m", "default")
    # print("duckduckgo输出(默认):\n", result_default)

    # # 2. Videos only
    # result_video = duckduckgo_search("粉末冶金", "m", "video")
    # print("duckduckgo输出(视频):\n", result_video)

    # # 3. News only
    # result_news = duckduckgo_search("粉末冶金", "m", "news")
    # print("duckduckgo输出(新闻):\n", result_news)

    # 4. Any other type: text only
    # NOTE(review): this query is plain text without the two JSON objects that
    # duckduckgo_search looks for, so the call returns its "no need to call the
    # tool" message instead of searching -- confirm this is intended.
    result_other = duckduckgo_search("粉末冶金", "m", "other")
    print("duckduckgo输出(其他):\n", result_other)
|
||||
57
langchain-chat/server/agent/tools/get_statistical_data.py
Normal file
57
langchain-chat/server/agent/tools/get_statistical_data.py
Normal file
@@ -0,0 +1,57 @@
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
import requests
|
||||
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.chat import utils
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
|
||||
|
||||
def mysql_statistic(query:str) -> str:
    """Send a natural-language question to the local NL2SQL service.

    ``query`` must look like ``<param>{"query": ...}</param>{"uuid": ...}``.
    The question is posted to ``http://127.0.0.1:6008/query`` and the
    ``"result"`` field of the JSON answer is wrapped in an instruction string
    for the agent.

    Returns:
        The instruction string, or "暂时无法查询" when the input is malformed,
        the request fails or the service answers with an error status.
    """
    unavailable = "暂时无法查询"
    logging.info(f"\n🔍 统计工具查询query: \n{query}\n")

    matches = re.search(r'<param>(.*?)</param>(\{.*\})', query.replace("\n", ""))
    if not matches:
        # The old code touched an unbound `uuid` here and crashed.
        return unavailable
    try:
        # The trailing object must be valid JSON carrying a uuid, as before;
        # the value itself is not used further in this function.
        json.loads(matches.group(2))["uuid"]
        question = json.loads(matches.group(1).strip())["query"]
    except (ValueError, KeyError, TypeError):
        logging.error("Invalid JSON format in query.")
        return unavailable

    logging.info(f"\n🔍 NL2SQL检索question: \n{question}\n")
    try:
        res = requests.post(
            url="http://127.0.0.1:6008/query",
            json={"question": question},
            headers={"Content-Type": "application/json"}
        )
    except requests.exceptions.RequestException as e:
        logging.error(f"NL2SQL request failed: {e}")
        return unavailable

    # A Response is falsy for 4xx/5xx statuses; previously that path fell off
    # the end of the function and returned None.
    if not res:
        logging.error(f"NL2SQL request returned status {res.status_code}")
        return unavailable

    data = res.json()["result"]
    if "'data': []" in data:
        return f"统计库未检索到数据,使用“联网思索”工具检索该请求:{question}\n"
    return f"判断得到的数据是否准确,如果不准确,则使用“联网思索”工具检索。如果准确,则根据数据表格并使用图表绘制工具制图。\n 数据如下所示: \n{data}\n"
|
||||
170
langchain-chat/server/agent/tools/knowledgebase_kgo_search.py
Normal file
170
langchain-chat/server/agent/tools/knowledgebase_kgo_search.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Any, Union
|
||||
import concurrent
|
||||
from pydantic import BaseModel, Field
|
||||
from difflib import SequenceMatcher
|
||||
from configs import (VECTOR_SEARCH_TOP_K,
|
||||
SCORE_THRESHOLD,
|
||||
DEFAULT_POLICY_BASE)
|
||||
from server.agent.tools import search_internet
|
||||
from server.chat import utils
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.utils import BaseResponse
|
||||
|
||||
|
||||
# Pydantic input model, presumably the argument schema of the knowledge-base /
# internet search tool.
# NOTE(review): the field is called `location` while knowledgebase_kgo_search()
# takes `query` -- confirm which name the tool registration expects.
class KnowledgeKgoInput(BaseModel):
    # Search text (see the Field description).
    location: str = Field(description="Query for Internet search")
|
||||
|
||||
|
||||
|
||||
def preprocess_text(text: str) -> str:
    """Return ``text`` with all whitespace and non-word characters removed.

    Letters, digits and underscores (regex ``\\w``) are kept; everything else
    is dropped.
    """
    trimmed = text.strip()
    cleaned = re.sub(r'[\s\W]', '', trimmed)
    return cleaned
|
||||
|
||||
|
||||
def knowledge_temperature(a: str, b: str) -> float:
    """Return the difflib similarity ratio of ``a`` and ``b`` (0.0 to 1.0)."""
    matcher = SequenceMatcher(None, a, b)
    similarity = matcher.ratio()
    return similarity
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# def knowledgebase_kgo_iter(query: str,
|
||||
# fileName: List = [],
|
||||
# knowledge_base_name: str = DEFAULT_POLICY_BASE,
|
||||
# top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
# score_threshold: float = SCORE_THRESHOLD) -> BaseResponse | list[str] | Any:
|
||||
# kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
# if kb is None:
|
||||
# return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
# query = query.strip()
|
||||
|
||||
# docs = search_docs(fileName=fileName,
|
||||
# query=query,
|
||||
# knowledge_base_name=knowledge_base_name,
|
||||
# top_k=top_k,
|
||||
# score_threshold=score_threshold)
|
||||
|
||||
# # 预处理查询文本
|
||||
# processed_query = preprocess_text(query).replace("Observ","")
|
||||
# print("processed_query:", processed_query)
|
||||
# knowledge_docs = []
|
||||
# knowledge_content = []
|
||||
# # 知识库返回的文档与query的相似度
|
||||
# if docs:
|
||||
# for enum, doc in enumerate(docs):
|
||||
# filename = doc.metadata.get("title")
|
||||
# detail_url = "https://policy.ckcest.cn/detail/" + doc.metadata.get("primary_key") + ".html"
|
||||
# if filename:
|
||||
# text = f"""政策类资料[{enum + 1}]: [{filename}]({detail_url})\n"""
|
||||
# else:
|
||||
# text = f"""政策类资料[{enum + 1}]: [{"原文地址"}]({detail_url})\n"""
|
||||
# knowledge_docs.append(text)
|
||||
# # print("knowledge_docs:", knowledge_docs)
|
||||
# knowledge_content = [doc.page_content for doc in docs]
|
||||
# # print("knowledge_content:", knowledge_content)
|
||||
# # 计算知识库返回的文档与query的相似度
|
||||
# titles = [doc.metadata.get("title") for doc in docs]
|
||||
# print("titles:", titles)
|
||||
# def check_similarity_threshold(titles: List[str], query: str, knowledge_docs: List[str], knowledge_content: List[str]) -> Union[
|
||||
# List[str], None]:
|
||||
# # 用于记录是否存在相似度大于0.55的标题
|
||||
# has_similar_title = False
|
||||
# for title in titles:
|
||||
# processed_title = preprocess_text(title)
|
||||
# similarity = knowledge_temperature(processed_query, processed_title)
|
||||
# print("processed_title:", processed_title)
|
||||
# print("similarity:", similarity)
|
||||
# if similarity >= 0.55:
|
||||
# has_similar_title = True
|
||||
# break
|
||||
# # 如果存在相似度大于0.55的标题,则直接返回 knowledge_docs
|
||||
# if has_similar_title:
|
||||
# knowledge = knowledge_content + knowledge_docs
|
||||
# return knowledge
|
||||
# # 如果所有标题的相似度都不大于0.55,则返回 None
|
||||
# return None
|
||||
|
||||
# # 在原函数中使用新的函数进行相似度阈值的判断
|
||||
# similar_docs = check_similarity_threshold(titles, query, knowledge_docs, knowledge_content)
|
||||
# if similar_docs is None:
|
||||
# # 如果所有标题的相似度都不大于0.55,则执行搜索引擎查询
|
||||
# kgo_docs = search_internet(processed_query)
|
||||
# # print("kgo_docs", kgo_docs)
|
||||
# return kgo_docs
|
||||
# else:
|
||||
# kgo_docs = search_internet(processed_query)
|
||||
# # print("similar_docs", similar_docs)
|
||||
# # print("kgo_docs", kgo_docs)
|
||||
# similar_docs.extend(kgo_docs)
|
||||
# return similar_docs
|
||||
# else:
|
||||
# # 执行搜索引擎查询
|
||||
# kgo_docs = search_internet(query)
|
||||
# return kgo_docs
|
||||
|
||||
def knowledgebase_kgo_iter(query: str, uid: str) -> BaseResponse | list[str] | Any:
    """Forward ``query`` and ``uid`` to ``search_internet`` and return its result."""
    return search_internet(query, uid)
|
||||
def knowledgebase_kgo_search(query: str) -> str:
    """Parse the agent's tool input, run the internet search and format the hits.

    ``query`` must contain two JSON objects: the first provides ``"query"``,
    the second provides ``"uuid"``. The search runs in a worker thread via
    ``knowledgebase_kgo_iter`` and is expected to return a pair
    ``(contents, sources)``.

    Returns:
        The concatenated contents and sources; a "titles only" message when
        only sources came back; otherwise a hint asking the agent to adjust
        its arguments and retry. The function does not raise.
    """
    # Imported here so the submodule is loaded explicitly: a bare
    # `import concurrent` does not guarantee `concurrent.futures` exists.
    from concurrent.futures import ThreadPoolExecutor

    retry_hint = "尝试调整入参重新调用联网思索工具(同一个问题调用超过三次就不要再使用该工具了,浪费时间)"
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            logging.error("Invalid JSON format in query.")
            return retry_hint

        parsed_query = json.loads(matches[0])
        time_based_uuid = json.loads(matches[1])["uuid"]
        with ThreadPoolExecutor() as executor:
            future = executor.submit(knowledgebase_kgo_iter, parsed_query["query"], time_based_uuid)
            res = future.result()

        try:
            contents, sources = res[0], res[1]
            if isinstance(contents, list) and contents:
                return "资料内容" + "".join(contents) + "资料来源" + "".join(sources)
            if isinstance(sources, list) and sources:
                # This branch used an unbound accumulator and a plain string
                # instead of an f-string, so the titles were never returned.
                doc_content = "资料来源" + "".join(sources)
                return f"只有标题没有内容,标题为:{doc_content}"
            return retry_hint
        except Exception as e:
            logging.error(f"Error occurred while processing query: {e}")
            return retry_hint
    except json.JSONDecodeError:
        logging.error("Invalid JSON format in query.")
        return retry_hint
    except KeyError:
        # A required key ("query" or "uuid") is missing from the parsed JSON.
        return retry_hint
    except Exception as e:
        logging.error(f"Error occurred while processing query: {e}")
        return retry_hint
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test.
    # NOTE(review): knowledgebase_kgo_iter is defined with two required
    # parameters (query, uid); this call passes only the query and will raise
    # TypeError -- supply a uid or update the call.
    result = knowledgebase_kgo_iter("《区块链和分布式记账技术标准体系建设指南》")
    print("检索结果:", result)
|
||||
113
langchain-chat/server/agent/tools/math.py
Normal file
113
langchain-chat/server/agent/tools/math.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.agent.tools.search_tool import search_tool
|
||||
from server.chat import utils
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def math_count(query: str):
    """Answer a calculation question with the math model, backed by a search.

    ``query`` must contain two JSON objects: the first provides ``"query"``,
    the second provides ``"uuid"``. A temporary shared variable keyed by
    ``uuid + "q"`` is created for the ``search_tool`` call and removed
    afterwards; the search output is passed to the LLM as ``math_doc``.

    Returns:
        The model's answer as a string, or the "do not call this tool again"
        marker when the input is malformed or any step fails.
    """
    stop_hint = "<system>不要再调用该工具了</system>"
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            # Previously this path fell through and returned None.
            return stop_hint

        parsed_query = json.loads(matches[0])["query"]
        time_based_uuid = json.loads(matches[1])["uuid"]

        scratch_key = time_based_uuid + "q"
        utils.set_shared_variable(scratch_key, {"source_docs": [], "num": 0, "title": []})
        first_json = {
            "query": parsed_query,
            "knowledge_name": [],
            "keywords": []
        }
        second_json = {
            "uuid": scratch_key
        }
        try:
            math_doc = search_tool(json.dumps(first_json) + json.dumps(second_json))
        finally:
            # Remove the scratch entry even when the search raises.
            utils.remove_shared_variable(scratch_key)

        res = get_llm_model_response(
            strategy_name="default_math",
            llm_model_name=LLM_MODELS[3],
            template_prompt_name="default_math",
            prompt_param_dict={"input": parsed_query, "math_doc": f"{math_doc}", "time": datetime.now().strftime("%Y%m%d")},
            temperature=0.01,
            max_tokens=512
        )
        return f"{res}"
    except Exception as e:
        logging.error(f"Error occurred while processing math query: {e}")
        return stop_hint
|
||||
|
||||
def code_count(query: str):
    """Answer a coding question with the code model, backed by a search.

    ``query`` must contain two JSON objects: the first provides ``"query"``,
    the second provides ``"uuid"``. A temporary shared variable keyed by
    ``uuid + "q"`` is created for the ``search_tool`` call and removed
    afterwards; the search output is passed to the LLM as ``code_doc``.
    Any ``<think>`` marker is stripped from the answer.

    Returns:
        The model's answer as a string, or the "do not call this tool again"
        marker when the input is malformed or any step fails.
    """
    stop_hint = "<system>不要再调用该工具了</system>"
    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            # Previously this path fell through and returned None.
            return stop_hint

        parsed_query = json.loads(matches[0])["query"]
        time_based_uuid = json.loads(matches[1])["uuid"]

        scratch_key = time_based_uuid + "q"
        utils.set_shared_variable(scratch_key, {"source_docs": [], "num": 0, "title": []})
        first_json = {
            "query": parsed_query,
            "knowledge_name": [],
            "keywords": []
        }
        second_json = {
            "uuid": scratch_key
        }
        try:
            code_doc = search_tool(json.dumps(first_json) + json.dumps(second_json))
        finally:
            # Remove the scratch entry even when the search raises.
            utils.remove_shared_variable(scratch_key)

        res = get_llm_model_response(
            strategy_name="default_code",
            llm_model_name=LLM_MODELS[2],
            template_prompt_name="default_code",
            prompt_param_dict={"input": parsed_query, "code_doc": f"{code_doc}", "time": datetime.now().strftime("%Y%m%d")},
            temperature=0.01,
            max_tokens=512
        )
        res = res.replace("<think>", "")
        return f"{res}"
    except Exception as e:
        # The message used to say "math query" (copied from math_count).
        logging.error(f"Error occurred while processing code query: {e}")
        return stop_hint
|
||||
|
||||
|
||||
|
||||
|
||||
# Pydantic input model with one required string.
# NOTE(review): the class is named RagSearchInput although this file defines
# math_count / code_count -- looks copied from rag_search.py; confirm usage.
class RagSearchInput(BaseModel):
    # Required; the Field description reads "query target".
    query: str = Field(...,description="查询对象")
|
||||
@@ -0,0 +1,42 @@
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS, LLM_MODELS
|
||||
import json
|
||||
import asyncio
|
||||
from typing import List, Tuple, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# Pydantic input model, presumably the argument schema of the policy search tool.
# NOTE(review): the field is called `location` while policy_knowledgebase_search()
# takes `query` -- confirm which name the tool registration expects.
class PolicyKnowledgeInput(BaseModel):
    # Policy-related search text (see the Field description).
    location: str = Field(description="The policy related query to be searched")
|
||||
|
||||
|
||||
async def policy_knowledgebase_search_iter(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Ask the policy knowledge base and return ``(answer, docs)``.

    Calls ``knowledge_base_chat`` in non-streaming mode against the
    ``t_policy_total_bge_new_v1`` knowledge base. Each chunk of the response
    body is parsed as JSON; the last seen ``"answer"`` and ``"docs"`` values
    are returned (``""`` and ``[]`` when a key never appears).
    """
    chat_kwargs = dict(
        query=query,
        fileName=None,
        knowledge_base_name_list=["t_policy_total_bge_new_v1"],
        model_name=LLM_MODELS[0],
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="default",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    response = await knowledge_base_chat(**chat_kwargs)

    answer = ""
    documents = []
    # Each chunk coming out of the body iterator is a JSON string.
    async for raw_chunk in response.body_iterator:
        payload = json.loads(raw_chunk)
        print("data>>>>>", payload)
        answer = payload.get("answer", answer)
        documents = payload.get("docs", documents)
    return answer, documents
|
||||
|
||||
|
||||
def policy_knowledgebase_search(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Synchronous wrapper: run ``policy_knowledgebase_search_iter`` via ``asyncio.run``."""
    pending_search = policy_knowledgebase_search_iter(query)
    return asyncio.run(pending_search)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: run one sample query and print the (answer, docs) pair.
    result = policy_knowledgebase_search("大数据男女比例")
    print("答案:", result)
|
||||
108
langchain-chat/server/agent/tools/rag_search.py
Normal file
108
langchain-chat/server/agent/tools/rag_search.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import json
|
||||
import re
|
||||
import concurrent
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from langchain.tools import YouTubeSearchTool
|
||||
from pydantic import BaseModel, Field
|
||||
from server.chat import utils
|
||||
from server.chat.utils import doc_to_list, solve_knowledge_map,solve_mental_data,shared_variable
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from configs import kb_config
|
||||
|
||||
|
||||
def rag_search1(query: str):
|
||||
"""
|
||||
根据用户输入的query,返回rag搜索结果
|
||||
"""
|
||||
try:
|
||||
matches = re.findall(r'\{.*?\}', query)
|
||||
if len(matches)>=2:
|
||||
query = matches[0]
|
||||
else:
|
||||
return "<关键指令>不需要再调用该工具了</关键指令>"
|
||||
time_based_uuid = json.loads(matches[1])["uuid"]
|
||||
search = json.loads(query)
|
||||
search_query = search["query"]
|
||||
search_keywords = []
|
||||
search_text = f"{search_query}"
|
||||
if type(search["keywords"]) == list:
|
||||
search_keywords = search["keywords"]
|
||||
for keyword in search_keywords:
|
||||
search_text += f" {keyword}"
|
||||
else:
|
||||
search_keywords = search["keywords"].split(",")
|
||||
for keyword in search_keywords:
|
||||
search_text += f" {keyword}"
|
||||
result = []
|
||||
source_docs = {}
|
||||
knownledge_name = []
|
||||
if type(search["knowledge_name"]) == list:
|
||||
knownledge_name=search["knowledge_name"]
|
||||
else:
|
||||
knownledge_name=search["knowledge_name"].split(",")
|
||||
for knownledge in knownledge_name:
|
||||
if not knownledge in kb_config.CH_BASE_NAME:
|
||||
knownledge_name.remove(knownledge)
|
||||
if len(knownledge_name)==0:
|
||||
result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
|
||||
return result
|
||||
# knownledge_name=kb_config.CH_BASE_NAME
|
||||
|
||||
knownledge_name=solve_knowledge_map(knownledge_name)
|
||||
num = 0
|
||||
for knownledge in knownledge_name:
|
||||
source_docs[knownledge] = []
|
||||
seen_docs = set()
|
||||
duplicate_indices = []
|
||||
doc_list = search_docs(usr_query=search_text,fileName= [],top_k=5,score_threshold=0.9,query=search_text, knowledge_base_name=knownledge)
|
||||
|
||||
|
||||
for inum,doc in enumerate(doc_list):
|
||||
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
|
||||
|
||||
# 从policydocs中删除重复的文档(从后往前删除以防止索引错位)
|
||||
for index in sorted(duplicate_indices, reverse=True):
|
||||
del doc_list[index]
|
||||
# 处理原文来源进入数组。使用开关语句明确各个条件分支
|
||||
match knownledge:
|
||||
# 属于政策库分支,入参为中文政策库名称
|
||||
case kb_config.DEFAULT_POLICY_BASE:
|
||||
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs[knownledge])
|
||||
# 属于期刊论文库分支,入参为期刊论文库的中文名称
|
||||
case kb_config.DEFAULT_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
|
||||
# 属于报告库分支,入参为报告库中文名称
|
||||
case kb_config.DEFAULT_REPORT_BASE:
|
||||
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
|
||||
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
|
||||
case kb_config.GY_NEWS_BASE:
|
||||
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs[knownledge])
|
||||
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
|
||||
case kb_config.GY_REPORT_BASE:
|
||||
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
|
||||
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
|
||||
case kb_config.GY_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
|
||||
case _:
|
||||
print(f"输入了没有的知识库名称")
|
||||
return("输入了没有的知识库名称")
|
||||
num += len(source_docs[knownledge])
|
||||
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
|
||||
del num
|
||||
source = utils.get_shared_variable(time_based_uuid)
|
||||
print(utils.get_shared_variable(time_based_uuid))
|
||||
source["source_docs"]=source_docs
|
||||
utils.set_shared_variable(time_based_uuid,source)
|
||||
if 0<len(result)<3:
|
||||
return f"当前资料:{result}\n<关键指令>搜索结果较少,更换知识库或联网思索重新搜索!!!</关键指令>"
|
||||
if len(result)==0:
|
||||
return "注意:【指令:更换知识库或联网思索继续搜索!!!】"
|
||||
except:
|
||||
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
|
||||
return "当前工具异常!请换其他工具"
|
||||
|
||||
return f"当前资料:{result}\n<关键指令>总结此内容!!!</关键指令>"
|
||||
|
||||
# Argument schema for the RAG search tool.
class RagSearchInput(BaseModel):
    # Required (the default is Ellipsis): the text to search for.
    query: str = Field(
        ...,
        description="查询对象",
    )
|
||||
@@ -0,0 +1,42 @@
|
||||
from server.chat.report_chat import report_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS, LLM_MODELS, DEFAULT_REPORT_BASE
|
||||
import json
|
||||
import asyncio
|
||||
from typing import List, Tuple, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# Argument schema for the report knowledge-base search tool.
class ReportKnowledgeInput(BaseModel):
    # Despite its name, this field carries the report query text.
    location: str = Field(
        description="The report related query to be searched",
    )
|
||||
|
||||
|
||||
async def report_knowledgebase_search_iter(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Query the default report knowledge base and collect its reply.

    Args:
        query: Question forwarded to ``report_chat``.

    Returns:
        ``(answer, docs)``: for each key, the value from the last response
        chunk that carried it. Defaults are ``""`` and ``[]`` when no chunk
        carried the key.
    """
    chat_options = dict(
        query=query,
        fileName=None,
        knowledge_base_name=DEFAULT_REPORT_BASE,
        model_name=LLM_MODELS[0],
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="default",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )
    response = await report_chat(**chat_options)

    answer = ""
    documents = []
    async for raw_chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(raw_chunk)
        print("data>>>>>", payload)
        if "answer" in payload.keys():
            answer = payload["answer"]
        if "docs" in payload.keys():
            documents = payload["docs"]
    return answer, documents
|
||||
|
||||
|
||||
def report_knowledgebase_search(query: str) -> tuple[str | Any, list[Any] | Any]:
    """Blocking wrapper: run ``report_knowledgebase_search_iter`` via ``asyncio.run``."""
    search_coroutine = report_knowledgebase_search_iter(query)
    return asyncio.run(search_coroutine)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test against the report knowledge base.
    print("答案:", report_knowledgebase_search("大数据男女比例"))
|
||||
89
langchain-chat/server/agent/tools/search_internet.py
Normal file
89
langchain-chat/server/agent/tools/search_internet.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import json
|
||||
import asyncio
|
||||
import unicodedata
|
||||
|
||||
from server.chat.KgoSearchAPIWrapper import KgoSearchAPIWrapper
|
||||
from server.chat.search_engine_chat import search_engine_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
from configs import LLM_MODELS, TEMPERATURE
|
||||
from configs.basic_config import *
|
||||
|
||||
# def get_kgo_search_type(query: str = "全部"):
|
||||
# # 过滤掉所有非汉字的字符
|
||||
# query = ''.join(char for char in query if unicodedata.category(char) == 'Lo')
|
||||
# search_map = KgoSearchAPIWrapper().search_map
|
||||
|
||||
# if "论文" in query:
|
||||
# return "1001"
|
||||
# elif "外文" in query or "英文" in query:
|
||||
# return "1013"
|
||||
# elif "期刊" in query or "研究进展" in query:
|
||||
# return "1002"
|
||||
# else:
|
||||
# matched_types = [value for key, value in search_map.items() if key in query]
|
||||
# if matched_types:
|
||||
# return ','.join(matched_types)
|
||||
# else:
|
||||
# print("未找到匹配的搜索类型,返回默认值:1000")
|
||||
# return "1000"
|
||||
|
||||
@timing_decorator
async def search_engine_iter(query: str , uid: str):
    """Run a web search through ``search_engine_chat`` and collect its output.

    Args:
        query: Search text.
        uid: Identifier forwarded unchanged to ``search_engine_chat``.

    Returns:
        ``(contents, docs)`` once every chunk has been consumed: ``contents``
        is the ``"answer"`` value of the last chunk and ``docs`` concatenates
        the ``"docs"`` of all chunks. If a chunk carries no docs, the
        documents gathered so far are returned on their own as a bare list.
    """
    # NOTE(review): the two return shapes differ (list vs. tuple); confirm
    # that every caller copes with both before unifying them.
    response = await search_engine_chat(
        uid=uid,
        query=query,
        search_engine_name="zhipu_search",
        model_name=LLM_MODELS[1],
        temperature=TEMPERATURE,  # original note: 0.1 is intended for agent web searches
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="search",
        stream=False,
        kgo_search_type="1000",
    )

    contents = ""
    docs = []
    async for raw_chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(raw_chunk)
        contents = payload.get("answer", [])
        chunk_docs = payload.get("docs", [])
        if not chunk_docs:
            logging.error("No docs found in the response")
            return docs
        docs.extend(chunk_docs)

    # Hand back both the summary text and the supporting documents.
    return contents, docs
|
||||
|
||||
|
||||
def search_internet(query: str, uid: str):
    """Blocking entry point for the internet-search tool.

    Args:
        query: Search text, forwarded unmodified.
        uid: Identifier forwarded to ``search_engine_iter``.

    Returns:
        Whatever ``search_engine_iter`` returns.
    """
    # Keyword stripping and search-type detection used to run here; both are
    # disabled, so the raw query is searched as-is.
    search_coroutine = search_engine_iter(query, uid)
    return asyncio.run(search_coroutine)
|
||||
|
||||
|
||||
# Argument schema for the internet-search tool.
class SearchInternetInput(BaseModel):
    # Despite its name, this field carries the search query text.
    location: str = Field(
        description="Query for Internet search",
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test. search_internet() requires a uid; omitting it raised
    # TypeError before any search ran.
    # NOTE(review): "local-debug" is a placeholder. search_engine_chat may
    # expect an id that was registered elsewhere; confirm before relying on it.
    result = search_internet("人工智能领域的政策", uid="local-debug")
    print("答案:", result)
|
||||
15
langchain-chat/server/agent/tools/search_internet.yaml
Normal file
15
langchain-chat/server/agent/tools/search_internet.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
name: search_internet
|
||||
description: Use this tool to surf internet and get information
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: Query for Internet search
|
||||
kgo_search_type:
|
||||
type: int
|
||||
description: the return value 'kgo_search_type' of the 'get_kgo_search_type'
|
||||
default: 1000
|
||||
required:
|
||||
- query
|
||||
- kgo_search_type
|
||||
@@ -0,0 +1,294 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
import warnings
|
||||
from typing import Dict
|
||||
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from typing import List, Any, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
async def search_knowledge_base_iter(database: str, query: str) -> str:
    """Ask one knowledge base a question and return the concatenated answer.

    Args:
        database: Name of the knowledge base to query.
        query: Question text.

    Returns:
        The ``"answer"`` parts of all response chunks joined in order.

    Raises:
        KeyError: If a response chunk has no ``"answer"`` key.
    """
    response = await knowledge_base_chat(
        query=query,
        knowledge_base_name=database,
        model_name=model_container.MODEL.model_name,
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="default",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )

    # Each chunk is a JSON string. Only its "answer" part is needed here; the
    # previously extracted "docs" value was never used and has been dropped.
    answer_parts = []
    async for raw_chunk in response.body_iterator:
        answer_parts.append(json.loads(raw_chunk)["answer"])
    return "".join(answer_parts)
|
||||
|
||||
|
||||
async def search_knowledge_multiple(queries) -> List[str]:
    """Run several knowledge-base lookups concurrently.

    Args:
        queries: Sequence of ``(database, query)`` pairs.

    Returns:
        One message per pair, in input order, each prefixed with a header
        naming the knowledge base it came from.
    """
    pending = []
    for kb_name, question in queries:
        pending.append(search_knowledge_base_iter(kb_name, question))
    answers = await asyncio.gather(*pending)

    # Label every answer with the knowledge base that produced it.
    return [
        f"\n查询到 {kb_name} 知识库的相关信息:\n{answer}"
        for (kb_name, _), answer in zip(queries, answers)
    ]
|
||||
|
||||
|
||||
def search_knowledge(queries) -> str:
    """Blocking wrapper around ``search_knowledge_multiple``.

    Args:
        queries: Sequence of ``(database, query)`` pairs.

    Returns:
        All per-knowledge-base messages, each followed by a blank line.
    """
    messages = asyncio.run(search_knowledge_multiple(queries))
    return "".join(message + "\n\n" for message in messages)
|
||||
|
||||
|
||||
# Prompt that asks the LLM to emit, inside a ```text section, one
# "knowledge base name,query" line per lookup. {database_names} and
# {question} are substituted when the prompt is formatted.
_PROMPT_TEMPLATE = """
用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。

对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。

例子:

robotic,机器人男女比例是多少
bigdata,大数据的就业情况如何


这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考


{database_names}

你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。
不要输出中文的逗号,不要输出引号。

Question: ${{用户的问题}}

```text
${{知识库名称,查询问题,不要带有任何除了,之外的符号,比如不要输出中文的逗号,不要输出引号}}

```output
数据库查询的结果

现在,我们开始作答
问题: {question}
"""

# Ready-to-use template exposing the two placeholders above.
PROMPT = PromptTemplate(
    input_variables=["question", "database_names"],
    template=_PROMPT_TEMPLATE,
)
|
||||
|
||||
|
||||
class LLMKnowledgeChain(LLMChain):
    """Chain that lets the LLM pick knowledge bases and queries, then runs them.

    The LLM is prompted to reply with a ```text section holding one
    ``database,query`` pair per line. The pairs are searched and the combined
    results are returned under ``output_key`` prefixed with ``"Answer: "``.
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    prompt: BasePromptTemplate = PROMPT
    """[Deprecated] Prompt to use to translate to python if necessary."""
    # Refreshed from model_container.DATABASE on every call.
    database_names: Dict[str, str] = None
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn about the deprecated ``llm`` argument and build ``llm_chain`` from it."""
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    @staticmethod
    def _parse_queries(expression: str) -> List[tuple]:
        """Turn the LLM's ```text payload into ``(database, query)`` tuples.

        Quotes and stray code fences are removed. Blank lines and lines without
        a separator are skipped instead of raising IndexError. Both the ASCII
        comma and the full-width comma (U+FF0C) are accepted; only the first
        one on a line splits it, so commas inside the query text are kept.
        """
        cleaned = (expression.replace("\"", "").replace("“", "")
                   .replace("”", "").replace("```", "").strip())
        queries = []
        for line in cleaned.split("\n"):
            parts = re.split(r"[,\uff0c]", line, maxsplit=1)
            if len(parts) != 2:
                continue
            database, question = parts[0].strip(), parts[1].strip()
            if database and question:
                queries.append((database, question))
        return queries

    def _evaluate_expression(self, queries) -> str:
        """Run the lookups synchronously; on failure return an error message string."""
        try:
            return search_knowledge(queries)
        except Exception as e:
            return "输入的信息有误或不存在知识库,错误信息如下:\n" + str(e)

    async def _aevaluate_expression(self, queries) -> str:
        """Async twin of ``_evaluate_expression``.

        Awaits the lookups directly, because ``asyncio.run`` cannot be used
        while an event loop is already running.
        """
        try:
            responses = await search_knowledge_multiple(queries)
        except Exception as e:
            return "输入的信息有误或不存在知识库,错误信息如下:\n" + str(e)
        return "".join(response + "\n\n" for response in responses)

    def _process_llm_result(
            self,
            llm_output: str,
            run_manager: CallbackManagerForChainRun
    ) -> Dict[str, str]:
        """Interpret the LLM reply and, if it requests lookups, execute them.

        Returns a dict with ``output_key`` mapped to the answer text, or to a
        format-error message when the reply cannot be interpreted.
        """
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)

        llm_output = llm_output.strip()
        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
        if text_match:
            queries = self._parse_queries(text_match.group(1))
            if not queries:
                # A ```text section without any usable "database,query" line.
                return {self.output_key: f"输入的格式不对:\n {llm_output}"}
            run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
            output = self._evaluate_expression(queries)
            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = llm_output.split("Answer:")[-1]
        else:
            return {self.output_key: f"输入的格式不对:\n {llm_output}"}
        return {self.output_key: answer}

    async def _aprocess_llm_result(
            self,
            llm_output: str,
            run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Async counterpart of ``_process_llm_result``.

        Raises:
            ValueError: If the reply cannot be interpreted.
        """
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
        if text_match:
            queries = self._parse_queries(text_match.group(1))
            if not queries:
                raise ValueError(f"unknown format from LLM: {llm_output}")
            await run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue",
                                      verbose=self.verbose)

            output = await self._aevaluate_expression(queries)
            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            raise ValueError(f"unknown format from LLM: {llm_output}")
        return {self.output_key: answer}

    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Prompt the LLM with the available knowledge bases and process its reply."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key])
        self.database_names = model_container.DATABASE
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = self.llm_chain.predict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return self._process_llm_result(llm_output, _run_manager)

    async def _acall(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async counterpart of ``_call``."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        await _run_manager.on_text(inputs[self.input_key])
        self.database_names = model_container.DATABASE
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = await self.llm_chain.apredict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        # _aprocess_llm_result takes (llm_output, run_manager); passing the
        # question as an extra argument raised TypeError on every async call.
        return await self._aprocess_llm_result(llm_output, _run_manager)

    @property
    def _chain_type(self) -> str:
        return "llm_knowledge_chain"

    @classmethod
    def from_llm(
            cls,
            llm: BaseLanguageModel,
            prompt: BasePromptTemplate = PROMPT,
            **kwargs: Any,
    ) -> LLMKnowledgeChain:
        """Build the chain from an LLM and a prompt."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
|
||||
def search_knowledgebase_complex(query: str):
    """Answer ``query`` by letting the LLM choose which knowledge bases to search.

    Builds an ``LLMKnowledgeChain`` around ``model_container.MODEL`` and
    returns the result of running it on ``query``.
    """
    chain = LLMKnowledgeChain.from_llm(model_container.MODEL, verbose=True, prompt=PROMPT)
    return chain.run(query)
|
||||
|
||||
# Argument schema for the knowledge-base search tool.
class KnowledgeSearchInput(BaseModel):
    # Despite its name, this field carries the query text.
    location: str = Field(
        description="The query to be searched",
    )
|
||||
|
||||
|
||||
# Argument schema for the RAG search tool: query text, target knowledge base
# and keywords. The descriptions are shown verbatim to the model.
class RagSearchInput(BaseModel):
    query: str = Field(
        description="查询对象",
    )
    knowledge_name: str = Field(
        description="The name of the knowledge base to be searched,policy knowledge base name is t_policy_total_bge_new_v2, example: t_policy_total_bge_new_v2]",
    )
    keywords: str = Field(
        description="The keywords to be searched example: age,child]",
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test of the multi-knowledge-base chain.
    print(search_knowledgebase_complex("机器人和大数据在代码教学上有什么区别"))
|
||||
|
||||
# 这是一个正常的切割
|
||||
# queries = [
|
||||
# ("bigdata", "大数据专业的男女比例"),
|
||||
# ("robotic", "机器人专业的优势")
|
||||
# ]
|
||||
# result = search_knowledge(queries)
|
||||
# print(result)
|
||||
@@ -0,0 +1,10 @@
|
||||
name: search_knowledgebase_complex
|
||||
description: Use this tool to search local knowledgebase and get information
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: The query to be searched
|
||||
required:
|
||||
- query
|
||||
234
langchain-chat/server/agent/tools/search_knowledgebase_once.py
Normal file
234
langchain-chat/server/agent/tools/search_knowledgebase_once.py
Normal file
@@ -0,0 +1,234 @@
|
||||
from __future__ import annotations
|
||||
import re
|
||||
import warnings
|
||||
from typing import Dict
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from typing import List, Any, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
|
||||
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
async def search_knowledge_base_iter(database: str, query: str):
    """Ask one knowledge base a question and return the concatenated answer.

    Args:
        database: Name of the knowledge base to query.
        query: Question text.

    Returns:
        The ``"answer"`` parts of all response chunks joined in order.

    Raises:
        KeyError: If a response chunk has no ``"answer"`` key.
    """
    response = await knowledge_base_chat(
        query=query,
        knowledge_base_name=database,
        model_name=model_container.MODEL.model_name,
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="knowledge_base_chat",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )

    # Each chunk is a JSON string. Only its "answer" part is needed here; the
    # previously extracted "docs" value was never used and has been dropped.
    answer_parts = []
    async for raw_chunk in response.body_iterator:
        answer_parts.append(json.loads(raw_chunk)["answer"])
    return "".join(answer_parts)
|
||||
|
||||
|
||||
# Prompt that asks the LLM to name, inside a ```text section, the single
# knowledge base to search. {database_names} and {question} are substituted
# when the prompt is formatted.
_PROMPT_TEMPLATE = """
用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考
Question: ${{用户的问题}}
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能:

{database_names}

你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
```text
${{知识库的名称}}
```
```output
数据库查询的结果
```
答案: ${{答案}}

现在,这是我的问题:
问题: {question}

"""
# Ready-to-use template exposing the two placeholders above.
PROMPT = PromptTemplate(
    input_variables=["question", "database_names"],
    template=_PROMPT_TEMPLATE,
)
|
||||
|
||||
|
||||
class LLMKnowledgeChain(LLMChain):
    """Chain that lets the LLM pick one knowledge base, then searches it.

    The LLM is prompted to reply with a ```text ... ``` section naming the
    knowledge base. The user's original question is then searched in that
    knowledge base and the result is returned under ``output_key`` prefixed
    with ``"Answer: "``.
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    prompt: BasePromptTemplate = PROMPT
    """[Deprecated] Prompt to use to translate to python if necessary."""
    # Default is read from model_container.DATABASE when the class is defined.
    database_names: Dict[str, str] = model_container.DATABASE
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn about the deprecated ``llm`` argument and build ``llm_chain`` from it."""
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _evaluate_expression(self, dataset, query) -> str:
        """Search ``dataset`` for ``query`` synchronously; return an error message on failure."""
        try:
            return asyncio.run(search_knowledge_base_iter(dataset, query))
        except Exception:
            return "输入的信息有误或不存在知识库"

    async def _aevaluate_expression(self, dataset, query) -> str:
        """Async twin of ``_evaluate_expression``.

        Awaits the search directly, because ``asyncio.run`` cannot be used
        while an event loop is already running.
        """
        try:
            return await search_knowledge_base_iter(dataset, query)
        except Exception:
            return "输入的信息有误或不存在知识库"

    def _process_llm_result(
            self,
            llm_output: str,
            llm_input: str,
            run_manager: CallbackManagerForChainRun
    ) -> Dict[str, str]:
        """Interpret the LLM reply and, if it names a knowledge base, search it.

        Args:
            llm_output: Raw LLM reply.
            llm_input: The user's question, used as the search query.
            run_manager: Callback manager that receives progress text.

        Returns:
            A dict with ``output_key`` mapped to the answer text, or to a
            format-error message when the reply cannot be interpreted.
        """
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)

        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if text_match:
            database = text_match.group(1).strip()
            output = self._evaluate_expression(database, llm_input)
            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            return {self.output_key: f"输入的格式不对: {llm_output}"}
        return {self.output_key: answer}

    async def _aprocess_llm_result(
            self,
            llm_output: str,
            llm_input: str,
            run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Async counterpart of ``_process_llm_result`` (same parameters).

        ``llm_input`` was missing here although ``_acall`` already passed it,
        and the search was invoked without its query argument.

        Raises:
            ValueError: If the reply cannot be interpreted.
        """
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if text_match:
            database = text_match.group(1).strip()
            output = await self._aevaluate_expression(database, llm_input)
            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            raise ValueError(f"unknown format from LLM: {llm_output}")
        return {self.output_key: answer}

    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Prompt the LLM with the available knowledge bases and process its reply."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key])
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = self.llm_chain.predict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager)

    async def _acall(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async counterpart of ``_call``."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        await _run_manager.on_text(inputs[self.input_key])
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = await self.llm_chain.apredict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)

    @property
    def _chain_type(self) -> str:
        return "llm_knowledge_chain"

    @classmethod
    def from_llm(
            cls,
            llm: BaseLanguageModel,
            prompt: BasePromptTemplate = PROMPT,
            **kwargs: Any,
    ) -> LLMKnowledgeChain:
        """Build the chain from an LLM and a prompt."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
|
||||
def search_knowledgebase_once(query: str):
    """Answer ``query`` by letting the LLM choose one knowledge base to search.

    Builds an ``LLMKnowledgeChain`` around ``model_container.MODEL`` and
    returns the result of running it on ``query``.
    """
    chain = LLMKnowledgeChain.from_llm(model_container.MODEL, verbose=True, prompt=PROMPT)
    return chain.run(query)
|
||||
|
||||
|
||||
# Argument schema for the single-knowledge-base search tool.
class KnowledgeSearchInput(BaseModel):
    # Despite its name, this field carries the query text.
    location: str = Field(
        description="The query to be searched",
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test of the single-knowledge-base chain.
    print(search_knowledgebase_once("大数据的男女比例"))
|
||||
@@ -0,0 +1,32 @@
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
|
||||
import json
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
|
||||
async def search_knowledge_base_iter(database: str, query: str) -> str:
    """Ask one knowledge base a question and return its answer.

    Args:
        database: Name of the knowledge base to query.
        query: Question text.

    Returns:
        The ``"answer"`` value of the last response chunk, or ``""`` when the
        response has no chunks.

    Raises:
        KeyError: If a response chunk has no ``"answer"`` key.
    """
    response = await knowledge_base_chat(
        query=query,
        knowledge_base_name=database,
        model_name=model_container.MODEL.model_name,
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="knowledge_base_chat",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )

    # Each chunk is a JSON string; later chunks overwrite earlier ones. The
    # previously extracted "docs" value was never used and has been dropped.
    contents = ""
    async for raw_chunk in response.body_iterator:
        contents = json.loads(raw_chunk)["answer"]
    return contents
|
||||
|
||||
def search_knowledgebase_simple(query: str, database: "str | None" = None):
    """Blocking single-knowledge-base search.

    The previous body called ``search_knowledge_base_iter(query)`` without the
    required ``database`` argument, so every call raised TypeError.

    Args:
        query: Question text.
        database: Knowledge base to search. When omitted, the first name in
            ``model_container.DATABASE`` is used.

    Returns:
        The answer text produced by ``search_knowledge_base_iter``.

    Raises:
        ValueError: If no knowledge base is given and none is registered.
    """
    if database is None:
        # NOTE(review): presumably DATABASE maps knowledge-base names to
        # descriptions; confirm that "first registered" is the right default.
        registered = model_container.DATABASE or {}
        database = next(iter(registered), None)
        if database is None:
            raise ValueError("no knowledge base is available to search")
    return asyncio.run(search_knowledge_base_iter(database, query))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test of the simple knowledge-base search.
    answer = search_knowledgebase_simple("大数据男女比例")
    print("答案:", answer)
|
||||
56
langchain-chat/server/agent/tools/search_picture.py
Normal file
56
langchain-chat/server/agent/tools/search_picture.py
Normal file
@@ -0,0 +1,56 @@
|
||||
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import List
|
||||
from urllib.parse import quote
|
||||
from server.agent.tools.search_tool import rag_search
|
||||
from server.chat import utils
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
|
||||
|
||||
def search_pic(query: str) -> str:
    """Search the ``p_meiyupic`` knowledge base for pictures matching a request.

    Args:
        query: Text containing at least two brace-delimited JSON objects. The
            first holds the search text under ``"query"``; the second holds
            the conversation id under ``"uuid"``.

    Returns:
        A message string for the agent: the picture links found, a notice that
        nothing was found, an instruction to stop calling the tool, or a
        ``"Failed to get picture."`` message when any step raises.
    """
    try:
        json_fragments = re.findall(r'\{.*?\}', query)
        if len(json_fragments) < 2:
            return "<关键指令>不需要再调用该工具了</关键指令>"

        session_id = json.loads(json_fragments[1])["uuid"]
        tip = utils.get_shared_variable(session_id)
        search_args = json.loads(json_fragments[0])
        res = search_docs(
            usr_query=search_args["query"],
            fileName=[],
            top_k=10,
            score_threshold=0.9,
            query=search_args["query"],
            knowledge_base_name="p_meiyupic",
        )

        # Sources already handed out in this conversation; shared with `tip`.
        seen_sources = tip["source_docs"]
        if len(res) == 0 and len(seen_sources) == 0:
            utils.set_shared_variable(session_id, tip)
            return "工具没有找到结果"

        links = []
        for item in res:
            source = item.metadata['source']
            if source in seen_sources:
                continue  # never return the same picture twice
            seen_sources.append(source)
            # NOTE(review): source and target of this replace are identical,
            # so it is currently a no-op; confirm the intended host mapping.
            picture_url = quote(
                item.page_content.replace(
                    "http://127.0.0.1:8099/chat_web_backend",
                    "http://127.0.0.1:8099/chat_web_backend",
                ),
                safe='/:?=&#+',
            )
            # NOTE(review): the original appended only "\n" and never used
            # the encoded URL. A Markdown image link is the presumed intent;
            # confirm the expected alt text.
            links.append(f'![image]({picture_url})\n')
        result = "".join(links)

        utils.set_shared_variable(session_id, tip)
        if len(result) > 0:
            print(f"美术作品链接:{result}")
            return f"注意:以下链接是图片不是参考文献,以下链接不要放到引文小标的格式输出而是以图片格式输出,禁止转义后面链接的编码,这个链接不能带中文。图片如下:{result}"
        return "<关键指令>不需要再调用该工具了</关键指令>"

    except Exception as e:
        return f"Failed to get picture.{e}"
|
||||
|
||||
331
langchain-chat/server/agent/tools/search_tool.py
Normal file
331
langchain-chat/server/agent/tools/search_tool.py
Normal file
@@ -0,0 +1,331 @@
|
||||
import asyncio
|
||||
import concurrent
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from fastapi import logger
|
||||
|
||||
from configs import kb_config
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.agent.tools import duckduckgo_search
|
||||
from server.agent.tools.duckduckgo_search import duckduckgo_search_iter
|
||||
from server.agent.tools.knowledgebase_kgo_search import knowledgebase_kgo_iter
|
||||
from server.agent.tools.rag_search import rag_search1
|
||||
from server.chat import utils
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from server.chat.utils import doc_to_list, get_similar_documents1, solve_knowledge_map, solve_mental_data
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
|
||||
|
||||
|
||||
def rag_search(query: str,uid):
|
||||
"""
|
||||
根据用户输入的query,返回rag搜索结果
|
||||
"""
|
||||
source_docs = []
|
||||
try:
|
||||
search = json.loads(query)
|
||||
logging.info(f'模型输入: {search["query"]}')
|
||||
original_query = search["query"]
|
||||
search_query = get_llm_model_response(
|
||||
strategy_name="rag_search_rewrite",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="rag_search_rewrite",
|
||||
prompt_param_dict={"input": search["query"], "year": datetime.now().strftime("%Y")},
|
||||
temperature=0.3,
|
||||
max_tokens=512
|
||||
)
|
||||
logging.info(f'模型改写: {search_query}')
|
||||
search_keywords = []
|
||||
search_text = f"{search_query}"
|
||||
# if type(search["keywords"]) == list:
|
||||
# search_keywords = search["keywords"]
|
||||
# for keyword in search_keywords:
|
||||
# search_text += f" {keyword}"
|
||||
# else:
|
||||
# search_keywords = search["keywords"].split(",")
|
||||
# for keyword in search_keywords:
|
||||
# search_text += f" {keyword}"
|
||||
self_database = utils.get_shared_variable(uid)
|
||||
result = []
|
||||
|
||||
knownledge_name = []
|
||||
if type(search["knowledge_name"]) == list:
|
||||
knownledge_name=search["knowledge_name"]
|
||||
else:
|
||||
knownledge_name=search["knowledge_name"].split(",")
|
||||
if "美术专业知识库" in knownledge_name:
|
||||
knownledge_name.remove("美术专业知识库")
|
||||
if "database" in self_database:
|
||||
self_database["database"]= self_database["database"].append("p_cafa0101011")
|
||||
else:
|
||||
self_database["database"] = ["p_cafa0101011"]
|
||||
# 添加个人知识库
|
||||
if "database" in self_database:
|
||||
knownledge_name.extend(self_database["database"])
|
||||
knownledge_name = [knownledge for knownledge in knownledge_name
|
||||
if (knownledge in kb_config.CH_BASE_NAME
|
||||
or knownledge in kb_config.EN_BASE_NAME
|
||||
or knownledge in getattr(kb_config, "YJ_BASE_NAME", [])
|
||||
or kb_config.SELF_KNOWLEDGE_BASE.match(knownledge)
|
||||
or knownledge == "coding")]
|
||||
if len(knownledge_name)==0:
|
||||
#result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
|
||||
return result,source_docs
|
||||
# knownledge_name=kb_config.CH_BASE_NAME
|
||||
|
||||
knownledge_name=solve_knowledge_map(knownledge_name)
|
||||
#knownledge_name = ["p_c88859a3d06e4265bd01d816ef2650d1"]
|
||||
num = 0
|
||||
temp=utils.get_shared_variable(uid)
|
||||
for knownledge in knownledge_name:
|
||||
seen_docs = set()
|
||||
duplicate_indices = []
|
||||
# 针对中国钢铁行业动态库增加日期范围过滤
|
||||
expr_param = ""
|
||||
if knownledge == kb_config.STEEL_KB:
|
||||
time_today = datetime.now().strftime("%Y-%m-%d")
|
||||
# 调用LLM生成日期表达式,模板沿用 get_policy_time
|
||||
try:
|
||||
expr_candidate = get_llm_model_response(
|
||||
strategy_name="get steel time",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="get_steel_time",
|
||||
prompt_param_dict={"query": original_query, "time": time_today},
|
||||
temperature=0.01,
|
||||
max_tokens=512
|
||||
).replace("None", "").strip()
|
||||
expr_param = expr_candidate if expr_candidate else ""
|
||||
except Exception as _:
|
||||
expr_param = ""
|
||||
|
||||
doc_list = search_docs(
|
||||
usr_query=original_query,
|
||||
fileName=[],
|
||||
top_k=20,
|
||||
score_threshold=1.0,
|
||||
query=search_text,
|
||||
knowledge_base_name=knownledge,
|
||||
expr=expr_param
|
||||
)
|
||||
|
||||
if len(doc_list)==0:
|
||||
return result,source_docs
|
||||
titles = temp["title"]
|
||||
doc_list,title = utils.remove_docs1(titles,doc_list)
|
||||
titles.extend(title)
|
||||
for inum,doc in enumerate(doc_list):
|
||||
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
|
||||
|
||||
# 从policydocs中删除重复的文档(从后往前删除以防止索引错位)
|
||||
for index in sorted(duplicate_indices, reverse=True):
|
||||
del doc_list[index]
|
||||
# 处理原文来源进入数组。使用开关语句明确各个条件分支
|
||||
match knownledge:
|
||||
# 属于政策库分支,入参为中文政策库名称
|
||||
case kb_config.DEFAULT_POLICY_BASE:
|
||||
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs)
|
||||
# 属于期刊论文库分支,入参为期刊论文库的中文名称
|
||||
case kb_config.DEFAULT_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 属于报告库分支,入参为报告库中文名称
|
||||
case kb_config.DEFAULT_REPORT_BASE1:
|
||||
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs)
|
||||
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
|
||||
case kb_config.GY_NEWS_BASE:
|
||||
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs)
|
||||
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
|
||||
case kb_config.GY_REPORT_BASE:
|
||||
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs)
|
||||
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
|
||||
case kb_config.GY_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金新闻库(2024年以及之前)
|
||||
case kb_config.YJ_NEWS_BASE:
|
||||
doc_to_list(num,kb_config.YJ_NEWS_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金中文期刊库
|
||||
case kb_config.YJ_CH_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.YJ_CH_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金外文期刊库
|
||||
case kb_config.YJ_FOR_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.YJ_FOR_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金OA期刊库
|
||||
case kb_config.YJ_OA_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.YJ_OA_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金政策库
|
||||
case kb_config.YJ_POLICYS_BASE:
|
||||
doc_to_list(num,kb_config.YJ_POLICYS_BASE_NAME,doc_list,source_docs)
|
||||
# 新增中国钢铁行业动态库
|
||||
case kb_config.STEEL_KB:
|
||||
doc_to_list(num,kb_config.STEEL_KB_NAME,doc_list,source_docs)
|
||||
# 属于个人知识库分支
|
||||
case _ if kb_config.SELF_KNOWLEDGE_BASE.match(knownledge) or knownledge == "coding":
|
||||
doc_to_list(num,knownledge,doc_list,source_docs)
|
||||
case _:
|
||||
print(f"输入了没有的知识库名称")
|
||||
return "输入了没有的知识库名称",source_docs
|
||||
# num += len(source_docs[knownledge])
|
||||
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
|
||||
# del num
|
||||
# source = utils.get_shared_variable(uid)
|
||||
# print(utils.get_shared_variable(uid))
|
||||
# source["source_docs"]=source_docs
|
||||
# utils.set_shared_variable(uid,source)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in rag_search: {e}")
|
||||
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
|
||||
return "当前工具异常!请换其他工具",source_docs
|
||||
|
||||
return result,source_docs
|
||||
|
||||
|
||||
|
||||
|
||||
def knowledgebase_kgo_search(query: str, uid) -> List[str]:
    """Run ``knowledgebase_kgo_iter`` and normalise its result to a two-slot value.

    Args:
        query: Search text forwarded unchanged to ``knowledgebase_kgo_iter``.
        uid: Identifier forwarded unchanged to ``knowledgebase_kgo_iter``.

    Returns:
        * the iterator's own result when both ``res[0]`` and ``res[1]`` are
          lists;
        * ``[[], res[1]]`` when only ``res[1]`` is a non-empty list;
        * ``{0: [], 1: []}`` when there is nothing usable.  This int-keyed
          dict is what callers have always received for the empty case and
          supports the same ``[0]`` / ``[1]`` indexing as the other shapes;
        * a plain instruction string when the search itself raised.

    NOTE(review): the ``List[str]`` annotation does not describe any of the
    shapes above; it is left untouched because it is part of the signature.
    """
    try:
        res = knowledgebase_kgo_iter(query, uid)
        try:
            answers, sources = res[0], res[1]
        except Exception as e:
            # Result is not indexable or has fewer than two slots.
            logging.error(f"No docs: {e}")
            return {0: [], 1: []}

        if isinstance(sources, list):
            if isinstance(answers, list):
                return res
            if sources:
                # Build a fresh container instead of assigning into ``res``:
                # item assignment fails on a tuple and used to discard the
                # sources through the exception path.
                return [[], sources]

        # Nothing usable came back.  This used to be reached through a
        # guaranteed IndexError (assigning into an empty list) that was
        # logged as an error; return the same value directly instead.
        logging.info("No docs: knowledgebase_kgo_iter returned no usable result")
        return {0: [], 1: []}
    except json.JSONDecodeError:
        # The query could not be decoded as JSON.
        logging.error("Invalid JSON format in query.")
        return "<关键指令>不需要再调用该工具了</关键指令>"
    except KeyError:
        # A required key was missing from the parsed input.
        return "请尝试调用其他工具"
    except Exception as e:
        # Any other failure: report it and steer the caller to another tool.
        return f"发生错误:{str(e)},请尝试调用其他工具"
|
||||
|
||||
|
||||
def inner_duckduckgo_search(query: str, uuid: str):
    """Run ``duckduckgo_search_iter`` to completion and return its result.

    The coroutine is driven with ``asyncio.run``, so this function must not
    be called from a thread that already has a running event loop.

    Args:
        query: Raw search input; logged and forwarded unchanged.
        uuid: Identifier forwarded unchanged to ``duckduckgo_search_iter``.
    """
    logging.info(f"模型输入: {query}")
    # NOTE(review): "y" and "default" are presumably the time range and the
    # resource type expected by duckduckgo_search_iter -- confirm there.
    search_coro = duckduckgo_search_iter(query, uuid, "y", "default")
    outcome = asyncio.run(search_coro)
    logging.info("返回JSON格式的结果给到模型...")
    return outcome
|
||||
|
||||
|
||||
def _kgo_parts(kgo_result):
    """Split a ``knowledgebase_kgo_search`` result into ``(texts, sources)`` lists.

    That function may return an instruction string on failure; indexing a
    string would yield single characters, so a string (or any non-list slot)
    is treated as "no results".
    """
    if isinstance(kgo_result, str):
        return [], []
    texts, sources = kgo_result[0], kgo_result[1]
    return (
        texts if isinstance(texts, list) else [],
        sources if isinstance(sources, list) else [],
    )


def search_tool(query: str):
    """Run the knowledge-base search and the KGO search for one tool call.

    ``query`` is the raw tool input.  Optional ``<param>`` tags are stripped,
    then the input must contain at least two ``{...}`` groups: the first is
    the JSON search request (forwarded verbatim to ``rag_search``; only its
    ``"query"`` key is read here) and the second is a JSON object holding the
    conversation ``"uuid"``.

    Both searches run concurrently in a thread pool.  Their texts and sources
    are merged, renumbered to continue from the ``"num"`` counter stored in
    the shared variable of the conversation, and the shared variable is
    updated with the new sources and counter.

    Returns:
        A string for the model: the merged material and sources, a hint to
        retry when nothing was found, or a stop instruction when the input
        has fewer than two ``{...}`` groups or the search fails.

    Raises:
        json.JSONDecodeError, KeyError: if either ``{...}`` group is not
            valid JSON or lacks ``"uuid"`` / ``"query"``; parsing happens
            before the guarded section.
    """
    if "<param>" in query:
        query = query.replace("<param>", "").replace("</param>", "")
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return "<关键指令>当前工具不需要再调用</关键指令>"
    query = matches[0]
    time_based_uuid = json.loads(matches[1])["uuid"]
    search = json.loads(query)

    # The KGO search takes a single string: first element of a list query,
    # a placeholder for an empty list, or the value itself.
    raw_query = search["query"]
    if isinstance(raw_query, list):
        searches = raw_query[0] if raw_query else "无"
    else:
        searches = raw_query

    # Scratch shared variable used only by the KGO search of this call.
    kgo_uid = time_based_uuid + "¥"

    try:
        utils.set_shared_variable(
            kgo_uid, {"num": 0, "source_docs": [], "END": "", "title": []}
        )
        try:
            future2 = None
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future1 = executor.submit(rag_search, query, time_based_uuid)
                # NOTE(review): a "type" key in the conversation's shared
                # variable presumably means the KGO results must be skipped
                # -- confirm with whoever sets it.
                if "type" not in utils.get_shared_variable(time_based_uuid):
                    future2 = executor.submit(
                        knowledgebase_kgo_search, searches, kgo_uid
                    )
                result1, sourcedocs = future1.result()
                # The flag is re-read because it may change while rag_search
                # runs; ``future2`` is guarded so a flag that disappears in
                # between can no longer hit an unbound variable.
                if "type" in utils.get_shared_variable(time_based_uuid) or future2 is None:
                    result2 = {0: [], 1: []}
                else:
                    result2 = future2.result()
        finally:
            # Remove the scratch variable under the key it was created with
            # (the old code removed the stale key ``uuid + "q"``).
            utils.remove_shared_variable(kgo_uid)

        kgo_texts, kgo_sources = _kgo_parts(result2)

        if isinstance(sourcedocs, list):
            sourcedocs.extend(kgo_sources)
        else:
            sourcedocs = []

        if isinstance(result1, list):
            result1.extend(kgo_texts)
            result3 = result1
        else:
            # rag_search reported an error string; fall back to KGO texts.
            result3 = kgo_texts

        logging.info(f"result2:{kgo_sources}")

        sources = utils.get_shared_variable(time_based_uuid)
        start = sources["num"]

        # Renumber the leading "[n]" marker of every source so numbering
        # continues from earlier calls; unusable entries consume no number.
        source = []
        i = start
        for doc in sourcedocs:
            try:
                renumbered = re.sub(
                    r'\[\d+\]', f"[{i + 1}]", doc.replace("\n", ""), count=1
                )
            except (AttributeError, TypeError):
                continue
            if renumbered:
                i += 1
                source.append(renumbered)

        res = [
            re.sub(r'\[\d+\]', f"[{j}]", text, count=1)
            for j, text in enumerate(result3, start + 1)
        ]

        logging.info(f"{utils.get_shared_variable(time_based_uuid)}")
        sources["source_docs"].extend(source)
        sources["num"] = i
        utils.set_shared_variable(time_based_uuid, sources)
        logging.info(f"result1:{result1},sourcedocs:{sourcedocs}")
        logging.info(f"result2:{result2}")
        logging.info(f"{res}")

        if len(res) == 0 and len(source) == 0:
            return f"尝试调整入参重新调用知识库联想工具(同一个问题调用超过三次就不要再使用知识库联想工具了,浪费时间)"
        return f"<关键指令>如果你在写文章禁止在非规定位置输出参考资料</关键指令>资料:{res}\n资料来源为:{source}\n 注意:如果你在根据大纲撰写文章,撰写中间部分章节禁止输出综上所述之类的影响文风的话,撰写中间部分禁止输出附录引用文献等!!!"
    except Exception as e:
        logging.error(f"Error occurred during search_tool execution.{e}")
        return "同一个问题调用知识库联想工具超过5次就不要再调用知识库联想"
|
||||
|
||||
9
langchain-chat/server/agent/tools/search_youtube.py
Normal file
9
langchain-chat/server/agent/tools/search_youtube.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Langchain 自带的 YouTube 搜索工具封装
|
||||
from langchain.tools import YouTubeSearchTool
|
||||
from pydantic import BaseModel, Field
|
||||
def search_youtube(query: str):
    """Run LangChain's ``YouTubeSearchTool`` on ``query`` and return its output."""
    return YouTubeSearchTool().run(tool_input=query)
|
||||
|
||||
class YoutubeInput(BaseModel):
    """Argument schema for the YouTube search tool."""

    # NOTE(review): the field is named ``location`` although its description
    # says it carries the video search query -- confirm the intended name.
    location: str = Field(description="Query for Videos search")
|
||||
10
langchain-chat/server/agent/tools/search_youtube.yaml
Normal file
10
langchain-chat/server/agent/tools/search_youtube.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
name: search_youtube
|
||||
description: Use this tool to search YouTube videos
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: Query for Videos search
|
||||
required:
|
||||
- query
|
||||
9
langchain-chat/server/agent/tools/shell.py
Normal file
9
langchain-chat/server/agent/tools/shell.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# LangChain 的 Shell 工具
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.tools import ShellTool
|
||||
def shell(query: str):
    """Execute ``query`` through LangChain's ``ShellTool`` and return its output.

    NOTE(review): ``query`` is handed to ``ShellTool`` without any validation
    or sandboxing in this function; confirm that callers never pass
    untrusted text here.
    """
    return ShellTool().run(tool_input=query)
|
||||
|
||||
class ShellInput(BaseModel):
    """Argument schema for the shell tool."""

    # The shell command line to run.
    query: str = Field(description="一个能在Linux命令行运行的Shell命令")
|
||||
10
langchain-chat/server/agent/tools/shell.yaml
Normal file
10
langchain-chat/server/agent/tools/shell.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
name: shell
|
||||
description: Use Linux Shell to execute Linux commands
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: The command to execute
|
||||
required:
|
||||
- query
|
||||
49
langchain-chat/server/agent/tools/weather_check.py
Normal file
49
langchain-chat/server/agent/tools/weather_check.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
更简单的单参数输入工具实现,用于查询现在天气的情况
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
import requests
|
||||
from configs.kb_config import SENIVERSE_API_KEY
|
||||
from server.chat import utils
|
||||
|
||||
|
||||
def weather(location: str, api_key: str):
    """Fetch the daily forecast for ``location`` from the Seniverse v3 API.

    The request asks for 5 days starting today (``start=0``, ``days=5``),
    Simplified Chinese text and Celsius units.

    Args:
        location: Place name or identifier accepted by the API.
        api_key: Seniverse API key.

    Returns:
        A JSON string with the keys ``"today"`` (first daily entry) and
        ``"others"`` (remaining entries); each value is itself a
        JSON-encoded string.

    Raises:
        Exception: if the HTTP status is not 200.
        requests.RequestException: on network failure or timeout.
    """
    response = requests.get(
        "https://api.seniverse.com/v3/weather/daily.json",
        # Let requests build the query string so the location and key are
        # URL-encoded and cannot inject extra parameters.
        params={
            "key": api_key,
            "location": location,
            "language": "zh-Hans",
            "unit": "c",
            "start": 0,
            "days": 5,
        },
        # Without a timeout a stalled connection would block the caller
        # forever.
        timeout=10,
    )
    if response.status_code != 200:
        raise Exception(
            f"Failed to retrieve weather: {response.status_code}")

    daily = response.json()["results"][0]["daily"]
    forecast = {
        "today": json.dumps(daily[0]),
        "others": json.dumps(daily[1:]),
    }
    return json.dumps(forecast)
|
||||
|
||||
|
||||
def weathercheck(query: str):
    """Parse the tool input and return the forecast for the requested location.

    ``query`` must contain at least two ``{...}`` groups: the first is a JSON
    object with a ``"location"`` key, the second a JSON object with a
    ``"uuid"`` key.

    Returns:
        The JSON string produced by ``weather``; a stop instruction when
        fewer than two groups are present; or a ``"Failed to retrieve
        weather."`` message carrying the error text when anything raises.
    """
    try:
        json_groups = re.findall(r'\{.*?\}', query)
        if len(json_groups) < 2:
            return "<关键指令>不需要再调用该工具了</关键指令>"
        request_part, context_part = json_groups[0], json_groups[1]
        location = json.loads(request_part)["location"]
        # The uuid is not used afterwards, but reading it keeps the input
        # check: a missing or malformed uuid still yields the failure message.
        _ = json.loads(context_part)["uuid"]
        return weather(location, SENIVERSE_API_KEY)
    except Exception as e:
        return f"Failed to retrieve weather.{e}"
|
||||
|
||||
|
||||
class WeatherInput(BaseModel):
    """Argument schema for the weather tool."""

    # Place whose weather is requested.
    location: str = Field(description="City name,include city and county")
|
||||
10
langchain-chat/server/agent/tools/weather_check.yaml
Normal file
10
langchain-chat/server/agent/tools/weather_check.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
name: weather_check
|
||||
description: Use Weather API to get weather information
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: City name,include city and county,like "厦门市思明区"
|
||||
required:
|
||||
- query
|
||||
11
langchain-chat/server/agent/tools/wolfram.py
Normal file
11
langchain-chat/server/agent/tools/wolfram.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# Langchain 自带的 Wolfram Alpha API 封装
|
||||
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
|
||||
from pydantic import BaseModel, Field
|
||||
# Placeholder value; it must be replaced with a real Wolfram Alpha app id.
wolfram_alpha_appid = "your key"


def wolfram(query: str):
    """Send ``query`` to Wolfram Alpha through LangChain's wrapper and return the answer."""
    client = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid)
    return client.run(query)
|
||||
|
||||
class WolframInput(BaseModel):
    """Argument schema for the Wolfram Alpha tool."""

    # NOTE(review): the field is named ``location`` although its description
    # says it holds the problem to compute -- confirm the intended name.
    location: str = Field(description="需要运算的具体问题")
|
||||
10
langchain-chat/server/agent/tools/wolfram.yaml
Normal file
10
langchain-chat/server/agent/tools/wolfram.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
name: wolfram
|
||||
description: Useful for when you need to calculate difficult math formulas
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: The formula to be calculated
|
||||
required:
|
||||
- query
|
||||
174
langchain-chat/server/agent/tools_select.py
Normal file
174
langchain-chat/server/agent/tools_select.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from langchain.tools import Tool
|
||||
from server.agent.tools import *
|
||||
from server.agent.tools import rag_search1
|
||||
from server.agent.tools import duckduckgo_search
|
||||
from server.agent.tools.do_nothing import do_nothing, doNothingInput, get_next_tip
|
||||
from server.agent.tools.draw_plot import create_and_save_plot, draw_ink_pic, draw_realistic_pic, drawInkInput, drawPlotInput, drawRealisticInput
|
||||
from server.agent.tools.get_statistical_data import mysql_statistic
|
||||
from server.agent.tools.math import code_count, math_count
|
||||
from server.agent.tools.search_picture import search_pic
|
||||
from server.agent.tools.search_tool import search_tool
|
||||
|
||||
# 请注意,如果你是为了使用AgentLM,在这里,你应该使用英文版本。
|
||||
|
||||
# Registered agent tools as (function, tool name, description, args schema),
# in registration order.  The descriptions are prompt text shown to the model.
#
# Registrations that are currently disabled (not part of this list):
# calculate, arxiv, weather_check, shell, search_knowledgebase_complex,
# search_internet, Wolfram, search_youtube, chat_with_Yi34B,
# duckduckgo_search, do_nothing, get_next_tip, draw_realistic_pic,
# draw_ink_pic, rag_search1, policy_knowledgebase_search,
# report_knowledgebase_search.
#
# NOTE(review): every tool after the chart tool reuses ``drawPlotInput`` as
# its schema although its description documents a ``query`` / ``location``
# parameter -- looks like copy-paste; confirm the intended schema classes.
_TOOL_SPECS = [
    (
        search_tool,
        "知识库联想",
        "\n【参数说明】:\n 参数格式:参数必须使用JSON格式提供\n 必需参数:\n 1. query:查询内容。\n2.knowledge_name:知识库名称:必须从【中国钢铁行业动态库、政策库、期刊论文库、冶金新闻库(2024年以及之前)、冶金中文期刊库、冶金外文期刊库、冶金OA期刊库、冶金行业新闻库、冶金专业知识库、冶金行业报告库、报告库、美术专业知识库】中选择\n3.keywords:与搜索主题高度相关\n【使用指南】:\n当需要利用文献资源辅助回答问题时,请使用此工具。\n输入必须包含三个参数:{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}\n示例:\n{\"query\":\"人工智能的相关政策\",\"knowledge_name\":[\"政策库\"],\"keywords\": [\"人工智能\", \"国家级\"]}\n务必以JSON格式输入参数,以便工具能够正确解析和使用。",
        RagSearchInput,
    ),
    (
        knowledgebase_kgo_search,
        "联网思索",
        "\n【参数说明】:参数格式:以JSON格式提供,参数为:query\n query 为【user input】的原文,例如:\n{\"query\":\"人工智能是什么\"}\n注意!禁止改写【user input】的内容",
        KnowledgeKgoInput,
    ),
    (
        create_and_save_plot,
        "图表绘制",
        "注意本工具一次只能画一张图<关键指令>使用本工具后你必须按工具返回要求输出图片</关键指令>【参数说明】:参数以<param></param>标签包裹。必须提供以下参数格式: 必需参数:\n 1. data:图表数据格式如下{\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X}其中XXX代表分类名称,xx代表分类数据量。\n2.title:图表标题3.xlabel:横轴标题(你按照分的几类属于哪个大类)必须有\n4.ylabel:纵轴标题(你的数值数据是什么)必须有\n 5.plot_type:图表类型必须从以下几个当中选一个作为输入【bar,pie,line】其中bar代表柱状图pie代表饼图line代表折线图你只能选一个作为输入,\n【使用指南】:\n当需要生成图表时,请使用此工具。\n输入必须包含以下参数::<tool_input><param>{\n\"data\": {\"XXX\": XX, \"XXX\": XX, \"XXX\": X, \"XXX\": X},\"title\": \"X\",\"xlabel\": \"X\",\"ylabel\": \"X\",\"plot_type\": \"X\"}\n示例:{'\"data\": {\"Category A\": 23, \"Category B\": 17, \"Category C\": 35, \"Category D\": 29},\"title\": \"My Chart\",\"xlabel\": \"Category\",\"ylabel\": \"Value\",\"plot_type\": \"pie\"}</param></tool_input>\n请务必以json格式输入方便使用。\n",
        drawPlotInput,
    ),
    (
        math_count,
        "数学运算",
        "【参数说明】:参数格式:以JSON格式提供,参数为:query\n query 为数学问题描述,例如:\n{\"query\":\"9.9和9.11谁大\"}\n",
        drawPlotInput,
    ),
    (
        code_count,
        "代码专家",
        "【参数说明】:参数格式:以JSON格式提供,参数为:query\n query 为代码问题描述,例如:\n{\"query\":\"写一个es增量数据处理的脚本\"}\n",
        drawPlotInput,
    ),
    (
        weathercheck,
        "天气工具",
        "【参数说明】:注意:仅支持三天内天气查询,参数格式:以JSON格式提供,参数为:location\n location 为查询天气的城市名称,例如:\n{\"location\":\"北京\"}\n",
        drawPlotInput,
    ),
    (
        search_pic,
        "美术作品获取",
        "【参数说明】:参数格式:以JSON格式提供,参数为:query\n query 为查询美术作品的描述,务必是美术作品的描述,不要直接输入美术作品四个字而是如山水画写生画,草原,太阳等等,例如:\n{\"query\":\"山水画\"}\n,需将该工具返回的图片链接以markdown格式给出",
        drawPlotInput,
    ),
    (
        mysql_statistic,
        "统计数据查询",
        "提示统计库数据最多只有2023年及之前的199几年的最多【参数说明】:参数格式:以JSON格式提供,并必须以<param></param>标签包裹。参数为:query\n query 为详细的查询问题例如:\n<tool_input> <param>{\"query\":\"原油出口统计数据\"}</param> </tool_input>\n,需将返回数据做成数据表和图表",
        drawPlotInput,
    ),
]

tools = [
    Tool.from_function(
        name=tool_name,
        func=tool_func,
        description=tool_description,
        args_schema=tool_schema,
    )
    for tool_func, tool_name, tool_description, tool_schema in _TOOL_SPECS
]
|
||||
|
||||
# Names of all registered tools, in registration order.
tool_names = [registered.name for registered in tools]

# Web search tools.
# NOTE(review): this repeats the object ``search_internet`` once per
# registered tool and never looks at the tool itself; it also needs
# ``search_internet`` to be in scope at import time.  Looks unintended --
# confirm what this list is supposed to contain.
search_tool_names = [search_internet for _ in tools]
|
||||
Reference in New Issue
Block a user