592 lines
26 KiB
Python
592 lines
26 KiB
Python
from concurrent.futures import ThreadPoolExecutor
|
||
from datetime import datetime
|
||
import itertools
|
||
import random
|
||
from bs4 import BeautifulSoup
|
||
from selenium import webdriver
|
||
from selenium.webdriver.chrome.service import Service
|
||
from selenium.webdriver.chrome.options import Options
|
||
from selenium.webdriver.common.by import By
|
||
from langchain.utilities.bing_search import BingSearchAPIWrapper
|
||
import requests
|
||
from fake_useragent import UserAgent
|
||
# from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
|
||
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
|
||
LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE,
|
||
TEXT_SPLITTER_NAME, OVERLAP_SIZE)
|
||
from fastapi import Body
|
||
from sse_starlette import EventSourceResponse
|
||
from fastapi.concurrency import run_in_threadpool
|
||
from configs.kb_config import CHROME_DIR, EN_BASE_NAME
|
||
from server.agent.tools.duckduckgo_search import DuckduckgoInput
|
||
from server.chat import utils
|
||
from server.chat.ZhipuSearchAPI import ZhipuSearchAPIWrapper
|
||
from server.utils import wrap_done, get_ChatOpenAI
|
||
from server.utils import BaseResponse, get_prompt_template
|
||
from langchain.chains import LLMChain
|
||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||
from typing import AsyncIterable
|
||
import asyncio
|
||
from langchain.prompts.chat import ChatPromptTemplate
|
||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||
from typing import List, Optional, Dict
|
||
from server.chat.utils import History, get_similar_documents
|
||
from langchain.docstore.document import Document
|
||
import json
|
||
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
|
||
from markdownify import markdownify
|
||
from .KgoSearchAPIWrapper import KgoSearchAPIWrapper
|
||
from server.chat.policy_fun_iast import get_llm_model_response
|
||
import re
|
||
from configs.basic_config import *
|
||
|
||
def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
    """Search Bing through LangChain's ``BingSearchAPIWrapper``.

    Args:
        text: Query string forwarded to the wrapper.
        result_len: Number of results requested from the wrapper.
        **kwargs: Accepted and ignored so every engine can be called with
            the same keyword set.

    Returns:
        The list returned by ``BingSearchAPIWrapper.results``; when the Bing
        URL or subscription key is not configured, a single placeholder
        entry explaining which settings are missing.
    """
    bing_configured = bool(BING_SEARCH_URL and BING_SUBSCRIPTION_KEY)
    if bing_configured:
        wrapper = BingSearchAPIWrapper(
            bing_subscription_key=BING_SUBSCRIPTION_KEY,
            bing_search_url=BING_SEARCH_URL,
        )
        return wrapper.results(text, result_len)

    # Not configured: hand back one pseudo-result instead of raising.
    placeholder = {
        "snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
        "title": "env info is not found",
        "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html",
    }
    return [placeholder]
|
||
|
||
|
||
# def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
|
||
# search = DuckDuckGoSearchAPIWrapper()
|
||
# return search.results(text, result_len)
|
||
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
    """Search DuckDuckGo through the project's ``DuckduckgoInput`` helper.

    Args:
        text: Query string.
        result_len: Number of results requested.
        **kwargs: Accepted and ignored (shared engine call signature).

    Returns:
        Whatever ``DuckduckgoInput.results`` returns for the query.
    """
    return DuckduckgoInput().results(text, result_len)
|
||
|
||
def kgo_search(text, origin_query, kgo_search_type="1000", **kwargs):
    """Search the KGO backend through ``KgoSearchAPIWrapper``.

    Args:
        text: Search keywords sent to the wrapper.
        origin_query: The user's original question, forwarded as a keyword
            argument.
        kgo_search_type: Search type code(s) passed to the wrapper.
        **kwargs: Accepted and ignored (shared engine call signature).

    Returns:
        Whatever ``KgoSearchAPIWrapper.results`` returns.
    """
    client = KgoSearchAPIWrapper()
    hits = client.results(text, kgo_search_type, origin_query=origin_query)
    return hits
|
||
|
||
|
||
def metaphor_search(
        text: str,
        result_len: int = SEARCH_ENGINE_TOP_K,
        split_result: bool = False,
        chunk_size: int = 500,
        chunk_overlap: int = OVERLAP_SIZE,
        **kwargs,
) -> List[Dict]:
    """Search with Metaphor and return results as snippet/link/title dicts.

    Args:
        text: Query string.
        result_len: Number of results requested and, when splitting, the
            maximum number of chunks returned.
        split_result: When true, split each page into chunks and keep the
            chunks most similar to ``text``.
        chunk_size: Chunk size handed to the text splitter.
        chunk_overlap: Chunk overlap handed to the text splitter.
        **kwargs: Accepted and ignored. The engine dispatcher passes the same
            keyword set (for example ``origin_query`` and
            ``kgo_search_type``) to every engine; without this catch-all the
            call raised ``TypeError``.

    Returns:
        A list of ``{"snippet", "link", "title"}`` dicts, or an empty list
        when ``METAPHOR_API_KEY`` is not configured.
    """
    if not METAPHOR_API_KEY:
        return []

    # Imported lazily, after the key check, so the optional dependency is
    # only required when the engine is actually configured.
    from metaphor_python import Metaphor

    client = Metaphor(METAPHOR_API_KEY)
    search = client.search(text, num_results=result_len, use_autoprompt=True)
    contents = search.get_contents().contents
    for x in contents:
        # Convert the returned extract to Markdown in place.
        x.extract = markdownify(x.extract)

    if not split_result:
        return [{"snippet": x.extract,
                 "link": x.url,
                 "title": x.title}
                for x in contents]

    # Metaphor returns long texts: split them into chunks before selecting.
    docs = [Document(page_content=x.extract,
                     metadata={"link": x.url, "title": x.title})
            for x in contents]
    text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
                                                   chunk_size=chunk_size,
                                                   chunk_overlap=chunk_overlap)
    splitted_docs = text_splitter.split_documents(docs)

    # More chunks than requested: rank by normalized Levenshtein similarity
    # to the query and keep the top ``result_len``.
    if len(splitted_docs) > result_len:
        normal = NormalizedLevenshtein()
        for x in splitted_docs:
            x.metadata["score"] = normal.similarity(text, x.page_content)
        splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
        splitted_docs = splitted_docs[:result_len]

    return [{"snippet": x.page_content,
             "link": x.metadata["link"],
             "title": x.metadata["title"]}
            for x in splitted_docs]
|
||
|
||
def zhipu_search(origin_query, **kwargs):
    """Search with the Zhipu search API wrapper.

    Args:
        origin_query: Query string handed to the wrapper.
        **kwargs: Accepted and ignored (shared engine call signature).

    Returns:
        Whatever ``ZhipuSearchAPIWrapper.zhipu_search`` returns.
    """
    return ZhipuSearchAPIWrapper().zhipu_search(origin_query)
|
||
|
||
# Registry of supported search engines: maps the engine name received in a
# request to the function in this module that performs the search.
SEARCH_ENGINES = {"bing": bing_search,
                  "duckduckgo": duckduckgo_search,
                  "metaphor": metaphor_search,
                  "kgo": kgo_search,
                  "zhipu_search": zhipu_search,
                  }
|
||
|
||
|
||
def search_result2docs(search_results):
    """Convert generic search hits into LangChain ``Document`` objects.

    Each hit is read with ``.get`` so missing fields become empty strings.
    The page content is a fixed six-line labelled summary (title, keywords,
    snippet, author, publish year, resource type); the metadata carries the
    link under ``"source"`` and the title under ``"filename"``.

    Args:
        search_results: Iterable of dict-like hits.

    Returns:
        One ``Document`` per hit, in input order.
    """
    documents = []
    for hit in search_results:
        title = hit.get("title", "")
        keywords = hit.get("keywords", "")
        snippet = hit.get("snippet", "")
        author = hit.get("author", "")
        resource_type = hit.get("resource_type", "")
        publish_year = hit.get("publish_year", "")

        # Labelled text block handed to the LLM as context.
        content_lines = (
            f"【本资料的标题为:{title}】\n",
            f"【本资料的关键词为:{keywords}】\n",
            f"【本资料的摘要为:{snippet}】\n",
            f"【本资料的作者为:{author}】\n",
            f"【本资料的发布时间为:{publish_year}】\n",
            f"【本资料的资源类型为:{resource_type}】\n",
        )
        documents.append(
            Document(
                page_content="".join(content_lines),
                metadata={
                    "source": hit.get("link", ""),
                    "filename": title,
                    "author": author,
                    "keywords": keywords,
                    "publish_year": publish_year,
                    "resource_type": resource_type,
                },
            )
        )
    return documents
|
||
|
||
# 定义替换函数
|
||
# def replace_ref(match):
|
||
# num = match.group(1)
|
||
# return f'^[{num}]^'
|
||
|
||
def ZhipuSearch_result2docs(search_results):
    """Convert Zhipu search hits into LangChain ``Document`` objects.

    Reads ``title``, ``link``, ``content`` and ``media`` from each hit
    (missing fields become empty strings), builds a labelled text block as
    page content, stores the same fields as metadata, and logs the metadata
    of every document created.

    Args:
        search_results: Iterable of dict-like hits.

    Returns:
        One ``Document`` per hit, in input order.
    """
    documents = []
    for hit in search_results:
        raw_title = hit.get("title", "")
        # The reference name shown to users should not carry the publish time.
        # NOTE(review): the pattern ends in a lazy ``.*?`` with nothing after
        # it, so it removes only the literal label, not the timestamp that
        # follows it — confirm whether that is intended.
        display_title = re.sub(r'(发布时间:.*?)', '', raw_title)
        link = hit.get("link", "")
        content = hit.get("content", "")
        media = hit.get("media", "")

        page_content = "".join((
            f"【资料:{display_title}\n",
            f"内容如下:{content}\n",
            f"资料链接:{link}\n",
            f"发布媒体为:{media}】\n",
        ))
        document = Document(
            page_content=page_content,
            metadata={
                "link": link,
                "title": display_title,
                "content": content,
                "media": media,
            },
        )
        documents.append(document)
        logger.info(f"Zhipu搜索召回资料:{document.metadata}")
    return documents
|
||
|
||
def Duckduckgo_result2docs(search_results):
    """Convert DuckDuckGo hits into LangChain ``Document`` objects.

    Reads ``title``, ``url``, ``body``, ``date``, ``image`` and ``source``
    from each hit (missing fields become empty strings), builds a labelled
    text block as page content, stores the fields as metadata (``body`` is
    stored under ``"content"``), and logs each document's metadata.

    Args:
        search_results: Iterable of dict-like hits.

    Returns:
        One ``Document`` per hit, in input order.
    """
    documents = []
    for hit in search_results:
        title = hit.get("title", "")
        url = hit.get("url", "")
        body = hit.get("body", "")
        date = hit.get("date", "")
        image = hit.get("image", "")
        source = hit.get("source", "")

        page_content = "".join((
            f"【本资料的标题为:{title}\n",
            f"链接为:{url}\n",
            f"内容为:{body}\n",
            f"发布时间为:{date}\n",
            f"封面图为:{image}\n",
            f"发布媒体为:{source}】\n",
        ))
        document = Document(
            page_content=page_content,
            metadata={
                "url": url,
                "title": title,
                "date": date,
                "content": body,
                "image": image,
                "source": source,
            },
        )
        documents.append(document)
        logger.info(f"Duckduckgo搜索召回资料:{document.metadata}")
    return documents
|
||
|
||
def extract_html(url, proxy):
    """Download a page through an HTTP proxy and return its HTML and title.

    Args:
        url: Address of the page to fetch.
        proxy: Mapping read for the keys ``"sever"`` (host) and ``"port"``.

    Returns:
        ``{"html": <prettified markup>, "title": <page title or "">}``.
        Any ``requests`` error (connection failure, timeout, HTTP error
        status) yields empty strings for both fields instead of raising.
    """
    try:
        proxy_url = f"http://{proxy['sever']}:{proxy['port']}"
        response = requests.get(
            url,
            proxies={"http": proxy_url, "https": proxy_url},
            timeout=3,
            # A random User-Agent is sent with every request.
            headers={'User-Agent': UserAgent().random},
        )
        response.raise_for_status()
        soup = BeautifulSoup(response.content, 'html.parser')
        page = {"html": soup.prettify()}
        title = soup.title.string if soup.title else ""
        # ``.string`` is None when <title> has no single text child.
        page["title"] = "" if title is None else title
        return page
    except requests.exceptions.RequestException:
        return {"html": "", "title": ""}
|
||
|
||
|
||
# Module-level Selenium configuration: Chrome launch options plus the
# ChromeDriver service, both pointing at binaries under CHROME_DIR.
# NOTE(review): neither object is referenced in the visible part of this
# file — confirm they are still needed.
chrome_options = Options()
chrome_options.add_argument("--headless")  # headless mode (no visible window)
chrome_options.add_argument("--disable-gpu")  # disable the GPU
chrome_options.add_argument("--ignore-certificate-errors")  # ignore SSL certificate errors
chrome_options.add_argument("--no-sandbox")  # works around the "DevToolsActivePort file doesn't exist" error
chrome_options.add_argument("--disable-dev-shm-usage")  # works around resource limits
chrome_options.binary_location = f"{CHROME_DIR}/chrome-linux64/chrome"  # Chrome executable
service = Service(f'{CHROME_DIR}/chromedriver-linux64/chromedriver')  # ChromeDriver
|
||
def _rank_zhipu_results(results, origin_query, top_k):
    """Filter Zhipu hits with the "default_similar" LLM prompt.

    The hits are numbered from 1 and sent to the model; the reply is parsed
    as comma-separated positions ("无" means none) and handed, together with
    the hits, to ``get_similar_documents``. If the reply cannot be parsed,
    or that call fails, ``get_similar_documents`` is retried with no
    positions.
    """
    try:
        sentences = [doc["title"] for doc in results]
        numbered = [str(i + 1) + ":【" + doc["title"] + doc["content"] + "】"
                    for i, doc in enumerate(results)]
    except Exception:
        # Fallback used when labelling by "title" fails: label by "source".
        sentences = [doc["source"] for doc in results]
        numbered = [str(i + 1) + ":【" + doc["source"] + doc["content"] + "】"
                    for i, doc in enumerate(results)]

    # NOTE(review): this is a synchronous call made from a coroutine; if the
    # helper is thread-safe it could be moved to the thread pool as well.
    res = get_llm_model_response(
        strategy_name="default_similar",
        llm_model_name=LLM_MODELS[0],
        template_prompt_name="default_similar",
        prompt_param_dict={"input": origin_query,
                           "title": f"{numbered}",
                           "time": datetime.now().strftime("%Y%m%d")},
        temperature=0.01,
        max_tokens=512
    )
    try:
        # The model answers with 1-based positions; convert to 0-based.
        index = [] if res == "无" else [int(i) - 1 for i in res.split(",")]
        return get_similar_documents(index=index, sentences=sentences, query=origin_query,
                                     docs=results, top_k=top_k)
    except Exception:
        logger.exception("Failed to apply LLM relevance selection; falling back to no selection")
        return get_similar_documents(index=[], sentences=sentences, query=origin_query,
                                     docs=results, top_k=top_k)


def _fetch_result_pages(results):
    """Fetch every result's page through a leased proxy; never raises.

    Returns one ``{"html", "title"}`` dict per result, in order. Pages that
    could not be fetched (including when the proxy API returned fewer
    proxies than results) are represented by empty strings.
    """
    empty_pages = [{"html": "", "title": ""} for _ in results]
    try:
        # SECURITY NOTE(review): the proxy provider's API key is hard-coded in
        # this URL; it should be moved to configuration and rotated.
        url = f"https://sch.shanchendaili.com/api.html?action=get_ip&key=HU301b55318830279250RjH2&time=10&count={len(results)}&type=txt&province=215&only=0"
        # A timeout keeps a stalled proxy API from hanging the request forever.
        ip_list = requests.get(url=url, timeout=10).json()["list"]
        links = [result.get("link", "") for result in results]
        with ThreadPoolExecutor() as executor:
            pages = list(executor.map(extract_html, links, ip_list))
    except Exception:
        logger.exception("Failed to fetch result pages through proxies")
        return empty_pages
    # ``executor.map`` stops at the shorter iterable; pad so that no search
    # result is dropped when fewer proxies than results were returned.
    pages.extend(empty_pages[len(pages):])
    return pages


async def lookup_search_engine(
        uid: str,
        search_query: str,
        origin_query: str,
        search_engine_name: str,
        top_k: int = SEARCH_ENGINE_TOP_K,
        split_result: bool = False,
        kgo_search_type: str = "1000"
):
    """Run one search engine and return its hits as ``Document`` objects.

    Args:
        uid: Key of the per-session shared variable; any "¥" characters are
            removed before it is used as the key.
        search_query: Query sent to the engine.
        origin_query: The user's original question (used for relevance
            selection and forwarded to the engine).
        search_engine_name: A key of ``SEARCH_ENGINES``.
        top_k: Number of results to request / keep.
        split_result: Forwarded to the engine.
        kgo_search_type: Forwarded to the engine.

    Returns:
        A list of ``Document`` objects; an empty list when the Zhipu engine
        yields nothing usable.

    Raises:
        ValueError: If ``search_engine_name`` is not registered.
    """
    search_engine = SEARCH_ENGINES.get(search_engine_name)
    if not search_engine:
        raise ValueError(f"Unsupported search engine: {search_engine_name}")

    if search_engine_name != "zhipu_search":
        # The other engines share one keyword set; each ignores what it
        # does not need.
        results = await run_in_threadpool(
            search_engine,
            search_query,
            origin_query=origin_query,
            result_len=top_k,
            split_result=split_result,
            kgo_search_type=kgo_search_type
        )
        # DuckDuckGo hits use url/body/date keys, and the chat endpoint reads
        # "title"/"url" from their metadata, so they need their own converter.
        if search_engine_name == "duckduckgo":
            return Duckduckgo_result2docs(results)
        return search_result2docs(results)

    # Zhipu only needs the query itself.
    results = await run_in_threadpool(search_engine, search_query)
    if results:
        results = _rank_zhipu_results(results, origin_query, top_k)
    if not results:
        return []

    # Blocking network I/O: run it off the event loop.
    pages = await run_in_threadpool(_fetch_result_pages, results)

    enriched = []
    for result, page in zip(results, pages):
        listed_title = result.get("title", "")
        page_title = page["title"]
        enriched.append({
            "content": result.get("content", ""),
            "icon": "无",
            "index": 0,
            "link": result.get("link", ""),
            "media": "未知",
            "refer": (result.get("positions") or [0])[0],
            # Prefer the page's own <title> when it contains the listed title
            # with "..." removed.
            "title": page_title if listed_title.replace("...", "") in page_title else listed_title,
        })
    results = enriched

    # Drop hits whose titles were already recorded for this session and
    # record the new ones.
    session_key = uid.replace("¥", "")
    shared = utils.get_shared_variable(session_key)
    results, titles = utils.remove_docs(shared["title"], results)
    shared["title"].extend(titles)
    utils.set_shared_variable(session_key, shared)

    return ZhipuSearch_result2docs(results)
|
||
|
||
|
||
# Maps the resource-type names found in the KGO_NER model output to the
# type codes sent to the KGO search engine.
_KGO_SEARCH_TYPE_MAP = {
    "全部": "1000",
    "期刊论文": "1002",
    "学位论文": "1003",
    "会议论文": "1004",
    "政策": "1005",
    "成果": "1006",
    "科技成果": "1006",
    "项目": "1007",
    "报告": "1010",
    "图书": "1011",
    "外文期刊论文": "1013",
    "外文资料": "1013",
    "期刊": "1015",
    "专利": "3001",
    "新闻": "4001",
    "数据集": "4004",
    "视频": "4005",
    "统计数据": "6001"
}


def _parse_kgo_entities(raw_entities, query):
    """Turn the KGO_NER model output into ``(search_query, kgo_search_type)``.

    Every ``{...}`` fragment in the output is parsed as JSON; its keys are
    search keywords and its values name resource types (a string, possibly
    comma-separated, or a list of strings). Fragments that are not valid
    JSON fall back to the user's original query. Unknown or missing type
    names resolve to the default code "1000".

    Returns:
        A comma-joined keyword string (``query`` when nothing was extracted)
        and a comma-joined, de-duplicated, order-preserving string of codes.
    """
    keywords = []
    type_names = []
    fragments = (re.findall(r'\{.*?\}', raw_entities, flags=re.DOTALL)
                 if isinstance(raw_entities, str) else [])
    for fragment in fragments:
        try:
            data = json.loads(fragment)
        except json.JSONDecodeError:
            # Unparseable fragment: keep the original query searchable.
            if query not in keywords:
                keywords.append(query)
            continue
        for key, value in data.items():
            keywords.append(key)
            names = [value] if isinstance(value, str) else value
            if not isinstance(names, (list, tuple)):
                continue
            for name in names:
                # Accept both ASCII and full-width commas inside one value.
                type_names.extend(part.strip() for part in re.split(r'[,,]', str(name)))

    matched = [_KGO_SEARCH_TYPE_MAP[name] for name in type_names
               if name in _KGO_SEARCH_TYPE_MAP]
    if not matched:
        matched = ["1000"]
    search_query = ','.join(keywords) or query
    return search_query, ','.join(dict.fromkeys(matched))


async def search_engine_chat(uid: Optional[str]=Body(None, description="userID"),
                             query: str = Body(..., description="用户输入", examples=["你好"]),
                             search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
                             top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                             history: List[History] = Body([],
                                                           description="历史对话",
                                                           examples=[[
                                                               {"role": "user",
                                                                "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                               {"role": "assistant",
                                                                "content": "虎头虎脑"}]]
                                                           ),
                             stream: bool = Body(False, description="流式输出"),
                             model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                             temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                             max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                             prompt_name: str = Body("default",
                                                     description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                             split_result: bool = Body(False, description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)"),
                             kgo_search_type: str = Body("1000", description="kgo搜索引擎搜索类型,默认1000"),
                             ):
    """Search-engine retrieval endpoint.

    Validates the engine, derives the search query (for "kgo" via the
    KGO_NER prompt), runs the search and streams a single server-sent event
    whose JSON payload holds the retrieved context (``"answer"``) and the
    formatted reference lines (``"docs"``). Answer generation by the LLM is
    currently disabled; only retrieval results are returned.

    Returns:
        A ``BaseResponse`` with code 404 when the engine is unsupported or
        Bing is requested without a subscription key; otherwise an
        ``EventSourceResponse``.
    """
    if search_engine_name == "bing" and not BING_SUBSCRIPTION_KEY:
        return BaseResponse(code=404, msg="要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`")
    if search_engine_name not in SEARCH_ENGINES:
        return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")

    history = [History.from_data(h) for h in history]
    # Only the user's earlier messages are passed to the KGO_NER prompt.
    user_queries = [message.content for message in history if message.role == 'user']

    # Default for every engine; previously only "kgo" and "zhipu_search"
    # bound this name, so the other engines crashed on an unbound variable.
    search_query = query
    if search_engine_name == "kgo":
        # NOTE(review): synchronous LLM call inside a coroutine.
        metal_search_entity = get_llm_model_response(
            strategy_name="KGO_NER",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="KGO_NER",
            prompt_param_dict={"input": query, "history": user_queries},
            temperature=TEMPERATURE,
            max_tokens=512
        )
        logger.info(f"KGO_NER体识别结果:{metal_search_entity}")
        search_query, kgo_search_type = _parse_kgo_entities(metal_search_entity, query)
        logger.info(f"KGO检索关键词:{search_query},KGO检索类型:{kgo_search_type}")

    async def search_engine_chat_iterator(
            uid: str,
            search_query: str,
            origin_query: str,
            search_engine_name: str,
            top_k: int,
            history: Optional[List[History]],
            model_name: str = LLM_MODELS[0],
            prompt_name: str = prompt_name,
            kgo_search_type: str = kgo_search_type
    ) -> AsyncIterable[str]:
        """Run the search and yield one JSON event with context and references."""
        nonlocal max_tokens
        callback = AsyncIteratorCallbackHandler()
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        # NOTE(review): the model is still constructed although the LLM chain
        # that consumed it is disabled; kept so construction-time behaviour
        # is unchanged. Remove together with ``callback`` once confirmed.
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback]
        )

        docs = await lookup_search_engine(uid, search_query, origin_query, search_engine_name, top_k,
                                          split_result=split_result,
                                          kgo_search_type=kgo_search_type)

        # Reference numbering continues after what the session has already
        # recorded: its "num" counter plus the knowledge-base documents
        # stored under each name in EN_BASE_NAME.
        shared = utils.get_shared_variable(uid)
        index = shared["num"]
        if uid and "source_docs" in shared:
            for knowledge_name in EN_BASE_NAME:
                if knowledge_name in shared["source_docs"]:
                    index += len(shared["source_docs"][knowledge_name])

        context = [f"""\n\n资料[{index + k + 1}]内容:{doc.page_content} ) \n\n"""
                   for k, doc in enumerate(docs)]

        def reference_line(position, doc):
            """Format one Markdown reference entry for the current engine."""
            meta = doc.metadata
            if search_engine_name == "zhipu_search":
                title, link = meta["title"], meta.get("link")
            elif search_engine_name == "duckduckgo":
                # Tolerate both metadata shapes (title/url or filename/source).
                title = meta.get("title", meta.get("filename", ""))
                link = meta.get("url") or meta.get("source")
            else:
                return f"""[{position}] [{meta["filename"]}]({meta["source"]}) \n"""
            if link:
                return f"""[{position}] [{title}]({link}) \n"""
            return f"""[{position}] [{title}]\n"""

        source_documents = [reference_line(index + inum + 1, doc)
                            for inum, doc in enumerate(docs)]
        if not source_documents and search_engine_name == "duckduckgo":
            # Only the DuckDuckGo branch reports an explicit "nothing found".
            source_documents.append("未找到相关资料")

        # Best effort: record the new references in the session state.
        try:
            res = utils.get_shared_variable(uid)
            res["num"] = index + len(source_documents)
            res["source_docs"].extend(source_documents)
            utils.set_shared_variable(uid, res)
        except Exception:
            logger.warning("Failed to update shared session state for uid=%s", uid, exc_info=True)

        yield json.dumps({"answer": context, "docs": source_documents}, ensure_ascii=False)

    return EventSourceResponse(search_engine_chat_iterator(uid=uid,
                                                           search_query=search_query,
                                                           origin_query=query,
                                                           search_engine_name=search_engine_name,
                                                           top_k=top_k,
                                                           history=history,
                                                           model_name=model_name,
                                                           prompt_name=prompt_name,
                                                           kgo_search_type=kgo_search_type),
                               )
|