# File: gangyan/langchain-chat/tests/test_concurrency_csv.py
# (139 lines, 4.9 KiB, Python — header captured from a web file viewer)

import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from tabulate import tabulate
# Model endpoint configuration, keyed by model alias. Each entry holds the
# model name sent in request bodies, the API base URL (the "/completions"
# path is appended below) and the API key sent as a Bearer token.
model_config = {
    "Qwen1.5-32B-Chat": {
        "model_name": "Qwen1.5-32B-Chat",
        "api_base_url": "http://192.168.56.123:8821/v1",
        "api_key": "fake",
    },
}

# Alias (a key of model_config) of the model to benchmark.
selected_model = "Qwen1.5-32B-Chat"
_selected_cfg = model_config[selected_model]
# Completions endpoint derived from the selected model's base URL.
API_URL = f"{_selected_cfg['api_base_url']}/completions"
API_KEY = _selected_cfg["api_key"]
MODEL_NAME = _selected_cfg["model_name"]

# Prompt sent with every benchmark request.
TEST_PROMPT = "从前有座山"
# 测试每秒输出字符数
def test_output_speed(prompt, api_url, api_key, model_name, max_tokens):
headers = {"Authorization": f"Bearer {api_key}"}
data = {
"model": model_name,
"prompt": prompt,
"max_tokens": max_tokens
}
start_time = time.time()
response = requests.post(api_url, headers=headers, json=data)
end_time = time.time()
if response.status_code == 200:
output_text = response.json().get("choices", [{}])[0].get("text", "")
char_count = len(output_text)
elapsed_time = end_time - start_time
chars_per_second = char_count / elapsed_time if elapsed_time > 0 else 0
result = [
max_tokens,
char_count,
f"{elapsed_time:.2f} ",
f"{chars_per_second:.2f} "
]
return result
else:
print(f"错误: {response.status_code}, {response.text}")
return max_tokens, 0, 0, 0
# 测试最大并发请求数
def make_request(api_url, api_key, prompt, model_name, max_tokens):
headers = {"Authorization": f"Bearer {api_key}"}
data = {
"model": model_name,
"prompt": prompt,
"max_tokens": max_tokens
}
start_time = time.time()
response = requests.post(api_url, headers=headers, json=data)
end_time = time.time()
response_time = end_time - start_time
return response.status_code, response.text, response_time
def test_concurrent_requests(api_url, api_key, prompt, model_name, max_workers, max_tokens):
    """Fire ``max_workers`` identical requests concurrently and summarise them.

    One request is submitted per worker thread of a ``ThreadPoolExecutor``
    sized ``max_workers``, so every request is issued concurrently.

    Args:
        api_url, api_key, prompt, model_name, max_tokens: Forwarded unchanged
            to ``make_request`` for every request.
        max_workers: Thread-pool size; also the number of requests sent.

    Returns:
        A row ``[max_tokens, max_workers, total_requests,
        avg_response_time, throughput_qps, success_count, success_rate]``.
        The three float metrics are two-decimal strings. A request counts as
        successful only on HTTP 200. ``success_rate`` is a percentage.
        ``throughput_qps`` is successful requests per second of wall-clock
        time for the whole batch.
    """
    total_requests = max_workers
    success_count = 0
    response_times = []

    # QPS must be measured against the wall-clock span of the batch. The old
    # code divided by the SUM of per-request latencies. Those latencies
    # overlap because the requests run in parallel, so that sum overstates
    # elapsed time by roughly the concurrency level and under-reports QPS.
    batch_start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(make_request, api_url, api_key, prompt, model_name, max_tokens)
            for _ in range(total_requests)
        ]
        for future in as_completed(futures):
            try:
                status_code, _, response_time = future.result()
            except OSError as exc:
                # requests' RequestException derives from OSError. One
                # connection error/timeout is a failed request, not a reason
                # to abort the whole batch.
                print(f"错误: {exc}")
                continue
            response_times.append(response_time)
            if status_code == 200:
                success_count += 1
    batch_elapsed = time.perf_counter() - batch_start

    # Average over the requests that produced a timing. When every request
    # returns, this equals the old total / total_requests.
    average_response_time = sum(response_times) / len(response_times) if response_times else 0
    throughput = success_count / batch_elapsed if batch_elapsed > 0 else 0
    success_rate = success_count / total_requests * 100 if total_requests > 0 else 0
    return [
        max_tokens,
        max_workers,
        total_requests,
        f"{average_response_time:.2f}",
        f"{throughput:.2f}",
        success_count,
        f"{success_rate:.2f}",
    ]


# This module lives in tests/ and the function name starts with "test_", so
# pytest would try to collect it and fail on the missing fixtures.
test_concurrent_requests.__test__ = False
# Concurrency sweep: run the concurrency benchmark once per concurrency level.
def find_max_concurrent_requests(api_url, api_key, prompt, model_name, max_tokens, max_workers_list):
    """Run ``test_concurrent_requests`` for each entry of ``max_workers_list``.

    Despite the name, no maximum is searched for. One batch runs per
    concurrency level, in list order, and the result rows are collected so
    the caller can see where latency or success rate degrades.

    Returns:
        A list with one result row per element of ``max_workers_list``, in
        the same order.
    """
    return [
        test_concurrent_requests(api_url, api_key, prompt, model_name, workers, max_tokens)
        for workers in max_workers_list
    ]
# Print result tables as comma-separated values (CSV).
def print_csv(headers, rows, file=None):
    """Write ``headers`` and then every row of ``rows`` as CSV lines.

    Uses the ``csv`` module so that fields containing commas, quotes or
    newlines are quoted. The previous hand-rolled ``",".join`` emitted such
    fields verbatim. The concurrency table's average-response-time header
    contains a comma, so it was split into two columns and every later
    column shifted by one.

    Args:
        headers: Column names.
        rows: Iterable of row sequences; every value is converted with
            ``str()`` before writing, as before.
        file: Text stream to write to. ``None`` (the default) means
            ``sys.stdout``, looked up at call time so stdout redirection
            still works.
    """
    import csv
    import sys

    out = sys.stdout if file is None else file
    # "\n" matches the line endings print() produced; the csv default is "\r\n".
    writer = csv.writer(out, lineterminator="\n")
    writer.writerow([str(value) for value in headers])
    for row in rows:
        writer.writerow([str(value) for value in row])
# Script entry point: run the output-speed benchmark, then the concurrency
# sweep, printing each result table as CSV.
if __name__ == "__main__":
    # max_tokens values exercised by both benchmarks.
    max_tokens_list = [512, 1024, 2048, 4096, 8192]
    # Concurrency levels (thread-pool sizes) for the sweep.
    max_workers_list = [10, 20, 50, 100]
    # Smaller alternatives for a quick smoke run:
    # max_tokens_list = [10, 20, 30]
    # max_workers_list = [2, 4, 8, 16]

    # Benchmark 1: output speed, one sequential request per max_tokens value.
    output_speed_results = [
        test_output_speed(TEST_PROMPT, API_URL, API_KEY, MODEL_NAME, tokens)
        for tokens in max_tokens_list
    ]
    output_speed_headers = ["max_tokens", "输出字符数", "耗时 (秒)", "每秒字符数"]
    print("每秒输出字符数测试结果:")
    print_csv(output_speed_headers, output_speed_results)

    # Benchmark 2: concurrency sweep over every (max_tokens, max_workers)
    # pair, flattened into one table ordered by max_tokens, then by
    # concurrency level.
    all_concurrency_results = [
        row
        for tokens in max_tokens_list
        for row in find_max_concurrent_requests(
            API_URL, API_KEY, TEST_PROMPT, MODEL_NAME, tokens, max_workers_list
        )
    ]
    concurrency_headers = [
        "max_tokens",
        "并发请求数",
        "请求个数",
        "平均响应时间(RT,单位秒)",
        "吞吐量 (QPS)",
        "请求成功个数",
        "请求成功率 (%)",
    ]
    print("\n测试并发结果:")
    print_csv(concurrency_headers, all_concurrency_results)