# tests/tracking/integration_test_utils.py
  1  import contextlib
  2  import logging
  3  import os
  4  import socket
  5  import sys
  6  import time
  7  from subprocess import Popen
  8  from threading import Thread
  9  from typing import Any, Generator, Literal
 10  
 11  import requests
 12  import uvicorn
 13  from fastapi import FastAPI
 14  
 15  import mlflow
 16  from mlflow.server import ARTIFACT_ROOT_ENV_VAR, BACKEND_STORE_URI_ENV_VAR
 17  
 18  from tests.helper_functions import LOCALHOST, get_safe_port
 19  
 20  _logger = logging.getLogger(__name__)
 21  
 22  
 23  def _await_server_up_or_die(port: int, timeout: int = 30) -> None:
 24      """Waits until the local flask server is listening on the given port."""
 25      _logger.info(f"Awaiting server to be up on {LOCALHOST}:{port}")
 26      start_time = time.time()
 27      while time.time() - start_time < timeout:
 28          with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
 29              sock.settimeout(2)
 30              if sock.connect_ex((LOCALHOST, port)) == 0:
 31                  _logger.info(f"Server is up on {LOCALHOST}:{port}!")
 32                  break
 33          _logger.info("Server not yet up, waiting...")
 34          time.sleep(0.5)
 35      else:
 36          raise Exception(f"Failed to connect on {LOCALHOST}:{port} within {timeout} seconds")
 37  
 38  
 39  @contextlib.contextmanager
 40  def _init_server(
 41      backend_uri: str,
 42      root_artifact_uri: str,
 43      extra_env: dict[str, Any] | None = None,
 44      app: str | None = None,
 45      server_type: Literal["flask", "fastapi"] = "fastapi",
 46  ) -> Generator[str, None, None]:
 47      """
 48      Launch a new REST server using the tracking store specified by backend_uri and root artifact
 49      directory specified by root_artifact_uri.
 50  
 51      Args:
 52          backend_uri: Backend store URI for the server
 53          root_artifact_uri: Root artifact URI for the server
 54          extra_env: Additional environment variables
 55          app: Application module path (defaults based on server_type if None)
 56          server_type: Server type to use - "fastapi" (default) or "flask"
 57  
 58      Yields:
 59          The string URL of the server.
 60      """
 61      mlflow.set_tracking_uri(None)
 62      server_port = get_safe_port()
 63  
 64      if server_type == "fastapi":
 65          # Use uvicorn for FastAPI
 66          cmd = [
 67              sys.executable,
 68              "-m",
 69              "uvicorn",
 70              app or "mlflow.server.fastapi_app:app",
 71              "--host",
 72              LOCALHOST,
 73              "--port",
 74              str(server_port),
 75          ]
 76      else:
 77          # Default to Flask
 78          cmd = [
 79              sys.executable,
 80              "-m",
 81              "flask",
 82              "--app",
 83              app or "mlflow.server:app",
 84              "run",
 85              "--host",
 86              LOCALHOST,
 87              "--port",
 88              str(server_port),
 89          ]
 90  
 91      with Popen(
 92          cmd,
 93          env={
 94              **os.environ,
 95              BACKEND_STORE_URI_ENV_VAR: backend_uri,
 96              ARTIFACT_ROOT_ENV_VAR: root_artifact_uri,
 97              **(extra_env or {}),
 98          },
 99      ) as proc:
100          try:
101              _await_server_up_or_die(server_port)
102              url = f"http://{LOCALHOST}:{server_port}"
103              _logger.info(
104                  f"Launching tracking server on {url} with backend URI {backend_uri} and "
105                  f"artifact root {root_artifact_uri}"
106              )
107              yield url
108          finally:
109              proc.terminate()
110  
111  
def _send_rest_tracking_post_request(tracking_server_uri, api_path, json_payload, auth=None):
    """
    Make a POST request to the specified MLflow Tracking API and retrieve the
    corresponding `requests.Response` object.

    Args:
        tracking_server_uri: Base URL of the tracking server, e.g. "http://127.0.0.1:5000".
        api_path: API path appended verbatim to ``tracking_server_uri``.
        json_payload: JSON-serializable object sent as the request body.
        auth: Optional auth object forwarded to ``requests.post`` (e.g. a
            (username, password) tuple).

    Returns:
        The `requests.Response` returned by the server.
    """
    # NOTE: the redundant function-local `import requests` was removed;
    # `requests` is already imported at module level.
    url = tracking_server_uri + api_path
    return requests.post(url, json=json_payload, auth=auth)
121  
122  
class ServerThread(Thread):
    """Host a FastAPI app with uvicorn on a daemon background thread.

    Intended for tests and short-lived runs via the context-manager protocol:
    ``with ServerThread(app, port) as url: ...`` yields the base URL once the
    ``/health`` endpoint responds (or a 5-second readiness window elapses).
    """

    def __init__(self, app: FastAPI, port: int):
        super().__init__(name="mlflow-tracking-server", daemon=True)
        self.host = "127.0.0.1"
        self.port = port
        self.url = f"http://{self.host}:{port}"
        self.health_url = f"{self.url}/health"
        self.server = uvicorn.Server(
            uvicorn.Config(app, host=self.host, port=self.port, log_level="error", ws="none")
        )

    def run(self) -> None:
        """Thread entry point: uvicorn drives its own event loop."""
        self.server.run()

    def shutdown(self) -> None:
        """Signal uvicorn to stop; its serving loop polls this flag."""
        self.server.should_exit = True

    def __enter__(self) -> str:
        """Start the thread, wait briefly for readiness, and return the URL."""
        self.start()

        # Poll /health for up to 5 seconds; a connection error or timeout just
        # means the server is not listening yet.
        deadline = time.time() + 5.0
        ready = False
        while not ready and time.time() < deadline:
            try:
                ready = requests.get(self.health_url, timeout=0.2).ok
            except (requests.ConnectionError, requests.Timeout):
                ready = False
            if not ready:
                time.sleep(0.1)
        return self.url

    def __exit__(self, exc_type, exc, tb) -> bool | None:
        """Request shutdown and wait (bounded) for the thread to terminate."""
        self.shutdown()
        self.join(timeout=5.0)
        return None