test_utils.py
1 import io 2 import itertools 3 import os 4 import pickle 5 import uuid 6 from importlib import reload 7 from pathlib import Path 8 from unittest import mock 9 from urllib.parse import urlparse 10 from urllib.request import url2pathname 11 12 import pytest 13 14 import mlflow 15 from mlflow.environment_variables import ( 16 MLFLOW_ENABLE_WORKSPACES, 17 MLFLOW_TRACKING_INSECURE_TLS, 18 MLFLOW_TRACKING_PASSWORD, 19 MLFLOW_TRACKING_TOKEN, 20 MLFLOW_TRACKING_URI, 21 MLFLOW_TRACKING_USERNAME, 22 ) 23 from mlflow.exceptions import MlflowException 24 from mlflow.server import ARTIFACT_ROOT_ENV_VAR 25 from mlflow.store.db.db_types import DATABASE_ENGINES 26 from mlflow.store.tracking.databricks_rest_store import DatabricksTracingRestStore 27 from mlflow.store.tracking.file_store import FileStore 28 from mlflow.store.tracking.rest_store import RestStore 29 from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore 30 from mlflow.tracking._tracking_service.registry import TrackingStoreRegistry 31 from mlflow.tracking._tracking_service.utils import ( 32 _get_store, 33 _get_tracking_scheme, 34 _resolve_custom_scheme, 35 _resolve_tracking_uri, 36 _use_tracking_uri, 37 get_tracking_uri, 38 set_tracking_uri, 39 ) 40 from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException 41 from mlflow.utils.file_utils import path_to_local_file_uri 42 from mlflow.utils.os import is_windows 43 44 from tests.helpers.db_mocks import mock_get_managed_session_maker 45 from tests.tracing.helper import get_tracer_tracking_uri 46 47 # Disable mocking tracking URI here, as we want to test setting the tracking URI via 48 # environment variable. See 49 # http://doc.pytest.org/en/latest/skipping.html#skip-all-test-functions-of-a-class-or-module 50 # and https://github.com/mlflow/mlflow/blob/master/CONTRIBUTING.md#writing-python-tests 51 # for more information. 
pytestmark = pytest.mark.notrackingurimock


@pytest.mark.skip(reason="FileStore is no longer supported")
def test_tracking_scheme_with_existing_mlruns(tmp_path, monkeypatch):
    """An ``mlruns`` dir containing experiment metadata selects a FileStore."""
    monkeypatch.chdir(tmp_path)
    mlruns_dir = tmp_path / "mlruns"
    mlruns_dir.mkdir()
    exp_dir = mlruns_dir / "0"
    exp_dir.mkdir()
    (exp_dir / "meta.yaml").touch()
    store = _get_store()
    assert isinstance(store, FileStore)


def test_tracking_scheme_without_existing_mlruns(tmp_path, monkeypatch):
    """Without an ``mlruns`` dir in the cwd, the default store is a SqlAlchemyStore."""
    monkeypatch.chdir(tmp_path)
    store = _get_store()
    assert isinstance(store, SqlAlchemyStore)


@pytest.mark.skip(reason="FileStore is no longer supported")
def test_get_store_with_existing_mlruns_data(tmp_path, monkeypatch):
    """Existing ``mlruns`` experiment data yields a FileStore rooted at ``./mlruns``."""
    monkeypatch.chdir(tmp_path)
    mlruns_dir = tmp_path / "mlruns"
    mlruns_dir.mkdir()
    exp_dir = mlruns_dir / "0"
    exp_dir.mkdir()
    (exp_dir / "meta.yaml").touch()

    store = _get_store()
    assert isinstance(store, FileStore)
    assert os.path.abspath(store.root_directory) == os.path.abspath("mlruns")


def test_get_store_with_empty_mlruns(tmp_path, monkeypatch):
    """An empty ``mlruns`` dir does not select a FileStore; SqlAlchemyStore is used."""
    monkeypatch.chdir(tmp_path)
    mlruns_dir = tmp_path / "mlruns"
    mlruns_dir.mkdir()

    store = _get_store()
    assert isinstance(store, SqlAlchemyStore)


def test_get_store_with_mlruns_dir_but_no_meta_yaml(tmp_path, monkeypatch):
    """An ``mlruns/0`` dir lacking ``meta.yaml`` still yields a SqlAlchemyStore."""
    monkeypatch.chdir(tmp_path)
    mlruns_dir = tmp_path / "mlruns"
    mlruns_dir.mkdir()
    (mlruns_dir / "0").mkdir()

    store = _get_store()
    assert isinstance(store, SqlAlchemyStore)


def test_default_sqlite_tracking_uri_respects_cwd(tmp_path, monkeypatch):
    """With no tracking URI set, the default SQLite database lives in the cwd."""
    monkeypatch.chdir(tmp_path)
    with _use_tracking_uri(None):
        store = _get_store()

    assert isinstance(store, SqlAlchemyStore)
    sqlite_uri = store.db_uri
    assert sqlite_uri.startswith("sqlite:")
    parsed = urlparse(sqlite_uri)
    path = parsed.path
    # ``sqlite:////abs/path`` parses with an empty netloc and a path starting with
    # ``//``; drop one slash to recover the absolute filesystem path.
    if not parsed.netloc and path.startswith("//"):
        path = path[1:]
    # A non-empty netloc is folded back into the path (presumably a UNC-style host
    # on Windows) so ``url2pathname`` can convert it.
    if parsed.netloc:
        path = f"//{parsed.netloc}{path}"
    db_path = Path(url2pathname(path))
    assert db_path.parent == tmp_path


@pytest.mark.skip(reason="FileStore is no longer supported")
def test_get_store_file_store_from_arg(tmp_path, monkeypatch):
    """A bare path passed to ``_get_store`` yields a FileStore rooted there."""
    monkeypatch.chdir(tmp_path)
    store = _get_store("other/path")
    assert isinstance(store, FileStore)
    assert os.path.abspath(store.root_directory) == os.path.abspath("other/path")


@pytest.mark.skip(reason="FileStore is no longer supported")
@pytest.mark.parametrize("uri", ["other/path", "file:other/path"])
def test_get_store_file_store_from_env(tmp_path, monkeypatch, uri):
    """A path or ``file:`` URI in the tracking env var yields a FileStore."""
    monkeypatch.chdir(tmp_path)
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, uri)
    store = _get_store()
    assert isinstance(store, FileStore)
    assert os.path.abspath(store.root_directory) == os.path.abspath("other/path")


def test_get_store_basic_rest_store(monkeypatch):
    """An https tracking URI yields a RestStore with no token and scheme ``https``."""
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, "https://my-tracking-server:5050")
    store = _get_store()
    assert isinstance(store, RestStore)
    assert store.get_host_creds().host == "https://my-tracking-server:5050"
    assert store.get_host_creds().token is None
    assert _get_tracking_scheme() == "https"


def test_get_store_rest_store_with_password(monkeypatch):
    """Username/password env vars are surfaced through the RestStore host creds."""
    for k, v in {
        MLFLOW_TRACKING_URI.name: "https://my-tracking-server:5050",
        MLFLOW_TRACKING_USERNAME.name: "Bob",
        MLFLOW_TRACKING_PASSWORD.name: "Ross",
    }.items():
        monkeypatch.setenv(k, v)

    store = _get_store()
    assert isinstance(store, RestStore)
    assert store.get_host_creds().host == "https://my-tracking-server:5050"
    assert store.get_host_creds().username == "Bob"
    assert store.get_host_creds().password == "Ross"


def test_get_store_rest_store_with_token(monkeypatch):
    """The token env var is surfaced through the RestStore host creds."""
    for k, v in {
        MLFLOW_TRACKING_URI.name: "https://my-tracking-server:5050",
        MLFLOW_TRACKING_TOKEN.name: "my-token",
    }.items():
        monkeypatch.setenv(k, v)

    store = _get_store()
    assert isinstance(store, RestStore)
    assert store.get_host_creds().token == "my-token"


def test_get_store_rest_store_with_insecure(monkeypatch):
    """``MLFLOW_TRACKING_INSECURE_TLS=true`` disables TLS verification."""
    for k, v in {
        MLFLOW_TRACKING_URI.name: "https://my-tracking-server:5050",
        MLFLOW_TRACKING_INSECURE_TLS.name: "true",
    }.items():
        monkeypatch.setenv(k, v)
    store = _get_store()
    assert isinstance(store, RestStore)
    assert store.get_host_creds().ignore_tls_verification


def test_get_store_rest_store_with_no_insecure(monkeypatch):
    """TLS verification stays on when the insecure flag is ``false`` or unset."""
    with monkeypatch.context() as m:
        for k, v in {
            MLFLOW_TRACKING_URI.name: "https://my-tracking-server:5050",
            MLFLOW_TRACKING_INSECURE_TLS.name: "false",
        }.items():
            m.setenv(k, v)
        store = _get_store()
        assert isinstance(store, RestStore)
        assert not store.get_host_creds().ignore_tls_verification

    # By default, should not ignore verification.
    with monkeypatch.context() as m:
        # Patch through the scoped ``m`` (as above) so the variable is undone when this
        # context exits instead of lingering until test teardown.
        m.setenv(MLFLOW_TRACKING_URI.name, "https://my-tracking-server:5050")
        store = _get_store()
        assert isinstance(store, RestStore)
        assert not store.get_host_creds().ignore_tls_verification


@pytest.mark.parametrize("db_type", DATABASE_ENGINES)
def test_get_store_sqlalchemy_store(tmp_path, monkeypatch, db_type):
    """A DB-engine URI yields a SqlAlchemyStore with a cached engine and ./mlruns artifacts."""
    monkeypatch.chdir(tmp_path)
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")
    uri = f"{db_type}://hostname/database-{uuid.uuid4().hex}"
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, uri)
    monkeypatch.delenv("MLFLOW_SQLALCHEMYSTORE_POOLCLASS", raising=False)
    with (
        mock.patch("sqlalchemy.create_engine") as mock_create_engine,
        mock.patch("sqlalchemy.event.listens_for"),
        mock.patch("mlflow.store.db.utils._verify_schema"),
        mock.patch("mlflow.store.db.utils._initialize_tables"),
        mock.patch(
            "mlflow.store.db.utils._get_managed_session_maker",
            new=mock_get_managed_session_maker,
        ),
        mock.patch(
            # In sqlalchemy 1.4.0, `SqlAlchemyStore.search_experiments`, which is called when
            # fetching the store, results in an error when called with a mocked sqlalchemy engine.
            # Accordingly, we mock `SqlAlchemyStore.search_experiments`
            "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore.search_experiments",
            return_value=[],
        ),
        mock.patch(
            "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore._initialize_store_state",
            return_value=None,
        ),
    ):
        store = _get_store()
        assert isinstance(store, SqlAlchemyStore)
        assert store.db_uri == uri
        # Create another store to ensure the engine is cached
        another_store = _get_store()
        assert store.engine is another_store.engine
        if is_windows():
            assert store.artifact_root_uri == Path.cwd().joinpath("mlruns").as_uri()
        else:
            assert store.artifact_root_uri == Path.cwd().joinpath("mlruns").as_posix()
        assert _get_tracking_scheme() == db_type

    # Engine caching means create_engine ran exactly once across both _get_store calls.
    mock_create_engine.assert_called_once_with(uri, pool_pre_ping=True)


@pytest.mark.parametrize("db_type", DATABASE_ENGINES)
def test_get_store_sqlalchemy_store_with_artifact_uri(tmp_path, monkeypatch, db_type):
    """An explicit relative ``file:`` artifact URI is resolved against the cwd."""
    monkeypatch.chdir(tmp_path)
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")
    uri = f"{db_type}://hostname/database-{uuid.uuid4().hex}"
    artifact_uri = "file:artifact/path"
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, uri)
    monkeypatch.delenv("MLFLOW_SQLALCHEMYSTORE_POOLCLASS", raising=False)
    with (
        mock.patch("sqlalchemy.create_engine") as mock_create_engine,
        mock.patch("sqlalchemy.event.listens_for"),
        mock.patch("mlflow.store.db.utils._verify_schema"),
        mock.patch("mlflow.store.db.utils._initialize_tables"),
        mock.patch(
            "mlflow.store.db.utils._get_managed_session_maker",
            new=mock_get_managed_session_maker,
        ),
        mock.patch(
            "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore.search_experiments",
            return_value=[],
        ),
        mock.patch(
            "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore._initialize_store_state",
            return_value=None,
        ),
    ):
        store = _get_store(artifact_uri=artifact_uri)
        assert isinstance(store, SqlAlchemyStore)
        assert store.db_uri == uri
        if is_windows():
            assert store.artifact_root_uri == Path.cwd().joinpath("artifact", "path").as_uri()
        else:
            assert store.artifact_root_uri == path_to_local_file_uri(
                Path.cwd().joinpath("artifact", "path")
            )

    mock_create_engine.assert_called_once_with(uri, pool_pre_ping=True)


def test_get_sqlalchemy_store_uses_server_artifact_root(tmp_path, monkeypatch):
    """With no artifact URI given, the server artifact-root env var is passed through."""
    store_uri = f"sqlite:///{tmp_path.joinpath('backend_store.db')}"
    artifact_path = tmp_path / "server-artifacts"
    artifact_uri = path_to_local_file_uri(artifact_path)
    # monkeypatch undoes this at teardown, so no manual delenv is needed.
    monkeypatch.setenv(ARTIFACT_ROOT_ENV_VAR, artifact_uri)

    with mock.patch("mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore") as mock_store:
        mlflow.tracking._tracking_service.utils._get_sqlalchemy_store(
            store_uri=store_uri, artifact_uri=None
        )

    mock_store.assert_called_once()
    # The artifact root is the second positional constructor argument.
    assert mock_store.call_args.args[1] == artifact_uri


def test_get_store_databricks(monkeypatch):
    """The ``databricks`` URI yields a DatabricksTracingRestStore using the SDK."""
    for k, v in {
        MLFLOW_TRACKING_URI.name: "databricks",
        "DATABRICKS_HOST": "https://my-tracking-server",
        "DATABRICKS_TOKEN": "abcdef",
    }.items():
        monkeypatch.setenv(k, v)
    store = _get_store()
    assert isinstance(store, DatabricksTracingRestStore)
    assert store.get_host_creds().use_databricks_sdk
    assert _get_tracking_scheme() == "databricks"


def test_get_store_databricks_profile(monkeypatch):
    """A ``databricks://<profile>`` URI names the profile when creds cannot be resolved."""
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, "databricks://mycoolprofile")
    # It's kind of annoying to setup a profile, and we're not really trying to test
    # that anyway, so just check if we raise a relevant exception.
    store = _get_store()
    assert isinstance(store, DatabricksTracingRestStore)
    with pytest.raises(MlflowException, match="mycoolprofile"):
        store.get_host_creds()


def test_get_store_caches_on_store_uri_and_artifact_uri(tmp_path):
    """The registry caches one store per distinct (store_uri, artifact_uri) pair."""
    registry = mlflow.tracking._tracking_service.utils._tracking_store_registry

    store_uri_1 = f"sqlite:///{tmp_path.joinpath('backend_store_1.db')}"
    store_uri_2 = f"sqlite:///{tmp_path.joinpath('backend_store_2.db')}"
    stores_uris = [store_uri_1, store_uri_2]
    artifact_uris = [
        None,
        str(tmp_path.joinpath("artifact_root_1")),
        str(tmp_path.joinpath("artifact_root_2")),
    ]

    stores = []
    for args in itertools.product(stores_uris, artifact_uris):
        store1 = registry.get_store(*args)
        store2 = registry.get_store(*args)
        assert store1 is store2
        stores.append(store1)

    # Every distinct key must map to a distinct store object.
    assert all(s1 is not s2 for s1, s2 in itertools.combinations(stores, 2))


def test_standard_store_registry_with_mocked_entrypoint():
    """Reloading the utils module registers built-in schemes plus entrypoint schemes."""
    mock_entrypoint = mock.Mock()
    mock_entrypoint.name = "mock-scheme"

    with mock.patch("mlflow.utils.plugins._get_entry_points", return_value=[mock_entrypoint]):
        # Entrypoints are registered at import time, so we need to reload the
        # module to register the entrypoint returned by the mocked
        # `mlflow.utils.plugins._get_entry_points`
        reload(mlflow.tracking._tracking_service.utils)

        expected_standard_registry = {
            "",
            "file",
            "http",
            "https",
            "postgresql",
            "mysql",
            "sqlite",
            "mssql",
            "databricks",
            "mock-scheme",
        }
        assert expected_standard_registry.issubset(
            mlflow.tracking._tracking_service.utils._tracking_store_registry._registry.keys()
        )


@pytest.mark.skip(reason="FileStore is no longer supported")
def test_standard_store_registry_with_installed_plugin(tmp_path, monkeypatch):
    """An installed ``file-plugin`` store is registered and resolvable by URI."""
    monkeypatch.chdir(tmp_path)
    reload(mlflow.tracking._tracking_service.utils)
    assert (
        "file-plugin" in mlflow.tracking._tracking_service.utils._tracking_store_registry._registry
    )

    from mlflow_test_plugin.file_store import PluginFileStore

    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, "file-plugin:test-path")
    plugin_file_store = mlflow.tracking._tracking_service.utils._get_store()
    assert isinstance(plugin_file_store, PluginFileStore)
    assert plugin_file_store.is_plugin
    assert _get_tracking_scheme() == "custom_scheme"


def test_plugin_registration():
    """A manually registered builder is invoked with the URI and no artifact URI."""
    tracking_store = TrackingStoreRegistry()

    test_uri = "mock-scheme://fake-host/fake-path"
    test_scheme = "mock-scheme"

    mock_plugin = mock.Mock()
    tracking_store.register(test_scheme, mock_plugin)
    assert test_scheme in tracking_store._registry
    assert tracking_store.get_store(test_uri) == mock_plugin.return_value
    mock_plugin.assert_called_once_with(store_uri=test_uri, artifact_uri=None)


def test_plugin_registration_via_entrypoints():
    """Builders loaded from the ``mlflow.tracking_store`` entrypoint group are used."""
    mock_plugin_function = mock.Mock()
    mock_entrypoint = mock.Mock(load=mock.Mock(return_value=mock_plugin_function))
    mock_entrypoint.name = "mock-scheme"

    with mock.patch(
        "mlflow.utils.plugins._get_entry_points", return_value=[mock_entrypoint]
    ) as mock_get_group_all:
        tracking_store = TrackingStoreRegistry()
        tracking_store.register_entrypoints()

    assert tracking_store.get_store("mock-scheme://") == mock_plugin_function.return_value

    mock_plugin_function.assert_called_once_with(store_uri="mock-scheme://", artifact_uri=None)
    mock_get_group_all.assert_called_once_with("mlflow.tracking_store")


@pytest.mark.parametrize(
    "exception", [AttributeError("test exception"), ImportError("test exception")]
)
def test_handle_plugin_registration_failure_via_entrypoints(exception):
    """A failing entrypoint load emits a UserWarning instead of raising."""
    mock_entrypoint = mock.Mock(load=mock.Mock(side_effect=exception))
    mock_entrypoint.name = "mock-scheme"

    with mock.patch(
        "mlflow.utils.plugins._get_entry_points", return_value=[mock_entrypoint]
    ) as mock_get_group_all:
        tracking_store = TrackingStoreRegistry()

        # Check that the raised warning contains the message from the original exception
        with pytest.warns(UserWarning, match="test exception"):
            tracking_store.register_entrypoints()

    mock_entrypoint.load.assert_called_once()
    mock_get_group_all.assert_called_once_with("mlflow.tracking_store")


def test_get_store_for_unregistered_scheme():
    """Requesting a store for an unknown scheme raises the unsupported-URI exception."""
    tracking_store = TrackingStoreRegistry()

    with pytest.raises(
        UnsupportedModelRegistryStoreURIException,
        match="Model registry functionality is unavailable",
    ):
        tracking_store.get_store("unknown-scheme://")


def test_resolve_tracking_uri_with_param():
    """An explicit URI takes precedence over the globally configured one."""
    with mock.patch(
        "mlflow.tracking._tracking_service.utils.get_tracking_uri",
        return_value="databricks://tracking_qoeirj",
    ):
        overriding_uri = "databricks://tracking_poiwerow"
        assert _resolve_tracking_uri(overriding_uri) == overriding_uri


def test_resolve_tracking_uri_with_no_param():
    """Without an explicit URI, the globally configured one is returned."""
    with mock.patch(
        "mlflow.tracking._tracking_service.utils.get_tracking_uri",
        return_value="databricks://tracking_zlkjdas",
    ):
        assert _resolve_tracking_uri() == "databricks://tracking_zlkjdas"


@pytest.mark.skip(reason="FileStore is no longer supported")
def test_store_object_can_be_serialized_by_pickle(tmp_path):
    """
    This test ensures a store object generated by `_get_store` can be serialized by pickle
    to prevent issues such as https://github.com/mlflow/mlflow/issues/2954
    """
    pickle.dump(_get_store(f"file:///{tmp_path.joinpath('mlflow')}"), io.BytesIO())
    pickle.dump(_get_store("databricks"), io.BytesIO())
    pickle.dump(_get_store("https://example.com"), io.BytesIO())
    # A SqlAlchemyStore (e.g. ``sqlite:///<tmp_path>/mlflow.db``) is not covered here:
    # pickling it raises ``AttributeError: Can't pickle local object
    # 'create_engine.<locals>.connect'``.


@pytest.mark.parametrize("absolute", [True, False], ids=["absolute", "relative"])
def test_set_tracking_uri_with_path(tmp_path, monkeypatch, absolute):
    """A ``Path`` tracking URI is stored as the absolute, resolved ``file:`` URI."""
    monkeypatch.chdir(tmp_path)
    path = Path("foo/bar")
    if absolute:
        path = tmp_path / path
    # Patching the module global isolates this test from the process-wide tracking URI.
    with mock.patch("mlflow.tracking._tracking_service.utils._tracking_uri", None):
        set_tracking_uri(path)
        assert get_tracking_uri() == path.absolute().resolve().as_uri()


def test_set_tracking_uri_update_trace_provider(tmp_path):
    """Changing the tracking URI is reflected in the tracer's tracking URI."""
    default_uri = mlflow.get_tracking_uri()
    sqlite_uri = f"sqlite:///{tmp_path / 'mlflow.db'}"
    try:
        assert get_tracer_tracking_uri() != sqlite_uri

        set_tracking_uri(sqlite_uri)
        assert get_tracer_tracking_uri() == sqlite_uri

        set_tracking_uri("https://foo")
        assert get_tracer_tracking_uri() == "https://foo"
    finally:
        # clean up
        set_tracking_uri(default_uri)


@pytest.mark.parametrize("store_uri", ["databricks-uc", "databricks-uc://profile"])
def test_get_store_raises_on_uc_uri(store_uri):
    """A Unity Catalog tracking URI is rejected when constructing an MlflowClient."""
    # Scope the UC URI to this test instead of leaving it set globally for later tests
    # (NOTE(review): relies on `_use_tracking_uri` restoring the previous URI on exit).
    with _use_tracking_uri(store_uri):
        with pytest.raises(
            MlflowException,
            match="Setting the tracking URI to a Unity Catalog backend is not "
            "supported in the current version of the MLflow client",
        ):
            mlflow.tracking.MlflowClient()
        assert _get_tracking_scheme() == "databricks-uc"


@pytest.mark.parametrize("tracking_uri", ["file:///tmp/mlruns", "sqlite:///tmp/mlruns.db", ""])
def test_set_get_tracking_uri_consistency(tracking_uri):
    """``get_tracking_uri`` returns exactly what ``set_tracking_uri`` was given."""
    # Wrap in `_use_tracking_uri` so the URI set below does not leak into later tests
    # (NOTE(review): relies on it restoring the previous URI on exit).
    with _use_tracking_uri(None):
        mlflow.set_tracking_uri(tracking_uri)
        assert mlflow.get_tracking_uri() == tracking_uri


def test_get_tracking_scheme():
    """The scheme is parsed from an explicit URI; unknown schemes map to ``"None"``."""
    assert _get_tracking_scheme("uc://profile@databricks") == "uc"
    # no builder registered for custom scheme
    assert _get_tracking_scheme("custom-scheme://") == "None"


@pytest.mark.parametrize(
    ("scheme", "uri", "expected"),
    [
        ("arn", "arn:aws:sagemaker:us-east-1:123456789:mlflow-tracking-server/my-server", "aws"),
        ("arn", "arn:aws:sagemaker:eu-west-1:987654321:mlflow-tracking-server/test", "aws"),
        ("azureml", "azureml://eastus.api.azureml.ms/mlflow/v2.0/subscriptions/123", "azure"),
        ("azureml", "azureml://workspace", "azure"),
        ("some-plugin", "some-plugin://host/path", "custom_scheme"),
    ],
)
def test_resolve_custom_scheme(scheme, uri, expected):
    """Known plugin schemes map to provider names; others map to ``custom_scheme``."""
    assert _resolve_custom_scheme(scheme, uri) == expected