Files
gangyan/langchain-chat/server/chat/search_engine_chat.py

592 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import itertools
import random
from bs4 import BeautifulSoup
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By
from langchain.utilities.bing_search import BingSearchAPIWrapper
import requests
from fake_useragent import UserAgent
# from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE,
TEXT_SPLITTER_NAME, OVERLAP_SIZE)
from fastapi import Body
from sse_starlette import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs.kb_config import CHROME_DIR, EN_BASE_NAME
from server.agent.tools.duckduckgo_search import DuckduckgoInput
from server.chat import utils
from server.chat.ZhipuSearchAPI import ZhipuSearchAPIWrapper
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import List, Optional, Dict
from server.chat.utils import History, get_similar_documents
from langchain.docstore.document import Document
import json
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from markdownify import markdownify
from .KgoSearchAPIWrapper import KgoSearchAPIWrapper
from server.chat.policy_fun_iast import get_llm_model_response
import re
from configs.basic_config import *
def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
    """Query Bing through LangChain's ``BingSearchAPIWrapper``.

    Args:
        text: The search query.
        result_len: Number of results requested from the wrapper.
        **kwargs: Accepted and ignored, so the function tolerates the extra
            keyword arguments passed to every engine.

    Returns:
        The value of ``BingSearchAPIWrapper.results``; when either
        ``BING_SEARCH_URL`` or ``BING_SUBSCRIPTION_KEY`` is unset, a list with
        a single placeholder entry describing the missing configuration.
    """
    bing_configured = bool(BING_SEARCH_URL) and bool(BING_SUBSCRIPTION_KEY)
    if bing_configured:
        wrapper = BingSearchAPIWrapper(
            bing_subscription_key=BING_SUBSCRIPTION_KEY,
            bing_search_url=BING_SEARCH_URL,
        )
        return wrapper.results(text, result_len)

    # Configuration is incomplete: answer with a hint instead of raising.
    placeholder = {
        "snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
        "title": "env info is not found",
        "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html",
    }
    return [placeholder]
# def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
# search = DuckDuckGoSearchAPIWrapper()
# return search.results(text, result_len)
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
    """Run a DuckDuckGo search through the project's ``DuckduckgoInput`` helper.

    Args:
        text: The search query.
        result_len: Number of results requested.
        **kwargs: Accepted and ignored.

    Returns:
        Whatever ``DuckduckgoInput().results`` returns.
    """
    return DuckduckgoInput().results(text, result_len)
def kgo_search(text, origin_query, kgo_search_type="1000", **kwargs):
    """Run a KGO search through ``KgoSearchAPIWrapper``.

    Args:
        text: The (possibly rewritten) search keywords.
        origin_query: The user's original question, forwarded by keyword.
        kgo_search_type: Search-type code(s) forwarded to the wrapper.
        **kwargs: Accepted and ignored.

    Returns:
        Whatever ``KgoSearchAPIWrapper().results`` returns.
    """
    return KgoSearchAPIWrapper().results(
        text, kgo_search_type, origin_query=origin_query
    )
def metaphor_search(
        text: str,
        result_len: int = SEARCH_ENGINE_TOP_K,
        split_result: bool = False,
        chunk_size: int = 500,
        chunk_overlap: int = OVERLAP_SIZE,
        **kwargs,
) -> List[Dict]:
    """Search with the Metaphor API and return snippet/link/title dicts.

    Args:
        text: The search query.
        result_len: Number of results requested, and the maximum number of
            chunks kept when ``split_result`` is true.
        split_result: When true, split every extract into chunks and keep the
            chunks most similar to ``text``.
        chunk_size: Chunk size handed to the text splitter.
        chunk_overlap: Chunk overlap handed to the text splitter.
        **kwargs: Accepted and ignored. ``lookup_search_engine`` calls every
            engine with ``origin_query`` and ``kgo_search_type``; without this
            catch-all the call raised ``TypeError`` for this engine.

    Returns:
        A list of ``{"snippet", "link", "title"}`` dicts; an empty list when
        ``METAPHOR_API_KEY`` is not configured.
    """
    from metaphor_python import Metaphor

    if not METAPHOR_API_KEY:
        return []

    client = Metaphor(METAPHOR_API_KEY)
    search = client.search(text, num_results=result_len, use_autoprompt=True)
    contents = search.get_contents().contents
    for x in contents:
        x.extract = markdownify(x.extract)

    if not split_result:
        return [{"snippet": x.extract,
                 "link": x.url,
                 "title": x.title}
                for x in contents]

    # Metaphor extracts are long texts, so split them before selecting.
    docs = [Document(page_content=x.extract,
                     metadata={"link": x.url, "title": x.title})
            for x in contents]
    text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
                                                   chunk_size=chunk_size,
                                                   chunk_overlap=chunk_overlap)
    splitted_docs = text_splitter.split_documents(docs)

    # Keep only the TOP_K chunks, ranked by normalized Levenshtein similarity
    # to the query. When there are no more than TOP_K chunks they are returned
    # in their original order.
    if len(splitted_docs) > result_len:
        normal = NormalizedLevenshtein()
        splitted_docs.sort(
            key=lambda chunk: normal.similarity(text, chunk.page_content),
            reverse=True,
        )
        splitted_docs = splitted_docs[:result_len]

    return [{"snippet": x.page_content,
             "link": x.metadata["link"],
             "title": x.metadata["title"]}
            for x in splitted_docs]
def zhipu_search(origin_query, **kwargs):
    """Run a Zhipu web search for ``origin_query``.

    Extra keyword arguments are accepted and ignored.

    Returns:
        Whatever ``ZhipuSearchAPIWrapper().zhipu_search`` returns.
    """
    return ZhipuSearchAPIWrapper().zhipu_search(origin_query)
# Registry of supported search engines: request name -> search callable.
SEARCH_ENGINES = {
    "bing": bing_search,
    "duckduckgo": duckduckgo_search,
    "metaphor": metaphor_search,
    "kgo": kgo_search,
    "zhipu_search": zhipu_search,
}
def search_result2docs(search_results):
    """Convert generic search-result dicts into LangChain ``Document`` objects.

    Each result is read with ``.get`` so missing fields default to ``""``.
    The page content lists title, keywords, abstract, author, publish year
    and resource type, one per line; the metadata carries the link (as
    ``source``), the title (as ``filename``) and the remaining fields.

    Args:
        search_results: Iterable of result dicts.

    Returns:
        A list with one ``Document`` per result, in input order.
    """

    def _to_document(item):
        # Pull every field up front; absent keys become empty strings.
        title = item.get("title", "")
        keywords = item.get("keywords", "")
        snippet = item.get("snippet", "")
        author = item.get("author", "")
        resource_type = item.get("resource_type", "")
        publish_year = item.get("publish_year", "")

        lines = (
            f"【本资料的标题为:{title}\n",
            f"【本资料的关键词为:{keywords}\n",
            f"【本资料的摘要为:{snippet}\n",
            f"【本资料的作者为:{author}\n",
            f"【本资料的发布时间为:{publish_year}\n",
            f"【本资料的资源类型为:{resource_type}\n",
        )
        metadata = {
            "source": item.get("link", ""),
            "filename": title,
            "author": author,
            "keywords": keywords,
            "publish_year": publish_year,
            "resource_type": resource_type,
        }
        return Document(page_content="".join(lines), metadata=metadata)

    return [_to_document(item) for item in search_results]
# 定义替换函数
# def replace_ref(match):
# num = match.group(1)
# return f'^[{num}]^'
# Matches a parenthesised "(发布时间:...)" suffix in a title. Both ASCII and
# full-width parentheses/colons are accepted; the full-width forms are written
# as \u escapes so they cannot be silently lost or confused in an editor.
# The previous pattern had an unbalanced "(" and raised re.error on every call.
_PUBLISH_TIME_RE = re.compile(r"[(\uff08]发布时间[:\uff1a].*?[)\uff09]")


def ZhipuSearch_result2docs(search_results):
    """Convert Zhipu search-result dicts into LangChain ``Document`` objects.

    The publish-time suffix is stripped from each title because the title is
    shown to users as the reference name.

    Args:
        search_results: Iterable of result dicts; ``title``, ``link``,
            ``content`` and ``media`` are read with ``.get`` and default to
            ``""``.

    Returns:
        A list with one ``Document`` per result, in input order. Each document
        is also logged at INFO level.
    """
    docs = []
    for result in search_results:
        title = result.get("title", "")
        cleaned_title = _PUBLISH_TIME_RE.sub("", title)
        link = result.get("link", "")
        content = result.get("content", "")
        media = result.get("media", "")

        page_content = (
            f"【资料:{cleaned_title}\n"
            f"内容如下:{content}\n"
            f"资料链接:{link}\n"
            f"发布媒体为:{media}\n"
        )
        metadata = {
            "link": link,
            "title": cleaned_title,
            "content": content,
            "media": media,
        }
        doc = Document(page_content=page_content, metadata=metadata)
        docs.append(doc)
        logger.info(f"Zhipu搜索召回资料:{doc.metadata}")
    return docs
def Duckduckgo_result2docs(search_results):
    """Convert DuckDuckGo result dicts into LangChain ``Document`` objects.

    The fields ``title``, ``url``, ``body``, ``date``, ``image`` and
    ``source`` are read with ``.get`` and default to ``""``. ``body`` is
    exposed as ``content`` in the metadata.

    Args:
        search_results: Iterable of result dicts.

    Returns:
        A list with one ``Document`` per result, in input order. Each document
        is also logged at INFO level.
    """
    wanted = ("title", "url", "body", "date", "image", "source")
    converted = []
    for item in search_results:
        field = {name: item.get(name, "") for name in wanted}

        text = "".join((
            f"【本资料的标题为:{field['title']}\n",
            f"链接为:{field['url']}\n",
            f"内容为:{field['body']}\n",
            f"发布时间为:{field['date']}\n",
            f"封面图为:{field['image']}\n",
            f"发布媒体为:{field['source']}\n",
        ))
        meta = {
            "url": field["url"],
            "title": field["title"],
            "date": field["date"],
            "content": field["body"],
            "image": field["image"],
            "source": field["source"],
        }

        record = Document(page_content=text, metadata=meta)
        converted.append(record)
        logger.info(f"Duckduckgo搜索召回资料:{record.metadata}")
    return converted
def extract_html(url, proxy):
    """Fetch ``url`` through an HTTP proxy and return its HTML and title.

    This is a best-effort helper: any failure yields empty strings instead of
    an exception, so one bad page or proxy cannot abort a batch of fetches.

    Args:
        url: Page address to download.
        proxy: Mapping with the proxy host under ``"sever"`` and the port
            under ``"port"``. NOTE(review): ``"sever"`` is the key the code
            has always read; presumably it mirrors the proxy provider's
            payload, so confirm before renaming it.

    Returns:
        ``{"html": <prettified HTML>, "title": <page title>}``; both values
        are ``""`` when the proxy entry is malformed or the request fails.
    """
    try:
        proxy_url = "http://%(ip)s:%(port)s" % {
            "ip": proxy["sever"],
            "port": proxy["port"],
        }
    except (KeyError, TypeError, IndexError):
        # Malformed proxy entry (missing key, or not a mapping at all).
        return {"html": "", "title": ""}

    try:
        headers = {'User-Agent': UserAgent().random}
        response = requests.get(
            url,
            proxies={"http": proxy_url, "https": proxy_url},
            timeout=3,
            headers=headers,
        )
        response.raise_for_status()
        soup = BeautifulSoup(response.content, 'html.parser')
    except requests.exceptions.RequestException:
        return {"html": "", "title": ""}

    # ``soup.title.string`` is None for an empty or nested <title> element.
    title = soup.title.string if soup.title else ""
    return {"html": soup.prettify(), "title": "" if title is None else title}
# Module-level headless Chrome configuration.
chrome_options = Options()
for _chrome_flag in (
    "--headless",  # run without a visible window
    "--disable-gpu",  # disable the GPU
    "--ignore-certificate-errors",  # ignore SSL certificate errors
    "--no-sandbox",  # works around the "DevToolsActivePort file doesn't exist" error
    "--disable-dev-shm-usage",  # works around resource limits
):
    chrome_options.add_argument(_chrome_flag)
del _chrome_flag  # keep the module namespace unchanged
chrome_options.binary_location = f"{CHROME_DIR}/chrome-linux64/chrome"  # Chrome executable
service = Service(f"{CHROME_DIR}/chromedriver-linux64/chromedriver")  # ChromeDriver
def _fetch_result_pages(results):
    """Download every result's page through a freshly requested proxy (blocking).

    Returns exactly one ``{"html", "title"}`` dict per entry of ``results``.
    Entries are blank when the proxy list cannot be obtained, when it is
    shorter than ``results``, or when fetching fails.
    """
    try:
        # SECURITY NOTE(review): the proxy provider's API key is hard-coded in
        # this URL; it should be moved to configuration and rotated.
        url = f"https://sch.shanchendaili.com/api.html?action=get_ip&key=HU301b55318830279250RjH2&time=10&count={len(results)}&type=txt&province=215&only=0"
        # A timeout keeps a stalled proxy provider from hanging the request.
        ip_list = requests.get(url=url, timeout=10).json()["list"]
        links = [result.get("link", "") for result in results]
        with ThreadPoolExecutor() as executor:
            pages = list(executor.map(extract_html, links, ip_list))
    except Exception as e:
        logger.warning("Fetching result pages failed, using search titles only: %s", e)
        pages = []

    # ``executor.map`` stops at the shorter of its inputs, so pad with blanks
    # to avoid silently dropping results further down.
    pages.extend({"html": "", "title": ""} for _ in range(len(results) - len(pages)))
    return pages


async def lookup_search_engine(
        uid: str,
        search_query: str,
        origin_query: str,
        search_engine_name: str,
        top_k: int = SEARCH_ENGINE_TOP_K,
        split_result: bool = False,
        kgo_search_type: str = "1000"
) -> List[Document]:
    """Run the selected search engine and return its hits as ``Document`` objects.

    For ``"zhipu_search"`` the raw hits are re-ranked with the
    ``default_similar`` LLM prompt plus ``get_similar_documents``, their page
    titles are refreshed by downloading each link, and titles already stored
    in the shared variable for ``uid`` are filtered out via
    ``utils.remove_docs``. Every other engine is called with the common
    keyword arguments and converted with ``search_result2docs``.

    NOTE(review): ``get_llm_model_response`` is called synchronously and so
    blocks the event loop while the model answers.

    Args:
        uid: Key of the shared per-conversation state.
        search_query: Query sent to the search engine.
        origin_query: The user's original question, used for re-ranking.
        search_engine_name: A key of ``SEARCH_ENGINES``.
        top_k: Maximum number of results to keep.
        split_result: Forwarded to the engine.
        kgo_search_type: Forwarded to the engine.

    Returns:
        The documents found; an empty list when the Zhipu search yields nothing.

    Raises:
        ValueError: If ``search_engine_name`` is not registered.
    """
    search_engine = SEARCH_ENGINES.get(search_engine_name)
    if not search_engine:
        raise ValueError(f"Unsupported search engine: {search_engine_name}")

    if search_engine_name != "zhipu_search":
        # Generic engines share one calling convention.
        results = await run_in_threadpool(
            search_engine,
            search_query,
            origin_query=origin_query,
            result_len=top_k,
            split_result=split_result,
            kgo_search_type=kgo_search_type
        )
        return search_result2docs(results)

    # zhipu_search only takes the query itself.
    results = await run_in_threadpool(search_engine, search_query)
    if not results:
        return []

    # Numbered "title + content" lines for the re-ranking prompt. Results
    # without a "title" fall back to "source".
    # NOTE(review): the opening bracket in ":【" has no closing counterpart
    # here; a closing character may have been lost -- confirm against the
    # raw file.
    try:
        sentences = [doc["title"] for doc in results]
        sentences_page_content = [str(i + 1) + ":【" + doc["title"] + doc["content"]
                                  for i, doc in enumerate(results)]
    except (KeyError, TypeError):
        sentences = [doc["source"] for doc in results]
        sentences_page_content = [str(i + 1) + ":【" + doc["source"] + doc["content"]
                                  for i, doc in enumerate(results)]

    res = get_llm_model_response(
        strategy_name="default_similar",
        llm_model_name=LLM_MODELS[0],
        template_prompt_name="default_similar",
        prompt_param_dict={"input": origin_query, "title": f"{sentences_page_content}",
                           "time": datetime.now().strftime("%Y%m%d")},
        temperature=0.01,
        max_tokens=512
    )
    try:
        # The model answers with comma-separated 1-based positions.
        index = [int(i) - 1 for i in res.split(",")] if res != "" else []
        results = get_similar_documents(index=index, sentences=sentences, query=origin_query,
                                        docs=results, top_k=top_k)
    except Exception as e:
        logger.warning("Could not use the LLM ranking %r: %s", res, e)
        results = get_similar_documents(index=[], sentences=sentences, query=origin_query,
                                        docs=results, top_k=top_k)
    if not results:
        return []

    # The page download is blocking I/O, so keep it off the event loop.
    html_contents = await run_in_threadpool(_fetch_result_pages, results)

    res_list = []
    for result, html_content in zip(results, html_contents):
        result_title = result.get("title") or ""
        page_title = html_content["title"]
        res_list.append({
            "content": result.get("content", ""),
            "icon": "",
            "index": 0,
            "link": result.get("link", ""),
            "media": "未知",
            # An empty "positions" list falls back to 0 as well.
            "refer": (result.get("positions") or [0])[0],
            # Prefer the real page title when the (possibly "..."-truncated)
            # search title is contained in it.
            "title": page_title if result_title.replace("...", "") in page_title else result_title,
        })

    # NOTE(review): the original tested `"" in uid` and called
    # `uid.replace("", "")`. With an empty literal that is a no-op, so a marker
    # character was most likely lost from the source; restore it in
    # `uid_marker`. The unconditional replace below is equivalent to the old
    # if/else for any marker value.
    uid_marker = ""
    shared_key = uid.replace(uid_marker, "")
    shared = utils.get_shared_variable(shared_key)
    results, titles = utils.remove_docs(shared["title"], res_list)
    shared["title"].extend(titles)
    utils.set_shared_variable(shared_key, shared)

    return ZhipuSearch_result2docs(results)
# KGO resource-type names (as produced by the KGO_NER prompt) -> type codes.
_KGO_SEARCH_TYPE_MAP = {
    "全部": "1000",
    "期刊论文": "1002",
    "学位论文": "1003",
    "会议论文": "1004",
    "政策": "1005",
    "成果": "1006",
    "科技成果": "1006",
    "项目": "1007",
    "报告": "1010",
    "图书": "1011",
    "外文期刊论文": "1013",
    "外文资料": "1013",
    "期刊": "1015",
    "专利": "3001",
    "新闻": "4001",
    "数据集": "4004",
    "视频": "4005",
    "统计数据": "6001"
}
_KGO_KNOWN_CODES = frozenset(_KGO_SEARCH_TYPE_MAP.values())


def _split_type_tokens(value):
    """Flatten a KGO_NER type value (a string or a list of strings) into tokens."""
    if isinstance(value, str):
        items = [value]
    elif isinstance(value, (list, tuple)):
        items = [item for item in value if isinstance(item, str)]
    else:
        items = []
    tokens = []
    for item in items:
        # Accept both the ASCII and the full-width comma as separators.
        tokens.extend(part.strip() for part in re.split("[,\uff0c]", item))
    return [token for token in tokens if token]


def _parse_kgo_entities(llm_output, query):
    """Turn the KGO_NER model output into ``(search_query, kgo_search_type)``.

    Every ``{...}`` fragment of the output is parsed as JSON: keys become
    search keywords, values name resource types. A fragment that is not valid
    JSON contributes the original query and every known type code, which is
    what the old fallback intended. No keywords at all means the original
    query is used; no recognised type means ``"1000"``.
    """
    keywords = []
    type_codes = []
    used_fallback = False
    for json_string in re.findall(r'\{.*?\}', llm_output or ""):
        try:
            data = json.loads(json_string)
        except json.JSONDecodeError:
            if not used_fallback:
                keywords.append(query)
                type_codes.extend(_KGO_SEARCH_TYPE_MAP.values())
                used_fallback = True
            continue
        for key, value in data.items():
            keywords.append(key)
            for token in _split_type_tokens(value):
                if token in _KGO_SEARCH_TYPE_MAP:
                    type_codes.append(_KGO_SEARCH_TYPE_MAP[token])
                elif token in _KGO_KNOWN_CODES:
                    # Already a code: pass it through unchanged.
                    type_codes.append(token)
    search_query = ','.join(keywords) or query
    # dict.fromkeys removes duplicates while keeping a deterministic order.
    kgo_search_type = ','.join(dict.fromkeys(type_codes)) or "1000"
    return search_query, kgo_search_type


def _build_source_documents(docs, index, search_engine_name):
    """Format the numbered markdown reference lines returned to the client.

    Numbering continues after ``index``. For duckduckgo the metadata written
    by ``search_result2docs`` (``filename``/``source``) is accepted as a
    fallback for ``title``/``url``; reading ``metadata["title"]`` directly
    raised ``KeyError`` for those documents.
    """
    if search_engine_name == "zhipu_search":
        title_keys, link_keys, link_optional = ("title",), ("link",), True
    elif search_engine_name == "duckduckgo":
        title_keys, link_keys, link_optional = ("title", "filename"), ("url", "source"), True
    else:
        title_keys, link_keys, link_optional = ("filename",), ("source",), False

    def first_value(metadata, keys):
        for key in keys:
            if metadata.get(key):
                return metadata[key]
        return ""

    source_documents = []
    for inum, doc in enumerate(docs):
        title = first_value(doc.metadata, title_keys)
        link = first_value(doc.metadata, link_keys)
        position = index + inum + 1
        if link or not link_optional:
            source_documents.append(f"""[{position}] [{title}]({link}) \n""")
        else:
            source_documents.append(f"""[{position}] [{title}]\n""")
    if not source_documents and search_engine_name == "duckduckgo":
        source_documents.append("未找到相关资料")
    return source_documents


async def search_engine_chat(uid: Optional[str]=Body(None, description="userID"),
                             query: str = Body(..., description="用户输入", examples=["你好"]),
                             search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
                             top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                             history: List[History] = Body([],
                                                           description="历史对话",
                                                           examples=[[
                                                               {"role": "user",
                                                                "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                               {"role": "assistant",
                                                                "content": "虎头虎脑"}]]
                                                           ),
                             stream: bool = Body(False, description="流式输出"),
                             model_name: str = Body(LLM_MODELS[1], description="LLM 模型名称。"),
                             temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                             max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
                             prompt_name: str = Body("default",
                                                     description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                             split_result: bool = Body(False, description="是否对搜索结果进行拆分主要用于metaphor搜索引擎"),
                             kgo_search_type: str = Body("1000", description="kgo搜索引擎搜索类型默认1000"),
                             ):
    """Search with the chosen engine and stream back the context and references.

    The response is a server-sent-event stream carrying one JSON message:
    ``{"answer": [<context snippets>], "docs": [<reference lines>]}``.
    LLM answer generation is currently disabled in this endpoint; only the
    retrieved material is returned. As before, ``stream`` is accepted but
    not read.

    For the "kgo" engine the query is first passed through the ``KGO_NER``
    prompt to extract search keywords and resource types. Every other engine
    searches with the user's query as-is; previously ``search_query`` was
    left unassigned for bing/duckduckgo/metaphor, which raised
    ``UnboundLocalError``.

    Returns:
        An ``EventSourceResponse``, or a ``BaseResponse`` with code 404 when
        the engine is unknown or Bing is selected without a subscription key.
    """
    if search_engine_name == "bing" and not BING_SUBSCRIPTION_KEY:
        return BaseResponse(code=404, msg="要使用Bing搜索引擎需要设置 `BING_SUBSCRIPTION_KEY`")
    if search_engine_name not in SEARCH_ENGINES:
        return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")

    history = [History.from_data(h) for h in history]
    # Only the user's earlier messages are passed to the NER prompt.
    user_queries = [message.content for message in history if message.role == 'user']

    search_query = query
    if search_engine_name == "kgo":
        metal_search_entity = get_llm_model_response(
            strategy_name="KGO_NER",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="KGO_NER",
            prompt_param_dict={"input": query, "history": user_queries},
            temperature=TEMPERATURE,
            max_tokens=512
        )
        logger.info(f"KGO_NER体识别结果{metal_search_entity}")
        search_query, kgo_search_type = _parse_kgo_entities(metal_search_entity, query)
        logger.info(f"KGO检索关键词:{search_query}KGO检索类型:{kgo_search_type}")

    async def search_engine_chat_iterator(
            uid: str,
            search_query: str,
            origin_query: str,
            search_engine_name: str,
            top_k: int,
            history: Optional[List[History]],
            model_name: str = LLM_MODELS[1],
            prompt_name: str = prompt_name,
            kgo_search_type: str = kgo_search_type
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        callback = AsyncIteratorCallbackHandler()
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        # NOTE(review): the model is still constructed although the chain that
        # used it is disabled; kept so an invalid model configuration keeps
        # failing here as before.
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback]
        )

        docs = await lookup_search_engine(uid, search_query, origin_query, search_engine_name, top_k,
                                          split_result=split_result,
                                          kgo_search_type=kgo_search_type)

        # Reference numbering continues after what this conversation has
        # already collected. Read after the lookup because the lookup may
        # update the shared state.
        shared = utils.get_shared_variable(uid)
        index = shared["num"]
        if uid and "source_docs" in shared:
            for knowledge_name in EN_BASE_NAME:
                if knowledge_name in shared["source_docs"]:
                    index += len(shared["source_docs"][knowledge_name])

        context = [
            f"""\n\n资料[{index + k + 1}]内容:{doc.page_content} ) \n\n"""
            for k, doc in enumerate(docs)
        ]
        source_documents = _build_source_documents(docs, index, search_engine_name)

        try:
            res = utils.get_shared_variable(uid)
            res["num"] = index + len(source_documents)
            # NOTE(review): "source_docs" is indexed like a dict above but
            # extended like a list here -- confirm the intended shape.
            res["source_docs"].extend(source_documents)
            utils.set_shared_variable(uid, res)
        except Exception:
            # Best-effort bookkeeping: log the failure instead of hiding it.
            logger.exception("Failed to update the shared state for uid %r", uid)

        yield json.dumps({"answer": context, "docs": source_documents}, ensure_ascii=False)

    return EventSourceResponse(search_engine_chat_iterator(uid=uid,
                                                           search_query=search_query,
                                                           origin_query=query,
                                                           search_engine_name=search_engine_name,
                                                           top_k=top_k,
                                                           history=history,
                                                           model_name=model_name,
                                                           prompt_name=prompt_name,
                                                           kgo_search_type=kgo_search_type),
                               )