[全量] 初始化项目代码、配置、文档及Agent协同harness

This commit is contained in:
2026-04-02 11:36:05 +08:00
parent 0553309cdf
commit 87e571d9ec
1133 changed files with 221948 additions and 0 deletions

View File

@@ -0,0 +1,214 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
from configs import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path
from pprint import pprint
# Base URL of the API server under test, as returned by api_address().
api_base_url = api_address()
# Name of the knowledge base that the tests in this module create and delete.
kb = "kb_for_api_test"
# Maps the file name sent to the upload endpoint -> local path of the file,
# resolved through get_file_path against the "samples" knowledge base.
test_files = {
"wiki/Home.MD": get_file_path("samples", "wiki/Home.md"),
"wiki/开发环境部署.MD": get_file_path("samples", "wiki/开发环境部署.md"),
"test_files/test.txt": get_file_path("samples", "test_files/test.txt"),
}
# Banner printed once at import time; the tests below call the endpoints by URL.
print("\n\n直接url访问\n")
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
    """Remove a leftover test knowledge base before the other tests run.

    Returns immediately when the knowledge base path does not exist.
    Otherwise the delete endpoint is called and the knowledge base list is
    fetched to check that ``kb`` is no longer listed.
    """
    leftover = Path(get_kb_path(kb))
    if leftover.exists():
        print("\n测试知识库存在,需要删除")
        delete_body = requests.post(api_base_url + api, json=kb).json()
        pprint(delete_body)

        # The knowledge base must be gone from the listing afterwards.
        print("\n获取知识库列表:")
        listing = requests.get(api_base_url + "/knowledge_base/list_knowledge_bases").json()
        pprint(listing)
        assert listing["code"] == 200
        names = listing["data"]
        assert isinstance(names, list) and len(names) > 0
        assert kb not in names
def test_create_kb(api="/knowledge_base/create_knowledge_base"):
    """Exercise knowledge base creation: blank name, new name, duplicate name.

    Each scenario posts a name to the create endpoint and checks the
    ``code`` and ``msg`` fields of the JSON response.
    """
    endpoint = api_base_url + api
    # (banner to print, name to submit, expected code, expected message)
    scenarios = [
        (f"\n尝试用空名称创建知识库:", " ", 404, "知识库名称不能为空,请重新填写知识库名称"),
        (f"\n创建新知识库: {kb}", kb, 200, f"已新增知识库 {kb}"),
        (f"\n尝试创建同名知识库: {kb}", kb, 404, f"已存在同名知识库 {kb}"),
    ]
    for banner, name, expected_code, expected_msg in scenarios:
        print(banner)
        body = requests.post(endpoint, json={"knowledge_base_name": name}).json()
        pprint(body)
        assert body["code"] == expected_code
        assert body["msg"] == expected_msg
def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
    """Fetch the knowledge base list and check that ``kb`` is in it."""
    print("\n获取知识库列表:")
    listing = requests.get(api_base_url + api).json()
    pprint(listing)
    assert listing["code"] == 200
    names = listing["data"]
    assert isinstance(names, list) and len(names) > 0
    assert kb in names
def test_upload_docs(api="/knowledge_base/upload_docs"):
    """Upload the sample files three times: override, no override, custom docs.

    The sample files are opened anew for every request and closed as soon as
    the request has been sent, so no file handle outlives the test.
    """
    from contextlib import ExitStack

    url = api_base_url + api

    def upload(form):
        """Post ``form`` plus all sample files and return the JSON body."""
        with ExitStack() as stack:
            files = [
                ("files", (name, stack.enter_context(open(path, "rb"))))
                for name, path in test_files.items()
            ]
            response = requests.post(url, data=form, files=files)
        body = response.json()
        pprint(body)
        return body

    print(f"\n上传知识文件")
    data = upload({"knowledge_base_name": kb, "override": True})
    assert data["code"] == 200
    assert len(data["data"]["failed_files"]) == 0

    # Without override, every file is expected to be reported as failed.
    print(f"\n尝试重新上传知识文件, 不覆盖")
    data = upload({"knowledge_base_name": kb, "override": False})
    assert data["code"] == 200
    assert len(data["data"]["failed_files"]) == len(test_files)

    # Override again, this time also sending custom docs as a JSON string.
    print(f"\n尝试重新上传知识文件, 覆盖自定义docs")
    docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
    data = upload({"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)})
    assert data["code"] == 200
    assert len(data["data"]["failed_files"]) == 0
def test_list_files(api="/knowledge_base/list_files"):
    """List the files of the test knowledge base and expect every uploaded name."""
    print("\n获取知识库中文件列表:")
    response = requests.get(api_base_url + api, params={"knowledge_base_name": kb})
    body = response.json()
    pprint(body)
    assert body["code"] == 200
    assert isinstance(body["data"], list)
    for expected_name in test_files:
        assert expected_name in body["data"]
def test_search_docs(api="/knowledge_base/search_docs"):
    """Search the test knowledge base and expect VECTOR_SEARCH_TOP_K results."""
    query = "介绍一下langchain-chatchat项目"
    print("\n检索知识库:")
    print(query)
    payload = {"knowledge_base_name": kb, "query": query}
    hits = requests.post(api_base_url + api, json=payload).json()
    pprint(hits)
    assert isinstance(hits, list) and len(hits) == VECTOR_SEARCH_TOP_K
def test_update_info(api="/knowledge_base/update_info"):
    """Update the description of the "samples" knowledge base and expect code 200."""
    print("\n更新知识库介绍")
    payload = {"knowledge_base_name": "samples", "kb_info": "你好"}
    body = requests.post(api_base_url + api, json=payload).json()
    pprint(body)
    assert body["code"] == 200
def test_update_docs(api="/knowledge_base/update_docs"):
    """Ask the server to update all sample files and expect no failures."""
    print(f"\n更新知识文件")
    payload = {"knowledge_base_name": kb, "file_names": list(test_files)}
    body = requests.post(api_base_url + api, json=payload).json()
    pprint(body)
    assert body["code"] == 200
    assert len(body["data"]["failed_files"]) == 0
def test_delete_docs(api="/knowledge_base/delete_docs"):
    """Delete all sample files, then check that a search returns nothing."""
    print(f"\n删除知识文件")
    delete_payload = {"knowledge_base_name": kb, "file_names": list(test_files)}
    body = requests.post(api_base_url + api, json=delete_payload).json()
    pprint(body)
    assert body["code"] == 200
    assert len(body["data"]["failed_files"]) == 0

    # A search after deletion is expected to yield an empty list.
    query = "介绍一下langchain-chatchat项目"
    print("\n尝试检索删除后的检索知识库:")
    print(query)
    search_url = api_base_url + "/knowledge_base/search_docs"
    hits = requests.post(search_url, json={"knowledge_base_name": kb, "query": query}).json()
    pprint(hits)
    assert isinstance(hits, list) and len(hits) == 0
def test_recreate_vs(api="/knowledge_base/recreate_vector_store"):
    """Rebuild the vector store (streamed response), then search it again."""
    print("\n重建知识库:")
    stream = requests.post(api_base_url + api, json={"knowledge_base_name": kb}, stream=True)
    for raw in stream.iter_content(None):
        # The first 6 bytes of every chunk are dropped before JSON parsing
        # (presumably an SSE "data: " prefix — confirm against the server).
        event = json.loads(raw[6:])
        assert isinstance(event, dict)
        assert event["code"] == 200
        print(event["msg"])

    query = "本项目支持哪些文件格式?"
    print("\n尝试检索重建后的检索知识库:")
    print(query)
    search_url = api_base_url + "/knowledge_base/search_docs"
    hits = requests.post(search_url, json={"knowledge_base_name": kb, "query": query}).json()
    pprint(hits)
    assert isinstance(hits, list) and len(hits) == VECTOR_SEARCH_TOP_K
def test_delete_kb_after(api="/knowledge_base/delete_knowledge_base"):
    """Delete the test knowledge base and check it is gone from the listing."""
    print("\n删除知识库")
    delete_body = requests.post(api_base_url + api, json=kb).json()
    pprint(delete_body)

    # The knowledge base must not be listed anymore.
    print("\n获取知识库列表:")
    listing = requests.get(api_base_url + "/knowledge_base/list_knowledge_bases").json()
    pprint(listing)
    assert listing["code"] == 200
    names = listing["data"]
    assert isinstance(names, list) and len(names) > 0
    assert kb not in names

View File

@@ -0,0 +1,161 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
from configs import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path
from webui_pages.utils import ApiRequest
from pprint import pprint
# Base URL of the API server under test, as returned by api_address().
api_base_url = api_address()
# Client object through which every test in this module calls the API.
api: ApiRequest = ApiRequest(api_base_url)
# Name of the knowledge base that the tests in this module create and delete.
kb = "kb_for_api_test"
# Maps a file name -> local path; the paths are uploaded and the names are
# later expected in the knowledge base's file listing.
test_files = {
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
"README.MD": str(root_path / "README.MD"),
"test.txt": get_file_path("samples", "test.txt"),
}
# Banner printed once at import time.
print("\n\nApiRquest调用\n")
def test_delete_kb_before():
    """Delete a leftover test knowledge base through ApiRequest, if it exists."""
    leftover = Path(get_kb_path(kb))
    if leftover.exists():
        # NOTE(review): the checks below treat the delete response's "data" as
        # a non-empty list of knowledge base names — confirm the API returns
        # that (test_delete_kb_after lists the knowledge bases separately).
        result = api.delete_knowledge_base(kb)
        pprint(result)
        assert result["code"] == 200
        names = result["data"]
        assert isinstance(names, list) and len(names) > 0
        assert kb not in names
def test_create_kb():
    """Exercise knowledge base creation: blank name, new name, duplicate name."""
    # (banner to print, name to submit, expected code, expected message)
    scenarios = [
        (f"\n尝试用空名称创建知识库:", " ", 404, "知识库名称不能为空,请重新填写知识库名称"),
        (f"\n创建新知识库: {kb}", kb, 200, f"已新增知识库 {kb}"),
        (f"\n尝试创建同名知识库: {kb}", kb, 404, f"已存在同名知识库 {kb}"),
    ]
    for banner, name, expected_code, expected_msg in scenarios:
        print(banner)
        result = api.create_knowledge_base(name)
        pprint(result)
        assert result["code"] == expected_code
        assert result["msg"] == expected_msg
def test_list_kbs():
    """List knowledge bases through ApiRequest and expect ``kb`` among them."""
    names = api.list_knowledge_bases()
    pprint(names)
    assert isinstance(names, list) and len(names) > 0
    assert kb in names
def test_upload_docs():
    """Upload the sample files three times: override, no override, custom docs."""
    paths = list(test_files.values())

    print(f"\n上传知识文件")
    result = api.upload_kb_docs(paths, knowledge_base_name=kb, override=True)
    pprint(result)
    assert result["code"] == 200
    assert len(result["data"]["failed_files"]) == 0

    # Without override, every file is expected to be reported as failed.
    print(f"\n尝试重新上传知识文件, 不覆盖")
    result = api.upload_kb_docs(paths, knowledge_base_name=kb, override=False)
    pprint(result)
    assert result["code"] == 200
    assert len(result["data"]["failed_files"]) == len(test_files)

    # Override again, this time passing custom docs for one file name.
    print(f"\n尝试重新上传知识文件, 覆盖自定义docs")
    custom_docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
    result = api.upload_kb_docs(paths, knowledge_base_name=kb, override=True, docs=custom_docs)
    pprint(result)
    assert result["code"] == 200
    assert len(result["data"]["failed_files"]) == 0
def test_list_files():
    """List the knowledge base's files and expect every uploaded name."""
    print("\n获取知识库中文件列表:")
    listed = api.list_kb_docs(knowledge_base_name=kb)
    pprint(listed)
    assert isinstance(listed, list)
    for expected_name in test_files:
        assert expected_name in listed
def test_search_docs():
    """Search the test knowledge base and expect VECTOR_SEARCH_TOP_K results."""
    query = "介绍一下langchain-chatchat项目"
    print("\n检索知识库:")
    print(query)
    hits = api.search_kb_docs(query, kb)
    pprint(hits)
    assert isinstance(hits, list) and len(hits) == VECTOR_SEARCH_TOP_K
def test_update_docs():
    """Update all sample files through ApiRequest and expect no failures."""
    print(f"\n更新知识文件")
    names = list(test_files)
    result = api.update_kb_docs(knowledge_base_name=kb, file_names=names)
    pprint(result)
    assert result["code"] == 200
    assert len(result["data"]["failed_files"]) == 0
def test_delete_docs():
    """Delete all sample files, then check that a search returns nothing."""
    print(f"\n删除知识文件")
    names = list(test_files)
    result = api.delete_kb_docs(knowledge_base_name=kb, file_names=names)
    pprint(result)
    assert result["code"] == 200
    assert len(result["data"]["failed_files"]) == 0

    # A search after deletion is expected to yield an empty list.
    query = "介绍一下langchain-chatchat项目"
    print("\n尝试检索删除后的检索知识库:")
    print(query)
    hits = api.search_kb_docs(query, kb)
    pprint(hits)
    assert isinstance(hits, list) and len(hits) == 0
def test_recreate_vs():
    """Rebuild the vector store, checking each progress item, then search it."""
    print("\n重建知识库:")
    for progress in api.recreate_vector_store(kb):
        assert isinstance(progress, dict)
        assert progress["code"] == 200
        print(progress["msg"])

    query = "本项目支持哪些文件格式?"
    print("\n尝试检索重建后的检索知识库:")
    print(query)
    hits = api.search_kb_docs(query, kb)
    pprint(hits)
    assert isinstance(hits, list) and len(hits) == VECTOR_SEARCH_TOP_K
def test_delete_kb_after():
    """Delete the test knowledge base and check it is gone from the listing."""
    print("\n删除知识库")
    pprint(api.delete_knowledge_base(kb))

    # The knowledge base must not be listed anymore.
    print("\n获取知识库列表:")
    names = api.list_knowledge_bases()
    pprint(names)
    assert isinstance(names, list) and len(names) > 0
    assert kb not in names

View File

@@ -0,0 +1,44 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
# Base URL of the API server under test, as returned by api_address().
api_base_url = api_address()
# Knowledge base that the summary endpoints are called against.
kb = "samples"
# NOTE(review): absolute, machine-specific path — the file-summary test can
# only find this file on a host with the same directory layout.
file_name = "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/knowledge_base/samples/content/llm/大模型技术栈-实战与应用.md"
# Hard-coded document ids sent to the doc-id summary endpoint
# (presumably ids from one particular vector store — verify before running).
doc_ids = [
"357d580f-fdf7-495c-b58b-595a398284e8",
"c7338773-2e83-4671-b237-1ad20335b0f0",
"6da613d1-327d-466f-8c1a-b32e6f461f47"
]
def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summary_file_to_vector_store"):
    """Request a summary of ``file_name`` and check every streamed event."""
    print("\n文件摘要:")
    payload = {"knowledge_base_name": kb, "file_name": file_name}
    stream = requests.post(api_base_url + api, json=payload, stream=True)
    for raw in stream.iter_content(None):
        # The first 6 bytes of every chunk are dropped before JSON parsing
        # (presumably an SSE "data: " prefix — confirm against the server).
        event = json.loads(raw[6:])
        assert isinstance(event, dict)
        assert event["code"] == 200
        print(event["msg"])
def test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store"):
    """Request a summary for ``doc_ids`` and check every streamed event."""
    print("\n文件摘要:")
    payload = {"knowledge_base_name": kb, "doc_ids": doc_ids}
    stream = requests.post(api_base_url + api, json=payload, stream=True)
    for raw in stream.iter_content(None):
        # The first 6 bytes of every chunk are dropped before JSON parsing
        # (presumably an SSE "data: " prefix — confirm against the server).
        event = json.loads(raw[6:])
        assert isinstance(event, dict)
        assert event["code"] == 200
        print(event)

View File

@@ -0,0 +1,70 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import FSCHAT_MODEL_WORKERS
from server.utils import api_address, get_model_worker_config
from pprint import pprint
import random
from typing import List
def get_configured_models() -> List[str]:
    """Return the names in FSCHAT_MODEL_WORKERS, without the "default" entry."""
    names = [*FSCHAT_MODEL_WORKERS]
    try:
        names.remove("default")
    except ValueError:
        # No "default" entry is configured; nothing to drop.
        pass
    return names
api_base_url = api_address()


def get_running_models(api="/llm_model/list_models"):
    """Return the ``data`` field of the model-list response, or [] on a non-200 status."""
    response = requests.post(api_base_url + api)
    if response.status_code != 200:
        return []
    return response.json()["data"]
def test_running_models(api="/llm_model/list_running_models"):
    """Expect the running-model endpoint to return a non-empty list."""
    response = requests.post(api_base_url + api)
    assert response.status_code == 200
    print("\n获取当前正在运行的模型列表:")
    body = response.json()
    pprint(body)
    models = body["data"]
    assert isinstance(models, list)
    assert len(models) > 0
# 不建议使用stop_model功能。按现在的实现停止了就只能手动再启动
# def test_stop_model(api="/llm_model/stop"):
# url = api_base_url + api
# r = requests.post(url, json={""})
def test_change_model(api="/llm_model/change_model"):
    """Switch a randomly chosen local running model to a configured, idle one.

    The test needs at least one running model, at least one configured model
    that is not running, and at least one running model whose worker config
    has no truthy "online_api" entry. Each precondition is asserted
    explicitly so that a missing one fails with a clear message instead of an
    IndexError from ``random.choice``.
    """
    url = api_base_url + api

    running_models = get_running_models()
    assert len(running_models) > 0

    model_workers = get_configured_models()
    available_new_models = list(set(model_workers) - set(running_models))
    assert len(available_new_models) > 0
    print(available_new_models)

    # Only running models without a truthy "online_api" config are switched away from.
    local_models = [x for x in running_models if not get_model_worker_config(x).get("online_api")]
    assert len(local_models) > 0, "no local running model available to switch away from"

    model_name = random.choice(local_models)
    new_model_name = random.choice(available_new_models)
    print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
    r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
    assert r.status_code == 200

    running_models = get_running_models()
    assert new_model_name in running_models

View File

@@ -0,0 +1,47 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from webui_pages.utils import ApiRequest
import pytest
from pprint import pprint
from typing import List
api = ApiRequest()


def test_get_default_llm():
    """Expect the default LLM to be reported as a ``(str, bool)`` tuple."""
    default_llm = api.get_default_llm_model()
    print(default_llm)
    assert isinstance(default_llm, tuple)
    first, second = default_llm[0], default_llm[1]
    assert isinstance(first, str) and isinstance(second, bool)
def test_server_configs():
    """Expect the server configuration to be a non-empty dict."""
    server_configs = api.get_server_configs()
    pprint(server_configs, depth=2)
    assert isinstance(server_configs, dict)
    assert len(server_configs) > 0
def test_list_search_engines():
    """Expect the search engine listing to be a non-empty list."""
    available = api.list_search_engines()
    pprint(available)
    assert isinstance(available, list)
    assert len(available) > 0
@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"])
def test_get_prompt_template(type):
    """Expect a non-empty string template for each parametrized chat type."""
    print(f"prompt template for: {type}")
    prompt_template = api.get_prompt_template(type=type)
    print(prompt_template)
    assert isinstance(prompt_template, str)
    assert len(prompt_template) > 0

View File

@@ -0,0 +1,113 @@
import requests
import json
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs import BING_SUBSCRIPTION_KEY
from server.utils import api_address
from pprint import pprint
# Base URL of the API server under test, as returned by api_address().
api_base_url = api_address()
def dump_input(d, title):
    """Print a banner line containing ``title`` and then pretty-print ``d``."""
    rule = "=" * 30
    print("\n")
    print(rule + title + " input " + rule)
    pprint(d)
def dump_output(r, title):
    """Print a banner line containing ``title`` and then echo the streamed body of ``r``.

    ``r`` must provide ``iter_content(chunk_size, decode_unicode=...)``; each
    yielded piece is printed without a trailing newline and flushed.
    """
    rule = "=" * 30
    print("\n")
    print(rule + title + " output" + rule)
    for piece in r.iter_content(None, decode_unicode=True):
        print(piece, end="", flush=True)
# HTTP headers sent with the chat requests that pass ``headers=headers``.
headers = {
'accept': 'application/json',
'Content-Type': 'application/json',
}
# Shared request body: a query, a short two-turn history, streaming enabled.
# test_search_engine_chat declares this name ``global`` and modifies it.
data = {
"query": "请用100字左右的文字介绍自己",
"history": [
{
"role": "user",
"content": "你好"
},
{
"role": "assistant",
"content": "你好,我是人工智能大模型"
}
],
"stream": True,
"temperature": 0.7,
}
def test_chat_chat(api="/chat/chat"):
    """Post the shared request body to the chat endpoint and echo the stream."""
    endpoint = f"{api_base_url}{api}"
    dump_input(data, api)
    reply = requests.post(endpoint, headers=headers, json=data, stream=True)
    dump_output(reply, api)
    assert reply.status_code == 200
def test_knowledge_chat(api="/chat/knowledge_base_chat"):
    """Stream a knowledge base chat answer and expect docs in the last event."""
    endpoint = f"{api_base_url}{api}"
    payload = {
        "query": "如何提问以获得高质量答案",
        "knowledge_base_name": "samples",
        "history": [
            {"role": "user", "content": "你好"},
            {"role": "assistant", "content": "你好,我是 ChatGLM"},
        ],
        "stream": True,
    }
    dump_input(payload, api)
    response = requests.post(endpoint, headers=headers, json=payload, stream=True)

    rule = "=" * 30
    print("\n")
    print(rule + api + " output" + rule)

    # Holds the most recently parsed event; stays the request body if the
    # stream yields nothing.
    latest = payload
    for line in response.iter_content(None, decode_unicode=True):
        # The first 6 characters of every chunk are dropped before parsing
        # (presumably an SSE "data: " prefix — confirm against the server).
        latest = json.loads(line[6:])
        if "answer" in latest:
            print(latest["answer"], end="", flush=True)
    pprint(latest)
    assert "docs" in latest and len(latest["docs"]) > 0
    assert response.status_code == 200
def test_search_engine_chat(api="/chat/search_engine_chat"):
    """Run the search engine chat endpoint once per engine ("bing", "duckduckgo").

    A fresh request body is built for every engine from the module-level
    ``data`` template, which is neither mutated nor rebound. Parsed response
    chunks therefore cannot leak into the next engine's request.

    When Bing is selected and ``BING_SUBSCRIPTION_KEY`` is not set, the
    endpoint is expected to answer with a plain JSON error (code 404). That
    body is checked and the streaming checks are skipped for this engine,
    because the response has already been consumed by ``response.json()``.
    """
    url = f"{api_base_url}{api}"
    for se in ["bing", "duckduckgo"]:
        payload = {**data, "query": "室温超导最新进展是什么样?", "search_engine_name": se}
        dump_input(payload, api + f" by {se}")
        response = requests.post(url, json=payload, stream=True)

        if se == "bing" and not BING_SUBSCRIPTION_KEY:
            error = response.json()
            assert error["code"] == 404
            assert error["msg"] == f"要使用Bing搜索引擎需要设置 `BING_SUBSCRIPTION_KEY`"
            continue

        print("\n")
        print("=" * 30 + api + f" by {se} output" + "="*30)
        # Holds the most recently parsed event of the stream.
        last_event = {}
        for line in response.iter_content(None, decode_unicode=True):
            # The first 6 characters of every chunk are dropped before parsing
            # (presumably an SSE "data: " prefix — confirm against the server).
            last_event = json.loads(line[6:])
            if "answer" in last_event:
                print(last_event["answer"], end="", flush=True)
        assert "docs" in last_event and len(last_event["docs"]) > 0
        pprint(last_event["docs"])
        assert response.status_code == 200

View File

@@ -0,0 +1,81 @@
import requests
import json
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs import BING_SUBSCRIPTION_KEY
from server.utils import api_address
from pprint import pprint
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
# Base URL of the API server under test, as returned by api_address().
api_base_url = api_address()
def dump_input(d, title):
    """Print a banner line containing ``title`` and then pretty-print ``d``."""
    rule = "=" * 30
    print("\n")
    print(rule + title + " input " + rule)
    pprint(d)
def dump_output(r, title):
    """Print a banner line containing ``title`` and then echo the streamed body of ``r``.

    ``r`` must provide ``iter_content(chunk_size, decode_unicode=...)``; each
    yielded piece is printed without a trailing newline and flushed.
    """
    rule = "=" * 30
    print("\n")
    print(rule + title + " output" + rule)
    for piece in r.iter_content(None, decode_unicode=True):
        print(piece, end="", flush=True)
# HTTP headers sent with every request made by knowledge_chat().
headers = {
'accept': 'application/json',
'Content-Type': 'application/json',
}
def knowledge_chat(api="/chat/knowledge_base_chat"):
    """Run one streamed knowledge base chat request and return its parsed events.

    Returns:
        A list with one parsed JSON object per streamed chunk, in arrival order.
    """
    endpoint = f"{api_base_url}{api}"
    payload = {
        "query": "如何提问以获得高质量答案",
        "knowledge_base_name": "samples",
        "history": [
            {"role": "user", "content": "你好"},
            {"role": "assistant", "content": "你好,我是 ChatGLM"},
        ],
        "stream": True,
    }
    response = requests.post(endpoint, headers=headers, json=payload, stream=True)
    # The first 6 characters of every chunk are dropped before parsing
    # (presumably an SSE "data: " prefix — confirm against the server).
    return [
        json.loads(line[6:])
        for line in response.iter_content(None, decode_unicode=True)
    ]
def test_thread():
    """Fire 10 concurrent knowledge_chat() calls and report completion times.

    The executor is used as a context manager so its worker threads are shut
    down when the test ends, including when a result raises. For every
    finished call the time elapsed since the start of the test is recorded;
    all recorded times are printed at the end.
    """
    times = []
    start = time.time()
    with ThreadPoolExecutor() as pool:
        futures = [pool.submit(knowledge_chat) for _ in range(10)]
        for finished in as_completed(futures):
            times.append(time.time() - start)
            print("\nResult:\n")
            pprint(finished.result())

    print("\nTime used:\n")
    for x in times:
        print(f"{x}")

View File

@@ -0,0 +1,53 @@
import os
from transformers import AutoTokenizer
import sys
sys.path.append("../..")
from configs import (
CHUNK_SIZE,
OVERLAP_SIZE
)
from server.knowledge_base.utils import make_text_splitter
def text(splitter_name):
    """Load the sample text file and split it with the named text splitter.

    For "MarkdownHeaderTextSplitter" the first loaded document's text is
    split with ``split_text`` and every resulting doc that has non-empty
    metadata gets its "source" set to the file's base name. Every other
    splitter goes through ``split_documents``. Each resulting doc is printed.

    Returns:
        The docs produced by the splitter.
    """
    from langchain import document_loaders

    # Read the file with a DocumentLoader.
    filepath = "../../knowledge_base/samples/content/test.txt"
    loaded = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True).load()
    splitter = make_text_splitter(splitter_name, CHUNK_SIZE, OVERLAP_SIZE)

    if splitter_name != "MarkdownHeaderTextSplitter":
        pieces = splitter.split_documents(loaded)
    else:
        pieces = splitter.split_text(loaded[0].page_content)
        for piece in pieces:
            if piece.metadata:
                piece.metadata["source"] = os.path.basename(filepath)

    for piece in pieces:
        print(piece)
    return pieces
import pytest
from langchain.docstore.document import Document
@pytest.mark.parametrize(
    "splitter_name",
    [
        "ChineseRecursiveTextSplitter",
        "SpacyTextSplitter",
        "RecursiveCharacterTextSplitter",
        "MarkdownHeaderTextSplitter",
    ],
)
def test_different_splitter(splitter_name):
    """Split the sample file with each splitter and check the result type.

    Any exception, including a failed assertion, is reported through
    ``pytest.fail`` together with the splitter name.
    """
    try:
        pieces = text(splitter_name)
        assert isinstance(pieces, list)
        if len(pieces) > 0:
            first = pieces[0]
            assert isinstance(first, Document)
    except Exception as e:
        pytest.fail(f"test_different_splitter failed with {splitter_name}, error: {str(e)}")

View File

@@ -0,0 +1,10 @@
# Path of the HTML file to load, relative to the current working directory.
data_path = './人工智能发展月报.html'
from langchain_community.document_loaders import TextLoader
# Load the file through TextLoader and print what was loaded.
loader = TextLoader(data_path)
data = loader.load()
print(data)
from unstructured.partition.html import partition_html
# Partition the first loaded item's text as HTML and print each resulting
# element, separated by blank lines.
rst = partition_html(text=data[0].page_content)
print("\n\n".join([str(el) for el in rst]))

View File

@@ -0,0 +1,21 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from pprint import pprint
test_files = {
    "ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"),
}


def test_rapidocrloader():
    """Load the sample image with RapidOCRLoader and expect text content."""
    from document_loaders import RapidOCRLoader

    image_path = test_files["ocr_test.jpg"]
    loaded = RapidOCRLoader(image_path).load()
    pprint(loaded)
    assert isinstance(loaded, list) and len(loaded) > 0 and isinstance(loaded[0].page_content, str)

View File

@@ -0,0 +1,21 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from pprint import pprint
test_files = {
    "ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"),
}


def test_rapidocrpdfloader():
    """Load the sample PDF with RapidOCRPDFLoader and expect text content."""
    from document_loaders import RapidOCRPDFLoader

    document_path = test_files["ocr_test.pdf"]
    loaded = RapidOCRPDFLoader(document_path).load()
    pprint(loaded)
    assert isinstance(loaded, list) and len(loaded) > 0 and isinstance(loaded[0].page_content, str)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,55 @@
import zipfile
from lxml import etree
import json
import os
def extract_docx_info(docx_path: str, output_json: str = 'docx_info.json'):
    """Summarize the structure of a DOCX file into a JSON file.

    The summary counts element tags (local names) in ``word/document.xml``,
    counts relationship types (last path segment of each ``Type``) in
    ``word/_rels/document.xml.rels`` when that member exists, and lists all
    archive members under ``word/media/``.

    Args:
        docx_path: Path of the DOCX file to inspect.
        output_json: Path of the JSON file to write.

    Raises:
        FileNotFoundError: If ``docx_path`` is not an existing file.
    """
    if not os.path.isfile(docx_path):
        raise FileNotFoundError(f"未找到 DOCX 文件:{docx_path}")

    relationship_tag = '{http://schemas.openxmlformats.org/package/2006/relationships}Relationship'
    rels_member = 'word/_rels/document.xml.rels'
    element_totals = {}
    relationship_totals = {}

    with zipfile.ZipFile(docx_path) as archive:
        members = archive.namelist()

        # Count element tags in document.xml.
        document_root = etree.fromstring(archive.read('word/document.xml'))
        for node in document_root.iter():
            local_name = etree.QName(node).localname
            element_totals[local_name] = element_totals.get(local_name, 0) + 1

        # Count relationship types in document.xml.rels, when present.
        if rels_member in members:
            rels_root = etree.fromstring(archive.read(rels_member))
            for relationship in rels_root.findall('.//' + relationship_tag):
                kind = relationship.get('Type').split('/')[-1]
                relationship_totals[kind] = relationship_totals.get(kind, 0) + 1

        # List every embedded media file.
        embedded_media = [member for member in members if member.startswith('word/media/')]

    summary = {
        'source_docx': os.path.basename(docx_path),
        'elements': element_totals,
        'relationships': relationship_totals,
        'media_files': embedded_media,
    }

    with open(output_json, 'w', encoding='utf-8') as out:
        json.dump(summary, out, ensure_ascii=False, indent=2)
    print(f"已生成 JSON 文件:{output_json}")
if __name__ == '__main__':
    # Inspect test.docx from the current directory and write docx_info.json.
    extract_docx_info(docx_path='./test.docx', output_json='docx_info.json')

View File

@@ -0,0 +1,35 @@
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
from server.knowledge_base.migrate import create_tables
from server.knowledge_base.utils import KnowledgeFile
# Faiss-backed knowledge base service instance exercised by every test below.
kbService = FaissKBService("test")
test_kb_name = "test"
test_file_name = "README.md"
# Knowledge file object that is added to and deleted from the knowledge base.
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
# Query text used by the search test.
search_content = "如何启动api服务"
def test_init():
    """Call create_tables() before the remaining tests run."""
    create_tables()
def test_create_db():
    """Expect create_kb() to return a truthy value."""
    created = kbService.create_kb()
    assert created
def test_add_doc():
    """Expect add_doc() for the test file to return a truthy value."""
    added = kbService.add_doc(testKnowledgeFile)
    assert added
def test_search_db():
    """Expect a search for ``search_content`` to return at least one result."""
    hits = kbService.search_docs(search_content)
    assert len(hits) > 0
def test_delete_doc():
    """Expect delete_doc() for the test file to return a truthy value."""
    deleted = kbService.delete_doc(testKnowledgeFile)
    assert deleted
def test_delete_db():
    """Expect drop_kb() to return a truthy value."""
    dropped = kbService.drop_kb()
    assert dropped

View File

@@ -0,0 +1,4 @@
from pymilvus import MilvusClient
# Connect to the Milvus endpoint at the hard-coded local address and print
# the collections it reports.
client = MilvusClient("http://127.0.0.1:19530")
print(client.list_collections())

View File

@@ -0,0 +1,31 @@
# from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
# from server.knowledge_base.kb_service.pg_kb_service import PGKBService
from server.knowledge_base.migrate import create_tables
from server.knowledge_base.utils import KnowledgeFile
# Milvus-backed knowledge base service instance exercised by every test below.
kbService = MilvusKBService("test")
test_kb_name = "test"
test_file_name = "README.md"
# Knowledge file object that is added to and deleted from the knowledge base.
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
# Query text used by the search test.
search_content = "如何启动api服务"
def test_init():
    """Call create_tables() before the remaining tests run."""
    create_tables()
def test_create_db():
    """Expect create_kb() to return a truthy value."""
    created = kbService.create_kb()
    assert created
def test_add_doc():
    """Expect add_doc() for the test file to return a truthy value."""
    added = kbService.add_doc(testKnowledgeFile)
    assert added
def test_search_db():
    """Expect a search for ``search_content`` to return at least one result."""
    hits = kbService.search_docs(search_content)
    assert len(hits) > 0
def test_delete_doc():
    """Expect delete_doc() for the test file to return a truthy value."""
    deleted = kbService.delete_doc(testKnowledgeFile)
    assert deleted

View File

@@ -0,0 +1,31 @@
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
from server.knowledge_base.kb_service.pg_kb_service import PGKBService
from server.knowledge_base.migrate import create_tables
from server.knowledge_base.utils import KnowledgeFile
# PG-backed knowledge base service instance exercised by every test below.
kbService = PGKBService("test")
test_kb_name = "test"
test_file_name = "README.md"
# Knowledge file object that is added to and deleted from the knowledge base.
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
# Query text used by the search test.
search_content = "如何启动api服务"
def test_init():
    """Call create_tables() before the remaining tests run."""
    create_tables()
def test_create_db():
    """Expect create_kb() to return a truthy value."""
    created = kbService.create_kb()
    assert created
def test_add_doc():
    """Expect add_doc() for the test file to return a truthy value."""
    added = kbService.add_doc(testKnowledgeFile)
    assert added
def test_search_db():
    """Expect a search for ``search_content`` to return at least one result."""
    hits = kbService.search_docs(search_content)
    assert len(hits) > 0
def test_delete_doc():
    """Expect delete_doc() for the test file to return a truthy value."""
    deleted = kbService.delete_doc(testKnowledgeFile)
    assert deleted

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.9 KiB

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,138 @@
import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from tabulate import tabulate
# Model configuration: one entry per model, keyed by model name.
model_config = {
"Qwen1.5-32B-Chat": {
"model_name": "Qwen1.5-32B-Chat",
"api_base_url": "http://192.168.56.123:8821/v1",
"api_key": "fake",
}
}
# Model selected for this benchmark run.
selected_model = "Qwen1.5-32B-Chat"
# Completions endpoint, key and model name derived from the selected entry.
API_URL = model_config[selected_model]["api_base_url"] + "/completions"
API_KEY = model_config[selected_model]["api_key"]
MODEL_NAME = model_config[selected_model]["model_name"]
# Prompt sent with every benchmark request.
TEST_PROMPT = "从前有座山"
# Measure output characters per second
def test_output_speed(prompt, api_url, api_key, model_name, max_tokens):
    """Send one completion request and measure output characters per second.

    Returns:
        On HTTP 200, the list ``[max_tokens, char_count, "<elapsed> ",
        "<chars_per_second> "]`` (both strings carry two decimals and a
        trailing space). On any other status, the error is printed and the
        tuple ``(max_tokens, 0, 0, 0)`` is returned.
    """
    request_body = {
        "model": model_name,
        "prompt": prompt,
        "max_tokens": max_tokens,
    }
    auth_headers = {"Authorization": f"Bearer {api_key}"}

    began = time.time()
    response = requests.post(api_url, headers=auth_headers, json=request_body)
    finished = time.time()

    if response.status_code != 200:
        print(f"错误: {response.status_code}, {response.text}")
        return max_tokens, 0, 0, 0

    first_choice = response.json().get("choices", [{}])[0]
    char_count = len(first_choice.get("text", ""))
    elapsed_time = finished - began
    if elapsed_time > 0:
        chars_per_second = char_count / elapsed_time
    else:
        chars_per_second = 0
    return [
        max_tokens,
        char_count,
        f"{elapsed_time:.2f} ",
        f"{chars_per_second:.2f} ",
    ]
# Concurrency test: single-request helper
def make_request(api_url, api_key, prompt, model_name, max_tokens):
    """Send one completion request and time it.

    Returns:
        ``(status_code, response_text, response_time_seconds)``.
    """
    request_body = {
        "model": model_name,
        "prompt": prompt,
        "max_tokens": max_tokens,
    }
    auth_headers = {"Authorization": f"Bearer {api_key}"}

    began = time.time()
    response = requests.post(api_url, headers=auth_headers, json=request_body)
    response_time = time.time() - began
    return response.status_code, response.text, response_time
def test_concurrent_requests(api_url, api_key, prompt, model_name, max_workers, max_tokens):
    """Send ``max_workers`` requests concurrently and summarize the outcome.

    Throughput (QPS) is the number of successful requests divided by the
    wall-clock time of the whole batch. Dividing by the sum of the individual
    latencies, as before, under-reports QPS because concurrent requests
    overlap in time.

    Returns:
        ``[max_tokens, max_workers, total_requests, "<avg response time>",
        "<throughput>", success_count, "<success rate %>"]`` with the string
        entries formatted to two decimals.
    """
    total_requests = max_workers
    success_count = 0
    total_response_time = 0

    batch_start = time.time()
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(make_request, api_url, api_key, prompt, model_name, max_tokens)
            for _ in range(max_workers)
        ]
        for future in as_completed(futures):
            status_code, _, response_time = future.result()
            total_response_time += response_time
            if status_code == 200:
                success_count += 1
    batch_elapsed = time.time() - batch_start

    average_response_time = total_response_time / total_requests if total_requests > 0 else 0
    throughput = success_count / batch_elapsed if batch_elapsed > 0 else 0
    success_rate = (success_count / total_requests) * 100 if total_requests > 0 else 0

    return [
        max_tokens,
        max_workers,
        total_requests,
        f"{average_response_time:.2f}",
        f"{throughput:.2f}",
        success_count,
        f"{success_rate:.2f}",
    ]
# Run the concurrency test at several concurrency levels
def find_max_concurrent_requests(api_url, api_key, prompt, model_name, max_tokens, max_workers_list):
    """Run test_concurrent_requests once per entry of ``max_workers_list``.

    Returns:
        The list of per-level result rows, in the order of ``max_workers_list``.
    """
    return [
        test_concurrent_requests(api_url, api_key, prompt, model_name, workers, max_tokens)
        for workers in max_workers_list
    ]
# Print a table as comma-separated text
def print_csv(headers, rows):
    """Write ``headers`` and ``rows`` to standard output as CSV.

    Every cell is converted with ``str``. Cells containing a comma, a quote
    or a line break are quoted by the csv module, so the output stays
    parseable; output for cells without such characters is a plain
    comma-joined line.

    Args:
        headers: Iterable of column titles.
        rows: Iterable of rows, each an iterable of cell values.
    """
    import csv
    import sys

    writer = csv.writer(sys.stdout, lineterminator="\n")
    writer.writerow([str(title) for title in headers])
    for row in rows:
        writer.writerow([str(cell) for cell in row])
# Main program
if __name__ == "__main__":
    # max_tokens values and concurrency levels to benchmark.
    max_tokens_list = [512, 1024, 2048, 4096, 8192]
    max_workers_list = [10, 20, 50, 100]
    # Smaller alternatives for a quick run:
    # max_tokens_list = [10, 20, 30]
    # max_workers_list = [2, 4, 8, 16]

    # Output-speed test: one row per max_tokens value.
    output_speed_results = []
    for max_tokens in max_tokens_list:
        output_speed_results.append(
            test_output_speed(TEST_PROMPT, API_URL, API_KEY, MODEL_NAME, max_tokens)
        )

    output_speed_headers = ["max_tokens", "输出字符数", "耗时 (秒)", "每秒字符数"]
    print("每秒输出字符数测试结果:")
    print_csv(output_speed_headers, output_speed_results)

    # Concurrency test: one row per (max_tokens, concurrency level) pair.
    all_concurrency_results = []
    for max_tokens in max_tokens_list:
        all_concurrency_results.extend(
            find_max_concurrent_requests(
                API_URL, API_KEY, TEST_PROMPT, MODEL_NAME, max_tokens, max_workers_list
            )
        )

    concurrency_headers = ["max_tokens", "并发请求数", "请求个数", "平均响应时间(RT,单位秒)", "吞吐量 (QPS)", "请求成功个数", "请求成功率 (%)"]
    print("\n测试并发结果:")
    print_csv(concurrency_headers, all_concurrency_results)

View File

@@ -0,0 +1,162 @@
import aiohttp
import asyncio
import time
from tqdm import tqdm
import random
import sys
# Configuration
# Alternative model/endpoint pairs, kept commented out:
# LLM_MODEL = "Qwen2-72B-Instruct"
# LLM_ENDPOINT = "http://192.168.56.123:8822/v1"
# LLM_MODEL = "deepseek-chat"
# LLM_ENDPOINT = "https://api.deepseek.com/v1"
LLM_MODEL = "qwen-max-2025-01-25"
LLM_ENDPOINT = "https://dashscope.aliyuncs.com/compatible-mode/v1"
TEMPERATURE = 0.7 # so that every response comes back different
MAX_TOKENS = 8192
# Question list; fetch() picks one at random for every request.
questions = [
"为什么鸟儿会唱歌?", "为什么我们有季节?", "为什么星星会闪烁?", "为什么我们会打哈欠?",
"为什么太阳是热的?", "为什么猫会咕噜咕噜叫?", "为什么狗会吠?", "为什么鱼会游泳?",
"为什么我们有指纹?", "为什么我们会打喷嚏?", "为什么我们有眉毛?", "为什么我们有头发?",
"为什么我们有指甲?", "为什么我们有牙齿?", "为什么我们有骨头?", "为什么我们有肌肉?",
"为什么我们有血液?", "为什么我们有心脏?", "为什么我们有肺?", "为什么我们有大脑?",
"为什么我们有皮肤?", "为什么我们有耳朵?", "为什么我们有眼睛?", "为什么我们有鼻子?",
"为什么我们有嘴巴?", "为什么我们有舌头?", "为什么我们有胃?", "为什么我们有肠子?",
"为什么我们有肝脏?", "为什么我们有肾脏?", "为什么我们有膀胱?", "为什么我们有胰腺?",
"为什么我们有脾脏?", "为什么我们有胆囊?", "为什么我们有甲状腺?", "为什么我们有肾上腺?",
"为什么我们有垂体?", "为什么我们有下丘脑?", "为什么我们有胸腺?", "为什么我们有淋巴结?",
"为什么我们有脊髓?", "为什么我们有神经?", "为什么我们有循环系统?", "为什么我们有呼吸系统?",
"为什么我们有消化系统?", "为什么我们有免疫系统?"
]
def log_to_file(file, message):
    """Append ``message`` followed by a newline to ``file`` (UTF-8 encoded)."""
    with open(file, mode='a', encoding='utf-8') as handle:
        handle.write(message + '\n')
async def fetch(session, url, file=None):
    """Send one chat-completion request with a random question and time it.

    The API key is read from the ``LLM_API_KEY`` environment variable and
    sent as ``Authorization: Bearer <key>``; no credential is stored in the
    source. When the variable is unset, no Authorization header is sent.

    Args:
        session: Object whose ``post(url, json=..., headers=...)`` is used as
            an async context manager yielding the response (an
            ``aiohttp.ClientSession`` in this script).
        url: Chat-completions endpoint URL.
        file: Optional path; when given, the question and answer are appended to it.

    Returns:
        ``(prompt_tokens, completion_tokens, request_time_seconds)``.
        ``(0, 0, 0)`` is returned on a non-200 status or when an exception occurs.
    """
    import os

    start_time = time.time()
    question = random.choice(questions)
    json_payload = {
        "model": LLM_MODEL,
        "messages": [
            {"role": "system", "content": "你的任务是用专业学术语言严谨的科学态度全面的数据支持完整详实地回答user的问题。"},
            {"role": "user", "content": question}
        ],
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
        "stream": False
    }
    headers = {"Content-Type": "application/json"}
    # SECURITY: the key comes from the environment, never from the source file.
    api_key = os.environ.get("LLM_API_KEY", "")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    try:
        async with session.post(url, json=json_payload, headers=headers) as response:
            if response.status != 200:
                print(f"错误: 收到响应码 {response.status}")
                return 0, 0, 0
            response_json = await response.json()
            request_time = time.time() - start_time

            # Both counters default to 0 so a response without 'usage' still
            # yields a well-formed result.
            prompt_tokens = 0
            completion_tokens = 0
            if 'usage' in response_json:
                usage = response_json['usage']
                completion_tokens = usage.get('completion_tokens', 0)
                prompt_tokens = usage.get('prompt_tokens', 0)
            else:
                print("警告: 响应中缺少 'usage' 字段。")

            answer = ""
            if 'choices' in response_json and len(response_json['choices']) > 0:
                answer = response_json['choices'][0]['message']['content']
            else:
                print("警告: 响应中缺少 'choices' 字段或内容为空。")
                # No answer was produced, so no output tokens are counted.
                completion_tokens = 0

            # Log the question/answer pair (log format unchanged).
            if file:
                log_to_file(file, f"Q: {question}\nA: {answer}\n")

            return prompt_tokens, completion_tokens, request_time
    except Exception as e:
        # Best-effort boundary: one failed request must not abort the whole
        # benchmark; it is reported and counted as an all-zero result.
        print(f"请求过程中发生异常: {e}")
        return 0, 0, 0
async def bound_fetch(sem, session, url, pbar, file=None):
    """Run ``fetch`` while holding ``sem`` and advance ``pbar`` by one.

    Args:
        sem: Async context manager limiting how many fetches run at once.
        session: Client session forwarded to ``fetch``.
        url: Endpoint URL forwarded to ``fetch``.
        pbar: Progress bar; ``update(1)`` is called after the fetch returns.
        file: Optional log path forwarded to ``fetch``.

    Returns:
        Whatever ``fetch`` returned for this request.
    """
    async with sem:
        outcome = await fetch(session, url, file=file)
        pbar.update(1)
    return outcome
async def run(load_url, max_concurrent_requests, total_requests, output_file):
    """Fire ``total_requests`` requests at ``load_url`` and aggregate the results.

    Args:
        load_url: Endpoint URL every request is sent to.
        max_concurrent_requests: Initial value of the semaphore that bounds
            how many requests are in flight at the same time.
        total_requests: Number of request tasks to create.
        output_file: Log path handed to each request.

    Returns:
        A ``(total_input_tokens, total_output_tokens, response_times)`` tuple:
        the two token sums over all results and the list of per-request times
        in task-creation order.
    """
    limiter = asyncio.Semaphore(max_concurrent_requests)
    client_timeout = aiohttp.ClientTimeout(total=6000, connect=6000, sock_read=6000, sock_connect=6000)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        with tqdm(total=total_requests) as pbar:
            pending = [
                asyncio.create_task(bound_fetch(limiter, session, load_url, pbar, file=output_file))
                for _ in range(total_requests)
            ]
            results = await asyncio.gather(*pending)

    # Aggregate token counts and collect the per-request response times.
    total_input_tokens = 0
    total_output_tokens = 0
    response_times = []
    for prompt_count, completion_count, elapsed in results:
        total_input_tokens += prompt_count
        total_output_tokens += completion_count
        response_times.append(elapsed)
    return total_input_tokens, total_output_tokens, response_times
if __name__ == '__main__':
    # Command line: python llm_test.py <C> <N>
    if len(sys.argv) != 3:
        print("用法: python llm_test.py <C> <N>")
        sys.exit(1)
    try:
        C = int(sys.argv[1])  # maximum number of concurrent requests
        N = int(sys.argv[2])  # total number of requests
    except ValueError:
        print("错误: C 和 N 必须是整数。")
        sys.exit(1)
    # A semaphore created with 0 would never let any request start and a
    # negative value is rejected by asyncio, so refuse both up front.
    if C <= 0:
        print("错误: C 必须是正整数。")
        sys.exit(1)

    url = f'{LLM_ENDPOINT}/chat/completions'
    output_file = 'A800_Qwen2.5-72Bint8_bench_.txt'
    # Start every run from an empty log file.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write('')

    # perf_counter is monotonic, so the total is immune to clock adjustments.
    start_time = time.perf_counter()
    total_input_tokens, total_output_tokens, response_times = asyncio.run(run(url, C, N, output_file))
    total_time = time.perf_counter() - start_time

    # fetch() reports a request_time of 0 for failed requests; averaging them
    # in would understate the latency of the requests that actually completed.
    completed_times = [t for t in response_times if t > 0]
    avg_time_per_request = sum(completed_times) / len(completed_times) if completed_times else 0
    tokens_per_second = (total_output_tokens) / total_time if total_time > 0 else 0

    # NOTE(review): the last row is labelled "QPS" but reports output tokens
    # per second; the label is left unchanged to keep the log format stable.
    final_output = (
        "最终表现:\n"
        f" 输入token数 : {total_input_tokens:.2f}\n"
        f" 输出token数 : {total_output_tokens:.2f}\n"
        f" 并发数 : {C}\n"
        f" 总请求数 : {N}\n"
        f" 总耗时 : {total_time:.2f}\n"
        f" 平均耗时 : {avg_time_per_request:.2f}\n"
        f" 吞吐(QPS) : {tokens_per_second:.2f} tokens/s"
    )
    print(final_output)
    log_to_file(output_file, final_output)  # also persist the summary in the log

View File

@@ -0,0 +1,138 @@
from pathlib import Path
from pprint import pprint
import os
import shutil
import sys
root_path = Path(__file__).parent.parent
sys.path.append(str(root_path))
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.knowledge_base.utils import get_kb_path, get_doc_path, KnowledgeFile
from server.knowledge_base.migrate import folder2db, prune_db_docs, prune_folder_files
# Fixture for the migrate tests: a dedicated knowledge base whose doc folder
# is populated with copies of the files listed in test_files.
kb_name = "test_kb_for_migrate"
test_files = {
    "readme.md": str(root_path / "readme.md"),
}

kb_path = get_kb_path(kb_name)
doc_path = get_doc_path(kb_name)

# Make sure the doc folder exists, then copy every fixture file into it.
os.makedirs(doc_path, exist_ok=True)
for target_name, source_path in test_files.items():
    shutil.copy(source_path, os.path.join(doc_path, target_name))
def test_recreate_vs():
    """Run ``folder2db`` in "recreate_vs" mode and check every fixture file.

    For each name in ``test_files`` the knowledge base must list the file and
    return at least one doc for it, both when queried by file name and when
    queried by ``source`` metadata; every returned doc must carry that name as
    its ``source``.
    """
    folder2db([kb_name], "recreate_vs")

    kb = KBServiceFactory.get_service_by_name(kb_name)
    assert kb and kb.exists()

    files = kb.list_files()
    print(files)
    for name in test_files:
        assert name in files

        # list docs based on file name
        docs = kb.list_docs(file_name=name)
        assert len(docs) > 0
        pprint(docs[0])
        for doc in docs:
            assert doc.metadata["source"] == name

        # list docs based on metadata
        docs = kb.list_docs(metadata={"source": name})
        assert len(docs) > 0
        for doc in docs:
            assert doc.metadata["source"] == name
def test_increment():
    """Clear the vector store, run ``folder2db`` in "increment" mode, verify.

    After ``clear_vs`` the knowledge base must list no files and no docs; after
    the incremental run every name in ``test_files`` must be listed again with
    at least one doc whose ``source`` metadata equals that name.
    """
    service = KBServiceFactory.get_service_by_name(kb_name)
    service.clear_vs()
    assert service.list_files() == []
    assert service.list_docs() == []

    folder2db([kb_name], "increment")

    indexed_files = service.list_files()
    print(indexed_files)
    for file_name in test_files:
        assert file_name in indexed_files

        file_docs = service.list_docs(file_name=file_name)
        assert len(file_docs) > 0
        pprint(file_docs[0])
        assert all(doc.metadata["source"] == file_name for doc in file_docs)
def test_prune_db():
    """Remove a fixture file from disk and check ``prune_db_docs`` drops it.

    The first name in ``test_files`` is deleted from the doc folder; after
    pruning, the knowledge base must no longer list it nor return docs for it.
    When a second fixture file exists it must remain listed with docs. The
    deleted file is copied back afterwards so later tests see the full folder.
    """
    # test_files may hold a single entry, so the "kept" file is optional;
    # unpacking two names unconditionally raised ValueError in that case.
    names = list(test_files)
    del_file = names[0]
    keep_file = names[1] if len(names) > 1 else None

    os.remove(os.path.join(doc_path, del_file))
    try:
        prune_db_docs([kb_name])

        kb = KBServiceFactory.get_service_by_name(kb_name)
        files = kb.list_files()
        print(files)
        assert del_file not in files
        docs = kb.list_docs(file_name=del_file)
        assert len(docs) == 0

        if keep_file is not None:
            assert keep_file in files
            docs = kb.list_docs(file_name=keep_file)
            assert len(docs) > 0
            pprint(docs[0])
    finally:
        # Restore the removed file even when an assertion above fails, so the
        # following tests do not break on a missing fixture.
        shutil.copy(test_files[del_file], os.path.join(doc_path, del_file))
def test_prune_folder():
    """Delete a file's docs from the KB and check ``prune_folder_files``.

    The docs of the first name in ``test_files`` are deleted through the
    knowledge-base service while the file itself stays on disk; after
    ``prune_folder_files`` the file must be gone from the doc folder. When a
    second fixture file exists, it must stay listed, keep its docs and remain
    on disk.
    """
    # test_files may hold a single entry, so the "kept" file is optional;
    # unpacking two names unconditionally raised ValueError in that case.
    names = list(test_files)
    del_file = names[0]
    keep_file = names[1] if len(names) > 1 else None

    kb = KBServiceFactory.get_service_by_name(kb_name)

    # delete docs for file
    kb.delete_doc(KnowledgeFile(del_file, kb_name))
    files = kb.list_files()
    print(files)
    assert del_file not in files
    docs = kb.list_docs(file_name=del_file)
    assert len(docs) == 0
    if keep_file is not None:
        assert keep_file in files
        docs = kb.list_docs(file_name=keep_file)
        assert len(docs) > 0

    # The file itself must still be on disk before the folder is pruned.
    assert os.path.isfile(os.path.join(doc_path, del_file))

    # prune folder
    prune_folder_files([kb_name])

    # check result
    assert not os.path.isfile(os.path.join(doc_path, del_file))
    if keep_file is not None:
        assert os.path.isfile(os.path.join(doc_path, keep_file))
def test_drop_kb():
    """Drop the test knowledge base and check that nothing is left of it.

    After ``drop_kb`` the service must report that it no longer exists, its
    folder must be gone, and the factory must return ``None`` for its name.
    """
    service = KBServiceFactory.get_service_by_name(kb_name)
    service.drop_kb()
    assert not service.exists()
    assert not os.path.isdir(kb_path)

    assert KBServiceFactory.get_service_by_name(kb_name) is None

View File

@@ -0,0 +1,71 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent
sys.path.append(str(root_path))
from configs import ONLINE_LLM_MODEL
from server.model_workers.base import *
from server.utils import get_model_worker_config, list_config_llm_models
from pprint import pprint
import pytest
# Collect the configured online models that also appear in ONLINE_LLM_MODEL,
# keeping first-seen order and skipping duplicates.
workers = []
for model_name in list_config_llm_models()["online"]:
    if model_name not in ONLINE_LLM_MODEL or model_name in workers:
        continue
    workers.append(model_name)
print(f"all workers to test: {workers}")
# To test a single worker, override the list, e.g.:
# workers = ["fangzhou-api"]
@pytest.mark.parametrize("worker", workers)
def test_chat(worker):
    """Chat once with ``worker`` and check every yielded chunk.

    Nothing is checked when the worker's config has no ``worker_class``.
    Otherwise each item produced by ``do_chat`` must be a dict whose
    ``error_code`` is 0.
    """
    params = ApiChatParams(messages=[{"role": "user", "content": "你是谁"}])
    print(f"\nchat with {worker} \n")

    worker_class = get_model_worker_config(worker).get("worker_class")
    if not worker_class:
        return

    for chunk in worker_class().do_chat(params):
        pprint(chunk)
        assert isinstance(chunk, dict)
        assert chunk["error_code"] == 0
@pytest.mark.parametrize("worker", workers)
def test_embeddings(worker):
    """Request embeddings from ``worker`` and validate the response shape.

    Nothing is checked when the worker's config has no ``worker_class`` or
    when ``can_embedding()`` is falsy. Otherwise the response must have code
    200 and a non-empty ``data`` list of non-empty lists of floats.
    """
    params = ApiEmbeddingsParams(
        texts=[
            "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。",
            "一种利用 langchain 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。",
        ],
        worker_name=worker,
    )

    if worker_class := get_model_worker_config(worker).get("worker_class"):
        if worker_class.can_embedding():
            # "\e" is not a valid escape sequence; a leading newline was
            # intended, matching the banner printed by test_chat.
            print(f"\nembeddings with {worker} \n")
            resp = worker_class().do_embeddings(params)

            pprint(resp, depth=2)
            assert resp["code"] == 200
            assert "data" in resp
            embeddings = resp["data"]
            assert isinstance(embeddings, list) and len(embeddings) > 0
            assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
            assert isinstance(embeddings[0][0], float)
            print("向量长度:", len(embeddings[0]))
# @pytest.mark.parametrize("worker", workers)
# def test_completion(worker):
# params = ApiCompletionParams(prompt="五十六个民族")
# print(f"\completion with {worker} \n")
# worker_class = get_model_worker_config(worker)["worker_class"]
# resp = worker_class().do_completion(params)
# pprint(resp)

File diff suppressed because one or more lines are too long