/ tests / tracking / _tracking_service / test_utils.py
test_utils.py
  1  import io
  2  import itertools
  3  import os
  4  import pickle
  5  import uuid
  6  from importlib import reload
  7  from pathlib import Path
  8  from unittest import mock
  9  from urllib.parse import urlparse
 10  from urllib.request import url2pathname
 11  
 12  import pytest
 13  
 14  import mlflow
 15  from mlflow.environment_variables import (
 16      MLFLOW_ENABLE_WORKSPACES,
 17      MLFLOW_TRACKING_INSECURE_TLS,
 18      MLFLOW_TRACKING_PASSWORD,
 19      MLFLOW_TRACKING_TOKEN,
 20      MLFLOW_TRACKING_URI,
 21      MLFLOW_TRACKING_USERNAME,
 22  )
 23  from mlflow.exceptions import MlflowException
 24  from mlflow.server import ARTIFACT_ROOT_ENV_VAR
 25  from mlflow.store.db.db_types import DATABASE_ENGINES
 26  from mlflow.store.tracking.databricks_rest_store import DatabricksTracingRestStore
 27  from mlflow.store.tracking.file_store import FileStore
 28  from mlflow.store.tracking.rest_store import RestStore
 29  from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore
 30  from mlflow.tracking._tracking_service.registry import TrackingStoreRegistry
 31  from mlflow.tracking._tracking_service.utils import (
 32      _get_store,
 33      _get_tracking_scheme,
 34      _resolve_custom_scheme,
 35      _resolve_tracking_uri,
 36      _use_tracking_uri,
 37      get_tracking_uri,
 38      set_tracking_uri,
 39  )
 40  from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException
 41  from mlflow.utils.file_utils import path_to_local_file_uri
 42  from mlflow.utils.os import is_windows
 43  
 44  from tests.helpers.db_mocks import mock_get_managed_session_maker
 45  from tests.tracing.helper import get_tracer_tracking_uri
 46  
# Disable mocking of the tracking URI for this module: the tests below need to exercise
# setting the tracking URI via environment variable. See
# http://doc.pytest.org/en/latest/skipping.html#skip-all-test-functions-of-a-class-or-module
# and https://github.com/mlflow/mlflow/blob/master/CONTRIBUTING.md#writing-python-tests
# for more information.
pytestmark = pytest.mark.notrackingurimock
 53  
 54  
 55  @pytest.mark.skip(reason="FileStore is no longer supported")
 56  def test_tracking_scheme_with_existing_mlruns(tmp_path, monkeypatch):
 57      monkeypatch.chdir(tmp_path)
 58      mlruns_dir = tmp_path / "mlruns"
 59      mlruns_dir.mkdir()
 60      exp_dir = mlruns_dir / "0"
 61      exp_dir.mkdir()
 62      (exp_dir / "meta.yaml").touch()
 63      store = _get_store()
 64      assert isinstance(store, FileStore)
 65  
 66  
 67  def test_tracking_scheme_without_existing_mlruns(tmp_path, monkeypatch):
 68      monkeypatch.chdir(tmp_path)
 69      store = _get_store()
 70      assert isinstance(store, SqlAlchemyStore)
 71  
 72  
 73  @pytest.mark.skip(reason="FileStore is no longer supported")
 74  def test_get_store_with_existing_mlruns_data(tmp_path, monkeypatch):
 75      monkeypatch.chdir(tmp_path)
 76      mlruns_dir = tmp_path / "mlruns"
 77      mlruns_dir.mkdir()
 78      exp_dir = mlruns_dir / "0"
 79      exp_dir.mkdir()
 80      (exp_dir / "meta.yaml").touch()
 81  
 82      store = _get_store()
 83      assert isinstance(store, FileStore)
 84      assert os.path.abspath(store.root_directory) == os.path.abspath("mlruns")
 85  
 86  
 87  def test_get_store_with_empty_mlruns(tmp_path, monkeypatch):
 88      monkeypatch.chdir(tmp_path)
 89      mlruns_dir = tmp_path / "mlruns"
 90      mlruns_dir.mkdir()
 91  
 92      store = _get_store()
 93      assert isinstance(store, SqlAlchemyStore)
 94  
 95  
 96  def test_get_store_with_mlruns_dir_but_no_meta_yaml(tmp_path, monkeypatch):
 97      monkeypatch.chdir(tmp_path)
 98      mlruns_dir = tmp_path / "mlruns"
 99      mlruns_dir.mkdir()
100      (mlruns_dir / "0").mkdir()
101  
102      store = _get_store()
103      assert isinstance(store, SqlAlchemyStore)
104  
105  
def test_default_sqlite_tracking_uri_respects_cwd(tmp_path, monkeypatch):
    """The default SQLite database file is located directly inside the current directory."""
    monkeypatch.chdir(tmp_path)
    with _use_tracking_uri(None):
        resolved_store = _get_store()

    assert isinstance(resolved_store, SqlAlchemyStore)
    db_uri = resolved_store.db_uri
    assert db_uri.startswith("sqlite:")

    # Convert the URI back into a filesystem path so it can be compared with tmp_path.
    uri_parts = urlparse(db_uri)
    raw_path = uri_parts.path
    if uri_parts.netloc:
        # A network location is re-attached in front of the path before conversion.
        raw_path = f"//{uri_parts.netloc}{raw_path}"
    elif raw_path.startswith("//"):
        # Collapse a doubled leading slash into a single one.
        raw_path = raw_path[1:]
    assert Path(url2pathname(raw_path)).parent == tmp_path
122  
123  
@pytest.mark.skip(reason="FileStore is no longer supported")
def test_get_store_file_store_from_arg(tmp_path, monkeypatch):
    """A relative path passed to ``_get_store`` is expected to give a FileStore rooted there."""
    monkeypatch.chdir(tmp_path)
    file_store = _get_store("other/path")
    assert isinstance(file_store, FileStore)
    expected_root = os.path.abspath("other/path")
    assert os.path.abspath(file_store.root_directory) == expected_root
130  
131  
@pytest.mark.skip(reason="FileStore is no longer supported")
@pytest.mark.parametrize("uri", ["other/path", "file:other/path"])
def test_get_store_file_store_from_env(tmp_path, monkeypatch, uri):
    """A local path or ``file:`` URI in the env var is expected to give a FileStore."""
    monkeypatch.chdir(tmp_path)
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, uri)
    file_store = _get_store()
    assert isinstance(file_store, FileStore)
    # Both parametrized spellings must resolve to the same directory.
    expected_root = os.path.abspath("other/path")
    assert os.path.abspath(file_store.root_directory) == expected_root
140  
141  
def test_get_store_basic_rest_store(monkeypatch):
    """An ``https`` tracking URI from the environment yields a RestStore with no token."""
    server_url = "https://my-tracking-server:5050"
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, server_url)

    rest_store = _get_store()
    assert isinstance(rest_store, RestStore)
    assert rest_store.get_host_creds().host == server_url
    assert rest_store.get_host_creds().token is None
    assert _get_tracking_scheme() == "https"
149  
150  
def test_get_store_rest_store_with_password(monkeypatch):
    """Username and password environment variables are surfaced in the host credentials."""
    server_url = "https://my-tracking-server:5050"
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, server_url)
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "Bob")
    monkeypatch.setenv(MLFLOW_TRACKING_PASSWORD.name, "Ross")

    rest_store = _get_store()
    assert isinstance(rest_store, RestStore)
    assert rest_store.get_host_creds().host == server_url
    assert rest_store.get_host_creds().username == "Bob"
    assert rest_store.get_host_creds().password == "Ross"
164  
165  
def test_get_store_rest_store_with_token(monkeypatch):
    """The token environment variable is surfaced in the host credentials."""
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, "https://my-tracking-server:5050")
    monkeypatch.setenv(MLFLOW_TRACKING_TOKEN.name, "my-token")

    rest_store = _get_store()
    assert isinstance(rest_store, RestStore)
    assert rest_store.get_host_creds().token == "my-token"
176  
177  
def test_get_store_rest_store_with_insecure(monkeypatch):
    """Setting the insecure-TLS variable to "true" disables TLS verification."""
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, "https://my-tracking-server:5050")
    monkeypatch.setenv(MLFLOW_TRACKING_INSECURE_TLS.name, "true")

    rest_store = _get_store()
    assert isinstance(rest_store, RestStore)
    assert rest_store.get_host_creds().ignore_tls_verification
187  
188  
def test_get_store_rest_store_with_no_insecure(monkeypatch):
    """TLS verification stays enabled when insecure-TLS is "false" or not set at all.

    Each scenario runs in its own ``monkeypatch.context()`` so that environment changes
    made for one scenario are rolled back before the next one starts.
    """
    tracking_uri = "https://my-tracking-server:5050"

    # Scenario 1: the flag is explicitly set to "false".
    with monkeypatch.context() as m:
        m.setenv(MLFLOW_TRACKING_URI.name, tracking_uri)
        m.setenv(MLFLOW_TRACKING_INSECURE_TLS.name, "false")
        store = _get_store()
        assert isinstance(store, RestStore)
        assert not store.get_host_creds().ignore_tls_verification

    # Scenario 2: by default, should not ignore verification.
    with monkeypatch.context() as m:
        # Use the scoped ``m`` (not the outer fixture) so the change is undone on exit.
        m.setenv(MLFLOW_TRACKING_URI.name, tracking_uri)
        # Clear any ambient value so that the default is what is really under test.
        m.delenv(MLFLOW_TRACKING_INSECURE_TLS.name, raising=False)
        store = _get_store()
        assert isinstance(store, RestStore)
        assert not store.get_host_creds().ignore_tls_verification
206  
207  
@pytest.mark.parametrize("db_type", DATABASE_ENGINES)
def test_get_store_sqlalchemy_store(tmp_path, monkeypatch, db_type):
    """A ``<db_type>://`` tracking URI resolves to a SqlAlchemyStore with a shared engine.

    ``sqlalchemy.create_engine`` and the schema/session helpers are patched, and the
    engine factory is checked to have been called exactly once.
    """
    monkeypatch.chdir(tmp_path)
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")
    # NOTE(review): the random suffix presumably avoids reusing a store cached by
    # another test -- confirm against the registry's caching behaviour.
    db_uri = f"{db_type}://hostname/database-{uuid.uuid4().hex}"
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, db_uri)
    monkeypatch.delenv("MLFLOW_SQLALCHEMYSTORE_POOLCLASS", raising=False)

    with (
        mock.patch("sqlalchemy.create_engine") as create_engine_mock,
        mock.patch("sqlalchemy.event.listens_for"),
        mock.patch("mlflow.store.db.utils._verify_schema"),
        mock.patch("mlflow.store.db.utils._initialize_tables"),
        mock.patch(
            "mlflow.store.db.utils._get_managed_session_maker",
            new=mock_get_managed_session_maker,
        ),
        # In sqlalchemy 1.4.0, `SqlAlchemyStore.search_experiments`, which is called when
        # fetching the store, results in an error when called with a mocked sqlalchemy
        # engine. Accordingly, we mock `SqlAlchemyStore.search_experiments`.
        mock.patch(
            "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore.search_experiments",
            return_value=[],
        ),
        mock.patch(
            "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore._initialize_store_state",
            return_value=None,
        ),
    ):
        first_store = _get_store()
        assert isinstance(first_store, SqlAlchemyStore)
        assert first_store.db_uri == db_uri

        # Create another store to ensure the engine is cached.
        second_store = _get_store()
        assert first_store.engine is second_store.engine

        # The default artifact root is <cwd>/mlruns; Windows expects the file-URI form.
        default_root = Path.cwd() / "mlruns"
        expected_root = default_root.as_uri() if is_windows() else default_root.as_posix()
        assert first_store.artifact_root_uri == expected_root
        assert _get_tracking_scheme() == db_type

    create_engine_mock.assert_called_once_with(db_uri, pool_pre_ping=True)
249  
250  
@pytest.mark.parametrize("db_type", DATABASE_ENGINES)
def test_get_store_sqlalchemy_store_with_artifact_uri(tmp_path, monkeypatch, db_type):
    """An explicit relative ``file:`` artifact URI is resolved against the CWD.

    ``sqlalchemy.create_engine`` and the schema/session helpers are patched, and the
    engine factory is checked to have been called exactly once.
    """
    monkeypatch.chdir(tmp_path)
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")
    db_uri = f"{db_type}://hostname/database-{uuid.uuid4().hex}"
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, db_uri)
    monkeypatch.delenv("MLFLOW_SQLALCHEMYSTORE_POOLCLASS", raising=False)

    with (
        mock.patch("sqlalchemy.create_engine") as create_engine_mock,
        mock.patch("sqlalchemy.event.listens_for"),
        mock.patch("mlflow.store.db.utils._verify_schema"),
        mock.patch("mlflow.store.db.utils._initialize_tables"),
        mock.patch(
            "mlflow.store.db.utils._get_managed_session_maker",
            new=mock_get_managed_session_maker,
        ),
        mock.patch(
            "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore.search_experiments",
            return_value=[],
        ),
        mock.patch(
            "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore._initialize_store_state",
            return_value=None,
        ),
    ):
        resolved_store = _get_store(artifact_uri="file:artifact/path")
        assert isinstance(resolved_store, SqlAlchemyStore)
        assert resolved_store.db_uri == db_uri

        # The expected root is <cwd>/artifact/path, rendered per platform.
        artifact_dir = Path.cwd() / "artifact" / "path"
        if is_windows():
            expected_root = artifact_dir.as_uri()
        else:
            expected_root = path_to_local_file_uri(artifact_dir)
        assert resolved_store.artifact_root_uri == expected_root

    create_engine_mock.assert_called_once_with(db_uri, pool_pre_ping=True)
288  
289  
def test_get_sqlalchemy_store_uses_server_artifact_root(tmp_path, monkeypatch):
    """With no explicit artifact URI, the server artifact-root env var is passed to the store.

    ``SqlAlchemyStore`` is replaced by a mock so only the constructor arguments are
    inspected. No manual environment cleanup is needed: the ``monkeypatch`` fixture
    restores the environment on teardown, including when an assertion fails.
    """
    store_uri = f"sqlite:///{tmp_path.joinpath('backend_store.db')}"
    artifact_uri = path_to_local_file_uri(tmp_path / "server-artifacts")
    monkeypatch.setenv(ARTIFACT_ROOT_ENV_VAR, artifact_uri)

    with mock.patch("mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore") as mock_store:
        mlflow.tracking._tracking_service.utils._get_sqlalchemy_store(
            store_uri=store_uri, artifact_uri=None
        )

    mock_store.assert_called_once()
    # The artifact root is the second positional argument of the store constructor call.
    assert mock_store.call_args.args[1] == artifact_uri
304  
305  
def test_get_store_databricks(monkeypatch):
    """The ``databricks`` tracking URI yields a DatabricksTracingRestStore using the SDK."""
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, "databricks")
    monkeypatch.setenv("DATABRICKS_HOST", "https://my-tracking-server")
    monkeypatch.setenv("DATABRICKS_TOKEN", "abcdef")

    databricks_store = _get_store()
    assert isinstance(databricks_store, DatabricksTracingRestStore)
    assert databricks_store.get_host_creds().use_databricks_sdk
    assert _get_tracking_scheme() == "databricks"
317  
318  
def test_get_store_databricks_profile(monkeypatch):
    """A ``databricks://<profile>`` URI builds a store whose credential lookup names the profile."""
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, "databricks://mycoolprofile")
    databricks_store = _get_store()
    assert isinstance(databricks_store, DatabricksTracingRestStore)
    # It's kind of annoying to set up a profile, and we're not really trying to test
    # that anyway, so just check that a relevant exception is raised.
    with pytest.raises(MlflowException, match="mycoolprofile"):
        databricks_store.get_host_creds()
327  
328  
def test_get_store_caches_on_store_uri_and_artifact_uri(tmp_path):
    """The registry returns one store per ``(store_uri, artifact_uri)`` pair.

    Repeating a lookup with the same pair gives the same object, while every distinct
    pair gives a distinct object.
    """
    registry = mlflow.tracking._tracking_service.utils._tracking_store_registry

    backend_uris = [
        f"sqlite:///{tmp_path / db_name}"
        for db_name in ("backend_store_1.db", "backend_store_2.db")
    ]
    artifact_roots = [
        None,
        str(tmp_path / "artifact_root_1"),
        str(tmp_path / "artifact_root_2"),
    ]

    seen_stores = []
    for backend_uri, artifact_root in itertools.product(backend_uris, artifact_roots):
        first_lookup = registry.get_store(backend_uri, artifact_root)
        second_lookup = registry.get_store(backend_uri, artifact_root)
        assert first_lookup is second_lookup
        seen_stores.append(first_lookup)

    # No two different argument pairs may share a store instance.
    for left, right in itertools.combinations(seen_stores, 2):
        assert left is not right
349  
350  
def test_standard_store_registry_with_mocked_entrypoint():
    """The module-level registry holds the standard schemes plus entry-point schemes.

    The utils module is reloaded under a patched entry-point lookup, then reloaded again
    afterwards so the fake "mock-scheme" registration does not leak into later tests.
    """
    mock_entrypoint = mock.Mock()
    mock_entrypoint.name = "mock-scheme"
    utils_module = mlflow.tracking._tracking_service.utils

    try:
        with mock.patch(
            "mlflow.utils.plugins._get_entry_points", return_value=[mock_entrypoint]
        ):
            # Entrypoints are registered at import time, so we need to reload the
            # module to register the entrypoint given by the mocked
            # `mlflow.utils.plugins._get_entry_points`.
            reload(utils_module)

            expected_standard_registry = {
                "",
                "file",
                "http",
                "https",
                "postgresql",
                "mysql",
                "sqlite",
                "mssql",
                "databricks",
                "mock-scheme",
            }
            assert expected_standard_registry.issubset(
                utils_module._tracking_store_registry._registry.keys()
            )
    finally:
        # Rebuild the registry from the real entry points, now that the patch is gone.
        reload(utils_module)
376  
377  
@pytest.mark.skip(reason="FileStore is no longer supported")
def test_standard_store_registry_with_installed_plugin(tmp_path, monkeypatch):
    """The ``file-plugin`` scheme is expected to be registered and to build a PluginFileStore."""
    monkeypatch.chdir(tmp_path)
    utils_module = mlflow.tracking._tracking_service.utils
    reload(utils_module)
    assert "file-plugin" in utils_module._tracking_store_registry._registry

    # Imported inside the test body rather than at module level.
    from mlflow_test_plugin.file_store import PluginFileStore

    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, "file-plugin:test-path")
    plugin_store = utils_module._get_store()
    assert isinstance(plugin_store, PluginFileStore)
    assert plugin_store.is_plugin
    assert _get_tracking_scheme() == "custom_scheme"
393  
394  
def test_plugin_registration():
    """A builder registered for a scheme is invoked with the URI and returns the store."""
    registry = TrackingStoreRegistry()
    scheme = "mock-scheme"
    store_uri = "mock-scheme://fake-host/fake-path"
    store_builder = mock.Mock()

    registry.register(scheme, store_builder)

    assert scheme in registry._registry
    assert registry.get_store(store_uri) == store_builder.return_value
    store_builder.assert_called_once_with(store_uri=store_uri, artifact_uri=None)
406  
407  
def test_plugin_registration_via_entrypoints():
    """``register_entrypoints`` loads builders from the ``mlflow.tracking_store`` group."""
    store_builder = mock.Mock()
    fake_entrypoint = mock.Mock(load=mock.Mock(return_value=store_builder))
    fake_entrypoint.name = "mock-scheme"

    with mock.patch(
        "mlflow.utils.plugins._get_entry_points", return_value=[fake_entrypoint]
    ) as entry_points_mock:
        registry = TrackingStoreRegistry()
        registry.register_entrypoints()

    store_uri = "mock-scheme://"
    assert registry.get_store(store_uri) == store_builder.return_value

    store_builder.assert_called_once_with(store_uri=store_uri, artifact_uri=None)
    entry_points_mock.assert_called_once_with("mlflow.tracking_store")
423  
424  
@pytest.mark.parametrize(
    "exception", [AttributeError("test exception"), ImportError("test exception")]
)
def test_handle_plugin_registration_failure_via_entrypoints(exception):
    """A failing entry-point ``load`` is reported as a UserWarning instead of raising."""
    failing_load = mock.Mock(side_effect=exception)
    fake_entrypoint = mock.Mock(load=failing_load)
    fake_entrypoint.name = "mock-scheme"

    with mock.patch(
        "mlflow.utils.plugins._get_entry_points", return_value=[fake_entrypoint]
    ) as entry_points_mock:
        registry = TrackingStoreRegistry()

        # Check that the raised warning contains the message from the original exception.
        with pytest.warns(UserWarning, match="test exception"):
            registry.register_entrypoints()

    fake_entrypoint.load.assert_called_once()
    entry_points_mock.assert_called_once_with("mlflow.tracking_store")
443  
444  
def test_get_store_for_unregistered_scheme():
    """Looking up a scheme with no registered builder raises the unsupported-URI error."""
    registry = TrackingStoreRegistry()
    expected_message = "Model registry functionality is unavailable"

    with pytest.raises(UnsupportedModelRegistryStoreURIException, match=expected_message):
        registry.get_store("unknown-scheme://")
453  
454  
def test_resolve_tracking_uri_with_param():
    """An explicitly supplied URI wins over the value from ``get_tracking_uri``."""
    explicit_uri = "databricks://tracking_poiwerow"
    patched_getter = mock.patch(
        "mlflow.tracking._tracking_service.utils.get_tracking_uri",
        return_value="databricks://tracking_qoeirj",
    )
    with patched_getter:
        assert _resolve_tracking_uri(explicit_uri) == explicit_uri
462  
463  
def test_resolve_tracking_uri_with_no_param():
    """Without an argument, the result is whatever ``get_tracking_uri`` returns."""
    ambient_uri = "databricks://tracking_zlkjdas"
    patched_getter = mock.patch(
        "mlflow.tracking._tracking_service.utils.get_tracking_uri",
        return_value=ambient_uri,
    )
    with patched_getter:
        assert _resolve_tracking_uri() == ambient_uri
470  
471  
@pytest.mark.skip(reason="FileStore is no longer supported")
def test_store_object_can_be_serialized_by_pickle(tmp_path):
    """
    This test ensures a store object generated by `_get_store` can be serialized by pickle
    to prevent issues such as https://github.com/mlflow/mlflow/issues/2954
    """
    picklable_uris = (
        f"file:///{tmp_path.joinpath('mlflow')}",
        "databricks",
        "https://example.com",
    )
    for store_uri in picklable_uris:
        pickle.dump(_get_store(store_uri), io.BytesIO())
    # A sqlite-backed store is intentionally left out: pickling it throws
    # `AttributeError: Can't pickle local object 'create_engine.<locals>.connect'`
483  
484  
@pytest.mark.parametrize("absolute", [True, False], ids=["absolute", "relative"])
def test_set_tracking_uri_with_path(tmp_path, monkeypatch, absolute):
    """A ``Path`` given to ``set_tracking_uri`` is reported back as an absolute file URI."""
    monkeypatch.chdir(tmp_path)
    relative_path = Path("foo/bar")
    path = tmp_path / relative_path if absolute else relative_path
    # Start from an unset module-level tracking URI for the duration of the check.
    with mock.patch("mlflow.tracking._tracking_service.utils._tracking_uri", None):
        set_tracking_uri(path)
        expected_uri = path.absolute().resolve().as_uri()
        assert get_tracking_uri() == expected_uri
494  
495  
def test_set_tracking_uri_update_trace_provider(tmp_path):
    """Each ``set_tracking_uri`` call is reflected by the tracer's tracking URI."""
    original_uri = mlflow.get_tracking_uri()
    db_uri = f"sqlite:///{tmp_path / 'mlflow.db'}"
    try:
        # Precondition: the tracer is not already pointing at the new database.
        assert get_tracer_tracking_uri() != db_uri

        for new_uri in (db_uri, "https://foo"):
            set_tracking_uri(new_uri)
            assert get_tracer_tracking_uri() == new_uri
    finally:
        # Restore the URI that was active before the test.
        set_tracking_uri(original_uri)
510  
511  
@pytest.mark.parametrize("store_uri", ["databricks-uc", "databricks-uc://profile"])
def test_get_store_raises_on_uc_uri(store_uri):
    """Creating a client with a Unity Catalog tracking URI raises an MlflowException."""
    set_tracking_uri(store_uri)
    expected_message = (
        "Setting the tracking URI to a Unity Catalog backend is not "
        "supported in the current version of the MLflow client"
    )
    with pytest.raises(MlflowException, match=expected_message):
        mlflow.tracking.MlflowClient()
    # The scheme is still reported even though no store can be built for it.
    assert _get_tracking_scheme() == "databricks-uc"
522  
523  
@pytest.mark.parametrize("tracking_uri", ["file:///tmp/mlruns", "sqlite:///tmp/mlruns.db", ""])
def test_set_get_tracking_uri_consistency(tracking_uri):
    """``get_tracking_uri`` returns exactly the value passed to ``set_tracking_uri``."""
    mlflow.set_tracking_uri(tracking_uri)
    reported_uri = mlflow.get_tracking_uri()
    assert reported_uri == tracking_uri
528  
529  
def test_get_tracking_scheme():
    """Check the scheme label reported for two explicitly supplied URIs."""
    uc_scheme = _get_tracking_scheme("uc://profile@databricks")
    assert uc_scheme == "uc"

    # No builder is registered for the custom scheme, so the label is the string "None".
    unregistered_scheme = _get_tracking_scheme("custom-scheme://")
    assert unregistered_scheme == "None"
534  
535  
@pytest.mark.parametrize(
    ("scheme", "uri", "expected"),
    [
        ("arn", "arn:aws:sagemaker:us-east-1:123456789:mlflow-tracking-server/my-server", "aws"),
        ("arn", "arn:aws:sagemaker:eu-west-1:987654321:mlflow-tracking-server/test", "aws"),
        ("azureml", "azureml://eastus.api.azureml.ms/mlflow/v2.0/subscriptions/123", "azure"),
        ("azureml", "azureml://workspace", "azure"),
        ("some-plugin", "some-plugin://host/path", "custom_scheme"),
    ],
)
def test_resolve_custom_scheme(scheme, uri, expected):
    """Each ``(scheme, uri)`` pair maps to its expected label."""
    resolved_label = _resolve_custom_scheme(scheme, uri)
    assert resolved_label == expected