#!/usr/bin/env python3
"""
LLM Performance Test Tool

Performance testing for local and cloud LLMs through any OpenAI-compatible
chat-completions endpoint.  Measures TTFT (time to first token), TPS
(approximate tokens per second) and total latency across a configurable set
of test cases, exposed via a small Flask UI / JSON API.
"""
import os
import json
import time
import uuid
import statistics
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

from flask import Flask, render_template, request, jsonify, send_from_directory
import requests

app = Flask(__name__)
app.config['SECRET_KEY'] = 'llm-perf-test-secret-key'

# Directory holding all persisted JSON state (created on import).
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
os.makedirs(DATA_DIR, exist_ok=True)

# Persisted-state file paths.
CONFIG_FILE = os.path.join(DATA_DIR, 'config.json')
TEST_CASES_FILE = os.path.join(DATA_DIR, 'test_cases.json')
RESULTS_FILE = os.path.join(DATA_DIR, 'results.json')

# Default API configuration (points at a local Ollama OpenAI-compatible endpoint).
DEFAULT_CONFIG = {
    "api_base": "http://localhost:11434/v1",
    "api_key": "",
    "model": "qwen2.5:latest",
    "timeout": 60,
    "max_tokens": 512,
    "temperature": 0.7
}

# Built-in test cases, installed the first time the app runs.
DEFAULT_TEST_CASES = [
    {
        "id": "tc_001",
        "name": "简单问答",
        "prompt": "你好,请介绍一下自己。",
        "expected_length": 100
    },
    {
        "id": "tc_002",
        "name": "代码生成",
        "prompt": "写一个Python函数,计算斐波那契数列的前n项。",
        "expected_length": 200
    },
    {
        "id": "tc_003",
        "name": "长文本理解",
        "prompt": """请总结以下段落的主要观点:\n\n人工智能(AI)是计算机科学的一个分支,致力于创造能够执行通常需要人类智能的任务的系统。这些任务包括视觉感知、语音识别、决策制定和语言翻译等。机器学习是AI的一个子集,它使计算机能够从数据中学习并改进,而无需明确编程。深度学习是机器学习的一种特定方法,使用人工神经网络来模拟人脑的工作方式。近年来,随着计算能力的提升和大数据的可用性,AI技术取得了显著进展,在医疗诊断、自动驾驶汽车、自然语言处理等领域展现出巨大潜力。然而,AI的发展也引发了关于隐私、就业和伦理等方面的担忧,需要社会各界共同探讨和制定相应的规范。""",
        "expected_length": 150
    },
    {
        "id": "tc_004",
        "name": "创意写作",
        "prompt": "写一个关于未来城市的短篇科幻故事,约300字。",
        "expected_length": 400
    }
]

# Serializes read-modify-write cycles on RESULTS_FILE.
results_lock = Lock()


def load_config():
    """Return the saved API config merged over DEFAULT_CONFIG.

    Saved keys win; any key missing from disk falls back to the default.
    """
    if os.path.exists(CONFIG_FILE):
        with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
            return {**DEFAULT_CONFIG, **json.load(f)}
    return DEFAULT_CONFIG.copy()


def save_config(config):
    """Persist *config* to CONFIG_FILE as pretty-printed UTF-8 JSON."""
    with open(CONFIG_FILE, 'w', encoding='utf-8') as f:
        json.dump(config, f, ensure_ascii=False, indent=2)


def load_test_cases():
    """Return the stored test cases, or a copy of the defaults if none saved."""
    if os.path.exists(TEST_CASES_FILE):
        with open(TEST_CASES_FILE, 'r', encoding='utf-8') as f:
            return json.load(f)
    return DEFAULT_TEST_CASES.copy()


def save_test_cases(test_cases):
    """Persist the full test-case list to TEST_CASES_FILE."""
    with open(TEST_CASES_FILE, 'w', encoding='utf-8') as f:
        json.dump(test_cases, f, ensure_ascii=False, indent=2)


def load_results():
    """Return all historical test-run records (empty list if none saved)."""
    if os.path.exists(RESULTS_FILE):
        with open(RESULTS_FILE, 'r', encoding='utf-8') as f:
            return json.load(f)
    return []


def save_result(result):
    """Append one test-run record to RESULTS_FILE.

    The whole load-append-write cycle runs under results_lock so concurrent
    test runs cannot clobber each other's writes.
    """
    with results_lock:
        results = load_results()
        results.append(result)
        with open(RESULTS_FILE, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)


def stream_chat_completion(api_base, api_key, model, messages,
                           max_tokens, temperature, timeout):
    """Stream one chat completion and measure TTFT / TPS on the fly.

    Sends an OpenAI-style ``stream=True`` request and consumes the SSE
    response line by line.  "Tokens" are approximated by character count,
    since the streaming deltas carry no token counts.

    Returns a dict with ``success=True`` plus ``ttft_ms``, ``total_time_ms``,
    ``tps``, ``total_chars`` and ``content`` on success, or
    ``{"success": False, "error": ...}`` on any request/stream failure.
    """
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    payload = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": True
    }

    url = f"{api_base}/chat/completions"
    first_token_time = None
    start_time = time.time()
    total_tokens = 0
    content_chunks = []

    try:
        response = requests.post(url, headers=headers, json=payload,
                                 timeout=timeout, stream=True)
        response.raise_for_status()

        for line in response.iter_lines():
            if not line:
                continue
            line_str = line.decode('utf-8')
            if line_str.startswith('data: '):
                data_str = line_str[6:]
                if data_str == '[DONE]':
                    break
                try:
                    data = json.loads(data_str)
                    delta = data.get('choices', [{}])[0].get('delta', {})
                    content = delta.get('content', '')
                    if content:
                        if first_token_time is None:
                            first_token_time = time.time()
                        content_chunks.append(content)
                        total_tokens += len(content)  # character count as token proxy
                except json.JSONDecodeError:
                    # Skip malformed keep-alive / partial SSE lines.
                    continue

        end_time = time.time()

        # Derived metrics; fix: compare the timestamp against None rather
        # than relying on truthiness (a timestamp is never falsy in practice,
        # but the identity check states the intent).
        ttft = (first_token_time - start_time) * 1000 if first_token_time is not None else 0  # ms
        total_time = (end_time - start_time) * 1000  # ms
        tps = total_tokens / (total_time / 1000) if total_time > 0 else 0

        return {
            "success": True,
            "ttft_ms": round(ttft, 2),
            "total_time_ms": round(total_time, 2),
            "tps": round(tps, 2),
            "total_chars": sum(len(c) for c in content_chunks),
            "content": ''.join(content_chunks)
        }
    except Exception as e:
        # Boundary: report any network/HTTP/decode failure as a failed run
        # instead of crashing the whole batch.
        return {"success": False, "error": str(e)}


def run_single_test(api_config, test_case, run_index=0):
    """Run one test case once and return its annotated metrics dict.

    *run_index* distinguishes repeated runs of the same case.
    """
    messages = [{"role": "user", "content": test_case["prompt"]}]

    result = stream_chat_completion(
        api_base=api_config["api_base"],
        # .get: a client-supplied config may legitimately omit the key
        # (e.g. local Ollama needs no auth) — don't KeyError on it.
        api_key=api_config.get("api_key", ""),
        model=api_config["model"],
        messages=messages,
        max_tokens=api_config.get("max_tokens", 512),
        temperature=api_config.get("temperature", 0.7),
        timeout=api_config.get("timeout", 60)
    )

    result["test_case_id"] = test_case["id"]
    result["test_case_name"] = test_case["name"]
    result["run_index"] = run_index
    result["timestamp"] = datetime.now().isoformat()
    return result


def run_batch_tests(api_config, test_cases, runs_per_case=1, concurrency=1):
    """Run every test case *runs_per_case* times with up to *concurrency*
    parallel workers; return the list of per-run result dicts (completion
    order, not submission order)."""
    all_tasks = [
        (api_config, test_case, i)
        for test_case in test_cases
        for i in range(runs_per_case)
    ]

    results = []
    completed = 0
    total = len(all_tasks)

    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        futures = {executor.submit(run_single_test, *task): task
                   for task in all_tasks}
        for future in as_completed(futures):
            try:
                results.append(future.result())
                completed += 1
                print(f"Progress: {completed}/{total}")
            except Exception as e:
                # A worker raising (rather than returning success=False)
                # should not abort the remaining runs.
                print(f"Test failed: {e}")

    return results


def calculate_statistics(results):
    """Aggregate per-run metrics into summary statistics.

    Returns ``{"error": ...}`` when there is nothing to aggregate
    (empty input or no successful run) — guards the division below
    against an empty *results* list.
    """
    if not results:
        return {"error": "No successful tests"}

    successful = [r for r in results if r.get("success")]
    failed = [r for r in results if not r.get("success")]

    if not successful:
        return {"error": "No successful tests"}

    ttfts = [r["ttft_ms"] for r in successful]
    tpss = [r["tps"] for r in successful]
    times = [r["total_time_ms"] for r in successful]

    stats = {
        "total_tests": len(results),
        "successful": len(successful),
        "failed": len(failed),
        "success_rate": round(len(successful) / len(results) * 100, 2),
        "ttft": {
            "avg": round(statistics.mean(ttfts), 2),
            "min": round(min(ttfts), 2),
            "max": round(max(ttfts), 2),
            "median": round(statistics.median(ttfts), 2)
        },
        "tps": {
            "avg": round(statistics.mean(tpss), 2),
            "min": round(min(tpss), 2),
            "max": round(max(tpss), 2),
            "median": round(statistics.median(tpss), 2)
        },
        "total_time": {
            "avg": round(statistics.mean(times), 2),
            "min": round(min(times), 2),
            "max": round(max(times), 2)
        }
    }
    return stats


# ==================== Flask Routes ====================

@app.route('/')
def index():
    """Serve the single-page UI."""
    return render_template('index.html')


@app.route('/api/config', methods=['GET', 'POST'])
def config_api():
    """GET the current API config, or POST a replacement."""
    if request.method == 'GET':
        return jsonify(load_config())
    else:
        new_config = request.json
        save_config(new_config)
        return jsonify({"status": "success"})


@app.route('/api/test-cases', methods=['GET', 'POST', 'PUT', 'DELETE'])
def test_cases_api():
    """CRUD for test cases.

    GET lists all; POST creates (server assigns the id); PUT replaces the
    case matching the body's id; DELETE removes by ``?id=`` query arg.
    """
    if request.method == 'GET':
        return jsonify(load_test_cases())
    elif request.method == 'POST':
        test_cases = load_test_cases()
        new_case = request.json
        new_case['id'] = f"tc_{uuid.uuid4().hex[:6]}"
        test_cases.append(new_case)
        save_test_cases(test_cases)
        return jsonify({"status": "success", "id": new_case['id']})
    elif request.method == 'PUT':
        updated_case = request.json
        test_cases = load_test_cases()
        for i, tc in enumerate(test_cases):
            if tc['id'] == updated_case['id']:
                test_cases[i] = updated_case
                break
        save_test_cases(test_cases)
        return jsonify({"status": "success"})
    elif request.method == 'DELETE':
        case_id = request.args.get('id')
        test_cases = load_test_cases()
        test_cases = [tc for tc in test_cases if tc['id'] != case_id]
        save_test_cases(test_cases)
        return jsonify({"status": "success"})


@app.route('/api/run-test', methods=['POST'])
def run_test_api():
    """Run a batch of tests and return the saved run record.

    Body (all optional): ``config`` (falls back to saved config),
    ``test_case_ids`` (empty = all), ``runs_per_case``, ``concurrency``.
    """
    # Guard against a missing/non-JSON body instead of crashing on None.
    data = request.get_json(silent=True) or {}
    api_config = data.get('config', load_config())
    test_case_ids = data.get('test_case_ids', [])
    runs_per_case = data.get('runs_per_case', 1)
    concurrency = data.get('concurrency', 1)

    # Resolve which test cases to run.
    all_test_cases = load_test_cases()
    if test_case_ids:
        test_cases = [tc for tc in all_test_cases if tc['id'] in test_case_ids]
    else:
        test_cases = all_test_cases

    if not test_cases:
        return jsonify({"error": "No test cases selected"}), 400

    # Run, aggregate, persist.
    results = run_batch_tests(api_config, test_cases, runs_per_case, concurrency)
    stats = calculate_statistics(results)

    test_run = {
        "id": f"run_{uuid.uuid4().hex[:8]}",
        "timestamp": datetime.now().isoformat(),
        "config": api_config,
        "stats": stats,
        "results": results
    }
    save_result(test_run)

    return jsonify(test_run)


@app.route('/api/results', methods=['GET'])
def get_results_api():
    """List all historical test runs."""
    return jsonify(load_results())


# Bug fix: the rule previously lacked the <result_id> URL variable, so the
# view's required argument was never supplied and every request failed.
@app.route('/api/results/<result_id>', methods=['GET'])
def get_result_detail_api(result_id):
    """Return one test run by id, or 404."""
    results = load_results()
    for result in results:
        if result.get('id') == result_id:
            return jsonify(result)
    return jsonify({"error": "Result not found"}), 404


@app.route('/api/results/<result_id>', methods=['DELETE'])
def delete_result_api(result_id):
    """Delete one test run by id (idempotent)."""
    with results_lock:
        results = load_results()
        results = [r for r in results if r.get('id') != result_id]
        with open(RESULTS_FILE, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
    return jsonify({"status": "success"})


if __name__ == '__main__':
    # Seed default config and test cases on first launch.
    if not os.path.exists(CONFIG_FILE):
        save_config(DEFAULT_CONFIG)
    if not os.path.exists(TEST_CASES_FILE):
        save_test_cases(DEFAULT_TEST_CASES)

    # NOTE(review): debug=True and 0.0.0.0 are development settings;
    # do not expose this server as-is on an untrusted network.
    app.run(host='0.0.0.0', port=8001, debug=True)