Compare commits

...

2 Commits

17 changed files with 176 additions and 53 deletions

View File

@@ -247,6 +247,7 @@ public class SmartChatController extends BaseController {
}
talkDto.setKnowledgeBaseNameList(knowledgeBaseNameList);
talkDto.setWebSearch(smartChatQueryDto.getWebSearch());
talkDto.setQuery(chatMessages.getContent().replaceAll("\n", ""));
talkDto.setStream(true);

View File

@@ -35,4 +35,7 @@ public class SmartChatQueryDto {
/* 用户请求类型 **/
private Integer chatType;
/* 是否开启联网搜索 **/
private Boolean webSearch;
}

View File

@@ -46,6 +46,10 @@ public class SmartChatSelfDto {
/* 用户请求类型 **/
private Integer chatType;
/* 是否开启联网搜索 **/
@JsonProperty("web_search")
private Boolean webSearch;
public String toJsonString() {
StringBuffer str = new StringBuffer();
extracted(knowledgeBaseNameList, str);
@@ -63,7 +67,7 @@ public class SmartChatSelfDto {
", \"fileNames\":" + fileNameList +
", \"quote\":\"" + ReplaceUtils.replaceHiddenChars(quote) + '\"' +
", \"prompt_name\":\"" + promptName + '\"' +
// ", \"use_model_self_response\":\"" + "False" + '\"' +
", \"web_search\":" + (webSearch != null && webSearch ? "true" : "false") +
'}';
}

View File

@@ -9,16 +9,27 @@
<!-- 文字窗口-->
<div>
<div class="tool-bar">
<div class="label">
<!-- <img src="../assets/images/writing/start.png">
<div>AI写作助手</div>-->
</div>
<div class="label"></div>
<div class="clean" @click="cleanChat">
<img src="../assets/images/writing/brush.png">
<div>清除对话</div>
</div>
</div>
<div class="search-scope" v-if="selectedFile">
<div class="scope-label">检索范围</div>
<div class="scope-option" :class="{ active: searchScope === 'file' }" @click="searchScope = 'file'"
:title="'仅在「' + selectedFile.fileName + '」中检索'">
<span class="scope-icon">📄</span>
<span class="scope-name">{{ selectedFile.fileName }}</span>
</div>
<div class="scope-option" :class="{ active: searchScope === 'kb' }" @click="searchScope = 'kb'"
:title="'在「' + selectedFile.folderName + '」知识库的所有文件中检索'">
<span class="scope-icon">📁</span>
<span class="scope-name">{{ selectedFile.folderName }}</span>
</div>
</div>
<div class="text-box">
<div class="quote-box" v-if="quoteMsg">
<div class="vertical-line"></div>
@@ -34,10 +45,17 @@
@input="handleInput"
@keydown.enter="keyDown"
placeholder="请输入你想提的问题字数不能超过1000字"/>
<div>
<div class="text-box-bottom">
<div class="web-search-toggle" :class="{ active: webSearchEnabled }" @click="webSearchEnabled = !webSearchEnabled"
:title="webSearchEnabled ? '联网搜索已开启,点击关闭' : '开启联网搜索,从互联网获取最新信息'">
<span class="ws-icon">🌐</span>
<span class="ws-text">联网搜索</span>
</div>
<div class="send-btn">
<img v-if="textarea&&!sendStatus" style="width: 38px" src="../assets/images/writing/send-blue.png" @click="send('','0')">
<img v-if="!textarea&&!sendStatus" src="../assets/images/writing/send-gray.png">
<img v-if="sendStatus" src="../assets/images/chat/stopChat.png" @click="handleStop"></img>
<img v-if="sendStatus" src="../assets/images/chat/stopChat.png" @click="handleStop">
</div>
</div>
</div>
@@ -76,6 +94,8 @@ const clearQuote = () => {
}
//const title = inject('aiboxTitle');
const searchScope = ref<'file' | 'kb'>('file');
const webSearchEnabled = ref(false);
const textarea = ref("");
const firstChat = ref(true);
const sendStatus = ref(false);
@@ -198,12 +218,13 @@ const getFetchChatAPIProcess = async (type: string) => {
headers: headers,
signal: controller.signal,
body: JSON.stringify({
fileNames: [selectedFile.value?.embeddingId],
fileNames: searchScope.value === 'file' ? [selectedFile.value?.embeddingId] : [],
conversationId: conversationId.value,
promptName: "default",
knowledgeBaseIdList: [selectedFile.value?.folderId],
chatType: type,
quote: quoteMsg.value
quote: quoteMsg.value,
webSearch: webSearchEnabled.value
}),
}
);
@@ -372,6 +393,7 @@ const loadChatHistory = async () => {
// React to the user switching files: fall back to the single-file retrieval
// scope first, then reload the conversation history.
watch(
  () => selectedFile.value?.fileId,
  function onSelectedFileChange() {
    searchScope.value = 'file';
    loadChatHistory();
  },
);
@@ -434,7 +456,7 @@ const handleStop = async () => {
<style lang="less" scoped>
.message-content {
height: calc(100% - 290px);
height: calc(100% - 320px);
overflow-y: auto;
padding: 20px;
@@ -505,28 +527,93 @@ const handleStop = async () => {
}
}
// Retrieval-scope selector: a horizontal row holding a label and the
// clickable scope options.
.search-scope {
  // layout
  display: flex;
  align-items: center;
  gap: 6px;
  padding: 4px 12px;

  // Fixed-width caption; never shrinks when the options grow.
  .scope-label {
    flex-shrink: 0;
    font-size: 13px;
    color: #333;
  }

  // One selectable option, drawn as a rounded outlined chip.
  .scope-option {
    // layout
    display: flex;
    align-items: center;
    gap: 5px;

    // box
    padding: 6px 14px;
    max-width: 45%;
    overflow: hidden;
    border: 1px solid #E0E0E0;
    border-radius: 14px;

    // text and interaction
    font-size: 13px;
    color: #666;
    cursor: pointer;
    transition: all 0.2s;

    // Hover only tints the outline and text.
    &:hover {
      border-color: #004EA0;
      color: #004EA0;
    }

    // Declared after :hover so the filled look wins while hovering the
    // currently selected option.
    &.active {
      border-color: #004EA0;
      background: #004EA0;
      color: #fff;
    }

    .scope-icon {
      flex-shrink: 0;
      font-size: 14px;
    }

    // Long names are cut off with an ellipsis instead of wrapping.
    .scope-name {
      white-space: nowrap;
      overflow: hidden;
      text-overflow: ellipsis;
    }
  }
}
.text-box {
//width: 100%;
height: 190px;
background: #FFFFFF;
border-radius: 8px;
border: 1px solid #D5DDFF;
margin: 12px 20px 12px 20px;
display: flex;
flex-direction: column;
.box-textarea {
outline: none;
border: none;
resize: none;
width: 100%;
height: calc(100% - 54px);
flex: 1;
padding: 16px;
line-height: 24px;
border-radius: 8px;
}
// Footer row of the input box: children are pushed to opposite ends
// (web-search toggle and send button).
.text-box-bottom {
  display: flex;
  align-items: center;
  justify-content: space-between;
  padding: 6px 12px;

  // Rounded outlined chip that switches web search on and off.
  .web-search-toggle {
    // layout
    display: flex;
    align-items: center;
    gap: 5px;

    // box
    padding: 6px 14px;
    border: 1px solid #E0E0E0;
    border-radius: 14px;

    // text and interaction
    font-size: 13px;
    color: #999;
    cursor: pointer;
    user-select: none;
    transition: all 0.2s;

    // Hover only tints the outline and text.
    &:hover {
      border-color: #10a37f;
      color: #10a37f;
    }

    // Declared after :hover so the filled look wins while hovering an
    // enabled toggle.
    &.active {
      border-color: #10a37f;
      background: #10a37f;
      color: #fff;
    }

    .ws-icon {
      font-size: 14px;
    }

    .ws-text {
      font-size: 13px;
    }
  }

  .send-btn {
    img {
      width: 38px;
      cursor: pointer;
    }
  }
}
img {
cursor:pointer;
float: right;
margin-right: 16px;
cursor: pointer;
}
.quote-box {

View File

@@ -34,7 +34,7 @@ async def search_engine_iter(query: str , uid: str):
response = await search_engine_chat(uid = uid,
query=query,
search_engine_name="zhipu_search",
model_name=LLM_MODELS[1],
model_name=LLM_MODELS[0],
temperature=TEMPERATURE, # Agent搜索互联网的时候温度设为0.1
history=[],
top_k=VECTOR_SEARCH_TOP_K,

View File

@@ -71,10 +71,7 @@ class ZhipuSearchAPIWrapper:
logging.info(f"Zhipu检索内容:{search_query}")
url = "http://ywk3hvt4d:01Jp2V1tR9PdTsYSz919779Rb9_@134.122.191.214/search"
if "天气" in search_query:
engines = "google"
else:
engines = "baidu"
engines = "duckduckgo,bing"
data = {
"format":"json",
"q":search_query,

View File

@@ -62,7 +62,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
if prompt_name == "Search Summary":
model = get_ChatOpenAI(
model_name=LLM_MODELS[1],
model_name=LLM_MODELS[0],
temperature=temperature,
max_tokens=max_tokens,
callbacks=callbacks,

View File

@@ -46,7 +46,7 @@ async def chat_comparison_test(
executor.submit(
get_llm_model_response,
strategy_name="query rewrite",
llm_model_name=LLM_MODELS[1],
llm_model_name=LLM_MODELS[0],
template_prompt_name="extract_key_points",
prompt_param_dict={"time": datetime.now().strftime("%Y%m%d"), "context": context, "content": content},
temperature=0.01,

View File

@@ -44,7 +44,7 @@ async def gen_abstract(
try:
article_abstract = get_llm_model_response(
strategy_name="gen_abstract",
llm_model_name=LLM_MODELS[1],
llm_model_name=LLM_MODELS[0],
template_prompt_name="gen_abstract",
prompt_param_dict={
"context": context_summary, # 使用摘要或原文

View File

@@ -44,7 +44,7 @@ async def gen_conclusion(
try:
article_conclusion = get_llm_model_response(
strategy_name="gen_conclusion",
llm_model_name=LLM_MODELS[1],
llm_model_name=LLM_MODELS[0],
template_prompt_name="gen_conclusion",
prompt_param_dict={
"context": context_summary, # 使用摘要或原文

View File

@@ -43,7 +43,7 @@ async def gen_keywords(
try:
article_keywords = get_llm_model_response(
strategy_name="gen_keywords",
llm_model_name=LLM_MODELS[1],
llm_model_name=LLM_MODELS[0],
template_prompt_name="gen_keywords",
prompt_param_dict={
"context": context_summary, # 使用摘要或原文

View File

@@ -44,7 +44,7 @@ async def gen_paragraph(
try:
article_paragraph = get_llm_model_response(
strategy_name="gen_paragraph",
llm_model_name=LLM_MODELS[1],
llm_model_name=LLM_MODELS[0],
template_prompt_name="gen_paragraph",
prompt_param_dict={
"context": context_summary, # 使用摘要或原文

View File

@@ -19,7 +19,7 @@ async def gen_title(
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", "content": "虎头虎脑"}]]
),
model_name: Optional[str] = Body(LLM_MODELS[1], description="LLM 模型名称。"),
model_name: Optional[str] = Body(LLM_MODELS[0], description="LLM 模型名称。"),
):
"""
根据一轮对话历史生成简洁标题\n
@@ -32,7 +32,7 @@ async def gen_title(
if model_name == "R1-70B":
model_name = DEEPSEEK_MODELS[1]
elif model_name == "QIANWEN":
model_name = LLM_MODELS[1]
model_name = LLM_MODELS[0]
else:
model_name = model_name

View File

@@ -138,7 +138,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
if prompt_name == "policy_chat":
model_name = LLM_MODELS[1]
model_name = LLM_MODELS[0]
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,

View File

@@ -364,7 +364,7 @@ async def search_engine_chat(uid: Optional[str]=Body(None, description="userID")
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[1], description="LLM 模型名称。"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",
@@ -483,7 +483,7 @@ async def search_engine_chat(uid: Optional[str]=Body(None, description="userID")
search_engine_name: str,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODELS[1],
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
kgo_search_type: str = kgo_search_type
) -> AsyncIterable[str]:

View File

@@ -50,6 +50,7 @@ async def self_kb_chat(
"content": "虎头虎脑"}]]
),
stream: bool = Body(True, description="流式输出"),
web_search: bool = Body(False, description="是否开启联网搜索"),
):
"""
个人知识库对话api\n
@@ -72,7 +73,7 @@ async def self_kb_chat(
async def knowledge_base_chat_iterator(
query: str,
model_name: str = LLM_MODELS[0],
model_name1: str = LLM_MODELS[1],
model_name1: str = LLM_MODELS[0],
prompt_name: str = "self_default",
) -> AsyncIterable[str]:
nonlocal fileNames, history
@@ -149,7 +150,29 @@ async def self_kb_chat(
except Exception as e:
logger.error(f"个人知识库问答路由错误: {self_kb_route}", exc_info=True)
docs = []
logger.info(f"个人知识库问答source_documents: {docs}")
logger.info(f"个人知识库问答source_documents: {len(docs)}")
# Optional web search.
# Produces two values that are consumed further down:
#   web_search_context - text block appended to the LLM context
#   web_search_results - the result dicts later rendered as source links
# Both are published together only after the context was built successfully,
# so a failure can never leave citations for results that were not actually
# given to the model.
web_search_context = ""
web_search_results = []
if web_search:
    try:
        from server.chat.ZhipuSearchAPI import ZhipuSearchAPIWrapper
        # NOTE(review): zhipu_search is called without await, so it presumably
        # blocks while the request is in flight - confirm whether it should be
        # moved off the event loop.
        web_results = ZhipuSearchAPIWrapper().zhipu_search(search_query)
        # Only the first five hits are used, both as context and as citations.
        top_results = web_results[:5] if web_results else []
        if top_results:
            web_parts = []
            for i, r in enumerate(top_results, 1):
                # "or ''" also covers keys that are present with a None value,
                # which would otherwise break the slice below.
                title = r.get("title") or ""
                content = (r.get("content") or "")[:300]
                url = r.get("url") or ""
                web_parts.append(f"[{i}] {title}\n{content}\n来源: {url}")
            web_search_context = "\n\n【联网搜索结果】\n" + "\n\n".join(web_parts)
            web_search_results = top_results
            logger.info(f"联网搜索获取到 {len(web_results)} 条结果")
    except Exception as e:
        # Best effort: the answer is still generated without web results.
        # Keep the traceback so search failures can be diagnosed.
        logger.error(f"联网搜索失败: {e}", exc_info=True)
# if SELF_USE_RERANKER:
# reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
# print("-----------------model path------------------")
@@ -184,25 +207,28 @@ async def self_kb_chat(
if '0' in self_kb_route:
context = "\n".join([doc.page_content for doc in docs]).strip("xa0")
logger.info(f"个人知识库问答 context 长度:{len(context)}")
# context_70 = context if len(context)<30000 else TextRank(context,num_sentences=70)
context = context[:40000] if len(context)>40000 else context
logger.info(f"截取后个人知识库问答 context 长度:{len(context)}")
context = context[:30000] if len(context)>30000 else context
if web_search_context:
context += web_search_context
logger.info(f"最终 context 长度:{len(context)}")
if history:
history = history if len(history) < 20000 else TextRank(history,num_sentences=1)
# logger.info(f"个人知识库问答 context 长度超过 30000使用 TextRank 算法进行降维得到 context 长度:{len(context)}")
chain = LLMChain(prompt=chat_prompt, llm=model1, verbose=True)
task = asyncio.create_task(wrap_done(
chain.acall({"context": context, "question": query, "history": history, "quote": quote, "fileName":fileNames}),
callback.done),
)
elif '1' in self_kb_route:
# 联网搜索结果作为额外文档加入
if web_search_context:
from langchain.docstore.document import Document as LCDocument
docs.append(LCDocument(page_content=web_search_context, metadata={"source": "web_search"}))
chain = load_qa_chain(
model,
chain_type="stuff",
prompt=chat_prompt,
verbose=True
)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"input_documents": docs, "question": query, "history": history, "quote": quote, "fileName":fileNames}),
callback.done),
@@ -235,14 +261,18 @@ async def self_kb_chat(
yield json.dumps(response, ensure_ascii=False)
await task
source_documents = []
if len(docs) == 0: # 没有找到相关文档
if len(docs) == 0 and not web_search_context:
source_documents.append(f"""暂未从本篇文献中找到答案,该回答为大模型自身能力解答!""")
else:
# 去除文件扩展名
# fileNames_without_ext = [name.rsplit('.', 1)[0] for name in fileNames]
# 连接文件名(如果有多个文件名)
# joined_fileNames = ', '.join(fileNames_without_ext)
if len(docs) > 0:
source_documents.append(f"""[{len(source_documents) + 1}] [{docs[0].metadata.get("source")}]()\n""")
# 联网搜索结果链接
if web_search_results:
for r in web_search_results:
title = r.get("title", "").replace("\n", "")
url = r.get("url", "")
if title and url:
source_documents.append(f"""[{len(source_documents) + 1}] [{title}]({url})\n""")
yield json.dumps({"docs": source_documents}, ensure_ascii=False)
return EventSourceResponse(knowledge_base_chat_iterator(query))

View File

@@ -116,15 +116,16 @@ def search_self_docs(
if top_k > 50:
data = docs
else:
# Milvus 已通过 expr 过滤了 source无需再按 fileNames 二次过滤
# Milvus 的 source 可能是原始文件名,而 fileNames 是 embeddingId格式不一致
data = [
DocumentWithVSId(
**{k: v for k, v in x[0].dict().items() if k != 'page_content'}, # 排除原有的 page_content
**{k: v for k, v in x[0].dict().items() if k != 'page_content'},
score=x[1],
id=x[0].metadata.get("id"),
page_content=f"【^[{index +1}]^ {x[0].page_content}" # 拼接索引和page_content
page_content=f"【^[{index +1}]^ {x[0].page_content}"
)
for index, x in enumerate(docs) # 使用enumerate来获取索引
if x[0].metadata.get("source") in flat_fileNames
for index, x in enumerate(docs)
]
else:
logger.warning(f"未找到知识库服务: {knowledge_base_name}")