/ tests / gateway / tools.py
tools.py
  1  import asyncio
  2  import json
  3  import os
  4  import signal
  5  import subprocess
  6  import sys
  7  import threading
  8  import time
  9  from pathlib import Path
 10  from typing import Any, NamedTuple
 11  from unittest import mock
 12  
 13  import aiohttp
 14  import requests
 15  import uvicorn
 16  import yaml
 17  from sentence_transformers import SentenceTransformer
 18  
 19  import mlflow
 20  from mlflow.gateway import app
 21  from mlflow.gateway.utils import kill_child_processes
 22  
 23  from tests.helper_functions import _get_mlflow_home, _start_scoring_proc, get_safe_port
 24  
 25  
 26  class Gateway:
 27      def __init__(self, config_path: str | Path, *args, **kwargs):
 28          self.port = get_safe_port()
 29          self.host = "localhost"
 30          self.url = f"http://{self.host}:{self.port}"
 31          self.workers = 2
 32          self.process = subprocess.Popen(
 33              [
 34                  sys.executable,
 35                  "-m",
 36                  "mlflow",
 37                  "gateway",
 38                  "start",
 39                  "--config-path",
 40                  config_path,
 41                  "--host",
 42                  self.host,
 43                  "--port",
 44                  str(self.port),
 45                  "--workers",
 46                  str(self.workers),
 47              ],
 48              *args,
 49              **kwargs,
 50          )
 51          self.wait_until_ready()
 52  
 53      def wait_until_ready(self) -> None:
 54          s = time.time()
 55          while time.time() - s < 10:
 56              try:
 57                  if self.get("health").ok:
 58                      return
 59              except requests.exceptions.ConnectionError:
 60                  time.sleep(0.5)
 61  
 62          raise Exception("Gateway failed to start")
 63  
 64      def wait_reload(self) -> None:
 65          """
 66          Should be called after we update a gateway config file in tests to ensure
 67          that the gateway service has reloaded the config.
 68          """
 69          time.sleep(self.workers)
 70  
 71      def request(self, method: str, path: str, *args: Any, **kwargs: Any) -> requests.Response:
 72          return requests.request(method, f"{self.url}/{path}", *args, **kwargs)
 73  
 74      def get(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
 75          return self.request("GET", path, *args, **kwargs)
 76  
 77      def assert_health(self):
 78          assert self.get("health").ok
 79  
 80      def post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
 81          return self.request("POST", path, *args, **kwargs)
 82  
 83      def __enter__(self):
 84          return self
 85  
 86      def __exit__(self, exc_type, exc_val, exc_tb):
 87          kill_child_processes(self.process.pid)
 88          self.process.terminate()
 89          self.process.wait()
 90  
 91  
 92  def save_yaml(path, conf):
 93      path.write_text(yaml.safe_dump(conf))
 94  
 95  
 96  class MockAsyncResponse:
 97      def __init__(self, data: dict[str, Any], status: int = 200):
 98          # Extract status and headers from data, if present
 99          self.status = status
100          self.headers = data.pop("headers", {"Content-Type": "application/json"})
101  
102          # Save the rest of the data as content
103          self._content = data
104  
105      def raise_for_status(self) -> None:
106          if 400 <= self.status < 600:
107              raise aiohttp.ClientResponseError(None, None, status=self.status)
108  
109      async def json(self) -> dict[str, Any]:
110          return self._content
111  
112      async def text(self) -> str:
113          return json.dumps(self._content)
114  
115      async def __aenter__(self):
116          return self
117  
118      async def __aexit__(self, exc_type, exc, traceback):
119          pass
120  
121  
122  class MockAsyncStreamingResponse:
123      def __init__(self, data: list[bytes], headers: dict[str, str] | None = None, status: int = 200):
124          self.status = status
125          self.headers = headers
126          self._content = data
127  
128      def raise_for_status(self) -> None:
129          if 400 <= self.status < 600:
130              raise aiohttp.ClientResponseError(None, None, status=self.status)
131  
132      async def _async_content(self):
133          for line in self._content:
134              yield line
135  
136      @property
137      def content(self):
138          return self._async_content()
139  
140      async def __aenter__(self):
141          return self
142  
143      async def __aexit__(self, exc_type, exc, traceback):
144          pass
145  
146  
class MockHttpClient(mock.Mock):
    """``Mock`` usable as an async context manager whose ``post`` returns a canned response."""

    def __init__(self, mock_response=None, *args, **kwargs):
        """
        Args:
            mock_response: Object that every ``post(...)`` call returns.
            *args: Forwarded to ``mock.Mock``.
            **kwargs: Forwarded to ``mock.Mock``.
        """
        super().__init__(*args, **kwargs)
        # ``post`` is its own Mock so calls made to it can be inspected by tests.
        post_mock = mock.Mock(return_value=mock_response)
        self.post = post_mock
        self._mock_response = mock_response

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        return None
159  
160  
def mock_http_client(mock_response: MockAsyncResponse | MockAsyncStreamingResponse):
    """Build a ``MockHttpClient`` whose ``post`` returns ``mock_response``."""
    client = MockHttpClient(mock_response=mock_response)
    return client
163  
164  
class UvicornGateway:
    """Run the AI Gateway inside the current process on a background thread.

    This test utility class is used to validate the internal functionality of the
    AI Gateway within-process so that the provider endpoints can be mocked,
    allowing a nearly end-to-end validation of the entire AI Gateway stack.

    NB: this implementation should only be used for integration testing. Unit tests that
    require validation of the AI Gateway server should use the `Gateway` implementation in
    this module which executes the uvicorn server through gunicorn as a process manager.
    """

    def __init__(self, config_path: str | Path, *args, **kwargs):
        """
        Args:
            config_path: Gateway config file used to build the app.
            *args: Accepted but not used.
            **kwargs: Accepted but not used.
        """
        self.port = get_safe_port()
        self.host = "127.0.0.1"
        self.url = f"http://{self.host}:{self.port}"
        self.config_path = config_path
        self.server = None
        self.loop = None
        self.thread = None
        # Set by the server thread once ``serve()`` has returned.
        self.stop_event = threading.Event()

    def start_server(self):
        """Create the app and serve it with uvicorn on a background thread.

        Blocks until uvicorn reports that it has started.

        Raises:
            Exception: If the server thread exits, or about 10 seconds pass,
                before uvicorn reports that it has started.
        """
        uvicorn_app = app.create_app_from_path(self.config_path)

        self.loop = asyncio.new_event_loop()
        # NOTE(review): this installs the loop as the *calling* thread's current
        # loop even though it is run on the server thread below; kept as-is
        # because later code in the calling thread may depend on it — confirm.
        asyncio.set_event_loop(self.loop)

        config = uvicorn.Config(
            app=uvicorn_app,
            host=self.host,
            port=self.port,
            lifespan="on",
            loop="auto",
            log_level="info",
            ws="none",
        )
        self.server = uvicorn.Server(config)
        self.stop_event.clear()

        # Bind the objects the thread needs so it never observes the attributes
        # being reset to None by ``stop``.
        loop, server, stopped = self.loop, self.server, self.stop_event

        def run():
            try:
                loop.run_until_complete(server.serve())
            finally:
                # Without this signal ``stop`` would wait on the event forever.
                stopped.set()

        self.thread = threading.Thread(name="gateway-server", target=run)
        self.thread.start()
        self._wait_until_started()

    def _wait_until_started(self, timeout: float = 10.0) -> None:
        """Block until uvicorn flags the server as started, or fail after ``timeout`` seconds."""
        deadline = time.monotonic() + timeout
        while not self.server.started:
            if not self.thread.is_alive() or time.monotonic() > deadline:
                self.stop()
                raise Exception("Gateway failed to start")
            time.sleep(0.05)

    def request(self, method: str, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        """Send an HTTP request to ``path`` (relative to the gateway URL)."""
        return requests.request(method, f"{self.url}/{path}", *args, **kwargs)

    def get(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        """Send a GET request to ``path`` on the gateway."""
        return self.request("GET", path, *args, **kwargs)

    def assert_health(self):
        """Assert that the ``health`` endpoint returns a successful response."""
        assert self.get("health").ok

    def post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        """Send a POST request to ``path`` on the gateway."""
        return self.request("POST", path, *args, **kwargs)

    def stop(self):
        """Ask uvicorn to exit, wait for the server thread, and reset state.

        Does nothing when no server is running, so it is safe to call twice.
        """
        if self.server is None:
            return
        self.server.should_exit = True  # Instruct the uvicorn server to stop
        if self.thread is not None:
            self.stop_event.wait()  # Set by the server thread when serve() returns
            self.thread.join()  # block until thread termination
        # NOTE(review): the event loop is intentionally left open; it was also
        # installed as the starting thread's current loop, so closing it here
        # could break later asyncio use in that thread — confirm before closing.
        self.server = None
        self.loop = None
        self.thread = None

    def __enter__(self):
        self.start_server()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Stop the server and the thread
        self.stop()
234          self.thread.join()
235  
236  
class ServerInfo(NamedTuple):
    """Handle for a model server launched by ``start_mlflow_server``."""

    # Whatever ``_start_scoring_proc`` returned for the launched server.
    # NOTE(review): annotated as ``int``, but ``stop_mlflow_server`` reads ``.pid``
    # from its argument — confirm what callers actually store and pass around.
    pid: int
    # Base URL of the server, e.g. ``http://127.0.0.1:<port>``.
    url: str
240  
241  
def log_sentence_transformers_model():
    """Log the ``all-MiniLM-L6-v2`` sentence-transformers model in a new MLflow run.

    Returns:
        The ``model_uri`` of the logged model (logged under the name ``gen_model``).
    """
    embedder = SentenceTransformer("all-MiniLM-L6-v2")

    with mlflow.start_run():
        logged = mlflow.sentence_transformers.log_model(embedder, name="gen_model")
        return logged.model_uri
252  
253  
def start_mlflow_server(port, model_uri):
    """Launch ``mlflow models serve`` for ``model_uri`` and wait until it answers ``/ping``.

    Args:
        port: Local port the model server should listen on.
        model_uri: URI of the model to serve.

    Returns:
        A ``ServerInfo`` holding the value returned by ``_start_scoring_proc`` and
        the server's base URL.

    Raises:
        Exception: If ``/ping`` never returns HTTP 200 within 120 attempts.
    """
    server_url = f"http://127.0.0.1:{port}"

    env = dict(os.environ)
    env.update(LC_ALL="en_US.UTF-8", LANG="en_US.UTF-8")
    env.update(MLFLOW_TRACKING_URI=mlflow.get_tracking_uri())
    env.update(MLFLOW_HOME=_get_mlflow_home())
    scoring_cmd = [
        "mlflow",
        "models",
        "serve",
        "-m",
        model_uri,
        "-p",
        str(port),
        "--install-mlflow",
        "--no-conda",
    ]

    server_pid = _start_scoring_proc(cmd=scoring_cmd, env=env, stdout=sys.stdout, stderr=sys.stdout)

    ping_status = None
    for _ in range(120):
        time.sleep(1)
        try:
            # Bound each probe: without a timeout a single hung connection would
            # block forever and defeat the attempt limit.
            ping_status = requests.get(url=f"{server_url}/ping", timeout=5)
        except Exception:
            # Best-effort polling: the server is expected to refuse or drop
            # connections while it is still booting, so just try again.
            continue
        if ping_status.status_code == 200:
            break
    if ping_status is None or ping_status.status_code != 200:
        raise Exception("Could not start mlflow serving instance.")

    return ServerInfo(pid=server_pid, url=server_url)
288  
289  
def stop_mlflow_server(server_pid):
    """Send ``SIGTERM`` to the whole process group of the given server process.

    Args:
        server_pid: Object exposing a ``pid`` attribute for the server process
            (despite the name, the code reads ``server_pid.pid``, not a bare int).
    """
    os.killpg(os.getpgid(server_pid.pid), signal.SIGTERM)