# /// script
# requires-python = ">=3.10"
# dependencies = ["aiohttp>=3.13.3,<4", "rich>=14.3.3,<15"]
# ///
"""Async HTTP benchmark client for the MLflow AI Gateway.

Can be imported by run.py or used standalone:
    uv run benchmark.py --url http://127.0.0.1:5731/gateway/benchmark-chat/mlflow/invocations
    uv run benchmark.py --url http://... --requests 5000 --max-concurrent 100 --runs 3
"""

import argparse
import asyncio
import math
import statistics
import time
from dataclasses import dataclass, field
from typing import Any

import aiohttp
from rich.console import Console  # type: ignore[import-not-found]
from rich.progress import (  # type: ignore[import-not-found]
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TaskID,
    TextColumn,
    TimeElapsedColumn,
)
from rich.table import Table  # type: ignore[import-not-found]

console = Console()

# Fixed request payload: every request sends the same body so runs are comparable.
_BODY = {
    "messages": [{"role": "user", "content": "benchmark request"}],
    "temperature": 0.0,
    "max_tokens": 50,
}


@dataclass
class RunResult:
    """Latency and failure record for one benchmark run."""

    # Latencies (ms) of successful requests only; failures are counted separately.
    latencies_ms: list[float] = field(default_factory=list)
    # Failure reason (e.g. "HTTP 503" or an exception class name) -> count.
    failures: dict[str, int] = field(default_factory=dict)
    # Wall-clock duration of the whole run, in seconds.
    wall_time: float = 0.0

    @property
    def n_success(self) -> int:
        return len(self.latencies_ms)

    @property
    def n_failures(self) -> int:
        return sum(self.failures.values())

    @property
    def throughput(self) -> float:
        """Successful requests per second over the run's wall time (0.0 if unset)."""
        return self.n_success / self.wall_time if self.wall_time > 0 else 0.0

    def percentile(self, p: float) -> float:
        """Nearest-rank percentile *p* (0-100) of the success latencies; 0.0 if empty."""
        if not self.latencies_ms:
            return 0.0
        s = sorted(self.latencies_ms)
        idx = max(0, math.ceil(p / 100 * len(s)) - 1)
        return s[idx]


async def _send(
    session: aiohttp.ClientSession,
    url: str,
    sem: asyncio.Semaphore,
    auth: aiohttp.BasicAuth | None = None,
) -> tuple[float, str | None]:
    """POST one benchmark request.

    Returns (latency_ms, error) where error is None on HTTP 200, an
    "HTTP <status>" string for non-200 responses, or the exception class
    name for transport-level failures. The semaphore caps concurrency.
    """
    async with sem:
        t0 = time.perf_counter()
        try:
            async with session.post(url, json=_BODY, auth=auth) as resp:
                # Drain the body so the latency includes the full response.
                await resp.read()
                ms = (time.perf_counter() - t0) * 1000
                if resp.status == 200:
                    return ms, None
                return ms, f"HTTP {resp.status}"
        except Exception as e:
            # Connection/timeout/etc. — report the exception class as the reason.
            return (time.perf_counter() - t0) * 1000, type(e).__name__


async def _run_once(
    url: str,
    n: int,
    max_concurrent: int,
    progress: Progress,
    task_id: TaskID,
    auth: aiohttp.BasicAuth | None = None,
) -> RunResult:
    """Fire *n* requests at *url* with bounded concurrency and collect a RunResult.

    Updates the given rich progress task with a live success/failure summary
    as responses complete.
    """
    sem = asyncio.Semaphore(max_concurrent)
    connector = aiohttp.TCPConnector(
        # Keep the pool comfortably larger than the in-flight cap so the
        # semaphore, not the connector, is the concurrency limiter.
        limit=max(max_concurrent * 2, 200),
        limit_per_host=max(max_concurrent, 200),
        force_close=False,
        # NOTE(review): enable_cleanup_closed is deprecated/no-op in recent
        # aiohttp (the underlying CPython SSL bug is fixed) — confirm against
        # the pinned aiohttp version and drop if so.
        enable_cleanup_closed=True,
    )
    result = RunResult()
    total_time = 0.0  # running sum of success latencies, for the live mean
    max_time = 0.0  # running max success latency

    async with aiohttp.ClientSession(connector=connector) as session:
        t0 = time.perf_counter()
        for coro in asyncio.as_completed([_send(session, url, sem, auth) for _ in range(n)]):
            ms, error = await coro
            if error:
                result.failures[error] = result.failures.get(error, 0) + 1
            else:
                result.latencies_ms.append(ms)
                total_time += ms
                if ms > max_time:
                    max_time = ms

            n_ok = result.n_success
            n_fail = result.n_failures
            mean = total_time / n_ok if n_ok else 0.0
            fail_part = f"[red]✗{n_fail}[/red] " if n_fail else ""
            live = f"{fail_part}✓{n_ok} mean={mean:.0f}ms max={max_time:.0f}ms"
            progress.update(task_id, advance=1, live=live)

        result.wall_time = time.perf_counter() - t0
    return result


async def _warmup(
    url: str, n: int, max_concurrent: int, auth: aiohttp.BasicAuth | None = None
) -> None:
    """Send *n* throwaway requests so connection setup doesn't skew the first run."""
    sem = asyncio.Semaphore(max_concurrent)
    connector = aiohttp.TCPConnector(limit=max(max_concurrent * 2, 200))
    async with aiohttp.ClientSession(connector=connector) as session:
        await asyncio.gather(*[_send(session, url, sem, auth) for _ in range(n)])


def run_benchmark(
    url: str,
    n_requests: int = 2000,
    max_concurrent: int = 50,
    runs: int = 3,
    auth: aiohttp.BasicAuth | None = None,
) -> list[RunResult]:
    """Warm up, then execute *runs* benchmark runs and return one RunResult each."""
    # Warm up with at least 50 requests (or the concurrency level if higher),
    # but never more than a single run's worth.
    warmup_n = min(max(50, max_concurrent), n_requests)
    console.print(f" [dim]Warming up ({warmup_n} requests)...[/dim]")
    asyncio.run(_warmup(url, warmup_n, max_concurrent, auth))

    results = []
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn(" {task.fields[live]}"),
        console=console,
    ) as progress:
        for i in range(runs):
            task_id = progress.add_task(f" Run {i + 1}/{runs}", total=n_requests, live="")
            results.append(
                asyncio.run(_run_once(url, n_requests, max_concurrent, progress, task_id, auth))
            )
    return results


def results_to_dict(results: list[RunResult]) -> dict[str, Any]:
    """Convert results into a JSON-serializable dict with per-run and summary stats."""
    runs = [
        {
            "n_success": r.n_success,
            "n_failures": r.n_failures,
            "failures": r.failures,
            "wall_time_s": r.wall_time,
            "mean_ms": statistics.mean(r.latencies_ms) if r.latencies_ms else 0.0,
            "p50_ms": r.percentile(50),
            "p95_ms": r.percentile(95),
            "p99_ms": r.percentile(99),
            "max_ms": max(r.latencies_ms) if r.latencies_ms else 0.0,
            "rps": r.throughput,
        }
        for r in results
    ]
    summary: dict[str, Any] = (
        {
            "avg_mean_ms": statistics.mean(
                statistics.mean(r.latencies_ms) if r.latencies_ms else 0.0 for r in results
            ),
            "avg_p50_ms": statistics.mean(r.percentile(50) for r in results),
            "avg_p99_ms": statistics.mean(r.percentile(99) for r in results),
            "avg_rps": statistics.mean(r.throughput for r in results),
        }
        if results
        else {}
    )
    return {"runs": runs, "summary": summary}


def print_results(results: list[RunResult]) -> None:
    """Render a per-run latency/throughput table (plus averages and failure breakdown)."""
    table = Table(show_header=True, header_style="bold cyan", box=None, padding=(0, 2))
    table.add_column("Run", style="dim", width=5)
    table.add_column("Mean ms", justify="right")
    table.add_column("P50 ms", justify="right")
    table.add_column("P95 ms", justify="right")
    table.add_column("P99 ms", justify="right")
    table.add_column("Max ms", justify="right")
    table.add_column("Req/s", justify="right")
    table.add_column("Failures", justify="right")

    means = []
    p50s = []
    p95s = []
    p99s = []
    maxes = []
    throughputs = []
    for i, r in enumerate(results):
        mean = statistics.mean(r.latencies_ms) if r.latencies_ms else 0.0
        p50 = r.percentile(50)
        p95 = r.percentile(95)
        p99 = r.percentile(99)
        mx = max(r.latencies_ms) if r.latencies_ms else 0.0
        means.append(mean)
        p50s.append(p50)
        p95s.append(p95)
        p99s.append(p99)
        maxes.append(mx)
        throughputs.append(r.throughput)
        fail_str = f"[red]{r.n_failures}[/red]" if r.n_failures else "0"
        table.add_row(
            str(i + 1),
            f"{mean:.1f}",
            f"{p50:.1f}",
            f"{p95:.1f}",
            f"{p99:.1f}",
            f"{mx:.1f}",
            f"{r.throughput:.0f}",
            fail_str,
        )

    # Averages row only makes sense with more than one run.
    if len(results) > 1:
        table.add_section()
        table.add_row(
            "[bold]avg[/bold]",
            f"[bold]{statistics.mean(means):.1f}[/bold]",
            f"[bold]{statistics.mean(p50s):.1f}[/bold]",
            f"[bold]{statistics.mean(p95s):.1f}[/bold]",
            f"[bold]{statistics.mean(p99s):.1f}[/bold]",
            f"[bold]{statistics.mean(maxes):.1f}[/bold]",
            f"[bold]{statistics.mean(throughputs):.0f}[/bold]",
            "",
        )

    console.print()
    console.print(table)

    # Aggregate failure reasons across all runs, most common first.
    combined: dict[str, int] = {}
    for r in results:
        for k, v in r.failures.items():
            combined[k] = combined.get(k, 0) + v
    if combined:
        console.print()
        console.print("[red]Failure breakdown:[/red]")
        for reason, count in sorted(combined.items(), key=lambda x: -x[1]):
            console.print(f"  {reason}: {count}")


def check_thresholds(
    results: list[RunResult],
    min_rps: float | None = None,
    max_p50_ms: float | None = None,
    max_p99_ms: float | None = None,
) -> bool:
    """Check results against performance thresholds. Returns True if all pass."""
    # Guard: with no runs (e.g. --runs 0) there is nothing to check;
    # statistics.mean would raise StatisticsError on the empty generators below.
    if not results:
        return True
    avg_rps = statistics.mean(r.throughput for r in results)
    avg_p50 = statistics.mean(r.percentile(50) for r in results)
    avg_p99 = statistics.mean(r.percentile(99) for r in results)
    passed = True

    if min_rps is not None and avg_rps < min_rps:
        console.print(
            f"\n[red]THRESHOLD FAILED:[/red] avg throughput {avg_rps:.0f} req/s"
            f" < minimum {min_rps:.0f} req/s"
        )
        passed = False

    if max_p50_ms is not None and avg_p50 > max_p50_ms:
        console.print(
            f"\n[red]THRESHOLD FAILED:[/red] avg P50 {avg_p50:.1f} ms > maximum {max_p50_ms:.1f} ms"
        )
        passed = False

    if max_p99_ms is not None and avg_p99 > max_p99_ms:
        console.print(
            f"\n[red]THRESHOLD FAILED:[/red] avg P99 {avg_p99:.1f} ms > maximum {max_p99_ms:.1f} ms"
        )
        passed = False

    # Only celebrate when at least one threshold was actually requested.
    if passed and (min_rps is not None or max_p50_ms is not None or max_p99_ms is not None):
        console.print("\n[green]All thresholds passed.[/green]")

    return passed


def main() -> None:
    """CLI entry point: parse args, run the benchmark, print results, enforce thresholds."""
    parser = argparse.ArgumentParser(description="Async HTTP benchmark client for MLflow Gateway")
    parser.add_argument("--url", required=True, help="Gateway invocation URL")
    parser.add_argument("--requests", type=int, default=2000)
    parser.add_argument("--max-concurrent", type=int, default=50)
    parser.add_argument("--runs", type=int, default=3)
    parser.add_argument(
        "--min-rps",
        type=float,
        default=None,
        metavar="N",
        help="Fail (exit 1) if average throughput falls below N req/s",
    )
    parser.add_argument(
        "--max-p50-ms",
        type=float,
        default=None,
        metavar="N",
        help="Fail (exit 1) if average P50 latency exceeds N ms",
    )
    parser.add_argument(
        "--max-p99-ms",
        type=float,
        default=None,
        metavar="N",
        help="Fail (exit 1) if average P99 latency exceeds N ms",
    )
    parser.add_argument(
        "--auth-username",
        default=None,
        help="Basic auth username. If set together with --auth-password, sent on every request.",
    )
    parser.add_argument(
        "--auth-password",
        default=None,
        help="Basic auth password. If set together with --auth-username, sent on every request.",
    )
    args = parser.parse_args()

    # Auth is only applied when both halves are provided (documented in --help).
    auth = (
        aiohttp.BasicAuth(args.auth_username, args.auth_password)
        if args.auth_username and args.auth_password
        else None
    )

    console.print(f"\n[bold]Benchmarking[/bold] {args.url}")
    console.print(
        f"  {args.requests} requests · {args.max_concurrent} concurrent · {args.runs} runs\n"
    )
    results = run_benchmark(args.url, args.requests, args.max_concurrent, args.runs, auth)
    print_results(results)

    if not check_thresholds(
        results, min_rps=args.min_rps, max_p50_ms=args.max_p50_ms, max_p99_ms=args.max_p99_ms
    ):
        raise SystemExit(1)


if __name__ == "__main__":
    main()