# dev/benchmarks/gateway/benchmark.py
  1  # /// script
  2  # requires-python = ">=3.10"
  3  # dependencies = ["aiohttp>=3.13.3,<4", "rich>=14.3.3,<15"]
  4  # ///
  5  """Async HTTP benchmark client for the MLflow AI Gateway.
  6  
  7  Can be imported by run.py or used standalone:
  8      uv run benchmark.py --url http://127.0.0.1:5731/gateway/benchmark-chat/mlflow/invocations
  9      uv run benchmark.py --url http://... --requests 5000 --max-concurrent 100 --runs 3
 10  """
 11  
 12  import argparse
 13  import asyncio
 14  import math
 15  import statistics
 16  import time
 17  from dataclasses import dataclass, field
 18  from typing import Any
 19  
 20  import aiohttp
 21  from rich.console import Console  # type: ignore[import-not-found]
 22  from rich.progress import (  # type: ignore[import-not-found]
 23      BarColumn,
 24      MofNCompleteColumn,
 25      Progress,
 26      SpinnerColumn,
 27      TaskID,
 28      TextColumn,
 29      TimeElapsedColumn,
 30  )
 31  from rich.table import Table  # type: ignore[import-not-found]
 32  
# Shared rich console used for all terminal output in this module.
console = Console()

# Fixed chat-completions payload sent on every benchmark request.
# temperature=0.0 and a small max_tokens keep responses deterministic and short.
_BODY = {
    "messages": [{"role": "user", "content": "benchmark request"}],
    "temperature": 0.0,
    "max_tokens": 50,
}
 40  
 41  
 42  @dataclass
 43  class RunResult:
 44      latencies_ms: list[float] = field(default_factory=list)
 45      failures: dict[str, int] = field(default_factory=dict)
 46      wall_time: float = 0.0
 47  
 48      @property
 49      def n_success(self) -> int:
 50          return len(self.latencies_ms)
 51  
 52      @property
 53      def n_failures(self) -> int:
 54          return sum(self.failures.values())
 55  
 56      @property
 57      def throughput(self) -> float:
 58          return self.n_success / self.wall_time if self.wall_time > 0 else 0.0
 59  
 60      def percentile(self, p: float) -> float:
 61          if not self.latencies_ms:
 62              return 0.0
 63          s = sorted(self.latencies_ms)
 64          idx = max(0, math.ceil(p / 100 * len(s)) - 1)
 65          return s[idx]
 66  
 67  
 68  async def _send(
 69      session: aiohttp.ClientSession,
 70      url: str,
 71      sem: asyncio.Semaphore,
 72      auth: aiohttp.BasicAuth | None = None,
 73  ) -> tuple[float, str | None]:
 74      async with sem:
 75          t0 = time.perf_counter()
 76          try:
 77              async with session.post(url, json=_BODY, auth=auth) as resp:
 78                  await resp.read()
 79                  ms = (time.perf_counter() - t0) * 1000
 80                  if resp.status == 200:
 81                      return ms, None
 82                  return ms, f"HTTP {resp.status}"
 83          except Exception as e:
 84              return (time.perf_counter() - t0) * 1000, type(e).__name__
 85  
 86  
 87  async def _run_once(
 88      url: str,
 89      n: int,
 90      max_concurrent: int,
 91      progress: Progress,
 92      task_id: TaskID,
 93      auth: aiohttp.BasicAuth | None = None,
 94  ) -> RunResult:
 95      sem = asyncio.Semaphore(max_concurrent)
 96      connector = aiohttp.TCPConnector(
 97          limit=max(max_concurrent * 2, 200),
 98          limit_per_host=max(max_concurrent, 200),
 99          force_close=False,
100          enable_cleanup_closed=True,
101      )
102      result = RunResult()
103      total_time = 0.0
104      max_time = 0.0
105  
106      async with aiohttp.ClientSession(connector=connector) as session:
107          t0 = time.perf_counter()
108          for coro in asyncio.as_completed([_send(session, url, sem, auth) for _ in range(n)]):
109              ms, error = await coro
110              if error:
111                  result.failures[error] = result.failures.get(error, 0) + 1
112              else:
113                  result.latencies_ms.append(ms)
114                  total_time += ms
115                  if ms > max_time:
116                      max_time = ms
117  
118              n_ok = result.n_success
119              n_fail = result.n_failures
120              mean = total_time / n_ok if n_ok else 0.0
121              fail_part = f"[red]✗{n_fail}[/red]  " if n_fail else ""
122              live = f"{fail_part}✓{n_ok}  mean={mean:.0f}ms  max={max_time:.0f}ms"
123              progress.update(task_id, advance=1, live=live)
124  
125          result.wall_time = time.perf_counter() - t0
126      return result
127  
128  
async def _warmup(
    url: str, n: int, max_concurrent: int, auth: aiohttp.BasicAuth | None = None
) -> None:
    """Send *n* throwaway requests so connections and server caches are warm."""
    semaphore = asyncio.Semaphore(max_concurrent)
    connector = aiohttp.TCPConnector(limit=max(max_concurrent * 2, 200))
    async with aiohttp.ClientSession(connector=connector) as session:
        requests = [_send(session, url, semaphore, auth) for _ in range(n)]
        await asyncio.gather(*requests)
136  
137  
def run_benchmark(
    url: str,
    n_requests: int = 2000,
    max_concurrent: int = 50,
    runs: int = 3,
    auth: aiohttp.BasicAuth | None = None,
) -> list[RunResult]:
    """Warm up the target, then execute *runs* benchmark passes.

    Returns one RunResult per pass, in execution order.
    """
    # Warm up with roughly one concurrency window's worth of requests,
    # but never more than the benchmark itself would send.
    warmup_n = min(max(50, max_concurrent), n_requests)
    console.print(f"  [dim]Warming up ({warmup_n} requests)...[/dim]")
    asyncio.run(_warmup(url, warmup_n, max_concurrent, auth))

    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("  {task.fields[live]}"),
    )
    results: list[RunResult] = []
    with Progress(*columns, console=console) as progress:
        for run_idx in range(runs):
            task_id = progress.add_task(f"  Run {run_idx + 1}/{runs}", total=n_requests, live="")
            run_result = asyncio.run(
                _run_once(url, n_requests, max_concurrent, progress, task_id, auth)
            )
            results.append(run_result)
    return results
165  
166  
def results_to_dict(results: list[RunResult]) -> dict[str, Any]:
    """Serialize benchmark results to plain dicts (e.g. for JSON output).

    Returns ``{"runs": [...], "summary": {...}}``; ``summary`` is empty when
    there are no runs.
    """

    def _mean_ms(r):
        # Mean latency of a run; 0.0 when the run had no successes.
        return statistics.mean(r.latencies_ms) if r.latencies_ms else 0.0

    runs = []
    for r in results:
        runs.append(
            {
                "n_success": r.n_success,
                "n_failures": r.n_failures,
                "failures": r.failures,
                "wall_time_s": r.wall_time,
                "mean_ms": _mean_ms(r),
                "p50_ms": r.percentile(50),
                "p95_ms": r.percentile(95),
                "p99_ms": r.percentile(99),
                "max_ms": max(r.latencies_ms) if r.latencies_ms else 0.0,
                "rps": r.throughput,
            }
        )

    summary: dict[str, Any] = {}
    if results:
        summary = {
            "avg_mean_ms": statistics.mean(_mean_ms(r) for r in results),
            "avg_p50_ms": statistics.mean(r.percentile(50) for r in results),
            "avg_p99_ms": statistics.mean(r.percentile(99) for r in results),
            "avg_rps": statistics.mean(r.throughput for r in results),
        }
    return {"runs": runs, "summary": summary}
196  
197  
def print_results(results: list[RunResult]) -> None:
    """Render a per-run latency/throughput table plus a failure breakdown."""
    table = Table(show_header=True, header_style="bold cyan", box=None, padding=(0, 2))
    table.add_column("Run", style="dim", width=5)
    table.add_column("Mean ms", justify="right")
    table.add_column("P50 ms", justify="right")
    table.add_column("P95 ms", justify="right")
    table.add_column("P99 ms", justify="right")
    table.add_column("Max ms", justify="right")
    table.add_column("Req/s", justify="right")
    table.add_column("Failures", justify="right")

    # One (mean, p50, p95, p99, max, rps) tuple per run, kept for the avg row.
    per_run: list[tuple[float, float, float, float, float, float]] = []
    for idx, r in enumerate(results, start=1):
        stats = (
            statistics.mean(r.latencies_ms) if r.latencies_ms else 0.0,
            r.percentile(50),
            r.percentile(95),
            r.percentile(99),
            max(r.latencies_ms) if r.latencies_ms else 0.0,
            r.throughput,
        )
        per_run.append(stats)
        table.add_row(
            str(idx),
            f"{stats[0]:.1f}",
            f"{stats[1]:.1f}",
            f"{stats[2]:.1f}",
            f"{stats[3]:.1f}",
            f"{stats[4]:.1f}",
            f"{stats[5]:.0f}",
            f"[red]{r.n_failures}[/red]" if r.n_failures else "0",
        )

    if len(results) > 1:
        # Column-wise averages across runs for the bold summary row.
        avgs = [statistics.mean(column) for column in zip(*per_run)]
        table.add_section()
        table.add_row(
            "[bold]avg[/bold]",
            f"[bold]{avgs[0]:.1f}[/bold]",
            f"[bold]{avgs[1]:.1f}[/bold]",
            f"[bold]{avgs[2]:.1f}[/bold]",
            f"[bold]{avgs[3]:.1f}[/bold]",
            f"[bold]{avgs[4]:.1f}[/bold]",
            f"[bold]{avgs[5]:.0f}[/bold]",
            "",
        )

    console.print()
    console.print(table)

    # Merge failure counts from every run and report them, most common first.
    combined: dict[str, int] = {}
    for r in results:
        for reason, count in r.failures.items():
            combined[reason] = combined.get(reason, 0) + count
    if combined:
        console.print()
        console.print("[red]Failure breakdown:[/red]")
        for reason, count in sorted(combined.items(), key=lambda item: -item[1]):
            console.print(f"  {reason}: {count}")
264  
265  
266  def check_thresholds(
267      results: list[RunResult],
268      min_rps: float | None = None,
269      max_p50_ms: float | None = None,
270      max_p99_ms: float | None = None,
271  ) -> bool:
272      """Check results against performance thresholds. Returns True if all pass."""
273      avg_rps = statistics.mean(r.throughput for r in results)
274      avg_p50 = statistics.mean(r.percentile(50) for r in results)
275      avg_p99 = statistics.mean(r.percentile(99) for r in results)
276      passed = True
277  
278      if min_rps is not None and avg_rps < min_rps:
279          console.print(
280              f"\n[red]THRESHOLD FAILED:[/red] avg throughput {avg_rps:.0f} req/s"
281              f" < minimum {min_rps:.0f} req/s"
282          )
283          passed = False
284  
285      if max_p50_ms is not None and avg_p50 > max_p50_ms:
286          console.print(
287              f"\n[red]THRESHOLD FAILED:[/red] avg P50 {avg_p50:.1f} ms > maximum {max_p50_ms:.1f} ms"
288          )
289          passed = False
290  
291      if max_p99_ms is not None and avg_p99 > max_p99_ms:
292          console.print(
293              f"\n[red]THRESHOLD FAILED:[/red] avg P99 {avg_p99:.1f} ms > maximum {max_p99_ms:.1f} ms"
294          )
295          passed = False
296  
297      if passed and (min_rps is not None or max_p50_ms is not None or max_p99_ms is not None):
298          console.print("\n[green]All thresholds passed.[/green]")
299  
300      return passed
301  
302  
def main() -> None:
    """CLI entry point: parse args, run the benchmark, report, enforce thresholds."""
    parser = argparse.ArgumentParser(description="Async HTTP benchmark client for MLflow Gateway")
    parser.add_argument("--url", required=True, help="Gateway invocation URL")
    parser.add_argument("--requests", type=int, default=2000)
    parser.add_argument("--max-concurrent", type=int, default=50)
    parser.add_argument("--runs", type=int, default=3)
    # The three optional pass/fail thresholds share the same shape.
    for flag, help_text in (
        ("--min-rps", "Fail (exit 1) if average throughput falls below N req/s"),
        ("--max-p50-ms", "Fail (exit 1) if average P50 latency exceeds N ms"),
        ("--max-p99-ms", "Fail (exit 1) if average P99 latency exceeds N ms"),
    ):
        parser.add_argument(flag, type=float, default=None, metavar="N", help=help_text)
    parser.add_argument(
        "--auth-username",
        default=None,
        help="Basic auth username. If set together with --auth-password, sent on every request.",
    )
    parser.add_argument(
        "--auth-password",
        default=None,
        help="Basic auth password. If set together with --auth-username, sent on every request.",
    )
    args = parser.parse_args()

    # Basic auth is only applied when both halves of the credential are given.
    auth = None
    if args.auth_username and args.auth_password:
        auth = aiohttp.BasicAuth(args.auth_username, args.auth_password)

    console.print(f"\n[bold]Benchmarking[/bold] {args.url}")
    console.print(
        f"  {args.requests} requests · {args.max_concurrent} concurrent · {args.runs} runs\n"
    )
    results = run_benchmark(args.url, args.requests, args.max_concurrent, args.runs, auth)
    print_results(results)

    ok = check_thresholds(
        results, min_rps=args.min_rps, max_p50_ms=args.max_p50_ms, max_p99_ms=args.max_p99_ms
    )
    if not ok:
        raise SystemExit(1)
359  
360  
# Allow standalone execution, e.g. `uv run benchmark.py --url ...`.
if __name__ == "__main__":
    main()