[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
214
langchain-chat/tests/api/test_kb_api.py
Normal file
214
langchain-chat/tests/api/test_kb_api.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from server.utils import api_address
|
||||
from configs import VECTOR_SEARCH_TOP_K
|
||||
from server.knowledge_base.utils import get_kb_path, get_file_path
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
|
||||
kb = "kb_for_api_test"
|
||||
test_files = {
|
||||
"wiki/Home.MD": get_file_path("samples", "wiki/Home.md"),
|
||||
"wiki/开发环境部署.MD": get_file_path("samples", "wiki/开发环境部署.md"),
|
||||
"test_files/test.txt": get_file_path("samples", "test_files/test.txt"),
|
||||
}
|
||||
|
||||
print("\n\n直接url访问\n")
|
||||
|
||||
|
||||
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
|
||||
if not Path(get_kb_path(kb)).exists():
|
||||
return
|
||||
|
||||
url = api_base_url + api
|
||||
print("\n测试知识库存在,需要删除")
|
||||
r = requests.post(url, json=kb)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
|
||||
# check kb not exists anymore
|
||||
url = api_base_url + "/knowledge_base/list_knowledge_bases"
|
||||
print("\n获取知识库列表:")
|
||||
r = requests.get(url)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb not in data["data"]
|
||||
|
||||
|
||||
def test_create_kb(api="/knowledge_base/create_knowledge_base"):
|
||||
url = api_base_url + api
|
||||
|
||||
print(f"\n尝试用空名称创建知识库:")
|
||||
r = requests.post(url, json={"knowledge_base_name": " "})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
|
||||
|
||||
print(f"\n创建新知识库: {kb}")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"已新增知识库 {kb}"
|
||||
|
||||
print(f"\n尝试创建同名知识库: {kb}")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == f"已存在同名知识库 {kb}"
|
||||
|
||||
|
||||
def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
|
||||
url = api_base_url + api
|
||||
print("\n获取知识库列表:")
|
||||
r = requests.get(url)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb in data["data"]
|
||||
|
||||
|
||||
def test_upload_docs(api="/knowledge_base/upload_docs"):
|
||||
url = api_base_url + api
|
||||
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||
|
||||
print(f"\n上传知识文件")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 不覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": False}
|
||||
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == len(test_files)
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
|
||||
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
|
||||
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
|
||||
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_list_files(api="/knowledge_base/list_files"):
|
||||
url = api_base_url + api
|
||||
print("\n获取知识库中文件列表:")
|
||||
r = requests.get(url, params={"knowledge_base_name": kb})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list)
|
||||
for name in test_files:
|
||||
assert name in data["data"]
|
||||
|
||||
|
||||
def test_search_docs(api="/knowledge_base/search_docs"):
|
||||
url = api_base_url + api
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n检索知识库:")
|
||||
print(query)
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_update_info(api="/knowledge_base/update_info"):
|
||||
url = api_base_url + api
|
||||
print("\n更新知识库介绍")
|
||||
r = requests.post(url, json={"knowledge_base_name": "samples", "kb_info": "你好"})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
|
||||
def test_update_docs(api="/knowledge_base/update_docs"):
|
||||
url = api_base_url + api
|
||||
|
||||
print(f"\n更新知识文件")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_delete_docs(api="/knowledge_base/delete_docs"):
|
||||
url = api_base_url + api
|
||||
|
||||
print(f"\n删除知识文件")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
url = api_base_url + "/knowledge_base/search_docs"
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n尝试检索删除后的检索知识库:")
|
||||
print(query)
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == 0
|
||||
|
||||
|
||||
def test_recreate_vs(api="/knowledge_base/recreate_vector_store"):
|
||||
url = api_base_url + api
|
||||
print("\n重建知识库:")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb}, stream=True)
|
||||
for chunk in r.iter_content(None):
|
||||
data = json.loads(chunk[6:])
|
||||
assert isinstance(data, dict)
|
||||
assert data["code"] == 200
|
||||
print(data["msg"])
|
||||
|
||||
url = api_base_url + "/knowledge_base/search_docs"
|
||||
query = "本项目支持哪些文件格式?"
|
||||
print("\n尝试检索重建后的检索知识库:")
|
||||
print(query)
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_delete_kb_after(api="/knowledge_base/delete_knowledge_base"):
|
||||
url = api_base_url + api
|
||||
print("\n删除知识库")
|
||||
r = requests.post(url, json=kb)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
|
||||
# check kb not exists anymore
|
||||
url = api_base_url + "/knowledge_base/list_knowledge_bases"
|
||||
print("\n获取知识库列表:")
|
||||
r = requests.get(url)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb not in data["data"]
|
||||
161
langchain-chat/tests/api/test_kb_api_request.py
Normal file
161
langchain-chat/tests/api/test_kb_api_request.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from server.utils import api_address
|
||||
from configs import VECTOR_SEARCH_TOP_K
|
||||
from server.knowledge_base.utils import get_kb_path, get_file_path
|
||||
from webui_pages.utils import ApiRequest
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
api: ApiRequest = ApiRequest(api_base_url)
|
||||
|
||||
|
||||
kb = "kb_for_api_test"
|
||||
test_files = {
|
||||
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
|
||||
"README.MD": str(root_path / "README.MD"),
|
||||
"test.txt": get_file_path("samples", "test.txt"),
|
||||
}
|
||||
|
||||
print("\n\nApiRquest调用\n")
|
||||
|
||||
|
||||
def test_delete_kb_before():
|
||||
if not Path(get_kb_path(kb)).exists():
|
||||
return
|
||||
|
||||
data = api.delete_knowledge_base(kb)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb not in data["data"]
|
||||
|
||||
|
||||
def test_create_kb():
|
||||
print(f"\n尝试用空名称创建知识库:")
|
||||
data = api.create_knowledge_base(" ")
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
|
||||
|
||||
print(f"\n创建新知识库: {kb}")
|
||||
data = api.create_knowledge_base(kb)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"已新增知识库 {kb}"
|
||||
|
||||
print(f"\n尝试创建同名知识库: {kb}")
|
||||
data = api.create_knowledge_base(kb)
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == f"已存在同名知识库 {kb}"
|
||||
|
||||
|
||||
def test_list_kbs():
|
||||
data = api.list_knowledge_bases()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) > 0
|
||||
assert kb in data
|
||||
|
||||
|
||||
def test_upload_docs():
|
||||
files = list(test_files.values())
|
||||
|
||||
print(f"\n上传知识文件")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 不覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": False}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == len(test_files)
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
|
||||
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
|
||||
data = {"knowledge_base_name": kb, "override": True, "docs": docs}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_list_files():
|
||||
print("\n获取知识库中文件列表:")
|
||||
data = api.list_kb_docs(knowledge_base_name=kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list)
|
||||
for name in test_files:
|
||||
assert name in data
|
||||
|
||||
|
||||
def test_search_docs():
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n检索知识库:")
|
||||
print(query)
|
||||
data = api.search_kb_docs(query, kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_update_docs():
|
||||
print(f"\n更新知识文件")
|
||||
data = api.update_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_delete_docs():
|
||||
print(f"\n删除知识文件")
|
||||
data = api.delete_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n尝试检索删除后的检索知识库:")
|
||||
print(query)
|
||||
data = api.search_kb_docs(query, kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == 0
|
||||
|
||||
|
||||
def test_recreate_vs():
|
||||
print("\n重建知识库:")
|
||||
r = api.recreate_vector_store(kb)
|
||||
for data in r:
|
||||
assert isinstance(data, dict)
|
||||
assert data["code"] == 200
|
||||
print(data["msg"])
|
||||
|
||||
query = "本项目支持哪些文件格式?"
|
||||
print("\n尝试检索重建后的检索知识库:")
|
||||
print(query)
|
||||
data = api.search_kb_docs(query, kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_delete_kb_after():
|
||||
print("\n删除知识库")
|
||||
data = api.delete_knowledge_base(kb)
|
||||
pprint(data)
|
||||
|
||||
# check kb not exists anymore
|
||||
print("\n获取知识库列表:")
|
||||
data = api.list_knowledge_bases()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) > 0
|
||||
assert kb not in data
|
||||
44
langchain-chat/tests/api/test_kb_summary_api.py
Normal file
44
langchain-chat/tests/api/test_kb_summary_api.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from server.utils import api_address
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
kb = "samples"
|
||||
file_name = "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/knowledge_base/samples/content/llm/大模型技术栈-实战与应用.md"
|
||||
doc_ids = [
|
||||
"357d580f-fdf7-495c-b58b-595a398284e8",
|
||||
"c7338773-2e83-4671-b237-1ad20335b0f0",
|
||||
"6da613d1-327d-466f-8c1a-b32e6f461f47"
|
||||
]
|
||||
|
||||
|
||||
def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summary_file_to_vector_store"):
|
||||
url = api_base_url + api
|
||||
print("\n文件摘要:")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb,
|
||||
"file_name": file_name
|
||||
}, stream=True)
|
||||
for chunk in r.iter_content(None):
|
||||
data = json.loads(chunk[6:])
|
||||
assert isinstance(data, dict)
|
||||
assert data["code"] == 200
|
||||
print(data["msg"])
|
||||
|
||||
|
||||
def test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store"):
|
||||
url = api_base_url + api
|
||||
print("\n文件摘要:")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb,
|
||||
"doc_ids": doc_ids
|
||||
}, stream=True)
|
||||
for chunk in r.iter_content(None):
|
||||
data = json.loads(chunk[6:])
|
||||
assert isinstance(data, dict)
|
||||
assert data["code"] == 200
|
||||
print(data)
|
||||
70
langchain-chat/tests/api/test_llm_api.py
Normal file
70
langchain-chat/tests/api/test_llm_api.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from configs.server_config import FSCHAT_MODEL_WORKERS
|
||||
from server.utils import api_address, get_model_worker_config
|
||||
|
||||
from pprint import pprint
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
|
||||
def get_configured_models() -> List[str]:
|
||||
model_workers = list(FSCHAT_MODEL_WORKERS)
|
||||
if "default" in model_workers:
|
||||
model_workers.remove("default")
|
||||
return model_workers
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
|
||||
def get_running_models(api="/llm_model/list_models"):
|
||||
url = api_base_url + api
|
||||
r = requests.post(url)
|
||||
if r.status_code == 200:
|
||||
return r.json()["data"]
|
||||
return []
|
||||
|
||||
|
||||
def test_running_models(api="/llm_model/list_running_models"):
|
||||
url = api_base_url + api
|
||||
r = requests.post(url)
|
||||
assert r.status_code == 200
|
||||
print("\n获取当前正在运行的模型列表:")
|
||||
pprint(r.json())
|
||||
assert isinstance(r.json()["data"], list)
|
||||
assert len(r.json()["data"]) > 0
|
||||
|
||||
|
||||
# 不建议使用stop_model功能。按现在的实现,停止了就只能手动再启动
|
||||
# def test_stop_model(api="/llm_model/stop"):
|
||||
# url = api_base_url + api
|
||||
# r = requests.post(url, json={""})
|
||||
|
||||
|
||||
def test_change_model(api="/llm_model/change_model"):
|
||||
url = api_base_url + api
|
||||
|
||||
running_models = get_running_models()
|
||||
assert len(running_models) > 0
|
||||
|
||||
model_workers = get_configured_models()
|
||||
|
||||
availabel_new_models = list(set(model_workers) - set(running_models))
|
||||
assert len(availabel_new_models) > 0
|
||||
print(availabel_new_models)
|
||||
|
||||
local_models = [x for x in running_models if not get_model_worker_config(x).get("online_api")]
|
||||
model_name = random.choice(local_models)
|
||||
new_model_name = random.choice(availabel_new_models)
|
||||
print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
|
||||
r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
|
||||
assert r.status_code == 200
|
||||
|
||||
running_models = get_running_models()
|
||||
assert new_model_name in running_models
|
||||
47
langchain-chat/tests/api/test_server_state_api.py
Normal file
47
langchain-chat/tests/api/test_server_state_api.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
|
||||
from webui_pages.utils import ApiRequest
|
||||
|
||||
import pytest
|
||||
from pprint import pprint
|
||||
from typing import List
|
||||
|
||||
|
||||
api = ApiRequest()
|
||||
|
||||
|
||||
def test_get_default_llm():
|
||||
llm = api.get_default_llm_model()
|
||||
|
||||
print(llm)
|
||||
assert isinstance(llm, tuple)
|
||||
assert isinstance(llm[0], str) and isinstance(llm[1], bool)
|
||||
|
||||
|
||||
def test_server_configs():
|
||||
configs = api.get_server_configs()
|
||||
pprint(configs, depth=2)
|
||||
|
||||
assert isinstance(configs, dict)
|
||||
assert len(configs) > 0
|
||||
|
||||
|
||||
def test_list_search_engines():
|
||||
engines = api.list_search_engines()
|
||||
pprint(engines)
|
||||
|
||||
assert isinstance(engines, list)
|
||||
assert len(engines) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"])
|
||||
def test_get_prompt_template(type):
|
||||
print(f"prompt template for: {type}")
|
||||
template = api.get_prompt_template(type=type)
|
||||
|
||||
print(template)
|
||||
assert isinstance(template, str)
|
||||
assert len(template) > 0
|
||||
113
langchain-chat/tests/api/test_stream_chat_api.py
Normal file
113
langchain-chat/tests/api/test_stream_chat_api.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
from configs import BING_SUBSCRIPTION_KEY
|
||||
from server.utils import api_address
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
|
||||
def dump_input(d, title):
|
||||
print("\n")
|
||||
print("=" * 30 + title + " input " + "="*30)
|
||||
pprint(d)
|
||||
|
||||
|
||||
def dump_output(r, title):
|
||||
print("\n")
|
||||
print("=" * 30 + title + " output" + "="*30)
|
||||
for line in r.iter_content(None, decode_unicode=True):
|
||||
print(line, end="", flush=True)
|
||||
|
||||
|
||||
headers = {
|
||||
'accept': 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
data = {
|
||||
"query": "请用100字左右的文字介绍自己",
|
||||
"history": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "你好,我是人工智能大模型"
|
||||
}
|
||||
],
|
||||
"stream": True,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
|
||||
|
||||
def test_chat_chat(api="/chat/chat"):
|
||||
url = f"{api_base_url}{api}"
|
||||
dump_input(data, api)
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
dump_output(response, api)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_knowledge_chat(api="/chat/knowledge_base_chat"):
|
||||
url = f"{api_base_url}{api}"
|
||||
data = {
|
||||
"query": "如何提问以获得高质量答案",
|
||||
"knowledge_base_name": "samples",
|
||||
"history": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "你好,我是 ChatGLM"
|
||||
}
|
||||
],
|
||||
"stream": True
|
||||
}
|
||||
dump_input(data, api)
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
print("\n")
|
||||
print("=" * 30 + api + " output" + "="*30)
|
||||
for line in response.iter_content(None, decode_unicode=True):
|
||||
data = json.loads(line[6:])
|
||||
if "answer" in data:
|
||||
print(data["answer"], end="", flush=True)
|
||||
pprint(data)
|
||||
assert "docs" in data and len(data["docs"]) > 0
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_search_engine_chat(api="/chat/search_engine_chat"):
|
||||
global data
|
||||
|
||||
data["query"] = "室温超导最新进展是什么样?"
|
||||
|
||||
url = f"{api_base_url}{api}"
|
||||
for se in ["bing", "duckduckgo"]:
|
||||
data["search_engine_name"] = se
|
||||
dump_input(data, api + f" by {se}")
|
||||
response = requests.post(url, json=data, stream=True)
|
||||
if se == "bing" and not BING_SUBSCRIPTION_KEY:
|
||||
data = response.json()
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`"
|
||||
|
||||
print("\n")
|
||||
print("=" * 30 + api + f" by {se} output" + "="*30)
|
||||
for line in response.iter_content(None, decode_unicode=True):
|
||||
data = json.loads(line[6:])
|
||||
if "answer" in data:
|
||||
print(data["answer"], end="", flush=True)
|
||||
assert "docs" in data and len(data["docs"]) > 0
|
||||
pprint(data["docs"])
|
||||
assert response.status_code == 200
|
||||
|
||||
81
langchain-chat/tests/api/test_stream_chat_api_thread.py
Normal file
81
langchain-chat/tests/api/test_stream_chat_api_thread.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
from configs import BING_SUBSCRIPTION_KEY
|
||||
from server.utils import api_address
|
||||
|
||||
from pprint import pprint
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import time
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
|
||||
def dump_input(d, title):
|
||||
print("\n")
|
||||
print("=" * 30 + title + " input " + "="*30)
|
||||
pprint(d)
|
||||
|
||||
|
||||
def dump_output(r, title):
|
||||
print("\n")
|
||||
print("=" * 30 + title + " output" + "="*30)
|
||||
for line in r.iter_content(None, decode_unicode=True):
|
||||
print(line, end="", flush=True)
|
||||
|
||||
|
||||
headers = {
|
||||
'accept': 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
|
||||
def knowledge_chat(api="/chat/knowledge_base_chat"):
|
||||
url = f"{api_base_url}{api}"
|
||||
data = {
|
||||
"query": "如何提问以获得高质量答案",
|
||||
"knowledge_base_name": "samples",
|
||||
"history": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "你好,我是 ChatGLM"
|
||||
}
|
||||
],
|
||||
"stream": True
|
||||
}
|
||||
result = []
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
for line in response.iter_content(None, decode_unicode=True):
|
||||
data = json.loads(line[6:])
|
||||
result.append(data)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def test_thread():
|
||||
threads = []
|
||||
times = []
|
||||
pool = ThreadPoolExecutor()
|
||||
start = time.time()
|
||||
for i in range(10):
|
||||
t = pool.submit(knowledge_chat)
|
||||
threads.append(t)
|
||||
|
||||
for r in as_completed(threads):
|
||||
end = time.time()
|
||||
times.append(end - start)
|
||||
print("\nResult:\n")
|
||||
pprint(r.result())
|
||||
|
||||
print("\nTime used:\n")
|
||||
for x in times:
|
||||
print(f"{x}")
|
||||
Reference in New Issue
Block a user