/ tests / server / test_init.py
test_init.py
  1  import signal
  2  import socket
  3  import subprocess
  4  import sys
  5  import time
  6  from unittest import mock
  7  
  8  import pytest
  9  
 10  from mlflow import server
 11  from mlflow.environment_variables import _MLFLOW_SGI_NAME
 12  from mlflow.exceptions import MlflowException
 13  from mlflow.utils import find_free_port
 14  from mlflow.utils.os import is_windows
 15  
 16  
 17  @pytest.fixture
 18  def mock_exec_cmd():
 19      with mock.patch("mlflow.server._exec_cmd") as m:
 20          yield m
 21  
 22  
 23  def _wait_for_port(host: str, port: int, proc: subprocess.Popen, timeout: int = 15) -> None:
 24      deadline = time.time() + timeout
 25      while time.time() < deadline:
 26          if proc.poll() is not None:
 27              stdout, stderr = proc.communicate()
 28              raise AssertionError(
 29                  "MLflow server exited before accepting connections.\n"
 30                  f"stdout:\n{stdout}\n"
 31                  f"stderr:\n{stderr}"
 32              )
 33          try:
 34              with socket.create_connection((host, port), timeout=1):
 35                  return
 36          except OSError:
 37              time.sleep(0.1)
 38      raise AssertionError(f"Timed out waiting for {host}:{port} to accept connections")
 39  
 40  
 41  def _wait_for_port_closed(host: str, port: int, timeout: int = 15) -> None:
 42      deadline = time.time() + timeout
 43      while time.time() < deadline:
 44          try:
 45              with socket.create_connection((host, port), timeout=1):
 46                  time.sleep(0.1)
 47          except OSError:
 48              return
 49      raise AssertionError(f"Timed out waiting for {host}:{port} to close")
 50  
 51  
 52  def test_find_app_custom_app_plugin():
 53      assert server._find_app("custom_app") == "mlflow_test_plugin.app:custom_app"
 54  
 55  
 56  def test_find_app_non_existing_app():
 57      with pytest.raises(MlflowException, match=r"Failed to find app 'does_not_exist'"):
 58          server._find_app("does_not_exist")
 59  
 60  
 61  def test_build_waitress_command():
 62      assert server._build_waitress_command(
 63          "", "localhost", "5000", f"{server.__name__}:app", is_factory=True
 64      ) == [
 65          sys.executable,
 66          "-m",
 67          "waitress",
 68          "--host=localhost",
 69          "--port=5000",
 70          "--ident=mlflow",
 71          "--call",
 72          "mlflow.server:app",
 73      ]
 74      assert server._build_waitress_command(
 75          "", "localhost", "5000", f"{server.__name__}:app", is_factory=False
 76      ) == [
 77          sys.executable,
 78          "-m",
 79          "waitress",
 80          "--host=localhost",
 81          "--port=5000",
 82          "--ident=mlflow",
 83          "mlflow.server:app",
 84      ]
 85  
 86  
 87  def test_build_gunicorn_command():
 88      assert server._build_gunicorn_command(
 89          "", "localhost", "5000", "4", f"{server.__name__}:app"
 90      ) == [
 91          sys.executable,
 92          "-m",
 93          "gunicorn",
 94          "-b",
 95          "localhost:5000",
 96          "-w",
 97          "4",
 98          "mlflow.server:app",
 99      ]
100  
101  
def test_build_uvicorn_command():
    fastapi_app = "mlflow.server.fastapi_app:app"
    launcher = [sys.executable, "-m", "uvicorn"]
    # Options the builder emits after any user-supplied uvicorn options.
    builder_opts = [
        "--log-config",
        str(server._UVICORN_LOG_CONFIG),
        "--host",
        "localhost",
        "--port",
        "5000",
        "--workers",
        "4",
    ]

    # No custom options: launcher, builder options, then the app target.
    default_cmd = server._build_uvicorn_command("", "localhost", "5000", "4", fastapi_app)
    assert default_cmd == [*launcher, *builder_opts, fastapi_app]

    # Custom uvicorn options are split and placed before the builder's own options.
    custom_cmd = server._build_uvicorn_command(
        "--reload --log-level debug", "localhost", "5000", "4", fastapi_app
    )
    assert custom_cmd == [*launcher, "--reload", "--log-level", "debug", *builder_opts, fastapi_app]

    # A factory app gets --factory immediately before the app target.
    factory_cmd = server._build_uvicorn_command(
        "", "localhost", "5000", "4", fastapi_app, None, is_factory=True
    )
    assert factory_cmd == [*launcher, *builder_opts, "--factory", fastapi_app]
158  
159  
def test_build_uvicorn_command_with_env_file():
    env_path = "/path/to/.env"
    app_target = "app:app"
    cmd = server._build_uvicorn_command(
        uvicorn_opts=None,
        host="localhost",
        port=5000,
        workers=4,
        app_name=app_target,
        env_file=env_path,
    )

    for expected_arg in ("--env-file", env_path, "--log-config"):
        assert expected_arg in cmd

    # --env-file must be immediately followed by its path, and both must precede the app.
    flag_pos, path_pos, app_pos = (cmd.index(arg) for arg in ("--env-file", env_path, app_target))
    assert flag_pos < app_pos
    assert path_pos == flag_pos + 1
    assert path_pos < app_pos
180  
181  
def test_run_server(mock_exec_cmd, monkeypatch):
    """``_run_server`` calls ``_exec_cmd`` exactly once when the platform is Linux."""
    monkeypatch.setenv("MLFLOW_SERVER_ENABLE_JOB_EXECUTION", "false")
    # Patch ``sys.platform`` with the string itself. Passing ``return_value="linux"``
    # would instead install a MagicMock (whose *call* returns "linux"), so any
    # ``sys.platform == ...`` comparison in the code under test would see a mock.
    with mock.patch("sys.platform", "linux"):
        server._run_server(
            file_store_path="",
            registry_store_uri="",
            default_artifact_root="",
            serve_artifacts="",
            artifacts_only="",
            artifacts_destination="",
            host="",
            port="",
        )
    mock_exec_cmd.assert_called_once()
196  
197  
def test_run_server_win32(mock_exec_cmd, monkeypatch):
    """``_run_server`` calls ``_exec_cmd`` exactly once when the platform is Windows."""
    monkeypatch.setenv("MLFLOW_SERVER_ENABLE_JOB_EXECUTION", "false")
    # Patch ``sys.platform`` with the string itself. Passing ``return_value="win32"``
    # would install a MagicMock, so ``sys.platform == "win32"`` would be False and
    # the Windows code path would never actually be exercised.
    with mock.patch("sys.platform", "win32"):
        server._run_server(
            file_store_path="",
            registry_store_uri="",
            default_artifact_root="",
            serve_artifacts="",
            artifacts_only="",
            artifacts_destination="",
            host="",
            port="",
        )
    mock_exec_cmd.assert_called_once()
212  
213  
def test_run_server_with_uvicorn(mock_exec_cmd, monkeypatch):
    """``_run_server`` forwards uvicorn options and launches uvicorn asynchronously."""
    monkeypatch.setenv("MLFLOW_SERVER_ENABLE_JOB_EXECUTION", "false")
    # Patch ``sys.platform`` with the string itself; ``return_value="linux"`` would
    # install a MagicMock rather than the platform string.
    with mock.patch("sys.platform", "linux"):
        server._run_server(
            file_store_path="",
            registry_store_uri="",
            default_artifact_root="",
            serve_artifacts="",
            artifacts_only="",
            artifacts_destination="",
            host="localhost",
            port="5000",
            uvicorn_opts="--reload",
        )
    expected_command = [
        sys.executable,
        "-m",
        "uvicorn",
        "--reload",
        "--log-config",
        str(server._UVICORN_LOG_CONFIG),
        "--host",
        "localhost",
        "--port",
        "5000",
        "--workers",
        "4",
        "mlflow.server.fastapi_app:app",
    ]
    mock_exec_cmd.assert_called_once_with(
        expected_command,
        extra_env={
            _MLFLOW_SGI_NAME.name: "uvicorn",
        },
        capture_output=False,
        synchronous=False,
    )
251  
252  
@pytest.mark.parametrize(
    "uvicorn_opts",
    [
        "--log-config /custom/path.yaml",
        "--log-config=/custom/path.yaml",
    ],
)
def test_build_uvicorn_command_user_log_config_takes_precedence(uvicorn_opts):
    """A user ``--log-config`` (space or ``=`` form) suppresses the bundled default config."""
    cmd = server._build_uvicorn_command(
        uvicorn_opts, "localhost", "5000", "4", "mlflow.server.fastapi_app:app"
    )
    default_config_name = "uvicorn_log_config.yaml"
    assert all(default_config_name not in arg for arg in cmd)
265  
266  
@pytest.mark.parametrize(
    "sig",
    [
        pytest.param(
            signal.SIGTERM,
            marks=pytest.mark.skipif(is_windows(), reason="SIGTERM is a hard kill on Windows"),
        ),
        signal.SIGINT,
    ],
)
def test_mlflow_server_shuts_down_on_signal(sig: signal.Signals, tmp_path):
    """Start a real ``mlflow server`` and check that it exits and frees its port after a signal.

    On Windows the server is started in a new process group and is always sent
    ``CTRL_BREAK_EVENT`` (``sig`` is not delivered there); elsewhere ``sig`` is sent directly.
    """
    on_windows = is_windows()
    host = "127.0.0.1"
    port = find_free_port()
    cmd = [
        sys.executable,
        "-m",
        "mlflow",
        "server",
        "--host",
        host,
        "--port",
        str(port),
        "--workers",
        "1",
        "--backend-store-uri",
        f"sqlite:///{tmp_path / 'mlflow.db'}",
    ]
    # CTRL_BREAK_EVENT can only target a process that leads its own process group.
    popen_kwargs = {"creationflags": subprocess.CREATE_NEW_PROCESS_GROUP} if on_windows else {}
    proc = subprocess.Popen(cmd, **popen_kwargs)
    try:
        _wait_for_port(host, port, proc, timeout=60 if on_windows else 15)
        proc.send_signal(signal.CTRL_BREAK_EVENT if on_windows else sig)
        proc.wait(timeout=30 if on_windows else 15)
        _wait_for_port_closed(host, port)
        # Accepted exit codes:
        #   0              -> the signal was caught and handled gracefully
        #   -sig / 128+sig -> the process was terminated by the signal (POSIX)
        #   0xC000013A     -> STATUS_CONTROL_C_EXIT, termination by CTRL_BREAK_EVENT (Windows)
        accepted_codes = (0, 0xC000013A) if on_windows else (0, -sig, 128 + sig)
        assert proc.returncode in accepted_codes
    finally:
        # Never leak the server if an assertion or wait above failed.
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
318              except subprocess.TimeoutExpired:
319                  proc.kill()