# integration_test_utils.py
import contextlib
import logging
import os
import socket
import sys
import time
from subprocess import Popen, TimeoutExpired
from threading import Thread
from typing import Any, Generator, Literal

import requests
import uvicorn
from fastapi import FastAPI

import mlflow
from mlflow.server import ARTIFACT_ROOT_ENV_VAR, BACKEND_STORE_URI_ENV_VAR

from tests.helper_functions import LOCALHOST, get_safe_port

_logger = logging.getLogger(__name__)


def _await_server_up_or_die(port: int, timeout: int = 30, *, proc: Popen | None = None) -> None:
    """Block until a local server accepts TCP connections on ``LOCALHOST:port``.

    Args:
        port: Port on ``LOCALHOST`` to probe.
        timeout: Maximum number of seconds to keep probing.
        proc: Optional handle to the server process. When given, the wait is aborted
            as soon as the process exits, instead of waiting out the whole ``timeout``.

    Raises:
        Exception: If nothing accepts a connection within ``timeout`` seconds, or if
            ``proc`` exits before the port opens.
    """
    _logger.info(f"Awaiting server to be up on {LOCALHOST}:{port}")
    start_time = time.time()
    while time.time() - start_time < timeout:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(2)
            if sock.connect_ex((LOCALHOST, port)) == 0:
                _logger.info(f"Server is up on {LOCALHOST}:{port}!")
                return
        # A crashed server will never open the port; report it immediately.
        if proc is not None and (returncode := proc.poll()) is not None:
            raise Exception(
                f"Server process exited with code {returncode} before listening on "
                f"{LOCALHOST}:{port}"
            )
        _logger.info("Server not yet up, waiting...")
        time.sleep(0.5)
    raise Exception(f"Failed to connect on {LOCALHOST}:{port} within {timeout} seconds")


@contextlib.contextmanager
def _init_server(
    backend_uri: str,
    root_artifact_uri: str,
    extra_env: dict[str, Any] | None = None,
    app: str | None = None,
    server_type: Literal["flask", "fastapi"] = "fastapi",
) -> Generator[str, None, None]:
    """
    Launch a new REST server using the tracking store specified by backend_uri and root artifact
    directory specified by root_artifact_uri.

    The server runs as a subprocess on a fresh port. On exit it is sent SIGTERM; if it has not
    stopped within 30 seconds it is killed, so a stuck server cannot hang the test session.

    Args:
        backend_uri: Backend store URI for the server
        root_artifact_uri: Root artifact URI for the server
        extra_env: Additional environment variables (override inherited ones)
        app: Application module path (defaults based on server_type if None)
        server_type: Server type to use - "fastapi" (default, run via uvicorn) or "flask"

    Yields:
        The string URL of the server.

    Raises:
        Exception: If the server does not start listening in time or exits during startup.
    """
    mlflow.set_tracking_uri(None)
    server_port = get_safe_port()

    if server_type == "fastapi":
        # Use uvicorn for FastAPI
        cmd = [
            sys.executable,
            "-m",
            "uvicorn",
            app or "mlflow.server.fastapi_app:app",
            "--host",
            LOCALHOST,
            "--port",
            str(server_port),
        ]
    else:
        # Default to Flask
        cmd = [
            sys.executable,
            "-m",
            "flask",
            "--app",
            app or "mlflow.server:app",
            "run",
            "--host",
            LOCALHOST,
            "--port",
            str(server_port),
        ]

    with Popen(
        cmd,
        env={
            **os.environ,
            BACKEND_STORE_URI_ENV_VAR: backend_uri,
            ARTIFACT_ROOT_ENV_VAR: root_artifact_uri,
            **(extra_env or {}),
        },
    ) as proc:
        try:
            _await_server_up_or_die(server_port, proc=proc)
            url = f"http://{LOCALHOST}:{server_port}"
            _logger.info(
                f"Launching tracking server on {url} with backend URI {backend_uri} and "
                f"artifact root {root_artifact_uri}"
            )
            yield url
        finally:
            proc.terminate()
            # Popen.__exit__ waits without a timeout; bound the wait ourselves and
            # escalate to SIGKILL so an unresponsive server cannot block forever.
            try:
                proc.wait(timeout=30)
            except TimeoutExpired:
                _logger.warning(
                    f"Server on port {server_port} did not stop after SIGTERM; killing it"
                )
                proc.kill()


def _send_rest_tracking_post_request(tracking_server_uri, api_path, json_payload, auth=None):
    """
    Make a POST request to the specified MLflow Tracking API and retrieve the
    corresponding `requests.Response` object

    Args:
        tracking_server_uri: Base URL of the server, e.g. ``http://127.0.0.1:5000``.
        api_path: Path appended verbatim to ``tracking_server_uri``.
        json_payload: Object serialized as the JSON request body.
        auth: Optional ``requests`` auth object or ``(user, password)`` tuple.
    """
    url = tracking_server_uri + api_path
    return requests.post(url, json=json_payload, auth=auth)


class ServerThread(Thread):
    """Run a FastAPI/uvicorn app in a background daemon thread, usable as a context manager.

    Entering the context starts the server and returns its base URL; exiting asks uvicorn
    to stop and joins the thread for up to 5 seconds.
    """

    def __init__(self, app: FastAPI, port: int):
        super().__init__(name="mlflow-tracking-server", daemon=True)
        self.host = "127.0.0.1"
        self.port = port
        self.url = f"http://{self.host}:{port}"
        self.health_url = f"{self.url}/health"
        config = uvicorn.Config(app, host=self.host, port=self.port, log_level="error", ws="none")
        self.server = uvicorn.Server(config)

    def run(self) -> None:
        """Thread target: let Uvicorn manage its own event loop."""
        self.server.run()

    def shutdown(self) -> None:
        """Ask Uvicorn to exit; the serving loop checks this flag."""
        self.server.should_exit = True

    def _is_ready(self) -> bool:
        """Return True once uvicorn reports startup or the health endpoint answers OK."""
        # uvicorn.Server sets ``started`` once its sockets are bound and startup has
        # finished, which also covers apps that expose no ``/health`` route.
        if getattr(self.server, "started", False):
            return True
        try:
            return requests.get(self.health_url, timeout=0.2).ok
        except (requests.ConnectionError, requests.Timeout):
            return False

    def __enter__(self) -> str:
        """Start the server thread and wait briefly (up to 5 s) for it to become ready.

        Returns:
            The base URL of the server.

        Raises:
            RuntimeError: If the server thread exits before the server becomes ready
                (e.g. the port could not be bound).
        """
        self.start()

        deadline = time.time() + 5.0
        while time.time() < deadline:
            # uvicorn exits the thread (via sys.exit) on startup failures such as a
            # busy port; without this check we would hand out a URL to nothing.
            if not self.is_alive():
                raise RuntimeError(f"Server thread exited before serving on {self.url}")
            if self._is_ready():
                break
            time.sleep(0.1)
        else:
            _logger.warning(f"Server on {self.url} did not report readiness within 5 seconds")
        return self.url

    def __exit__(self, exc_type, exc, tb) -> bool | None:
        """Clean up resources when exiting context; never suppresses exceptions."""
        self.shutdown()
        # Give the server a moment to wind down
        self.join(timeout=5.0)
        if self.is_alive():
            _logger.warning(f"Server thread for {self.url} did not stop within 5 seconds")
        return None