# /// script
# requires-python = ">=3.10"
# dependencies = ["aiohttp>=3.13.3,<4", "psycopg2-binary>=2.9,<3", "rich>=14.3.3,<15"]
# ///
"""MLflow AI Gateway benchmark runner.

Orchestrates fake OpenAI server, MLflow server(s), optional PostgreSQL and
nginx (via Docker), then runs the async benchmark client.

Usage:
    uv run run.py                             # 4 instances, PostgreSQL, nginx (Docker)
    uv run run.py --instances 1               # single instance, SQLite, no Docker
    uv run run.py --instances 1 --database postgres
    uv run run.py --instances 8 --workers 8
    uv run run.py --url http://...            # benchmark an existing endpoint directly
"""

import argparse
import base64
import contextlib
import json
import os
import shutil
import subprocess
import sys
import tempfile
import time
import urllib.error
import urllib.request
from collections.abc import Generator
from pathlib import Path
from typing import Any

sys.path.insert(0, str(Path(__file__).parent))
import aiohttp  # type: ignore[import-not-found]
import benchmark as bm  # local module; path inserted above
from rich.console import Console  # type: ignore[import-not-found]
from rich.panel import Panel  # type: ignore[import-not-found]
from rich.progress import (  # type: ignore[import-not-found]
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
)

SCRIPT_DIR = Path(__file__).parent

FAKE_SERVER_PORT = 9137
FAKE_SERVER_WORKERS = 8
MLFLOW_PORT = 5731
INSTANCE_BASE_PORT = 5800
POSTGRES_PORT = int(os.environ.get("GATEWAY_BENCH_POSTGRES_PORT", "5432"))
POSTGRES_PASSWORD = "benchmarkpass"
ENDPOINT_NAME = "benchmark-chat"

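# Relative REST paths; _api_post() joins each onto <tracking_uri>/api/3.0/mlflow/,
# e.g. http://127.0.0.1:5731/api/3.0/mlflow/gateway/secrets/create.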
_API_SECRET_CREATE = "gateway/secrets/create"
_API_MODEL_DEF_CREATE = "gateway/model-definitions/create"
_API_ENDPOINT_CREATE = "gateway/endpoints/create"

console = Console()


def _uv_prefix() -> list[str]:
    """Return uv run prefix when inside the mlflow repo, else empty list."""
    in_repo = (
        shutil.which("uv")
        and subprocess.run(
            ["git", "rev-parse", "HEAD"], cwd=SCRIPT_DIR, capture_output=True
        ).returncode
        == 0
    )
    return ["uv", "run", "--no-build-isolation", "--extra", "gateway"] if in_repo else []


def _subprocess_env() -> dict[str, str]:
    # OBJC_DISABLE_INITIALIZE_FORK_SAFETY avoids macOS Objective-C fork-safety aborts
    # when gunicorn/uvicorn fork worker processes.
    return os.environ | {"OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "YES"}


def _wait_for_port(port: int, label: str, log_file: Path | None = None, timeout: int = 30) -> None:
    url = f"http://127.0.0.1:{port}/health"
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        TimeElapsedColumn(),
        console=console,
        transient=True,
    ) as progress:
        progress.add_task(f"  Waiting for {label}...", total=None)
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                with urllib.request.urlopen(url, timeout=1):
                    break
            except Exception:
                time.sleep(0.5)
        else:
            console.print(f"  [red]✗ {label} failed to start within {timeout}s[/red]")
            if log_file and log_file.exists():
                console.print("  [yellow]Last 20 lines of log:[/yellow]")
                for line in log_file.read_text().splitlines()[-20:]:
                    console.print(f"    [dim]{line}[/dim]")
            sys.exit(1)
    console.print(f"  [green]✓[/green] {label} ready")


@contextlib.contextmanager
def _start_fake_server(
    work_dir: str, port: int = FAKE_SERVER_PORT, workers: int = FAKE_SERVER_WORKERS
) -> Generator[None, None, None]:
    prefix = _uv_prefix()
    log_file = Path(work_dir) / "fake_server.log"
    with (
        log_file.open("w") as f,
        subprocess.Popen(
            [
                *prefix,
                "uvicorn",
                "fake_server:app",
                "--workers",
                str(workers),
                "--host",
                "127.0.0.1",
                "--port",
                str(port),
                "--log-level",
                "warning",
            ],
            cwd=SCRIPT_DIR,
            stdout=f,
            stderr=f,
            env=_subprocess_env(),
        ) as proc,
    ):
        _wait_for_port(port, "fake OpenAI server", log_file)
        try:
            yield
        finally:
            proc.terminate()


@contextlib.contextmanager
def _start_mlflow(
    work_dir: str,
    port: int,
    workers: int,
    backend_uri: str,
    label: str = "MLflow server",
    host: str = "127.0.0.1",
    auth: bool = False,
) -> Generator[None, None, None]:
    prefix = _uv_prefix()
    # basic-auth requires the `auth` extra (Flask-WTF) at runtime.
    if auth and prefix:
        prefix = [*prefix, "--extra", "auth"]
    # psycopg2-binary lives in the `db` extra.
    if backend_uri.startswith("postgresql") and prefix:
        prefix = [*prefix, "--extra", "db"]
    log_file = Path(work_dir) / f"mlflow-{port}.log"
    cmd = [
        *prefix,
        "mlflow",
        "server",
        "--backend-store-uri",
        backend_uri,
        "--host",
        host,
        "--port",
        str(port),
        "--workers",
        str(workers),
        "--disable-security-middleware",
    ]
    if auth:
        cmd += ["--app-name", "basic-auth"]
    with (
        log_file.open("w") as f,
        subprocess.Popen(cmd, cwd=SCRIPT_DIR, stdout=f, stderr=f, env=_subprocess_env()) as proc,
    ):
        _wait_for_port(port, label, log_file)
        try:
            yield
        finally:
            proc.terminate()


def _check_docker() -> None:
    try:
        result = subprocess.run(["docker", "info"], capture_output=True)
    except FileNotFoundError:
        console.print(
            "[red]Docker is not installed. Install it at https://docs.docker.com/get-docker/[/red]"
        )
        sys.exit(1)
    if result.returncode != 0:
        console.print("[red]Docker daemon is not running. Please start Docker and try again.[/red]")
        sys.exit(1)


@contextlib.contextmanager
def _start_postgres(container_name: str = "benchmark-postgres") -> Generator[str, None, None]:
    """Start a PostgreSQL Docker container. Yields the connection URI."""
    subprocess.run(["docker", "rm", "-f", container_name], capture_output=True)

    with subprocess.Popen(
        [
            "docker",
            "run",
            "--rm",
            "--name",
            container_name,
            "-e",
            f"POSTGRES_PASSWORD={POSTGRES_PASSWORD}",
            "-e",
            "POSTGRES_DB=mlflow",
            "-p",
            f"127.0.0.1:{POSTGRES_PORT}:5432",
            "postgres:16-alpine",
            "-c",
            "max_connections=500",
        ],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    ):
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            TimeElapsedColumn(),
            console=console,
            transient=True,
        ) as progress:
            progress.add_task("  Starting PostgreSQL...", total=None)
            deadline = time.monotonic() + 30
            while time.monotonic() < deadline:
                if (
                    subprocess.run(
                        ["docker", "exec", container_name, "pg_isready", "-U", "postgres"],
                        capture_output=True,
                    ).returncode
                    == 0
                ):
                    break
                time.sleep(0.5)
            else:
                console.print("  [red]✗ PostgreSQL failed to start within 30s[/red]")
                sys.exit(1)

        console.print("  [green]✓[/green] PostgreSQL ready")
        try:
            yield f"postgresql://postgres:{POSTGRES_PASSWORD}@127.0.0.1:{POSTGRES_PORT}/mlflow"
        finally:
            subprocess.run(["docker", "kill", container_name], capture_output=True)


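# _basic_auth_header() builds the header that _api_post() and _sanity_check() attach
# when --auth is enabled. Example: the default credentials ("admin", "password1234")
# encode to {"Authorization": "Basic YWRtaW46cGFzc3dvcmQxMjM0"}.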
def _basic_auth_header(creds: tuple[str, str] | None) -> dict[str, str]:
    if creds is None:
        return {}
    token = base64.b64encode(f"{creds[0]}:{creds[1]}".encode()).decode()
    return {"Authorization": f"Basic {token}"}


def _api_post(
    tracking_uri: str,
    path: str,
    body: dict[str, Any],
    creds: tuple[str, str] | None = None,
) -> Any:
    url = f"{tracking_uri.rstrip('/')}/api/3.0/mlflow/{path}"
    headers = {"Content-Type": "application/json", **_basic_auth_header(creds)}
    req = urllib.request.Request(url, data=json.dumps(body).encode(), headers=headers)
    try:
        with urllib.request.urlopen(req, timeout=10) as resp:
            return json.loads(resp.read())
    except urllib.error.HTTPError as e:
        console.print(f"  [red]API error {e.code} at {url}: {e.read().decode()}[/red]")
        sys.exit(1)
    except urllib.error.URLError as e:
        console.print(f"  [red]API error at {url}: {e.reason}[/red]")
        sys.exit(1)


def _setup_endpoint(
    tracking_uri: str,
    fake_server_url: str,
    endpoint_name: str,
    usage_tracking: bool,
    creds: tuple[str, str] | None = None,
) -> str:
    """Create secret → model definition → endpoint. Returns the invocation URL."""
    console.print("  Creating secret...")
    secret_id = _api_post(
        tracking_uri,
        _API_SECRET_CREATE,
        {
            "secret_name": "benchmark-secret",
            "secret_value": {"api_key": "fake-benchmark-key"},
            "provider": "openai",
            "auth_config": {"api_base": fake_server_url},
        },
        creds,
    )["secret"]["secret_id"]

    console.print("  Creating model definition...")
    model_def_id = _api_post(
        tracking_uri,
        _API_MODEL_DEF_CREATE,
        {
            "name": "benchmark-model",
            "secret_id": secret_id,
            "provider": "openai",
            "model_name": "gpt-4o-mini",
        },
        creds,
    )["model_definition"]["model_definition_id"]

    console.print(f"  Creating endpoint '{endpoint_name}' (usage_tracking={usage_tracking})...")
    _api_post(
        tracking_uri,
        _API_ENDPOINT_CREATE,
        {
            "name": endpoint_name,
            "model_configs": [
                {"model_definition_id": model_def_id, "linkage_type": "PRIMARY", "weight": 1.0}
            ],
            "usage_tracking": usage_tracking,
        },
        creds,
    )

    invoke_url = f"{tracking_uri.rstrip('/')}/gateway/{endpoint_name}/mlflow/invocations"
    console.print(f"  [green]✓[/green] Endpoint ready: [cyan]{invoke_url}[/cyan]")
    return invoke_url


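# The sanity check below POSTs the same chat payload a manual curl would send to the
# invocation URL, e.g. (illustrative; default single-instance port and endpoint name):
#   curl -X POST http://127.0.0.1:5731/gateway/benchmark-chat/mlflow/invocations \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "test"}]}'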
def _sanity_check(url: str, creds: tuple[str, str] | None = None) -> None:
    console.print("  Sending sanity-check request...")
    body = json.dumps({"messages": [{"role": "user", "content": "test"}]}).encode()
    headers = {"Content-Type": "application/json", **_basic_auth_header(creds)}
    req = urllib.request.Request(url, data=body, headers=headers)
    try:
        with urllib.request.urlopen(req, timeout=10) as resp:
            if resp.status != 200:
                console.print(f"  [red]✗ Sanity check failed: HTTP {resp.status}[/red]")
                sys.exit(1)
    except Exception as e:
        console.print(f"  [red]✗ Sanity check failed: {e}[/red]")
        sys.exit(1)
    console.print("  [green]✓[/green] Sanity check passed")


def _run_benchmark(
    url: str,
    n_requests: int,
    max_concurrent: int,
    runs: int,
    min_rps: float | None = None,
    max_p50_ms: float | None = None,
    max_p99_ms: float | None = None,
    output: Path | None = None,
    creds: tuple[str, str] | None = None,
) -> None:
    auth = aiohttp.BasicAuth(*creds) if creds else None
    results = bm.run_benchmark(url, n_requests, max_concurrent, runs, auth)
    bm.print_results(results)
    if output is not None:
        output.write_text(json.dumps(bm.results_to_dict(results), indent=2))
        console.print(f"  Results saved to [cyan]{output}[/cyan]")
    if not bm.check_thresholds(
        results, min_rps=min_rps, max_p50_ms=max_p50_ms, max_p99_ms=max_p99_ms
    ):
        raise SystemExit(1)


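# _start_nginx below renders a minimal load-balancer config into the temp dir. As a
# sketch, two instances on ports 5800/5801 produce a conf.d/mlflow.conf along the
# lines of (illustrative):
#   upstream mlflow_backends {
#       server host.docker.internal:5800;
#       server host.docker.internal:5801;
#       keepalive 512;
#       ...
#   }
#   server {
#       listen 5731 reuseport backlog=65535;
#       location / { proxy_pass http://mlflow_backends; ... }
#   }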
@contextlib.contextmanager
def _start_nginx(
    work_dir: str, instance_ports: list[int], port: int, container_name: str = "benchmark-nginx"
) -> Generator[None, None, None]:
    nginx_dir = Path(work_dir) / "nginx"
    conf_d = nginx_dir / "conf.d"
    conf_d.mkdir(parents=True)

    upstream_lines = "\n".join(f"    server host.docker.internal:{p};" for p in instance_ports)
    (conf_d / "mlflow.conf").write_text(
        f"upstream mlflow_backends {{\n"
        f"{upstream_lines}\n"
        f"    keepalive 512;\n"
        f"    keepalive_requests 100000;\n"
        f"    keepalive_timeout 60s;\n"
        f"}}\n"
        f"server {{\n"
        f"    listen {port} reuseport backlog=65535;\n"
        f"    location / {{\n"
        f"        proxy_pass http://mlflow_backends;\n"
        f"        proxy_http_version 1.1;\n"
        f'        proxy_set_header Connection "";\n'
        f"        proxy_set_header Host $host;\n"
        f"        proxy_set_header X-Real-IP $remote_addr;\n"
        f"        proxy_connect_timeout 5s;\n"
        f"        proxy_send_timeout 60s;\n"
        f"        proxy_read_timeout 60s;\n"
        f"    }}\n"
        f"}}\n"
    )
    (nginx_dir / "nginx.conf").write_text(
        "worker_processes auto;\n"
        "worker_rlimit_nofile 65535;\n"
        "events {\n"
        "    worker_connections 16384;\n"
        "    use epoll;\n"
        "    multi_accept on;\n"
        "}\n"
        "http {\n"
        "    access_log off;\n"
        "    tcp_nodelay on;\n"
        "    keepalive_timeout 65;\n"
        "    keepalive_requests 100000;\n"
        "    reset_timedout_connection on;\n"
        "    include /etc/nginx/conf.d/*.conf;\n"
        "}\n"
    )

    subprocess.run(["docker", "rm", "-f", container_name], capture_output=True)
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        TimeElapsedColumn(),
        console=console,
        transient=True,
    ) as progress:
        progress.add_task("  Starting nginx...", total=None)
        subprocess.run(
            [
                "docker",
                "run",
                "--rm",
                "-d",
                "--name",
                container_name,
                "--add-host=host.docker.internal:host-gateway",
                "--ulimit",
                "nofile=65535:65535",
                "-v",
                f"{nginx_dir / 'nginx.conf'}:/etc/nginx/nginx.conf:ro",
                "-v",
                f"{conf_d}:/etc/nginx/conf.d:ro",
                "-p",
                f"127.0.0.1:{port}:{port}",
                "nginx:alpine",
            ],
            check=True,
            capture_output=True,
        )

        deadline = time.monotonic() + 15
        while time.monotonic() < deadline:
            if (
                subprocess.run(
                    ["docker", "exec", container_name, "nginx", "-t"], capture_output=True
                ).returncode
                == 0
            ):
                break
            time.sleep(0.5)
        else:
            console.print("  [red]✗ nginx failed to start[/red]")
            sys.exit(1)

    console.print("  [green]✓[/green] nginx ready")
    try:
        yield
    finally:
        subprocess.run(["docker", "kill", container_name], capture_output=True)


def cmd_bench(args: argparse.Namespace) -> None:
    instances = args.instances
    mode = "1 instance" if instances == 1 else f"{instances} instances, nginx LB"
    creds = (args.auth_username, args.auth_password) if args.auth else None

    if args.url:
        console.print(
            Panel.fit(
                f"[bold]Gateway Benchmark[/bold] ({mode})\n"
                f"URL: [cyan]{args.url}[/cyan]\n"
                f"Auth: {'basic-auth as ' + args.auth_username if creds else 'disabled'}\n"
                f"Requests: {args.requests}  ·  Concurrency: {args.max_concurrent}"
                f"  ·  Runs: {args.runs}",
                border_style="cyan",
            )
        )
        console.print("\n[bold]Running benchmark[/bold]")
        _run_benchmark(
            args.url,
            args.requests,
            args.max_concurrent,
            args.runs,
            args.min_rps,
            args.max_p50_ms,
            args.max_p99_ms,
            args.output,
            creds,
        )
        return

    needs_docker = instances > 1 or args.database == "postgres"
    if needs_docker:
        _check_docker()

    with tempfile.TemporaryDirectory(prefix="mlflow-bench-") as work_dir:
        port = args.port
        fake_port = args.fake_server_port
        instance_ports = [args.base_port + i for i in range(instances)]

        auth_line = f"basic-auth as {args.auth_username}" if creds else "disabled"
        if instances == 1:
            panel = (
                f"[bold]Gateway Benchmark[/bold] ({mode})\n"
                f"Workers: {args.workers}  ·  DB: {args.database.upper()}  ·  "
                f"Usage tracking: {args.usage_tracking}  ·  Auth: {auth_line}\n"
                f"Requests: {args.requests}  ·  Concurrency: {args.max_concurrent}  ·  "
                f"Runs: {args.runs}  ·  Fake delay: {args.fake_delay_ms}ms\n"
                f"Ports: MLflow :{port}  ·  Fake server :{fake_port}"
            )
        else:
            panel = (
                f"[bold]Gateway Benchmark[/bold] ({mode})\n"
                f"Workers/instance: {args.workers}  ·  "
                f"Total workers: {instances * args.workers}  ·  "
                f"Usage tracking: {args.usage_tracking}  ·  Auth: {auth_line}\n"
                f"Requests: {args.requests}  ·  Concurrency: {args.max_concurrent}  ·  "
                f"Runs: {args.runs}  ·  Fake delay: {args.fake_delay_ms}ms\n"
                f"Ports: instances {instance_ports[0]}–{instance_ports[-1]}"
                f"  ·  LB :{port}  ·  Fake server :{fake_port}"
            )
        console.print(Panel.fit(panel, border_style="cyan"))

        with contextlib.ExitStack() as stack:
            stack.callback(lambda: console.print("\n[dim]Cleaning up...[/dim]"))

            # Backend
            if instances > 1 or args.database == "postgres":
                console.print("\n[bold]PostgreSQL[/bold]")
                backend_uri = stack.enter_context(_start_postgres())
            else:
                db_path = Path(work_dir) / "mlflow.db"
                backend_uri = f"sqlite:///{db_path}"
                console.print(f"\n[dim]Using SQLite: {db_path}[/dim]")

            # Servers
            console.print("\n[bold]Starting servers[/bold]")
            stack.enter_context(
                _start_fake_server(work_dir, port=fake_port, workers=FAKE_SERVER_WORKERS)
            )

            if instances == 1:
                stack.enter_context(
                    _start_mlflow(work_dir, port, args.workers, backend_uri, auth=args.auth)
                )

                console.print("\n[bold]Setting up gateway endpoint[/bold]")
                invoke_url = _setup_endpoint(
                    f"http://127.0.0.1:{port}",
                    f"http://127.0.0.1:{fake_port}/v1",
                    ENDPOINT_NAME,
                    usage_tracking=args.usage_tracking,
                    creds=creds,
                )
                _sanity_check(invoke_url, creds)
            else:
                # Start instance 0 first — it initializes the DB schema.
                # All instances share the same PostgreSQL DB, so starting concurrently
                # can cause CREATE TABLE race conditions.
                stack.enter_context(
                    _start_mlflow(
                        work_dir,
                        instance_ports[0],
                        args.workers,
                        backend_uri,
                        "MLflow instance 0",
                        host="0.0.0.0",
                        auth=args.auth,
                    )
                )
                for i, p in enumerate(instance_ports[1:], start=1):
                    stack.enter_context(
                        _start_mlflow(
                            work_dir,
                            p,
                            args.workers,
                            backend_uri,
                            f"MLflow instance {i}",
                            host="0.0.0.0",
                            auth=args.auth,
                        )
                    )

                console.print("\n[bold]Setting up gateway endpoint[/bold]")
                _setup_endpoint(
                    f"http://127.0.0.1:{instance_ports[0]}",
                    f"http://127.0.0.1:{fake_port}/v1",
                    ENDPOINT_NAME,
                    usage_tracking=args.usage_tracking,
                    creds=creds,
                )

                console.print("\n[bold]Starting nginx load balancer[/bold]")
                nginx_container = "benchmark-nginx"
                stack.enter_context(
                    _start_nginx(
                        work_dir, instance_ports, port=port, container_name=nginx_container
                    )
                )
                subprocess.run(
                    ["docker", "exec", nginx_container, "nginx", "-s", "reload"],
                    capture_output=True,
                )
                time.sleep(1)

                invoke_url = f"http://127.0.0.1:{port}/gateway/{ENDPOINT_NAME}/mlflow/invocations"
                _sanity_check(invoke_url, creds)

            console.print("\n[bold]Running benchmark[/bold]")
            _run_benchmark(
                invoke_url,
                args.requests,
                args.max_concurrent,
                args.runs,
                args.min_rps,
                args.max_p50_ms,
                args.max_p99_ms,
                args.output,
                creds,
            )


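# Example CI-style invocation with pass/fail thresholds (illustrative values):
#   uv run run.py --instances 4 --workers 4 --requests 2000 --max-concurrent 50 \
#       --runs 3 --min-rps 200 --max-p99-ms 500 --output results.json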
def main() -> None:
    parser = argparse.ArgumentParser(
        description="MLflow AI Gateway benchmark",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--url",
        metavar="URL",
        help="Benchmark this endpoint URL directly, skipping server setup entirely",
    )
    parser.add_argument(
        "--instances",
        type=int,
        default=int(os.environ.get("INSTANCES", "4")),
        metavar="N",
        help=(
            "Number of MLflow instances to run (default: 4). "
            "Values >1 require Docker (postgres + nginx). "
            "Use --instances 1 for a single instance (SQLite by default, or --database postgres)."
        ),
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=int(os.environ.get("WORKERS_PER_INSTANCE", "4")),
        metavar="N",
        help="Gunicorn/uvicorn worker processes per MLflow instance (default: 4)",
    )
    parser.add_argument(
        "--database",
        choices=["sqlite", "postgres"],
        default="sqlite",
        help=(
            "Database to use — only applies when --instances 1. "
            "'postgres' auto-starts a Docker container. (default: sqlite)"
        ),
    )
    parser.add_argument(
        "--no-usage-tracking",
        dest="usage_tracking",
        action="store_false",
        default=True,
        help="Disable usage tracking (tracing) on the benchmark endpoint",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=int(os.environ.get("MLFLOW_PORT", str(MLFLOW_PORT))),
        metavar="N",
        help=(
            "Port the benchmark client sends requests to. "
            "For --instances 1 this is the MLflow port; "
            "for --instances >1 this is the nginx load balancer port. (default: 5731)"
        ),
    )
    parser.add_argument(
        "--base-port",
        type=int,
        default=int(os.environ.get("BASE_PORT", str(INSTANCE_BASE_PORT))),
        metavar="N",
        help=(
            "Starting port for MLflow instances in multi mode. "
            "Instances listen on base-port, base-port+1, … (default: 5800)"
        ),
    )
    parser.add_argument(
        "--fake-server-port",
        type=int,
        metavar="N",
        default=int(os.environ.get("FAKE_SERVER_PORT", str(FAKE_SERVER_PORT))),
        help="Port for the fake OpenAI server that simulates provider latency (default: 9137)",
    )
    parser.add_argument(
        "--requests",
        type=int,
        default=int(os.environ.get("REQUESTS", "2000")),
        metavar="N",
        help="Total requests to send per benchmark run (default: 2000)",
    )
    parser.add_argument(
        "--max-concurrent",
        type=int,
        default=int(os.environ.get("MAX_CONCURRENT", "50")),
        metavar="N",
        help="Maximum number of in-flight requests at any time (default: 50)",
    )
    parser.add_argument(
        "--runs",
        type=int,
        default=int(os.environ.get("RUNS", "3")),
        metavar="N",
        help="Number of timed runs; results are reported per-run and averaged (default: 3)",
    )
    parser.add_argument(
        "--fake-delay-ms",
        type=int,
        default=int(os.environ.get("FAKE_RESPONSE_DELAY_MS", "50")),
        metavar="N",
        help=(
            "Simulated provider latency in ms. Set to 0 to measure pure MLflow overhead "
            "with no provider delay. (default: 50)"
        ),
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        metavar="FILE",
        help="Write benchmark results as JSON to FILE (useful for CI artifact upload)",
    )
    parser.add_argument(
        "--min-rps",
        type=float,
        default=None,
        metavar="N",
        help="Exit 1 if average throughput across runs falls below N req/s (CI threshold)",
    )
    parser.add_argument(
        "--max-p50-ms",
        type=float,
        default=None,
        metavar="N",
        help="Exit 1 if average P50 latency across runs exceeds N ms (CI threshold)",
    )
    parser.add_argument(
        "--max-p99-ms",
        type=float,
        default=None,
        metavar="N",
        help="Exit 1 if average P99 latency across runs exceeds N ms (CI threshold)",
    )
    parser.add_argument(
        "--auth",
        action="store_true",
        default=os.environ.get("AUTH", "").lower() in ("1", "true"),
        help=(
            "Start MLflow with --app-name=basic-auth and authenticate every setup + "
            "benchmark request using --auth-username/--auth-password."
        ),
    )
    parser.add_argument(
        "--auth-username",
        default=os.environ.get("AUTH_USERNAME", "admin"),
        help="Basic auth username (default: admin, from basic_auth.ini)",
    )
    parser.add_argument(
        "--auth-password",
        default=os.environ.get("AUTH_PASSWORD", "password1234"),
        help="Basic auth password (default: password1234, from basic_auth.ini)",
    )

    args = parser.parse_args()
    os.environ["FAKE_RESPONSE_DELAY_MS"] = str(args.fake_delay_ms)
    cmd_bench(args)


if __name__ == "__main__":
    main()