[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
214
langchain-chat/tests/api/test_kb_api.py
Normal file
214
langchain-chat/tests/api/test_kb_api.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from server.utils import api_address
|
||||
from configs import VECTOR_SEARCH_TOP_K
|
||||
from server.knowledge_base.utils import get_kb_path, get_file_path
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
|
||||
kb = "kb_for_api_test"
|
||||
test_files = {
|
||||
"wiki/Home.MD": get_file_path("samples", "wiki/Home.md"),
|
||||
"wiki/开发环境部署.MD": get_file_path("samples", "wiki/开发环境部署.md"),
|
||||
"test_files/test.txt": get_file_path("samples", "test_files/test.txt"),
|
||||
}
|
||||
|
||||
print("\n\n直接url访问\n")
|
||||
|
||||
|
||||
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
|
||||
if not Path(get_kb_path(kb)).exists():
|
||||
return
|
||||
|
||||
url = api_base_url + api
|
||||
print("\n测试知识库存在,需要删除")
|
||||
r = requests.post(url, json=kb)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
|
||||
# check kb not exists anymore
|
||||
url = api_base_url + "/knowledge_base/list_knowledge_bases"
|
||||
print("\n获取知识库列表:")
|
||||
r = requests.get(url)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb not in data["data"]
|
||||
|
||||
|
||||
def test_create_kb(api="/knowledge_base/create_knowledge_base"):
|
||||
url = api_base_url + api
|
||||
|
||||
print(f"\n尝试用空名称创建知识库:")
|
||||
r = requests.post(url, json={"knowledge_base_name": " "})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
|
||||
|
||||
print(f"\n创建新知识库: {kb}")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"已新增知识库 {kb}"
|
||||
|
||||
print(f"\n尝试创建同名知识库: {kb}")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == f"已存在同名知识库 {kb}"
|
||||
|
||||
|
||||
def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
|
||||
url = api_base_url + api
|
||||
print("\n获取知识库列表:")
|
||||
r = requests.get(url)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb in data["data"]
|
||||
|
||||
|
||||
def test_upload_docs(api="/knowledge_base/upload_docs"):
|
||||
url = api_base_url + api
|
||||
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||
|
||||
print(f"\n上传知识文件")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 不覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": False}
|
||||
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == len(test_files)
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
|
||||
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
|
||||
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
|
||||
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_list_files(api="/knowledge_base/list_files"):
|
||||
url = api_base_url + api
|
||||
print("\n获取知识库中文件列表:")
|
||||
r = requests.get(url, params={"knowledge_base_name": kb})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list)
|
||||
for name in test_files:
|
||||
assert name in data["data"]
|
||||
|
||||
|
||||
def test_search_docs(api="/knowledge_base/search_docs"):
|
||||
url = api_base_url + api
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n检索知识库:")
|
||||
print(query)
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_update_info(api="/knowledge_base/update_info"):
|
||||
url = api_base_url + api
|
||||
print("\n更新知识库介绍")
|
||||
r = requests.post(url, json={"knowledge_base_name": "samples", "kb_info": "你好"})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
|
||||
def test_update_docs(api="/knowledge_base/update_docs"):
|
||||
url = api_base_url + api
|
||||
|
||||
print(f"\n更新知识文件")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_delete_docs(api="/knowledge_base/delete_docs"):
|
||||
url = api_base_url + api
|
||||
|
||||
print(f"\n删除知识文件")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
url = api_base_url + "/knowledge_base/search_docs"
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n尝试检索删除后的检索知识库:")
|
||||
print(query)
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == 0
|
||||
|
||||
|
||||
def test_recreate_vs(api="/knowledge_base/recreate_vector_store"):
|
||||
url = api_base_url + api
|
||||
print("\n重建知识库:")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb}, stream=True)
|
||||
for chunk in r.iter_content(None):
|
||||
data = json.loads(chunk[6:])
|
||||
assert isinstance(data, dict)
|
||||
assert data["code"] == 200
|
||||
print(data["msg"])
|
||||
|
||||
url = api_base_url + "/knowledge_base/search_docs"
|
||||
query = "本项目支持哪些文件格式?"
|
||||
print("\n尝试检索重建后的检索知识库:")
|
||||
print(query)
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_delete_kb_after(api="/knowledge_base/delete_knowledge_base"):
|
||||
url = api_base_url + api
|
||||
print("\n删除知识库")
|
||||
r = requests.post(url, json=kb)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
|
||||
# check kb not exists anymore
|
||||
url = api_base_url + "/knowledge_base/list_knowledge_bases"
|
||||
print("\n获取知识库列表:")
|
||||
r = requests.get(url)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb not in data["data"]
|
||||
161
langchain-chat/tests/api/test_kb_api_request.py
Normal file
161
langchain-chat/tests/api/test_kb_api_request.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from server.utils import api_address
|
||||
from configs import VECTOR_SEARCH_TOP_K
|
||||
from server.knowledge_base.utils import get_kb_path, get_file_path
|
||||
from webui_pages.utils import ApiRequest
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
api: ApiRequest = ApiRequest(api_base_url)
|
||||
|
||||
|
||||
kb = "kb_for_api_test"
|
||||
test_files = {
|
||||
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
|
||||
"README.MD": str(root_path / "README.MD"),
|
||||
"test.txt": get_file_path("samples", "test.txt"),
|
||||
}
|
||||
|
||||
print("\n\nApiRquest调用\n")
|
||||
|
||||
|
||||
def test_delete_kb_before():
|
||||
if not Path(get_kb_path(kb)).exists():
|
||||
return
|
||||
|
||||
data = api.delete_knowledge_base(kb)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb not in data["data"]
|
||||
|
||||
|
||||
def test_create_kb():
|
||||
print(f"\n尝试用空名称创建知识库:")
|
||||
data = api.create_knowledge_base(" ")
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
|
||||
|
||||
print(f"\n创建新知识库: {kb}")
|
||||
data = api.create_knowledge_base(kb)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"已新增知识库 {kb}"
|
||||
|
||||
print(f"\n尝试创建同名知识库: {kb}")
|
||||
data = api.create_knowledge_base(kb)
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == f"已存在同名知识库 {kb}"
|
||||
|
||||
|
||||
def test_list_kbs():
|
||||
data = api.list_knowledge_bases()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) > 0
|
||||
assert kb in data
|
||||
|
||||
|
||||
def test_upload_docs():
|
||||
files = list(test_files.values())
|
||||
|
||||
print(f"\n上传知识文件")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 不覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": False}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == len(test_files)
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
|
||||
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
|
||||
data = {"knowledge_base_name": kb, "override": True, "docs": docs}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_list_files():
|
||||
print("\n获取知识库中文件列表:")
|
||||
data = api.list_kb_docs(knowledge_base_name=kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list)
|
||||
for name in test_files:
|
||||
assert name in data
|
||||
|
||||
|
||||
def test_search_docs():
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n检索知识库:")
|
||||
print(query)
|
||||
data = api.search_kb_docs(query, kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_update_docs():
|
||||
print(f"\n更新知识文件")
|
||||
data = api.update_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_delete_docs():
|
||||
print(f"\n删除知识文件")
|
||||
data = api.delete_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n尝试检索删除后的检索知识库:")
|
||||
print(query)
|
||||
data = api.search_kb_docs(query, kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == 0
|
||||
|
||||
|
||||
def test_recreate_vs():
|
||||
print("\n重建知识库:")
|
||||
r = api.recreate_vector_store(kb)
|
||||
for data in r:
|
||||
assert isinstance(data, dict)
|
||||
assert data["code"] == 200
|
||||
print(data["msg"])
|
||||
|
||||
query = "本项目支持哪些文件格式?"
|
||||
print("\n尝试检索重建后的检索知识库:")
|
||||
print(query)
|
||||
data = api.search_kb_docs(query, kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_delete_kb_after():
|
||||
print("\n删除知识库")
|
||||
data = api.delete_knowledge_base(kb)
|
||||
pprint(data)
|
||||
|
||||
# check kb not exists anymore
|
||||
print("\n获取知识库列表:")
|
||||
data = api.list_knowledge_bases()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) > 0
|
||||
assert kb not in data
|
||||
44
langchain-chat/tests/api/test_kb_summary_api.py
Normal file
44
langchain-chat/tests/api/test_kb_summary_api.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from server.utils import api_address
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
kb = "samples"
|
||||
file_name = "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/knowledge_base/samples/content/llm/大模型技术栈-实战与应用.md"
|
||||
doc_ids = [
|
||||
"357d580f-fdf7-495c-b58b-595a398284e8",
|
||||
"c7338773-2e83-4671-b237-1ad20335b0f0",
|
||||
"6da613d1-327d-466f-8c1a-b32e6f461f47"
|
||||
]
|
||||
|
||||
|
||||
def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summary_file_to_vector_store"):
|
||||
url = api_base_url + api
|
||||
print("\n文件摘要:")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb,
|
||||
"file_name": file_name
|
||||
}, stream=True)
|
||||
for chunk in r.iter_content(None):
|
||||
data = json.loads(chunk[6:])
|
||||
assert isinstance(data, dict)
|
||||
assert data["code"] == 200
|
||||
print(data["msg"])
|
||||
|
||||
|
||||
def test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store"):
|
||||
url = api_base_url + api
|
||||
print("\n文件摘要:")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb,
|
||||
"doc_ids": doc_ids
|
||||
}, stream=True)
|
||||
for chunk in r.iter_content(None):
|
||||
data = json.loads(chunk[6:])
|
||||
assert isinstance(data, dict)
|
||||
assert data["code"] == 200
|
||||
print(data)
|
||||
70
langchain-chat/tests/api/test_llm_api.py
Normal file
70
langchain-chat/tests/api/test_llm_api.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from configs.server_config import FSCHAT_MODEL_WORKERS
|
||||
from server.utils import api_address, get_model_worker_config
|
||||
|
||||
from pprint import pprint
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
|
||||
def get_configured_models() -> List[str]:
|
||||
model_workers = list(FSCHAT_MODEL_WORKERS)
|
||||
if "default" in model_workers:
|
||||
model_workers.remove("default")
|
||||
return model_workers
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
|
||||
def get_running_models(api="/llm_model/list_models"):
|
||||
url = api_base_url + api
|
||||
r = requests.post(url)
|
||||
if r.status_code == 200:
|
||||
return r.json()["data"]
|
||||
return []
|
||||
|
||||
|
||||
def test_running_models(api="/llm_model/list_running_models"):
|
||||
url = api_base_url + api
|
||||
r = requests.post(url)
|
||||
assert r.status_code == 200
|
||||
print("\n获取当前正在运行的模型列表:")
|
||||
pprint(r.json())
|
||||
assert isinstance(r.json()["data"], list)
|
||||
assert len(r.json()["data"]) > 0
|
||||
|
||||
|
||||
# 不建议使用stop_model功能。按现在的实现,停止了就只能手动再启动
|
||||
# def test_stop_model(api="/llm_model/stop"):
|
||||
# url = api_base_url + api
|
||||
# r = requests.post(url, json={""})
|
||||
|
||||
|
||||
def test_change_model(api="/llm_model/change_model"):
|
||||
url = api_base_url + api
|
||||
|
||||
running_models = get_running_models()
|
||||
assert len(running_models) > 0
|
||||
|
||||
model_workers = get_configured_models()
|
||||
|
||||
availabel_new_models = list(set(model_workers) - set(running_models))
|
||||
assert len(availabel_new_models) > 0
|
||||
print(availabel_new_models)
|
||||
|
||||
local_models = [x for x in running_models if not get_model_worker_config(x).get("online_api")]
|
||||
model_name = random.choice(local_models)
|
||||
new_model_name = random.choice(availabel_new_models)
|
||||
print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
|
||||
r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
|
||||
assert r.status_code == 200
|
||||
|
||||
running_models = get_running_models()
|
||||
assert new_model_name in running_models
|
||||
47
langchain-chat/tests/api/test_server_state_api.py
Normal file
47
langchain-chat/tests/api/test_server_state_api.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
|
||||
from webui_pages.utils import ApiRequest
|
||||
|
||||
import pytest
|
||||
from pprint import pprint
|
||||
from typing import List
|
||||
|
||||
|
||||
api = ApiRequest()
|
||||
|
||||
|
||||
def test_get_default_llm():
|
||||
llm = api.get_default_llm_model()
|
||||
|
||||
print(llm)
|
||||
assert isinstance(llm, tuple)
|
||||
assert isinstance(llm[0], str) and isinstance(llm[1], bool)
|
||||
|
||||
|
||||
def test_server_configs():
|
||||
configs = api.get_server_configs()
|
||||
pprint(configs, depth=2)
|
||||
|
||||
assert isinstance(configs, dict)
|
||||
assert len(configs) > 0
|
||||
|
||||
|
||||
def test_list_search_engines():
|
||||
engines = api.list_search_engines()
|
||||
pprint(engines)
|
||||
|
||||
assert isinstance(engines, list)
|
||||
assert len(engines) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"])
|
||||
def test_get_prompt_template(type):
|
||||
print(f"prompt template for: {type}")
|
||||
template = api.get_prompt_template(type=type)
|
||||
|
||||
print(template)
|
||||
assert isinstance(template, str)
|
||||
assert len(template) > 0
|
||||
113
langchain-chat/tests/api/test_stream_chat_api.py
Normal file
113
langchain-chat/tests/api/test_stream_chat_api.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
from configs import BING_SUBSCRIPTION_KEY
|
||||
from server.utils import api_address
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
|
||||
def dump_input(d, title):
|
||||
print("\n")
|
||||
print("=" * 30 + title + " input " + "="*30)
|
||||
pprint(d)
|
||||
|
||||
|
||||
def dump_output(r, title):
|
||||
print("\n")
|
||||
print("=" * 30 + title + " output" + "="*30)
|
||||
for line in r.iter_content(None, decode_unicode=True):
|
||||
print(line, end="", flush=True)
|
||||
|
||||
|
||||
headers = {
|
||||
'accept': 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
data = {
|
||||
"query": "请用100字左右的文字介绍自己",
|
||||
"history": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "你好,我是人工智能大模型"
|
||||
}
|
||||
],
|
||||
"stream": True,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
|
||||
|
||||
def test_chat_chat(api="/chat/chat"):
|
||||
url = f"{api_base_url}{api}"
|
||||
dump_input(data, api)
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
dump_output(response, api)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_knowledge_chat(api="/chat/knowledge_base_chat"):
|
||||
url = f"{api_base_url}{api}"
|
||||
data = {
|
||||
"query": "如何提问以获得高质量答案",
|
||||
"knowledge_base_name": "samples",
|
||||
"history": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "你好,我是 ChatGLM"
|
||||
}
|
||||
],
|
||||
"stream": True
|
||||
}
|
||||
dump_input(data, api)
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
print("\n")
|
||||
print("=" * 30 + api + " output" + "="*30)
|
||||
for line in response.iter_content(None, decode_unicode=True):
|
||||
data = json.loads(line[6:])
|
||||
if "answer" in data:
|
||||
print(data["answer"], end="", flush=True)
|
||||
pprint(data)
|
||||
assert "docs" in data and len(data["docs"]) > 0
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_search_engine_chat(api="/chat/search_engine_chat"):
|
||||
global data
|
||||
|
||||
data["query"] = "室温超导最新进展是什么样?"
|
||||
|
||||
url = f"{api_base_url}{api}"
|
||||
for se in ["bing", "duckduckgo"]:
|
||||
data["search_engine_name"] = se
|
||||
dump_input(data, api + f" by {se}")
|
||||
response = requests.post(url, json=data, stream=True)
|
||||
if se == "bing" and not BING_SUBSCRIPTION_KEY:
|
||||
data = response.json()
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`"
|
||||
|
||||
print("\n")
|
||||
print("=" * 30 + api + f" by {se} output" + "="*30)
|
||||
for line in response.iter_content(None, decode_unicode=True):
|
||||
data = json.loads(line[6:])
|
||||
if "answer" in data:
|
||||
print(data["answer"], end="", flush=True)
|
||||
assert "docs" in data and len(data["docs"]) > 0
|
||||
pprint(data["docs"])
|
||||
assert response.status_code == 200
|
||||
|
||||
81
langchain-chat/tests/api/test_stream_chat_api_thread.py
Normal file
81
langchain-chat/tests/api/test_stream_chat_api_thread.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
from configs import BING_SUBSCRIPTION_KEY
|
||||
from server.utils import api_address
|
||||
|
||||
from pprint import pprint
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import time
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
|
||||
def dump_input(d, title):
|
||||
print("\n")
|
||||
print("=" * 30 + title + " input " + "="*30)
|
||||
pprint(d)
|
||||
|
||||
|
||||
def dump_output(r, title):
|
||||
print("\n")
|
||||
print("=" * 30 + title + " output" + "="*30)
|
||||
for line in r.iter_content(None, decode_unicode=True):
|
||||
print(line, end="", flush=True)
|
||||
|
||||
|
||||
headers = {
|
||||
'accept': 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
|
||||
def knowledge_chat(api="/chat/knowledge_base_chat"):
|
||||
url = f"{api_base_url}{api}"
|
||||
data = {
|
||||
"query": "如何提问以获得高质量答案",
|
||||
"knowledge_base_name": "samples",
|
||||
"history": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "你好,我是 ChatGLM"
|
||||
}
|
||||
],
|
||||
"stream": True
|
||||
}
|
||||
result = []
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
for line in response.iter_content(None, decode_unicode=True):
|
||||
data = json.loads(line[6:])
|
||||
result.append(data)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def test_thread():
|
||||
threads = []
|
||||
times = []
|
||||
pool = ThreadPoolExecutor()
|
||||
start = time.time()
|
||||
for i in range(10):
|
||||
t = pool.submit(knowledge_chat)
|
||||
threads.append(t)
|
||||
|
||||
for r in as_completed(threads):
|
||||
end = time.time()
|
||||
times.append(end - start)
|
||||
print("\nResult:\n")
|
||||
pprint(r.result())
|
||||
|
||||
print("\nTime used:\n")
|
||||
for x in times:
|
||||
print(f"{x}")
|
||||
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
import sys
|
||||
|
||||
sys.path.append("../..")
|
||||
from configs import (
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE
|
||||
)
|
||||
|
||||
from server.knowledge_base.utils import make_text_splitter
|
||||
|
||||
def text(splitter_name):
|
||||
from langchain import document_loaders
|
||||
|
||||
# 使用DocumentLoader读取文件
|
||||
filepath = "../../knowledge_base/samples/content/test.txt"
|
||||
loader = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True)
|
||||
docs = loader.load()
|
||||
text_splitter = make_text_splitter(splitter_name, CHUNK_SIZE, OVERLAP_SIZE)
|
||||
if splitter_name == "MarkdownHeaderTextSplitter":
|
||||
docs = text_splitter.split_text(docs[0].page_content)
|
||||
for doc in docs:
|
||||
if doc.metadata:
|
||||
doc.metadata["source"] = os.path.basename(filepath)
|
||||
else:
|
||||
docs = text_splitter.split_documents(docs)
|
||||
for doc in docs:
|
||||
print(doc)
|
||||
return docs
|
||||
|
||||
|
||||
|
||||
|
||||
import pytest
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
@pytest.mark.parametrize("splitter_name",
|
||||
[
|
||||
"ChineseRecursiveTextSplitter",
|
||||
"SpacyTextSplitter",
|
||||
"RecursiveCharacterTextSplitter",
|
||||
"MarkdownHeaderTextSplitter"
|
||||
])
|
||||
def test_different_splitter(splitter_name):
|
||||
try:
|
||||
docs = text(splitter_name)
|
||||
assert isinstance(docs, list)
|
||||
if len(docs)>0:
|
||||
assert isinstance(docs[0], Document)
|
||||
except Exception as e:
|
||||
pytest.fail(f"test_different_splitter failed with {splitter_name}, error: {str(e)}")
|
||||
10
langchain-chat/tests/document_loader/test_html.py
Normal file
10
langchain-chat/tests/document_loader/test_html.py
Normal file
@@ -0,0 +1,10 @@
|
||||
data_path = './人工智能发展月报.html'
|
||||
from langchain_community.document_loaders import TextLoader
|
||||
|
||||
loader = TextLoader(data_path)
|
||||
data = loader.load()
|
||||
print(data)
|
||||
|
||||
from unstructured.partition.html import partition_html
|
||||
rst = partition_html(text=data[0].page_content)
|
||||
print("\n\n".join([str(el) for el in rst]))
|
||||
21
langchain-chat/tests/document_loader/test_imgloader.py
Normal file
21
langchain-chat/tests/document_loader/test_imgloader.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from pprint import pprint
|
||||
|
||||
test_files = {
|
||||
"ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"),
|
||||
}
|
||||
|
||||
def test_rapidocrloader():
|
||||
img_path = test_files["ocr_test.jpg"]
|
||||
from document_loaders import RapidOCRLoader
|
||||
|
||||
loader = RapidOCRLoader(img_path)
|
||||
docs = loader.load()
|
||||
pprint(docs)
|
||||
assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)
|
||||
|
||||
|
||||
21
langchain-chat/tests/document_loader/test_pdfloader.py
Normal file
21
langchain-chat/tests/document_loader/test_pdfloader.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from pprint import pprint
|
||||
|
||||
test_files = {
|
||||
"ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"),
|
||||
}
|
||||
|
||||
def test_rapidocrpdfloader():
|
||||
pdf_path = test_files["ocr_test.pdf"]
|
||||
from document_loaders import RapidOCRPDFLoader
|
||||
|
||||
loader = RapidOCRPDFLoader(pdf_path)
|
||||
docs = loader.load()
|
||||
pprint(docs)
|
||||
assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)
|
||||
|
||||
|
||||
106
langchain-chat/tests/document_loader/人工智能发展月报.html
Normal file
106
langchain-chat/tests/document_loader/人工智能发展月报.html
Normal file
File diff suppressed because one or more lines are too long
55
langchain-chat/tests/docx_parser.py
Normal file
55
langchain-chat/tests/docx_parser.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import zipfile
|
||||
from lxml import etree
|
||||
import json
|
||||
import os
|
||||
|
||||
def extract_docx_info(docx_path: str, output_json: str = 'docx_info.json'):
|
||||
if not os.path.isfile(docx_path):
|
||||
raise FileNotFoundError(f"未找到 DOCX 文件:{docx_path}")
|
||||
|
||||
tag_counts = {}
|
||||
rels_counts = {}
|
||||
media_files = []
|
||||
|
||||
with zipfile.ZipFile(docx_path) as docx:
|
||||
file_list = docx.namelist()
|
||||
|
||||
# 统计 document.xml 中的元素标签
|
||||
doc_xml = docx.read('word/document.xml')
|
||||
doc_tree = etree.fromstring(doc_xml)
|
||||
for elem in doc_tree.iter():
|
||||
tag = etree.QName(elem).localname
|
||||
tag_counts[tag] = tag_counts.get(tag, 0) + 1
|
||||
|
||||
# 统计文档关系(document.xml.rels)
|
||||
rels_path = 'word/_rels/document.xml.rels'
|
||||
if rels_path in file_list:
|
||||
rels_xml = docx.read(rels_path)
|
||||
rels_tree = etree.fromstring(rels_xml)
|
||||
for rel in rels_tree.findall(
|
||||
'.//{http://schemas.openxmlformats.org/package/2006/relationships}Relationship'
|
||||
):
|
||||
rel_type = rel.get('Type').split('/')[-1]
|
||||
rels_counts[rel_type] = rels_counts.get(rel_type, 0) + 1
|
||||
|
||||
# 列出所有嵌入的媒体文件
|
||||
media_files = [f for f in file_list if f.startswith('word/media/')]
|
||||
|
||||
# 汇总信息
|
||||
info = {
|
||||
'source_docx': os.path.basename(docx_path),
|
||||
'elements': tag_counts,
|
||||
'relationships': rels_counts,
|
||||
'media_files': media_files,
|
||||
}
|
||||
|
||||
# 写入 JSON
|
||||
with open(output_json, 'w', encoding='utf-8') as f:
|
||||
json.dump(info, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"已生成 JSON 文件:{output_json}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 读取当前目录下的 test.docx
|
||||
extract_docx_info('./test.docx', 'docx_info.json')
|
||||
0
langchain-chat/tests/kb_vector_db/__init__.py
Normal file
0
langchain-chat/tests/kb_vector_db/__init__.py
Normal file
35
langchain-chat/tests/kb_vector_db/test_faiss_kb.py
Normal file
35
langchain-chat/tests/kb_vector_db/test_faiss_kb.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||
from server.knowledge_base.migrate import create_tables
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
|
||||
|
||||
kbService = FaissKBService("test")
|
||||
test_kb_name = "test"
|
||||
test_file_name = "README.md"
|
||||
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
|
||||
search_content = "如何启动api服务"
|
||||
|
||||
|
||||
def test_init():
|
||||
create_tables()
|
||||
|
||||
|
||||
def test_create_db():
|
||||
assert kbService.create_kb()
|
||||
|
||||
|
||||
def test_add_doc():
|
||||
assert kbService.add_doc(testKnowledgeFile)
|
||||
|
||||
|
||||
def test_search_db():
|
||||
result = kbService.search_docs(search_content)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_delete_doc():
|
||||
assert kbService.delete_doc(testKnowledgeFile)
|
||||
|
||||
|
||||
def test_delete_db():
|
||||
assert kbService.drop_kb()
|
||||
4
langchain-chat/tests/kb_vector_db/test_milvus_2026.py
Normal file
4
langchain-chat/tests/kb_vector_db/test_milvus_2026.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
client = MilvusClient("http://127.0.0.1:19530")
|
||||
print(client.list_collections())
|
||||
31
langchain-chat/tests/kb_vector_db/test_milvus_db.py
Normal file
31
langchain-chat/tests/kb_vector_db/test_milvus_db.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
# from server.knowledge_base.kb_service.pg_kb_service import PGKBService
|
||||
from server.knowledge_base.migrate import create_tables
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
|
||||
kbService = MilvusKBService("test")
|
||||
|
||||
test_kb_name = "test"
|
||||
test_file_name = "README.md"
|
||||
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
|
||||
search_content = "如何启动api服务"
|
||||
|
||||
def test_init():
|
||||
create_tables()
|
||||
|
||||
|
||||
def test_create_db():
|
||||
assert kbService.create_kb()
|
||||
|
||||
|
||||
def test_add_doc():
|
||||
assert kbService.add_doc(testKnowledgeFile)
|
||||
|
||||
|
||||
def test_search_db():
|
||||
result = kbService.search_docs(search_content)
|
||||
assert len(result) > 0
|
||||
def test_delete_doc():
|
||||
assert kbService.delete_doc(testKnowledgeFile)
|
||||
|
||||
31
langchain-chat/tests/kb_vector_db/test_pg_db.py
Normal file
31
langchain-chat/tests/kb_vector_db/test_pg_db.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||
from server.knowledge_base.kb_service.pg_kb_service import PGKBService
|
||||
from server.knowledge_base.migrate import create_tables
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
|
||||
kbService = PGKBService("test")
|
||||
|
||||
test_kb_name = "test"
|
||||
test_file_name = "README.md"
|
||||
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
|
||||
search_content = "如何启动api服务"
|
||||
|
||||
|
||||
def test_init():
|
||||
create_tables()
|
||||
|
||||
|
||||
def test_create_db():
|
||||
assert kbService.create_kb()
|
||||
|
||||
|
||||
def test_add_doc():
|
||||
assert kbService.add_doc(testKnowledgeFile)
|
||||
|
||||
|
||||
def test_search_db():
|
||||
result = kbService.search_docs(search_content)
|
||||
assert len(result) > 0
|
||||
def test_delete_doc():
|
||||
assert kbService.delete_doc(testKnowledgeFile)
|
||||
|
||||
BIN
langchain-chat/tests/samples/ocr_test.jpg
Normal file
BIN
langchain-chat/tests/samples/ocr_test.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.9 KiB |
BIN
langchain-chat/tests/samples/ocr_test.pdf
Normal file
BIN
langchain-chat/tests/samples/ocr_test.pdf
Normal file
Binary file not shown.
79
langchain-chat/tests/test_Qwen72B_32B.py
Normal file
79
langchain-chat/tests/test_Qwen72B_32B.py
Normal file
File diff suppressed because one or more lines are too long
138
langchain-chat/tests/test_concurrency_csv.py
Normal file
138
langchain-chat/tests/test_concurrency_csv.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import requests
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from tabulate import tabulate
|
||||
|
||||
# 模型配置
|
||||
model_config = {
|
||||
"Qwen1.5-32B-Chat": {
|
||||
"model_name": "Qwen1.5-32B-Chat",
|
||||
"api_base_url": "http://192.168.56.123:8821/v1",
|
||||
"api_key": "fake",
|
||||
}
|
||||
}
|
||||
|
||||
# 选择要测试的模型
|
||||
selected_model = "Qwen1.5-32B-Chat"
|
||||
API_URL = model_config[selected_model]["api_base_url"] + "/completions"
|
||||
API_KEY = model_config[selected_model]["api_key"]
|
||||
MODEL_NAME = model_config[selected_model]["model_name"]
|
||||
|
||||
# 测试输入
|
||||
TEST_PROMPT = "从前有座山"
|
||||
|
||||
# 测试每秒输出字符数
|
||||
def test_output_speed(prompt, api_url, api_key, model_name, max_tokens):
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
data = {
|
||||
"model": model_name,
|
||||
"prompt": prompt,
|
||||
"max_tokens": max_tokens
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
end_time = time.time()
|
||||
|
||||
if response.status_code == 200:
|
||||
output_text = response.json().get("choices", [{}])[0].get("text", "")
|
||||
char_count = len(output_text)
|
||||
elapsed_time = end_time - start_time
|
||||
chars_per_second = char_count / elapsed_time if elapsed_time > 0 else 0
|
||||
result = [
|
||||
max_tokens,
|
||||
char_count,
|
||||
f"{elapsed_time:.2f} ",
|
||||
f"{chars_per_second:.2f} "
|
||||
]
|
||||
return result
|
||||
else:
|
||||
print(f"错误: {response.status_code}, {response.text}")
|
||||
return max_tokens, 0, 0, 0
|
||||
|
||||
# 测试最大并发请求数
|
||||
def make_request(api_url, api_key, prompt, model_name, max_tokens):
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
data = {
|
||||
"model": model_name,
|
||||
"prompt": prompt,
|
||||
"max_tokens": max_tokens
|
||||
}
|
||||
start_time = time.time()
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
end_time = time.time()
|
||||
response_time = end_time - start_time
|
||||
return response.status_code, response.text, response_time
|
||||
|
||||
def test_concurrent_requests(api_url, api_key, prompt, model_name, max_workers, max_tokens):
|
||||
total_requests = max_workers
|
||||
success_count = 0
|
||||
total_response_time = 0
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [executor.submit(make_request, api_url, api_key, prompt, model_name, max_tokens) for _ in range(max_workers)]
|
||||
for future in as_completed(futures):
|
||||
status_code, _, response_time = future.result()
|
||||
total_response_time += response_time
|
||||
if status_code == 200:
|
||||
success_count += 1
|
||||
|
||||
average_response_time = total_response_time / total_requests if total_requests > 0 else 0
|
||||
throughput = success_count / total_response_time if total_response_time > 0 else 0
|
||||
success_rate = (success_count / total_requests) * 100 if total_requests > 0 else 0
|
||||
|
||||
result = [
|
||||
max_tokens,
|
||||
max_workers,
|
||||
total_requests,
|
||||
f"{average_response_time:.2f}",
|
||||
f"{throughput:.2f}",
|
||||
success_count,
|
||||
f"{success_rate:.2f}"
|
||||
]
|
||||
return result
|
||||
|
||||
# 动态测试最大并发请求数
|
||||
def find_max_concurrent_requests(api_url, api_key, prompt, model_name, max_tokens, max_workers_list):
|
||||
results = []
|
||||
for max_workers in max_workers_list:
|
||||
result = test_concurrent_requests(api_url, api_key, prompt, model_name, max_workers, max_tokens)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
# 输出为制表符分隔的文本格式
|
||||
def print_csv(headers, rows):
|
||||
header_line = ",".join(headers)
|
||||
print(header_line)
|
||||
for row in rows:
|
||||
row_line = ",".join(map(str, row))
|
||||
print(row_line)
|
||||
|
||||
# 主程序
|
||||
if __name__ == "__main__":
|
||||
# 定义要测试的 max_tokens 列表
|
||||
max_tokens_list = [512, 1024, 2048, 4096, 8192]
|
||||
max_workers_list = [10, 20, 50, 100]
|
||||
# max_tokens_list = [10, 20, 30]
|
||||
# max_workers_list = [2, 4, 8, 16]
|
||||
|
||||
output_speed_results = []
|
||||
for max_tokens in max_tokens_list:
|
||||
result = test_output_speed(TEST_PROMPT, API_URL, API_KEY, MODEL_NAME, max_tokens)
|
||||
output_speed_results.append(result)
|
||||
|
||||
# 输出每秒输出字符数的测试结果表格
|
||||
output_speed_headers = ["max_tokens", "输出字符数", "耗时 (秒)", "每秒字符数"]
|
||||
print("每秒输出字符数测试结果:")
|
||||
print_csv(output_speed_headers, output_speed_results)
|
||||
|
||||
# 动态测试最大并发请求数
|
||||
all_concurrency_results = []
|
||||
for max_tokens in max_tokens_list:
|
||||
results = find_max_concurrent_requests(API_URL, API_KEY, TEST_PROMPT, MODEL_NAME, max_tokens, max_workers_list)
|
||||
all_concurrency_results.extend(results)
|
||||
|
||||
# 输出并发请求测试结果
|
||||
concurrency_headers = ["max_tokens", "并发请求数", "请求个数", "平均响应时间(RT,单位秒)", "吞吐量 (QPS)", "请求成功个数", "请求成功率 (%)"]
|
||||
print("\n测试并发结果:")
|
||||
print_csv(concurrency_headers, all_concurrency_results)
|
||||
162
langchain-chat/tests/test_llm_bench_qa.py
Normal file
162
langchain-chat/tests/test_llm_bench_qa.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
import sys
|
||||
|
||||
# 配置
|
||||
# LLM_MODEL = "Qwen2-72B-Instruct"
|
||||
# LLM_ENDPOINT = "http://192.168.56.123:8822/v1"
|
||||
|
||||
# LLM_MODEL = "deepseek-chat"
|
||||
# LLM_ENDPOINT = "https://api.deepseek.com/v1"
|
||||
|
||||
LLM_MODEL = "qwen-max-2025-01-25"
|
||||
LLM_ENDPOINT = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
|
||||
TEMPERATURE = 0.7 # 确保每次返回的都不同
|
||||
MAX_TOKENS = 8192
|
||||
|
||||
# 问题列表
|
||||
questions = [
|
||||
"为什么鸟儿会唱歌?", "为什么我们有季节?", "为什么星星会闪烁?", "为什么我们会打哈欠?",
|
||||
"为什么太阳是热的?", "为什么猫会咕噜咕噜叫?", "为什么狗会吠?", "为什么鱼会游泳?",
|
||||
"为什么我们有指纹?", "为什么我们会打喷嚏?", "为什么我们有眉毛?", "为什么我们有头发?",
|
||||
"为什么我们有指甲?", "为什么我们有牙齿?", "为什么我们有骨头?", "为什么我们有肌肉?",
|
||||
"为什么我们有血液?", "为什么我们有心脏?", "为什么我们有肺?", "为什么我们有大脑?",
|
||||
"为什么我们有皮肤?", "为什么我们有耳朵?", "为什么我们有眼睛?", "为什么我们有鼻子?",
|
||||
"为什么我们有嘴巴?", "为什么我们有舌头?", "为什么我们有胃?", "为什么我们有肠子?",
|
||||
"为什么我们有肝脏?", "为什么我们有肾脏?", "为什么我们有膀胱?", "为什么我们有胰腺?",
|
||||
"为什么我们有脾脏?", "为什么我们有胆囊?", "为什么我们有甲状腺?", "为什么我们有肾上腺?",
|
||||
"为什么我们有垂体?", "为什么我们有下丘脑?", "为什么我们有胸腺?", "为什么我们有淋巴结?",
|
||||
"为什么我们有脊髓?", "为什么我们有神经?", "为什么我们有循环系统?", "为什么我们有呼吸系统?",
|
||||
"为什么我们有消化系统?", "为什么我们有免疫系统?"
|
||||
]
|
||||
def log_to_file(file, message):
|
||||
"""将消息追加写入指定的文件"""
|
||||
with open(file, 'a', encoding='utf-8') as f:
|
||||
f.write(message + '\n')
|
||||
|
||||
async def fetch(session, url, file=None):
|
||||
start_time = time.time()
|
||||
|
||||
question = random.choice(questions)
|
||||
json_payload = {
|
||||
"model": LLM_MODEL,
|
||||
"messages": [
|
||||
# {"role": "system", "content": "你的任务是学习和理解用户输入的文段,分析其中的实体关系,然后根据关系逻辑重新拟定你份合同,合同字数要在4000字以上。"},
|
||||
# {"role": "system", "content": "你的使命是翻译用户输入的文段。注意,一定要翻译完整。"},
|
||||
{"role": "system", "content": "你的任务是用专业学术语言,严谨的科学态度,全面的数据支持,完整详实地回答user的问题。"},
|
||||
{"role": "user", "content": question}
|
||||
],
|
||||
"temperature": TEMPERATURE,
|
||||
"max_tokens": MAX_TOKENS,
|
||||
"stream": False
|
||||
}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
# "Authorization": "Bearer sk-dba93353b0cc447ba55245e4f048c779" # deepseek
|
||||
"Authorization": "sk-8b498c0de2dc437aab8efa490d4021ba" # qwen
|
||||
}
|
||||
try:
|
||||
async with session.post(url, json=json_payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
print(f"错误: 收到响应码 {response.status}")
|
||||
return 0, 0, 0
|
||||
|
||||
response_json = await response.json()
|
||||
end_time = time.time()
|
||||
request_time = end_time - start_time
|
||||
|
||||
completion_tokens = 0
|
||||
if 'usage' in response_json:
|
||||
usage = response_json['usage']
|
||||
completion_tokens = usage.get('completion_tokens', 0)
|
||||
prompt_tokens = usage.get('prompt_tokens', 0)
|
||||
else:
|
||||
print("警告: 响应中缺少 'usage' 字段。")
|
||||
|
||||
answer = ""
|
||||
if 'choices' in response_json and len(response_json['choices']) > 0:
|
||||
answer = response_json['choices'][0]['message']['content']
|
||||
# completion_tokens = len(answer) / 1.5 # qwen
|
||||
# completion_tokens = len(answer) * 0.6 # deepseek
|
||||
else:
|
||||
print("警告: 响应中缺少 'choices' 字段或内容为空。")
|
||||
completion_tokens = 0
|
||||
|
||||
# 将输入输出写入文件(保持原有日志格式不变)
|
||||
if file:
|
||||
log_to_file(file, f"Q: {question}\nA: {answer}\n")
|
||||
|
||||
# 打印输入和输出的token数
|
||||
# print(f"输入token数: {input_tokens}, 输出token数: {output_tokens}")
|
||||
return prompt_tokens, completion_tokens, request_time
|
||||
except Exception as e:
|
||||
print(f"请求过程中发生异常: {e}")
|
||||
return 0, 0, 0
|
||||
|
||||
async def bound_fetch(sem, session, url, pbar, file=None):
|
||||
async with sem:
|
||||
result = await fetch(session, url, file=file)
|
||||
pbar.update(1)
|
||||
return result
|
||||
|
||||
async def run(load_url, max_concurrent_requests, total_requests, output_file):
|
||||
sem = asyncio.Semaphore(max_concurrent_requests)
|
||||
timeout = aiohttp.ClientTimeout(total=6000, connect=6000, sock_read=6000, sock_connect=6000)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
tasks = []
|
||||
with tqdm(total=total_requests) as pbar:
|
||||
for _ in range(total_requests):
|
||||
task = asyncio.create_task(bound_fetch(sem, session, load_url, pbar, file=output_file))
|
||||
tasks.append(task)
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# 聚合token数和响应时间
|
||||
total_input_tokens = sum(result[0] for result in results)
|
||||
total_output_tokens = sum(result[1] for result in results)
|
||||
response_times = [result[2] for result in results]
|
||||
|
||||
return total_input_tokens, total_output_tokens, response_times
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) != 3:
|
||||
print("用法: python llm_test.py <C> <N>")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
C = int(sys.argv[1]) # 最大并发数
|
||||
N = int(sys.argv[2]) # 请求总数
|
||||
except ValueError:
|
||||
print("错误: C 和 N 必须是整数。")
|
||||
sys.exit(1)
|
||||
|
||||
url = f'{LLM_ENDPOINT}/chat/completions'
|
||||
|
||||
output_file = 'A800_Qwen2.5-72Bint8_bench_.txt'
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write('') # 清空文件内容
|
||||
|
||||
start_time = time.time()
|
||||
total_input_tokens, total_output_tokens, response_times = asyncio.run(run(url, C, N, output_file))
|
||||
end_time = time.time()
|
||||
|
||||
total_time = end_time - start_time
|
||||
avg_time_per_request = sum(response_times) / len(response_times) if response_times else 0
|
||||
tokens_per_second = (total_output_tokens) / total_time if total_time > 0 else 0
|
||||
|
||||
final_output = (
|
||||
"最终表现:\n"
|
||||
f" 输入token数 : {total_input_tokens:.2f}\n"
|
||||
f" 输出token数 : {total_output_tokens:.2f}\n"
|
||||
f" 并发数 : {C}\n"
|
||||
f" 总请求数 : {N}\n"
|
||||
f" 总耗时 : {total_time:.2f} 秒\n"
|
||||
f" 平均耗时 : {avg_time_per_request:.2f} 秒\n"
|
||||
f" 吞吐(QPS) : {tokens_per_second:.2f} tokens/s"
|
||||
)
|
||||
|
||||
print(final_output)
|
||||
log_to_file(output_file, final_output) # 将最终的表现也写入文件
|
||||
138
langchain-chat/tests/test_migrate.py
Normal file
138
langchain-chat/tests/test_migrate.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
root_path = Path(__file__).parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.knowledge_base.utils import get_kb_path, get_doc_path, KnowledgeFile
|
||||
from server.knowledge_base.migrate import folder2db, prune_db_docs, prune_folder_files
|
||||
|
||||
|
||||
# setup test knowledge base
|
||||
kb_name = "test_kb_for_migrate"
|
||||
test_files = {
|
||||
"readme.md": str(root_path / "readme.md"),
|
||||
}
|
||||
|
||||
|
||||
kb_path = get_kb_path(kb_name)
|
||||
doc_path = get_doc_path(kb_name)
|
||||
|
||||
if not os.path.isdir(doc_path):
|
||||
os.makedirs(doc_path)
|
||||
|
||||
for k, v in test_files.items():
|
||||
shutil.copy(v, os.path.join(doc_path, k))
|
||||
|
||||
|
||||
def test_recreate_vs():
|
||||
folder2db([kb_name], "recreate_vs")
|
||||
|
||||
kb = KBServiceFactory.get_service_by_name(kb_name)
|
||||
assert kb and kb.exists()
|
||||
|
||||
files = kb.list_files()
|
||||
print(files)
|
||||
for name in test_files:
|
||||
assert name in files
|
||||
path = os.path.join(doc_path, name)
|
||||
|
||||
# list docs based on file name
|
||||
docs = kb.list_docs(file_name=name)
|
||||
assert len(docs) > 0
|
||||
pprint(docs[0])
|
||||
for doc in docs:
|
||||
assert doc.metadata["source"] == name
|
||||
|
||||
# list docs base on metadata
|
||||
docs = kb.list_docs(metadata={"source": name})
|
||||
assert len(docs) > 0
|
||||
|
||||
for doc in docs:
|
||||
assert doc.metadata["source"] == name
|
||||
|
||||
|
||||
def test_increment():
|
||||
kb = KBServiceFactory.get_service_by_name(kb_name)
|
||||
kb.clear_vs()
|
||||
assert kb.list_files() == []
|
||||
assert kb.list_docs() == []
|
||||
|
||||
folder2db([kb_name], "increment")
|
||||
|
||||
files = kb.list_files()
|
||||
print(files)
|
||||
for f in test_files:
|
||||
assert f in files
|
||||
|
||||
docs = kb.list_docs(file_name=f)
|
||||
assert len(docs) > 0
|
||||
pprint(docs[0])
|
||||
|
||||
for doc in docs:
|
||||
assert doc.metadata["source"] == f
|
||||
|
||||
|
||||
def test_prune_db():
|
||||
del_file, keep_file = list(test_files)[:2]
|
||||
os.remove(os.path.join(doc_path, del_file))
|
||||
|
||||
prune_db_docs([kb_name])
|
||||
|
||||
kb = KBServiceFactory.get_service_by_name(kb_name)
|
||||
files = kb.list_files()
|
||||
print(files)
|
||||
assert del_file not in files
|
||||
assert keep_file in files
|
||||
|
||||
docs = kb.list_docs(file_name=del_file)
|
||||
assert len(docs) == 0
|
||||
|
||||
docs = kb.list_docs(file_name=keep_file)
|
||||
assert len(docs) > 0
|
||||
pprint(docs[0])
|
||||
|
||||
shutil.copy(test_files[del_file], os.path.join(doc_path, del_file))
|
||||
|
||||
|
||||
def test_prune_folder():
|
||||
del_file, keep_file = list(test_files)[:2]
|
||||
kb = KBServiceFactory.get_service_by_name(kb_name)
|
||||
|
||||
# delete docs for file
|
||||
kb.delete_doc(KnowledgeFile(del_file, kb_name))
|
||||
files = kb.list_files()
|
||||
print(files)
|
||||
assert del_file not in files
|
||||
assert keep_file in files
|
||||
|
||||
docs = kb.list_docs(file_name=del_file)
|
||||
assert len(docs) == 0
|
||||
|
||||
docs = kb.list_docs(file_name=keep_file)
|
||||
assert len(docs) > 0
|
||||
|
||||
docs = kb.list_docs(file_name=del_file)
|
||||
assert len(docs) == 0
|
||||
|
||||
assert os.path.isfile(os.path.join(doc_path, del_file))
|
||||
|
||||
# prune folder
|
||||
prune_folder_files([kb_name])
|
||||
|
||||
# check result
|
||||
assert not os.path.isfile(os.path.join(doc_path, del_file))
|
||||
assert os.path.isfile(os.path.join(doc_path, keep_file))
|
||||
|
||||
|
||||
def test_drop_kb():
|
||||
kb = KBServiceFactory.get_service_by_name(kb_name)
|
||||
kb.drop_kb()
|
||||
assert not kb.exists()
|
||||
assert not os.path.isdir(kb_path)
|
||||
|
||||
kb = KBServiceFactory.get_service_by_name(kb_name)
|
||||
assert kb is None
|
||||
71
langchain-chat/tests/test_online_api.py
Normal file
71
langchain-chat/tests/test_online_api.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
root_path = Path(__file__).parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
|
||||
from configs import ONLINE_LLM_MODEL
|
||||
from server.model_workers.base import *
|
||||
from server.utils import get_model_worker_config, list_config_llm_models
|
||||
from pprint import pprint
|
||||
import pytest
|
||||
|
||||
|
||||
workers = []
|
||||
for x in list_config_llm_models()["online"]:
|
||||
if x in ONLINE_LLM_MODEL and x not in workers:
|
||||
workers.append(x)
|
||||
print(f"all workers to test: {workers}")
|
||||
|
||||
# workers = ["fangzhou-api"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("worker", workers)
|
||||
def test_chat(worker):
|
||||
params = ApiChatParams(
|
||||
messages = [
|
||||
{"role": "user", "content": "你是谁"},
|
||||
],
|
||||
)
|
||||
print(f"\nchat with {worker} \n")
|
||||
|
||||
if worker_class := get_model_worker_config(worker).get("worker_class"):
|
||||
for x in worker_class().do_chat(params):
|
||||
pprint(x)
|
||||
assert isinstance(x, dict)
|
||||
assert x["error_code"] == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("worker", workers)
|
||||
def test_embeddings(worker):
|
||||
params = ApiEmbeddingsParams(
|
||||
texts = [
|
||||
"LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。",
|
||||
"一种利用 langchain 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。",
|
||||
],
|
||||
worker_name=worker,
|
||||
)
|
||||
|
||||
if worker_class := get_model_worker_config(worker).get("worker_class"):
|
||||
if worker_class.can_embedding():
|
||||
print(f"\embeddings with {worker} \n")
|
||||
resp = worker_class().do_embeddings(params)
|
||||
|
||||
pprint(resp, depth=2)
|
||||
assert resp["code"] == 200
|
||||
assert "data" in resp
|
||||
embeddings = resp["data"]
|
||||
assert isinstance(embeddings, list) and len(embeddings) > 0
|
||||
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
|
||||
assert isinstance(embeddings[0][0], float)
|
||||
print("向量长度:", len(embeddings[0]))
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("worker", workers)
|
||||
# def test_completion(worker):
|
||||
# params = ApiCompletionParams(prompt="五十六个民族")
|
||||
|
||||
# print(f"\completion with {worker} \n")
|
||||
|
||||
# worker_class = get_model_worker_config(worker)["worker_class"]
|
||||
# resp = worker_class().do_completion(params)
|
||||
# pprint(resp)
|
||||
76
langchain-chat/tests/test_textrank.py
Normal file
76
langchain-chat/tests/test_textrank.py
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user