# test_dependencies_functions.py
from pathlib import Path
from unittest import mock

import pytest
import sklearn
from sklearn.linear_model import LinearRegression

import mlflow.utils.requirements_utils
from mlflow.exceptions import MlflowException
from mlflow.pyfunc import get_model_dependencies
from mlflow.utils import PYTHON_VERSION


def test_get_model_dependencies_read_req_file(tmp_path):
    """Check ``format="pip"`` when the model directory has a requirements.txt.

    Covers three things: the returned path points at a file with the same
    content as the requirements file, an install hint is logged (``pip`` by
    default, ``%pip`` when the notebook check is patched to return True), and
    an unknown ``format`` value raises ``MlflowException``.
    """
    req_file = tmp_path / "requirements.txt"
    req_file_content = """
mlflow
cloudpickle==2.0.0
scikit-learn==1.0.2"""
    req_file.write_text(req_file_content)

    model_path = str(tmp_path)

    # Test getting pip dependencies
    assert Path(get_model_dependencies(model_path, format="pip")).read_text() == req_file_content

    # Test getting pip dependencies will print instructions
    with mock.patch("mlflow.pyfunc._logger.info") as mock_log_info:
        get_model_dependencies(model_path, format="pip")
        mock_log_info.assert_called_once_with(
            "To install the dependencies that were used to train the model, run the "
            f"following command: 'pip install -r {req_file}'."
        )

        # Clear the recorded call so the notebook variant can be asserted with
        # assert_called_once_with on the same (still patched) logger mock.
        mock_log_info.reset_mock()
        with mock.patch("mlflow.pyfunc._is_in_ipython_notebook", return_value=True):
            get_model_dependencies(model_path, format="pip")
            mock_log_info.assert_called_once_with(
                "To install the dependencies that were used to train the model, run the "
                f"following command: '%pip install -r {req_file}'."
            )

    with pytest.raises(MlflowException, match="Illegal format argument 'abc'"):
        get_model_dependencies(model_path, format="abc")


@pytest.mark.parametrize(
    "ml_model_file_content",
    [
        # Both templates are f-strings so that {PYTHON_VERSION} is substituted
        # instead of being written literally into the MLmodel file.
        f"""
artifact_path: model
flavors:
  python_function:
    env: conda.yaml
    loader_module: mlflow.sklearn
    model_path: model.pkl
    python_version: {PYTHON_VERSION}
model_uuid: 722a374a432f48f09ee85da92df13bca
run_id: 765e66a5ba404650be51cb02cda66f35
""",
        f"""
artifact_path: model
flavors:
  python_function:
    env:
      conda: conda.yaml
      virtualenv: python_env.yaml
    loader_module: mlflow.sklearn
    model_path: model.pkl
    python_version: {PYTHON_VERSION}
model_uuid: 722a374a432f48f09ee85da92df13bca
run_id: 765e66a5ba404650be51cb02cda66f35
""",
    ],
    ids=["old_env", "new_env"],
)
def test_get_model_dependencies_read_conda_file(ml_model_file_content, tmp_path):
    """Check dependency lookup when only MLmodel and conda.yaml are present.

    Runs once per MLmodel layout (``env`` as a plain path, and ``env`` as a
    mapping with ``conda``/``virtualenv`` keys). Asserts that
    ``format="conda"`` returns the conda file content, that ``format="pip"``
    yields the pip section of conda.yaml and warns about the excluded conda
    dependencies, and that a conda.yaml without a pip section raises
    ``MlflowException``.
    """
    MLmodel_file = tmp_path / "MLmodel"
    MLmodel_file.write_text(ml_model_file_content)
    conda_yml_file = tmp_path / "conda.yaml"
    conda_yml_file_content = f"""
channels:
- conda-forge
dependencies:
- python={PYTHON_VERSION}
- pip=22.0.3
- scikit-learn=0.22.0
- tensorflow=2.0.0
- pip:
  - mlflow
  - cloudpickle==2.0.0
  - scikit-learn==1.0.1
name: mlflow-env
"""

    conda_yml_file.write_text(conda_yml_file_content)

    model_path = str(tmp_path)

    # Test getting conda environment
    assert (
        Path(get_model_dependencies(model_path, format="conda")).read_text()
        == conda_yml_file_content
    )

    # Test getting pip requirement file failed and fallback to extract pip section from conda.yaml
    with mock.patch("mlflow.pyfunc._logger.warning") as mock_warning:
        pip_file_path = get_model_dependencies(model_path, format="pip")
        assert (
            Path(pip_file_path).read_text().strip()
            == "mlflow\ncloudpickle==2.0.0\nscikit-learn==1.0.1"
        )
        mock_warning.assert_called_once_with(
            "The following conda dependencies have been excluded from the environment file: "
            f"python={PYTHON_VERSION}, pip=22.0.3, scikit-learn=0.22.0, tensorflow=2.0.0."
        )

    # Overwrite conda.yaml with a variant that has no nested pip section.
    conda_yml_file.write_text(
        f"""
channels:
- conda-forge
dependencies:
- python={PYTHON_VERSION}
- pip=22.0.3
- scikit-learn=0.22.0
- tensorflow=2.0.0
"""
    )

    with pytest.raises(MlflowException, match="No pip section found in conda.yaml file"):
        get_model_dependencies(model_path, format="pip")


def test_get_model_dependencies_with_model_version_uri():
    """Check that a ``models:/<name>/<version>`` URI is accepted.

    Logs a scikit-learn model registered under the name "linear", then asserts
    that the pip dependency file returned for ``models:/linear/1`` pins the
    installed scikit-learn version.
    """
    with mlflow.start_run():
        mlflow.sklearn.log_model(LinearRegression(), name="model", registered_model_name="linear")

    deps = get_model_dependencies("models:/linear/1", format="pip")
    assert f"scikit-learn=={sklearn.__version__}" in Path(deps).read_text()