# tools.py
import asyncio
import json
import os
import signal
import subprocess
import sys
import threading
import time
from pathlib import Path
from typing import Any, NamedTuple
from unittest import mock

import aiohttp
import requests
import uvicorn
import yaml
from sentence_transformers import SentenceTransformer

import mlflow
from mlflow.gateway import app
from mlflow.gateway.utils import kill_child_processes

from tests.helper_functions import _get_mlflow_home, _start_scoring_proc, get_safe_port


class Gateway:
    """Run an AI Gateway server as a subprocess (``mlflow gateway start``) and
    expose small HTTP helpers against it.

    Intended to be used as a context manager in tests; on exit the worker
    processes and the parent process are terminated.
    """

    def __init__(self, config_path: str | Path, *args, **kwargs):
        self.port = get_safe_port()
        self.host = "localhost"
        self.url = f"http://{self.host}:{self.port}"
        self.workers = 2
        # Extra *args/**kwargs are forwarded verbatim to subprocess.Popen
        # (e.g. env=..., cwd=...).
        self.process = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "mlflow",
                "gateway",
                "start",
                "--config-path",
                config_path,
                "--host",
                self.host,
                "--port",
                str(self.port),
                "--workers",
                str(self.workers),
            ],
            *args,
            **kwargs,
        )
        self.wait_until_ready()

    def wait_until_ready(self) -> None:
        """Poll the ``health`` endpoint for up to 10 seconds.

        Raises:
            Exception: If the gateway never reports healthy within the timeout.
        """
        s = time.time()
        while time.time() - s < 10:
            try:
                if self.get("health").ok:
                    return
            except requests.exceptions.ConnectionError:
                pass
            # Back off between polls. Previously this slept only on
            # ConnectionError, so a non-OK HTTP response caused a busy-spin
            # of requests for the full 10 seconds.
            time.sleep(0.5)

        raise Exception("Gateway failed to start")

    def wait_reload(self) -> None:
        """
        Should be called after we update a gateway config file in tests to ensure
        that the gateway service has reloaded the config.
        """
        # Heuristic: allow roughly one second per worker for the reload.
        time.sleep(self.workers)

    def request(self, method: str, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        """Issue an HTTP request against the gateway; ``path`` has no leading slash."""
        return requests.request(method, f"{self.url}/{path}", *args, **kwargs)

    def get(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        return self.request("GET", path, *args, **kwargs)

    def assert_health(self):
        assert self.get("health").ok

    def post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        return self.request("POST", path, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Kill the worker children first, then the parent process itself.
        kill_child_processes(self.process.pid)
        self.process.terminate()
        self.process.wait()


def save_yaml(path, conf):
    """Serialize ``conf`` to ``path`` (a ``pathlib.Path``) as YAML."""
    path.write_text(yaml.safe_dump(conf))


class MockAsyncResponse:
    """Mimic an aiohttp JSON response for mocking provider endpoints.

    ``data`` may carry a ``"headers"`` key, which becomes the response headers;
    the remaining items are returned by :meth:`json` / :meth:`text`.
    """

    def __init__(self, data: dict[str, Any], status: int = 200):
        # Work on a shallow copy: the original implementation popped "headers"
        # out of the caller's dict, mutating the test fixture in place.
        data = dict(data)

        # Extract status and headers from data, if present
        self.status = status
        self.headers = data.pop("headers", {"Content-Type": "application/json"})

        # Save the rest of the data as content
        self._content = data

    def raise_for_status(self) -> None:
        """Raise ClientResponseError for 4xx/5xx statuses, like aiohttp does."""
        if 400 <= self.status < 600:
            raise aiohttp.ClientResponseError(None, None, status=self.status)

    async def json(self) -> dict[str, Any]:
        return self._content

    async def text(self) -> str:
        return json.dumps(self._content)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, traceback):
        pass


class MockAsyncStreamingResponse:
    """Mimic an aiohttp streaming response: ``content`` yields raw byte lines."""

    def __init__(self, data: list[bytes], headers: dict[str, str] | None = None, status: int = 200):
        self.status = status
        self.headers = headers
        self._content = data

    def raise_for_status(self) -> None:
        """Raise ClientResponseError for 4xx/5xx statuses, like aiohttp does."""
        if 400 <= self.status < 600:
            raise aiohttp.ClientResponseError(None, None, status=self.status)

    async def _async_content(self):
        # Re-expose the canned byte lines as an async iterator, matching the
        # interface of aiohttp's StreamReader iteration.
        for line in self._content:
            yield line

    @property
    def content(self):
        return self._async_content()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, traceback):
        pass


class MockHttpClient(mock.Mock):
    """A mock aiohttp-style client whose ``post`` returns a canned response."""

    def __init__(self, mock_response=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._mock_response = mock_response
        # Create a mock for post that returns the response
        self.post = mock.Mock(return_value=mock_response)

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        return


def mock_http_client(mock_response: MockAsyncResponse | MockAsyncStreamingResponse):
    """Build a :class:`MockHttpClient` that returns ``mock_response`` from ``post``."""
    return MockHttpClient(mock_response=mock_response)


class UvicornGateway:
    # This test utility class is used to validate the internal functionality of the
    # AI Gateway within-process so that the provider endpoints can be mocked,
    # allowing a nearly end-to-end validation of the entire AI Gateway stack.
    # NB: this implementation should only be used for integration testing. Unit tests that
    # require validation of the AI Gateway server should use the `Gateway` implementation in
    # this module which executes the uvicorn server through gunicorn as a process manager.
    def __init__(self, config_path: str | Path, *args, **kwargs):
        self.port = get_safe_port()
        self.host = "127.0.0.1"
        self.url = f"http://{self.host}:{self.port}"
        self.config_path = config_path
        self.server = None
        self.loop = None
        self.thread = None
        # Signaled by the server thread once uvicorn has fully shut down.
        self.stop_event = threading.Event()

    def start_server(self):
        """Start uvicorn on a background thread serving the gateway app."""
        uvicorn_app = app.create_app_from_path(self.config_path)

        self.stop_event.clear()
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        config = uvicorn.Config(
            app=uvicorn_app,
            host=self.host,
            port=self.port,
            lifespan="on",
            loop="auto",
            log_level="info",
            ws="none",
        )
        self.server = uvicorn.Server(config)

        def run():
            try:
                self.loop.run_until_complete(self.server.serve())
            finally:
                # BUGFIX: stop() waits on stop_event, but nothing ever set it,
                # so stop() would block forever. Signal shutdown completion here.
                self.stop_event.set()

        self.thread = threading.Thread(name="gateway-server", target=run)
        self.thread.start()

    def request(self, method: str, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        """Issue an HTTP request against the in-process gateway server."""
        return requests.request(method, f"{self.url}/{path}", *args, **kwargs)

    def get(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        return self.request("GET", path, *args, **kwargs)

    def assert_health(self):
        assert self.get("health").ok

    def post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        return self.request("POST", path, *args, **kwargs)

    def stop(self):
        """Stop the server, wait for shutdown to complete, and drop references."""
        if self.server is not None:
            self.server.should_exit = True  # Instruct the uvicorn server to stop
            self.stop_event.wait()  # Wait for the server to actually stop
            self.thread.join()  # block until thread termination
            self.server = None
            self.loop = None
            self.thread = None

    def __enter__(self):
        self.start_server()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Stop the server and the thread
        if self.server is not None:
            self.server.should_exit = True
            self.thread.join()


class ServerInfo(NamedTuple):
    # NOTE(review): `pid` actually holds the object returned by
    # _start_scoring_proc (stop_mlflow_server reads `.pid` off it),
    # presumably a Popen-like handle — confirm against the helper.
    pid: int
    url: str


def log_sentence_transformers_model():
    """Log a small sentence-transformers model to MLflow and return its model URI."""
    model = SentenceTransformer("all-MiniLM-L6-v2")
    artifact_path = "gen_model"

    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            model,
            name=artifact_path,
        )
    return model_info.model_uri


def start_mlflow_server(port, model_uri):
    """Serve ``model_uri`` via ``mlflow models serve`` on ``port``.

    Polls the ``/ping`` endpoint for up to two minutes before giving up.

    Returns:
        ServerInfo with the server process handle and base URL.

    Raises:
        Exception: If the server never responds with HTTP 200 on /ping.
    """
    server_url = f"http://127.0.0.1:{port}"

    env = dict(os.environ)
    env.update(LC_ALL="en_US.UTF-8", LANG="en_US.UTF-8")
    env.update(MLFLOW_TRACKING_URI=mlflow.get_tracking_uri())
    env.update(MLFLOW_HOME=_get_mlflow_home())
    scoring_cmd = [
        "mlflow",
        "models",
        "serve",
        "-m",
        model_uri,
        "-p",
        str(port),
        "--install-mlflow",
        "--no-conda",
    ]

    server_pid = _start_scoring_proc(cmd=scoring_cmd, env=env, stdout=sys.stdout, stderr=sys.stdout)

    ping_status = None
    for _ in range(120):
        time.sleep(1)
        try:
            ping_status = requests.get(url=f"{server_url}/ping")
            if ping_status.status_code == 200:
                break
        except Exception:
            # Server not up yet; keep polling until the attempt budget runs out.
            pass
    if ping_status is None or ping_status.status_code != 200:
        raise Exception("Could not start mlflow serving instance.")

    return ServerInfo(pid=server_pid, url=server_url)


def stop_mlflow_server(server_pid):
    """Terminate the serving process group started by start_mlflow_server.

    ``server_pid`` is the ``pid`` field of the ServerInfo returned by
    start_mlflow_server (a process handle exposing ``.pid``).
    """
    process_group = os.getpgid(server_pid.pid)
    os.killpg(process_group, signal.SIGTERM)