test_model_utils.py
import os
import sys
from unittest import mock

import pytest
import sklearn.neighbors as knn
from sklearn import datasets

import mlflow.sklearn
import mlflow.utils.model_utils as mlflow_model_utils
from mlflow.environment_variables import MLFLOW_RECORD_ENV_VARS_IN_MODEL_LOGGING
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.utils.file_utils import TempDir
from mlflow.utils.model_utils import env_var_tracker


@pytest.fixture(scope="module")
def sklearn_knn_model():
    """Return a ``KNeighborsClassifier`` fitted on the first two iris features."""
    iris = datasets.load_iris()
    X = iris.data[:, :2]  # we only take the first two features.
    y = iris.target
    knn_model = knn.KNeighborsClassifier()
    knn_model.fit(X, y)
    return knn_model


@pytest.fixture
def model_path(tmp_path):
    """Return the path of a ``model`` entry inside the per-test temporary directory."""
    return os.path.join(tmp_path, "model")


def test_get_flavor_configuration_throws_exception_when_requested_flavor_is_missing(
    model_path, sklearn_knn_model
):
    """Requesting a flavor the saved model lacks must raise ``MlflowException``."""
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=model_path)

    # The saved model contains the "sklearn" flavor, so this call should succeed
    sklearn_flavor_config = mlflow_model_utils._get_flavor_configuration(
        model_path=model_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
    )
    assert sklearn_flavor_config is not None

    # The saved model does not contain the "onnx" flavor, so this call should fail
    with pytest.raises(MlflowException, match='does not have the "onnx" flavor'):
        mlflow_model_utils._get_flavor_configuration(model_path=model_path, flavor_name="onnx")


def test_get_flavor_configuration_with_present_flavor_returns_expected_configuration(
    sklearn_knn_model, model_path
):
    """The returned flavor configuration must equal the one stored in the MLmodel file."""
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=model_path)

    sklearn_flavor_config = mlflow_model_utils._get_flavor_configuration(
        model_path=model_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
    )
    model_config = Model.load(os.path.join(model_path, "MLmodel"))
    assert sklearn_flavor_config == model_config.flavors[mlflow.sklearn.FLAVOR_NAME]


def test_add_code_to_system_path(sklearn_knn_model, model_path):
    """Code paths saved with a model become importable once added to the system path."""
    mlflow.sklearn.save_model(
        sk_model=sklearn_knn_model,
        path=model_path,
        code_paths=[
            "tests/utils/test_resources/dummy_module.py",
            "tests/utils/test_resources/dummy_package",
        ],
    )

    sklearn_flavor_config = mlflow_model_utils._get_flavor_configuration(
        model_path=model_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
    )
    with TempDir(chdr=True):
        # Load the model from a new directory that is not a parent of the source code path to
        # verify that source code paths and their subdirectories are resolved correctly
        with pytest.raises(ModuleNotFoundError, match="No module named 'dummy_module'"):
            import dummy_module

        mlflow_model_utils._add_code_from_conf_to_system_path(model_path, sklearn_flavor_config)
        import dummy_module  # noqa: F401

    # If this raises an exception it's because dummy_package.test imported
    # dummy_package.operator and not the built-in operator module. This only
    # happens if MLflow is misconfiguring the system path.
    from dummy_package import base  # noqa: F401

    # Ensure that the custom tests/utils/test_resources/dummy_package/pandas.py is not
    # overwriting the 3rd party `pandas` package
    assert "dummy_package" in sys.modules
    assert "pandas" in sys.modules
    assert "site-packages" in sys.modules["pandas"].__file__


def test_add_code_to_system_path_not_copyable_file(sklearn_knn_model, model_path):
    """An ``OSError`` while copying a code path surfaces as ``MlflowException``."""
    with mock.patch("builtins.open", side_effect=OSError("[Errno 95] Operation not supported")):
        with pytest.raises(MlflowException, match=r"Failed to copy the specified code path"):
            mlflow.sklearn.save_model(
                sk_model=sklearn_knn_model,
                path=model_path,
                code_paths=["tests/utils/test_resources/dummy_module.py"],
            )


def test_env_var_tracker(monkeypatch):
    """``env_var_tracker`` records accessed, existing env vars and can be disabled."""
    monkeypatch.setenv("DATABRICKS_HOST", "host")
    assert "DATABRICKS_HOST" in os.environ
    assert "TEST_API_KEY" not in os.environ

    with env_var_tracker() as tracked_env_names:
        assert os.environ["DATABRICKS_HOST"] == "host"
        monkeypatch.setenv("TEST_API_KEY", "key")
        # accessed env var is tracked
        assert os.environ.get("TEST_API_KEY") == "key"
        # test non-existing env vars fetched by `get` are not tracked
        os.environ.get("INVALID_API_KEY", "abc")
        # test non-existing env vars are not tracked; the lookup itself must still
        # raise, so assert the KeyError instead of silently swallowing it
        with pytest.raises(KeyError, match="ANOTHER_API_KEY"):
            os.environ["ANOTHER_API_KEY"]
        assert all(x in tracked_env_names for x in ["DATABRICKS_HOST", "TEST_API_KEY"])
        assert all(x not in tracked_env_names for x in ["INVALID_API_KEY", "ANOTHER_API_KEY"])

    # Leaving the tracker must restore a regular environ with the same contents
    assert isinstance(os.environ, os._Environ)
    assert all(x in os.environ for x in ["DATABRICKS_HOST", "TEST_API_KEY"])
    assert all(x not in os.environ for x in ["INVALID_API_KEY", "ANOTHER_API_KEY"])

    monkeypatch.setenv(MLFLOW_RECORD_ENV_VARS_IN_MODEL_LOGGING.name, "false")
    with env_var_tracker() as env:
        os.environ.get("API_KEY")
        # TEST_API_KEY exists and was tracked while recording was enabled, so reading
        # it here shows that nothing is recorded once the flag is set to "false"
        assert os.environ.get("TEST_API_KEY") == "key"
        assert env == set()