[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
0
langchain-chat/webui_pages/__init__.py
Normal file
0
langchain-chat/webui_pages/__init__.py
Normal file
0
langchain-chat/webui_pages/dialogue/__init__.py
Normal file
0
langchain-chat/webui_pages/dialogue/__init__.py
Normal file
620
langchain-chat/webui_pages/dialogue/dialogue.py
Normal file
620
langchain-chat/webui_pages/dialogue/dialogue.py
Normal file
@@ -0,0 +1,620 @@
|
||||
import streamlit as st
|
||||
from webui_pages.utils import *
|
||||
from streamlit_chatbox import *
|
||||
from streamlit_modal import Modal
|
||||
from datetime import datetime
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES, LLM_MODELS,
|
||||
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
|
||||
from server.knowledge_base.utils import LOADER_DICT
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
|
||||
# Module-level chat widget shared by every function in this page; the
# assistant avatar is loaded from the relative "img" directory.
chat_box = ChatBox(
    assistant_avatar=os.path.join(
        "img",
        "chatchat_icon_blue_square_v2.png"
    )
)
|
||||
|
||||
|
||||
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
    """
    Return the chat history as ``{"role": ..., "content": ...}`` dicts.

    Only elements rendered as "markdown" or "text" contribute to the content;
    their texts are joined with a blank line.

    :param history_len: forwarded unchanged to ``chat_box.filter_history``.
    :param content_in_expander: when False (default), elements shown inside
        an expander are dropped. Per the original author's note this is
        usually enabled for exports and left off for the history sent to
        the LLM.
    :return: whatever ``chat_box.filter_history`` returns for the per-message
        converter below (presumably a list of the converted dicts -- confirm
        against streamlit_chatbox).
    """

    # Named so that it does not shadow the builtin ``filter``.
    def _to_plain_message(msg):
        elements = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
        if not content_in_expander:
            elements = [x for x in elements if not x._in_expander]

        return {
            "role": msg["role"],
            "content": "\n\n".join(x.content for x in elements),
        }

    return chat_box.filter_history(history_len=history_len, filter=_to_plain_message)
|
||||
|
||||
|
||||
@st.cache_data
def upload_temp_docs(files, _api: ApiRequest) -> str:
    """
    Upload files to a temporary directory for the file-chat mode.

    :param files: the uploaded files, forwarded to ``_api.upload_temp_docs``.
    :param _api: API client. NOTE(review): the leading underscore presumably
        keeps Streamlit's cache from hashing this argument -- confirm against
        the ``st.cache_data`` documentation.
    :return: the ``data.id`` field of the API response (the temporary vector
        store id), or ``None`` when the response carries no such field.
    """
    return _api.upload_temp_docs(files).get("data", {}).get("id")
|
||||
|
||||
|
||||
def parse_command(text: str, modal: Modal) -> bool:
    '''
    检查用户是否输入了自定义命令,当前支持:
    /new {session_name}。如果未提供名称,默认为“会话X”
    /del {session_name}。如果未提供名称,在会话数量>1的情况下,删除当前会话。
    /clear {session_name}。如果未提供名称,默认清除当前会话
    /help。查看命令帮助
    返回值:输入的是命令返回True,否则返回False
    '''
    # NOTE: the docstring above is runtime data -- dialogue_page() reads
    # parse_command.__doc__ and shows every line starting with "/" as the
    # in-app command help, so it is intentionally left untranslated.
    #
    # Summary: detect a slash command at the start of ``text`` and run it.
    #   /new [name]   create a conversation (default name "会话<N>")
    #   /del [name]   delete a conversation (default: the current one)
    #   /clear [name] clear a conversation's history (default: current)
    #   /help         open the help modal
    # Returns True when ``text`` looks like a command (any "/word ..."
    # input, including unrecognised commands), otherwise False.
    if m := re.match(r"/([^\s]+)\s*(.*)", text):
        cmd, name = m.groups()
        name = name.strip()
        conv_names = chat_box.get_chat_names()
        if cmd == "help":
            modal.open()
        elif cmd == "new":
            if not name:
                # Pick the first unused default name.
                i = 1
                while True:
                    name = f"会话{i}"
                    if name not in conv_names:
                        break
                    i += 1
            if name in st.session_state["conversation_ids"]:
                st.error(f"该会话名称 “{name}” 已存在")
                # Keep the error visible briefly before the caller reruns.
                time.sleep(1)
            else:
                st.session_state["conversation_ids"][name] = uuid.uuid4().hex
                st.session_state["cur_conv_name"] = name
        elif cmd == "del":
            name = name or st.session_state.get("cur_conv_name")
            if len(conv_names) == 1:
                st.error("这是最后一个会话,无法删除")
                time.sleep(1)
            elif not name or name not in st.session_state["conversation_ids"]:
                st.error(f"无效的会话名称:“{name}”")
                time.sleep(1)
            else:
                st.session_state["conversation_ids"].pop(name, None)
                chat_box.del_chat_name(name)
                st.session_state["cur_conv_name"] = ""
        elif cmd == "clear":
            chat_box.reset_history(name=name or None)
        return True
    return False
|
||||
|
||||
|
||||
def _stream_plain_reply(chunks, feedback_kwargs: Dict, on_feedback) -> None:
    """Stream text chunks into the current AI message, then attach feedback buttons."""
    text = ""
    # Placeholder id; replaced by the id carried in the streamed chunks.
    message_id = str(uuid.uuid1()) + "q"
    for t in chunks:
        if error_msg := check_error_msg(t):  # check whether error occured
            st.error(error_msg)
            break
        text += t.get("text", "")
        chat_box.update_msg(text)
        message_id = t.get("message_id", "")

    metadata = {
        "message_id": message_id,
    }
    # Write the final string with streaming off (removes the cursor).
    chat_box.update_msg(text, streaming=False, metadata=metadata)
    chat_box.show_feedback(**feedback_kwargs,
                           key=str(uuid.uuid1()) + "q",
                           on_submit=on_feedback,
                           kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})


def _stream_answer_with_docs(chunks) -> None:
    """Stream "answer" chunks into element 0, then put the last chunk's "docs" into element 1."""
    text = ""
    d = {}  # stays empty if the stream yields nothing, so the final update cannot hit an unbound name
    for d in chunks:
        if error_msg := check_error_msg(d):  # check whether error occured
            st.error(error_msg)
        elif chunk := d.get("answer"):
            text += chunk
            chat_box.update_msg(text, element_index=0)
    chat_box.update_msg(text, element_index=0, streaming=False)
    chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)


def dialogue_page(api: ApiRequest, is_lite: bool = False):
    """
    Render the chat page: sidebar configuration, message history, chat input.

    The sidebar selects the conversation, dialogue mode, LLM model, prompt
    template and the mode-specific options. A submitted prompt is either
    handled as a slash command (see ``parse_command``) or sent to the API
    method matching the selected dialogue mode, with the reply streamed
    into ``chat_box``.

    :param api: API client used for every backend call.
    :param is_lite: when True, local models are not offered and no model
        switch is attempted.
    """
    st.session_state.setdefault("conversation_ids", {})
    st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
    st.session_state.setdefault("file_chat_id", None)
    default_model = api.get_default_llm_model()[0]

    if not chat_box.chat_inited:
        st.toast(
            f"欢迎使用 知冶大模型! \n\n"
            f"当前运行的模型`{default_model}`, 您可以开始提问了."
        )
        chat_box.init_session()

    # Help pop-up for the custom commands; the text comes from the
    # "/..." lines of parse_command's docstring.
    modal = Modal("自定义命令", key="cmd_help", max_width="500")
    if modal.is_open():
        with modal.container():
            cmds = [x for x in parse_command.__doc__.split("\n") if x.strip().startswith("/")]
            st.write("\n\n".join(cmds))

    with st.sidebar:
        # Multiple conversations
        conv_names = list(st.session_state["conversation_ids"].keys())
        index = 0
        if st.session_state.get("cur_conv_name") in conv_names:
            index = conv_names.index(st.session_state.get("cur_conv_name"))
        conversation_name = st.selectbox("当前会话:", conv_names, index=index)
        chat_box.use_chat_name(conversation_name)
        conversation_id = st.session_state["conversation_ids"][conversation_name]

        def on_mode_change():
            mode = st.session_state.dialogue_mode
            text = f"已切换到 {mode} 模式。"
            if mode == "知识库问答-旧" or mode == "知识库问答":
                cur_kb = st.session_state.get("selected_kb")
                if cur_kb:
                    text = f"{text} 当前知识库: `{cur_kb}`。"
            st.toast(text)

        dialogue_modes = ["LLM 对话",
                          "知识库问答",
                          "知识库问答-旧",
                          "文件对话",
                          "搜索引擎问答",
                          "自定义Agent问答",
                          "翻译",
                          "智能大纲生成",
                          "智能大纲补全",
                          "个人知识库问答"
                          ]
        dialogue_mode = st.selectbox("请选择对话模式:",
                                     dialogue_modes,
                                     index=0,
                                     on_change=on_mode_change,
                                     key="dialogue_mode",
                                     )

        def on_llm_change():
            if llm_model:
                config = api.get_model_config(llm_model)
                if not config.get("online_api"):  # only local model_workers can be switched
                    st.session_state["prev_llm_model"] = llm_model
                st.session_state["cur_llm_model"] = st.session_state.llm_model

        def llm_model_format_func(x):
            if x in running_models:
                return f"{x} (Running)"
            return x

        running_models = list(api.list_running_models())
        available_models = []
        config_models = api.list_config_models()
        if not is_lite:
            for k, v in config_models.get("local", {}).items():
                if (v.get("model_path_exists")
                        and k not in running_models):
                    available_models.append(k)
        for k, v in config_models.get("online", {}).items():
            if not v.get("provider") and k not in running_models and k in LLM_MODELS:
                available_models.append(k)
        llm_models = running_models + available_models
        cur_llm_model = st.session_state.get("cur_llm_model", default_model)
        if cur_llm_model in llm_models:
            index = llm_models.index(cur_llm_model)
        else:
            index = 0
        llm_model = st.selectbox("选择LLM模型:",
                                 llm_models,
                                 index,
                                 format_func=llm_model_format_func,
                                 on_change=on_llm_change,
                                 key="llm_model",
                                 )
        # A model that is configured locally but not yet running has to be
        # loaded in place of the previous one.
        if (st.session_state.get("prev_llm_model") != llm_model
                and not is_lite
                and llm_model not in config_models.get("online", {})
                and llm_model not in config_models.get("langchain", {})
                and llm_model not in running_models):
            with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
                prev_model = st.session_state.get("prev_llm_model")
                r = api.change_llm_model(prev_model, llm_model)
                if msg := check_error_msg(r):
                    st.error(msg)
                elif msg := check_success_msg(r):
                    st.success(msg)
                    st.session_state["prev_llm_model"] = llm_model

        # Dialogue mode -> key of the prompt-template group in PROMPT_TEMPLATES.
        index_prompt = {
            "LLM 对话": "llm_chat",
            "翻译": "llm_chat",
            "智能大纲生成": "llm_chat",
            "智能大纲补全": "llm_chat",
            "自定义Agent问答": "agent_chat",
            "搜索引擎问答": "search_engine_chat",
            "知识库问答": "knowledge_base_chat",
            "知识库问答-旧": "knowledge_base_chat",
            "文件对话": "knowledge_base_chat",
            "个人知识库问答": "knowledge_base_chat",
        }
        prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
        prompt_template_name = prompt_templates_kb_list[0]
        if "prompt_template_select" not in st.session_state:
            st.session_state.prompt_template_select = prompt_templates_kb_list[0]

        def prompt_change():
            text = f"已切换为 {prompt_template_name} 模板。"
            st.toast(text)

        st.selectbox(
            "请选择Prompt模板:",
            prompt_templates_kb_list,
            index=0,
            on_change=prompt_change,
            key="prompt_template_select",
        )
        prompt_template_name = st.session_state.prompt_template_select
        temperature = st.slider("Temperature:", 0.0, 2.0, TEMPERATURE, 0.05)
        history_len = st.number_input("历史对话轮数:", 0, 20, HISTORY_LEN)

        def on_kb_change():
            st.toast(f"已加载知识库: {st.session_state.selected_kb}")

        if dialogue_mode in ["知识库问答-旧", "知识库问答"]:
            with st.expander("知识库配置", True):
                kb_list = api.list_knowledge_bases()
                index = 0
                if DEFAULT_KNOWLEDGE_BASE in kb_list:
                    index = kb_list.index(DEFAULT_KNOWLEDGE_BASE)
                selected_kb = st.selectbox(
                    "请选择知识库:",
                    kb_list,
                    index=index,
                    on_change=on_kb_change,
                    key="selected_kb",
                )
                kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

                # Fix: this branch previously defined no score_threshold even
                # though both knowledge-base chat calls below pass it, which
                # raised NameError on the first question.
                # BGE models can score above 1, hence the 0-2 range.
                score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)

        elif dialogue_mode in ["个人知识库问答"]:
            with st.expander("知识库配置", expanded=True):
                # Fetch the knowledge base list
                kb_list = api.list_knowledge_bases()

                # Index of the knowledge base selected by default
                index = 0
                if DEFAULT_KNOWLEDGE_BASE in kb_list:
                    index = kb_list.index(DEFAULT_KNOWLEDGE_BASE)

                # Knowledge base select box
                selected_kb = st.selectbox(
                    "请选择知识库:",
                    kb_list,
                    index=index,
                    on_change=on_kb_change,
                    key="selected_kb",
                )

                # Files of the selected knowledge base
                file_list = api.list_kb_docs(selected_kb)

                # File multi-select
                selected_files = st.multiselect(
                    "请选择文件:",
                    file_list,
                    default=st.session_state.get("selected_files", []),
                    on_change=lambda: st.session_state.setdefault("selected_files", selected_files),
                    key="selected_files",
                )
                kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

                # BGE models can score above 1, hence the 0-2 range.
                score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
        elif dialogue_mode == "文件对话":
            with st.expander("文件对话配置", True):
                files = st.file_uploader("上传知识文件:",
                                         [i for ls in LOADER_DICT.values() for i in ls],
                                         accept_multiple_files=True,
                                         )
                kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

                # BGE models can score above 1, hence the 0-2 range.
                score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
                if st.button("开始上传", disabled=len(files) == 0):
                    st.session_state["file_chat_id"] = upload_temp_docs(files, api)
        elif dialogue_mode == "搜索引擎问答":
            search_engine_list = api.list_search_engines()
            if DEFAULT_SEARCH_ENGINE in search_engine_list:
                index = search_engine_list.index(DEFAULT_SEARCH_ENGINE)
            else:
                index = search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0
            with st.expander("搜索引擎配置", True):
                search_engine = st.selectbox(
                    label="请选择搜索引擎",
                    options=search_engine_list,
                    index=index,
                )
                se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)

    # Display chat messages from history on app rerun
    chat_box.output_messages()

    chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "

    def on_feedback(
            feedback,
            message_id: str = "",
            history_index: int = -1,
    ):
        reason = feedback["text"]
        score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
        api.chat_feedback(message_id=message_id,
                          score=score_int,
                          reason=reason)
        st.session_state["need_rerun"] = True

    feedback_kwargs = {
        "feedback_type": "thumbs",
        "optional_text_label": "欢迎反馈您打分的理由",
    }

    if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
        if parse_command(text=prompt, modal=modal):  # the user entered a custom command
            st.rerun()
        else:
            history = get_messages_history(history_len)
            chat_box.user_say(prompt)
            if dialogue_mode == "LLM 对话":
                chat_box.ai_say("正在思考...")
                _stream_plain_reply(
                    api.chat_chat(prompt,
                                  history=history,
                                  conversation_id=conversation_id,
                                  model=llm_model,
                                  prompt_name=prompt_template_name,
                                  temperature=temperature),
                    feedback_kwargs, on_feedback)
            elif dialogue_mode == "翻译":
                chat_box.ai_say("正在思考...")
                _stream_plain_reply(
                    api.chat_translate(
                        prompt,
                        conversation_id=conversation_id,
                        to_lang="zh-cn"
                    ),
                    feedback_kwargs, on_feedback)
            elif dialogue_mode == "智能大纲生成":
                chat_box.ai_say("正在思考...")
                _stream_plain_reply(
                    api.chat_outlines(prompt,
                                      history=history,
                                      conversation_id=conversation_id,
                                      model=llm_model,
                                      prompt_name=prompt_template_name,
                                      temperature=temperature),
                    feedback_kwargs, on_feedback)
            elif dialogue_mode == "智能大纲补全":
                chat_box.ai_say("正在思考...")
                _stream_plain_reply(
                    api.finsh_outlines(prompt,
                                       history=history,
                                       conversation_id=conversation_id,
                                       model=llm_model,
                                       prompt_name=prompt_template_name,
                                       temperature=temperature),
                    feedback_kwargs, on_feedback)
            elif dialogue_mode == "自定义Agent问答":
                if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
                    chat_box.ai_say([
                        f"正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n",
                        Markdown("...", in_expander=True, title="思考过程", state="complete"),

                    ])
                else:
                    chat_box.ai_say([
                        f"正在思考...",
                        Markdown("...", in_expander=True, title="思考过程", state="complete"),

                    ])
                text = ""
                ans = ""
                for d in api.agent_chat(prompt,
                                        history=history,
                                        model=llm_model,
                                        prompt_name=prompt_template_name,
                                        temperature=temperature,
                                        ):
                    try:
                        d = json.loads(d)
                    except (TypeError, ValueError):
                        # Chunk is already decoded (TypeError) or is not valid
                        # JSON (ValueError): use it as received.
                        pass
                    if error_msg := check_error_msg(d):  # check whether error occured
                        st.error(error_msg)
                    if chunk := d.get("answer"):
                        text += chunk
                        chat_box.update_msg(text, element_index=1)
                    if chunk := d.get("final_answer"):
                        ans += chunk
                        chat_box.update_msg(ans, element_index=0)
                    if chunk := d.get("tools"):
                        text += "\n\n".join(d.get("tools", []))
                        chat_box.update_msg(text, element_index=1)
                chat_box.update_msg(ans, element_index=0, streaming=False)
                chat_box.update_msg(text, element_index=1, streaming=False)
            elif dialogue_mode in ["知识库问答-旧", "知识库问答"]:
                chat_box.ai_say([
                    f"正在查询知识库 `{selected_kb}` ...",
                    Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
                ])
                if dialogue_mode == "知识库问答-旧":
                    chunks = api.knowledge_base_chat_old(prompt,
                                                         knowledge_base_name=selected_kb,
                                                         top_k=kb_top_k,
                                                         score_threshold=score_threshold,
                                                         conversation_id=conversation_id,
                                                         history=history,
                                                         model=llm_model,
                                                         prompt_name=prompt_template_name,
                                                         temperature=temperature)
                else:
                    chunks = api.knowledge_base_chat(prompt,
                                                     knowledge_base_name_list=[selected_kb],
                                                     top_k=kb_top_k,
                                                     score_threshold=score_threshold,
                                                     conversation_id=conversation_id,
                                                     history=history,
                                                     model=llm_model,
                                                     prompt_name=prompt_template_name,
                                                     temperature=temperature)
                _stream_answer_with_docs(chunks)

            elif dialogue_mode == "个人知识库问答":
                chat_box.ai_say([
                    f"正在查询知识库 `{selected_kb}`中的文件: `{selected_files}` ...",
                    Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
                ])
                # NOTE(review): selected_files is already a list (multiselect),
                # so fileNames receives a nested list -- confirm that
                # api.self_kb_chat expects this shape.
                _stream_answer_with_docs(
                    api.self_kb_chat(
                        prompt,
                        fileNames=[selected_files],
                        knowledge_base_name_list=[selected_kb],
                        history=history
                    ))

            elif dialogue_mode == "文件对话":
                if st.session_state["file_chat_id"] is None:
                    st.error("请先上传文件再进行对话")
                    st.stop()
                chat_box.ai_say([
                    f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
                    Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
                ])
                _stream_answer_with_docs(
                    api.file_chat(prompt,
                                  knowledge_id=st.session_state["file_chat_id"],
                                  top_k=kb_top_k,
                                  score_threshold=score_threshold,
                                  history=history,
                                  model=llm_model,
                                  prompt_name=prompt_template_name,
                                  temperature=temperature))
            elif dialogue_mode == "搜索引擎问答":
                chat_box.ai_say([
                    f"正在执行 `{search_engine}` 搜索...",
                    Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
                ])
                _stream_answer_with_docs(
                    api.search_engine_chat(prompt,
                                           search_engine_name=search_engine,
                                           top_k=se_top_k,
                                           history=history,
                                           model=llm_model,
                                           prompt_name=prompt_template_name,
                                           temperature=temperature,
                                           split_result=se_top_k > 1))

    # on_feedback asks for a rerun through this flag.
    if st.session_state.get("need_rerun"):
        st.session_state["need_rerun"] = False
        st.rerun()

    now = datetime.now()
    with st.sidebar:

        cols = st.columns(2)
        export_btn = cols[0]
        if cols[1].button(
                "清空对话",
                use_container_width=True,
        ):
            chat_box.reset_history()
            st.rerun()

    export_btn.download_button(
        "导出记录",
        "".join(chat_box.export2md()),
        file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
        mime="text/markdown",
        use_container_width=True,
    )
|
||||
1
langchain-chat/webui_pages/knowledge_base/__init__.py
Normal file
1
langchain-chat/webui_pages/knowledge_base/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# 知识库页面模块
|
||||
367
langchain-chat/webui_pages/knowledge_base/knowledge_base.py
Normal file
367
langchain-chat/webui_pages/knowledge_base/knowledge_base.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import streamlit as st
|
||||
from webui_pages.utils import *
|
||||
from st_aggrid import AgGrid, JsCode
|
||||
from st_aggrid.grid_options_builder import GridOptionsBuilder
|
||||
import pandas as pd
|
||||
from server.knowledge_base.utils import get_file_path, LOADER_DICT
|
||||
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
|
||||
from typing import Literal, Dict, Tuple
|
||||
from configs import (kbs_config,
|
||||
EMBEDDING_MODEL, DEFAULT_VS_TYPE,
|
||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
|
||||
from server.utils import list_embed_models, list_online_embed_models
|
||||
import os
|
||||
|
||||
|
||||
# JavaScript cell renderer passed to AgGrid columns: returns '✓' when the
# cell value is loosely equal to true, and '×' otherwise.
cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""")
|
||||
|
||||
|
||||
def config_aggrid(
        df: pd.DataFrame,
        columns: Dict[Tuple[str, str], Dict] = None,
        selection_mode: Literal["single", "multiple", "disabled"] = "single",
        use_checkbox: bool = False,
) -> GridOptionsBuilder:
    """
    Build the AgGrid options for ``df``.

    :param df: data frame the grid is built from.
    :param columns: maps ``(column_name, header_text)`` to extra keyword
        arguments for ``configure_column``. ``None`` (default) configures
        no extra columns; the default is no longer a shared mutable dict.
    :param selection_mode: row selection mode passed to the grid.
    :param use_checkbox: whether rows are selected through a checkbox.
    :return: the configured builder; the caller invokes ``.build()``.
    """
    gb = GridOptionsBuilder.from_dataframe(df)
    gb.configure_column("No", width=40)
    for (col, header), kw in (columns or {}).items():
        gb.configure_column(col, header, wrapHeaderText=True, **kw)
    gb.configure_selection(
        selection_mode=selection_mode,
        use_checkbox=use_checkbox,
        # Falls back to pre-selecting the first row when the session holds
        # no "selected_rows" entry.
        pre_selected_rows=st.session_state.get("selected_rows", [0]),
    )
    gb.configure_pagination(
        enabled=True,
        paginationAutoPageSize=False,
        paginationPageSize=10
    )
    return gb
|
||||
|
||||
|
||||
def file_exists(kb: str, selected_rows: List) -> Tuple[str, str]:
    """
    Check whether the first selected document exists in the local
    knowledge base folder.

    :param kb: knowledge base name, passed to ``get_file_path``.
    :param selected_rows: selected grid rows; only the first row's
        ``"file_name"`` entry is examined.
    :return: ``(file_name, file_path)`` when the file exists on disk,
        otherwise ``("", "")`` (also when nothing is selected).
    """
    if selected_rows:
        file_name = selected_rows[0]["file_name"]
        file_path = get_file_path(kb, file_name)
        if os.path.isfile(file_path):
            return file_name, file_path
    return "", ""
|
||||
|
||||
|
||||
def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
    """
    Render the knowledge base management page.

    Lets the user pick an existing knowledge base or create a new one,
    upload files, edit the description, tune the text-splitting options and
    add/remove documents to/from the vector store or rebuild it.

    :param api: API client used for every backend call.
    :param is_lite: when truthy, only online embedding models are offered
        for a new knowledge base.
    """
    try:
        kb_list = {x["kb_name"]: x for x in get_kb_details()}
    except Exception:
        # UI boundary: show a hint and halt this script run.
        st.error(
            "获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
        st.stop()
    kb_names = list(kb_list.keys())

    if "selected_kb_name" in st.session_state and st.session_state["selected_kb_name"] in kb_names:
        selected_kb_index = kb_names.index(st.session_state["selected_kb_name"])
    elif "samples" in kb_names:
        selected_kb_index = kb_names.index("samples")
    else:
        # Fix: the page used to raise ValueError when no "samples"
        # knowledge base exists; fall back to the first entry.
        selected_kb_index = 0

    if "selected_kb_info" not in st.session_state:
        st.session_state["selected_kb_info"] = ""

    def format_selected_kb(kb_name: str) -> str:
        if kb := kb_list.get(kb_name):
            return f"{kb_name} ({kb['vs_type']} @ {kb['embed_model']})"
        else:
            return kb_name

    selected_kb = st.selectbox(
        "请选择或新建知识库:",
        kb_names + ["新建知识库"],
        format_func=format_selected_kb,
        index=selected_kb_index
    )

    if selected_kb == "新建知识库":
        with st.form("新建知识库"):

            kb_name = st.text_input(
                "新建知识库名称",
                placeholder="新知识库名称,不支持中文命名",
                key="kb_name",
            )
            kb_info = st.text_input(
                "知识库简介",
                placeholder="知识库简介,方便Agent查找",
                key="kb_info",
            )

            cols = st.columns(2)

            vs_types = list(kbs_config.keys())
            vs_type = cols[0].selectbox(
                "向量库类型",
                vs_types,
                index=vs_types.index(DEFAULT_VS_TYPE),
                key="vs_type",
            )

            if is_lite:
                embed_models = list_online_embed_models()
            else:
                embed_models = list_embed_models() + list_online_embed_models()

            embed_model = cols[1].selectbox(
                "Embedding 模型",
                embed_models,
                # Fix: the configured default may be missing from the list
                # (e.g. lite mode offers online models only); that used to
                # raise ValueError.
                index=embed_models.index(EMBEDDING_MODEL) if EMBEDDING_MODEL in embed_models else 0,
                key="embed_model",
            )

            submit_create_kb = st.form_submit_button(
                "新建",
                # disabled=not bool(kb_name),
                use_container_width=True,
            )

        if submit_create_kb:
            if not kb_name or not kb_name.strip():
                st.error(f"知识库名称不能为空!")
            elif kb_name in kb_list:
                st.error(f"名为 {kb_name} 的知识库已经存在!")
            else:
                ret = api.create_knowledge_base(
                    knowledge_base_name=kb_name,
                    vector_store_type=vs_type,
                    embed_model=embed_model,
                )
                st.toast(ret.get("msg", " "))
                st.session_state["selected_kb_name"] = kb_name
                st.session_state["selected_kb_info"] = kb_info
                st.rerun()

    elif selected_kb:
        kb = selected_kb
        st.session_state["selected_kb_info"] = kb_list[kb]['kb_info']
        # Upload files
        files = st.file_uploader("上传知识文件:",
                                 [i for ls in LOADER_DICT.values() for i in ls],
                                 accept_multiple_files=True,
                                 )
        kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None,
                               key=None,
                               help=None, on_change=None, args=None, kwargs=None)

        # Persist an edited description.
        if kb_info != st.session_state["selected_kb_info"]:
            st.session_state["selected_kb_info"] = kb_info
            api.update_kb_info(kb, kb_info)

        # with st.sidebar:
        with st.expander(
                "文件处理配置",
                expanded=True,
        ):
            cols = st.columns(3)
            chunk_size = cols[0].number_input("单段文本最大长度:", 1, 1000, CHUNK_SIZE)
            chunk_overlap = cols[1].number_input("相邻文本重合长度:", 0, chunk_size, OVERLAP_SIZE)
            # Two empty writes push the checkbox down (presumably to align
            # it with the number inputs).
            cols[2].write("")
            cols[2].write("")
            zh_title_enhance = cols[2].checkbox("开启中文标题加强", ZH_TITLE_ENHANCE)

        if st.button(
                "添加文件到知识库",
                # use_container_width=True,
                disabled=len(files) == 0,
        ):
            ret = api.upload_kb_docs(files,
                                     knowledge_base_name=kb,
                                     override=True,
                                     chunk_size=chunk_size,
                                     chunk_overlap=chunk_overlap,
                                     zh_title_enhance=zh_title_enhance)
            if msg := check_success_msg(ret):
                st.toast(msg, icon="✔")
            elif msg := check_error_msg(ret):
                st.toast(msg, icon="✖")

        st.divider()

        # Knowledge base details
        # st.info("请选择文件,点击按钮进行操作。")
        doc_details = pd.DataFrame(get_kb_file_details(kb))
        selected_rows = []
        if not len(doc_details):
            st.info(f"知识库 `{kb}` 中暂无文件")
        else:
            st.write(f"知识库 `{kb}` 中已有文件:")
            st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作")
            doc_details.drop(columns=["kb_name"], inplace=True)
            doc_details = doc_details[[
                "No", "file_name", "document_loader", "text_splitter", "docs_count", "in_folder", "in_db",
            ]]
            # NOTE(review): after these replacements "in_folder"/"in_db" hold
            # the strings "✓"/"×", so the truthiness checks on "in_db" further
            # down look always-true for a selected row -- confirm the intent.
            doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
            doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
            gb = config_aggrid(
                doc_details,
                {
                    ("No", "序号"): {},
                    ("file_name", "文档名称"): {},
                    # ("file_ext", "文档类型"): {},
                    # ("file_version", "文档版本"): {},
                    ("document_loader", "文档加载器"): {},
                    ("docs_count", "文档数量"): {},
                    ("text_splitter", "分词器"): {},
                    # ("create_time", "创建时间"): {},
                    ("in_folder", "源文件"): {"cellRenderer": cell_renderer},
                    ("in_db", "向量库"): {"cellRenderer": cell_renderer},
                },
                "multiple",
            )

            doc_grid = AgGrid(
                doc_details,
                gb.build(),
                columns_auto_size_mode="FIT_CONTENTS",
                theme="alpine",
                custom_css={
                    "#gridToolBar": {"display": "none"},
                },
                allow_unsafe_jscode=True,
                enable_enterprise_modules=False
            )

            selected_rows = doc_grid.get("selected_rows", [])

            cols = st.columns(4)
            file_name, file_path = file_exists(kb, selected_rows)
            if file_path:
                with open(file_path, "rb") as fp:
                    cols[0].download_button(
                        "下载选中文档",
                        fp,
                        file_name=file_name,
                        use_container_width=True, )
            else:
                cols[0].download_button(
                    "下载选中文档",
                    "",
                    disabled=True,
                    use_container_width=True, )

            st.write()
            # Split the selected files and load them into the vector store
            if cols[1].button(
                    "重新添加至向量库" if selected_rows and (
                            pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
                    disabled=not file_exists(kb, selected_rows)[0],
                    use_container_width=True,
            ):
                file_names = [row["file_name"] for row in selected_rows]
                api.update_kb_docs(kb,
                                   file_names=file_names,
                                   chunk_size=chunk_size,
                                   chunk_overlap=chunk_overlap,
                                   zh_title_enhance=zh_title_enhance)
                st.rerun()

            # Remove the files from the vector store but keep the files themselves.
            if cols[2].button(
                    "从向量库删除",
                    disabled=not (selected_rows and selected_rows[0]["in_db"]),
                    use_container_width=True,
            ):
                file_names = [row["file_name"] for row in selected_rows]
                api.delete_kb_docs(kb, file_names=file_names)
                st.rerun()

            # Remove the files from the knowledge base, content included.
            if cols[3].button(
                    "从知识库中删除",
                    type="primary",
                    use_container_width=True,
            ):
                file_names = [row["file_name"] for row in selected_rows]
                api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
                st.rerun()

        st.divider()

        cols = st.columns(3)

        if cols[0].button(
                "依据源文件重建向量库",
                help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
                use_container_width=True,
                type="primary",
        ):
            with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
                empty = st.empty()
                empty.progress(0.0, "")
                # Each streamed item reports progress as finished/total.
                for d in api.recreate_vector_store(kb,
                                                   chunk_size=chunk_size,
                                                   chunk_overlap=chunk_overlap,
                                                   zh_title_enhance=zh_title_enhance):
                    if msg := check_error_msg(d):
                        st.toast(msg)
                    else:
                        empty.progress(d["finished"] / d["total"], d["msg"])
                st.rerun()
|
||||
|
||||
# if cols[2].button(
|
||||
# "删除知识库",
|
||||
# use_container_width=True,
|
||||
# ):
|
||||
# ret = api.delete_knowledge_base(kb)
|
||||
# st.toast(ret.get("msg", " "))
|
||||
# time.sleep(1)
|
||||
# st.rerun()
|
||||
|
||||
# with st.sidebar:
|
||||
# keyword = st.text_input("查询关键字")
|
||||
# top_k = st.slider("匹配条数", 1, 100, 3)
|
||||
|
||||
# st.write("文件内文档列表。双击进行修改,在删除列填入 Y 可删除对应行。")
|
||||
# docs = []
|
||||
# df = pd.DataFrame([], columns=["seq", "id", "content", "source"])
|
||||
# if selected_rows:
|
||||
# file_name = selected_rows[0]["file_name"]
|
||||
# docs = api.search_kb_docs(knowledge_base_name=selected_kb, file_name=file_name)
|
||||
# data = [
|
||||
# {"seq": i + 1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"),
|
||||
# "type": x["type"],
|
||||
# "metadata": json.dumps(x["metadata"], ensure_ascii=False),
|
||||
# "to_del": "",
|
||||
# } for i, x in enumerate(docs)]
|
||||
# df = pd.DataFrame(data)
|
||||
|
||||
# gb = GridOptionsBuilder.from_dataframe(df)
|
||||
# gb.configure_columns(["id", "source", "type", "metadata"], hide=True)
|
||||
# gb.configure_column("seq", "No.", width=50)
|
||||
# gb.configure_column("page_content", "内容", editable=True, autoHeight=True, wrapText=True, flex=1,
|
||||
# cellEditor="agLargeTextCellEditor", cellEditorPopup=True)
|
||||
# gb.configure_column("to_del", "删除", editable=True, width=50, wrapHeaderText=True,
|
||||
# cellEditor="agCheckboxCellEditor", cellRender="agCheckboxCellRenderer")
|
||||
# gb.configure_selection()
|
||||
# edit_docs = AgGrid(df, gb.build())
|
||||
|
||||
# if st.button("保存更改"):
|
||||
# origin_docs = {
|
||||
# x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in
|
||||
# docs}
|
||||
# changed_docs = []
|
||||
# for index, row in edit_docs.data.iterrows():
|
||||
# origin_doc = origin_docs[row["id"]]
|
||||
# if row["page_content"] != origin_doc["page_content"]:
|
||||
# if row["to_del"] not in ["Y", "y", 1]:
|
||||
# changed_docs.append({
|
||||
# "page_content": row["page_content"],
|
||||
# "type": row["type"],
|
||||
# "metadata": json.loads(row["metadata"]),
|
||||
# })
|
||||
|
||||
# if changed_docs:
|
||||
# if api.update_kb_docs(knowledge_base_name=selected_kb,
|
||||
# file_names=[file_name],
|
||||
# docs={file_name: changed_docs}):
|
||||
# st.toast("更新文档成功")
|
||||
# else:
|
||||
# st.toast("更新文档失败")
|
||||
1
langchain-chat/webui_pages/model_config/__init__.py
Normal file
1
langchain-chat/webui_pages/model_config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# 模型配置页面模块
|
||||
4
langchain-chat/webui_pages/model_config/model_config.py
Normal file
4
langchain-chat/webui_pages/model_config/model_config.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from webui_pages.utils import *
|
||||
|
||||
def model_config_page(api: ApiRequest):
    """Placeholder for the model configuration page; currently does nothing."""
    pass
|
||||
1221
langchain-chat/webui_pages/utils.py
Normal file
1221
langchain-chat/webui_pages/utils.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user