# test_environment.py — tests for mlflow.utils.environment helpers
# Tests for mlflow.utils.environment: conda-env construction, pip-requirement
# parsing/processing, and requirement deduplication helpers.
import importlib.metadata
import os
from unittest import mock

import pytest
import yaml

from mlflow.exceptions import MlflowException
from mlflow.utils.environment import (
    _contains_mlflow_requirement,
    _deduplicate_requirements,
    _get_pip_deps,
    _get_pip_requirement_specifier,
    _is_mlflow_requirement,
    _is_pip_deps,
    _mlflow_conda_env,
    _overwrite_pip_deps,
    _parse_pip_requirements,
    _process_conda_env,
    _process_pip_requirements,
    _remove_incompatible_requirements,
    _validate_env_arguments,
    infer_pip_requirements,
)

from tests.helper_functions import _mlflow_major_version_string


@pytest.fixture
def conda_env_path(tmp_path):
    # Path inside pytest's per-test temporary directory where a conda env
    # YAML file may be written by _mlflow_conda_env.
    return os.path.join(tmp_path, "conda_env.yaml")


def test_mlflow_conda_env_returns_none_when_output_path_is_specified(conda_env_path):
    """When `path` is given, _mlflow_conda_env writes to disk and returns None."""
    env_creation_output = _mlflow_conda_env(
        path=conda_env_path,
        additional_conda_deps=["conda-dep-1=0.0.1", "conda-dep-2"],
        additional_pip_deps=["pip-dep-1", "pip-dep2==0.1.0"],
    )

    assert env_creation_output is None


def test_mlflow_conda_env_returns_expected_env_dict_when_output_path_is_not_specified():
    """When `path` is None, the env dict is returned and contains the conda deps."""
    conda_deps = ["conda-dep-1=0.0.1", "conda-dep-2"]
    env = _mlflow_conda_env(path=None, additional_conda_deps=conda_deps)

    for conda_dep in conda_deps:
        assert conda_dep in env["dependencies"]


@pytest.mark.parametrize("conda_deps", [["conda-dep-1=0.0.1", "conda-dep-2"], None])
def test_mlflow_conda_env_includes_pip_dependencies_but_pip_is_not_specified(conda_deps):
    """When pip deps are requested but `pip` itself is not a conda dep, a
    `pip<=<installed pip version>` entry is expected in the dependencies."""
    additional_pip_deps = ["pip-dep==0.0.1"]
    env = _mlflow_conda_env(
        path=None, additional_conda_deps=conda_deps, additional_pip_deps=additional_pip_deps
    )
    if conda_deps is not None:
        for conda_dep in conda_deps:
            assert conda_dep in env["dependencies"]
    # The auto-added pip pin is capped at the version installed in this environment.
    pip_version = importlib.metadata.version("pip")
    assert f"pip<={pip_version}" in env["dependencies"]
@pytest.mark.parametrize("pip_specification", ["pip", "pip==20.0.02"])
def test_mlflow_conda_env_includes_pip_dependencies_and_pip_is_specified(pip_specification):
    """If `pip` is already listed among the conda deps, it is kept as specified."""
    conda_deps = ["conda-dep-1=0.0.1", "conda-dep-2", pip_specification]
    additional_pip_deps = ["pip-dep==0.0.1"]
    env = _mlflow_conda_env(
        path=None, additional_conda_deps=conda_deps, additional_pip_deps=additional_pip_deps
    )
    for conda_dep in conda_deps:
        assert conda_dep in env["dependencies"]
    assert pip_specification in env["dependencies"]


def test_is_pip_deps():
    """_is_pip_deps recognizes a `{"pip": [...]}` mapping and nothing else."""
    assert _is_pip_deps({"pip": ["a"]})
    assert not _is_pip_deps({"ipi": ["a"]})
    assert not _is_pip_deps("")
    assert not _is_pip_deps([])


def test_overwrite_pip_deps():
    """_overwrite_pip_deps replaces (or creates) the pip section of a conda env dict."""
    # dependencies field doesn't exist
    name_and_channels = {"name": "env", "channels": ["conda-forge"]}
    expected = {**name_and_channels, "dependencies": [{"pip": ["scipy"]}]}
    assert _overwrite_pip_deps(name_and_channels, ["scipy"]) == expected

    # dependencies field doesn't contain pip dependencies
    conda_env = {**name_and_channels, "dependencies": ["pip"]}
    expected = {**name_and_channels, "dependencies": ["pip", {"pip": ["scipy"]}]}
    assert _overwrite_pip_deps(conda_env, ["scipy"]) == expected

    # dependencies field contains pip dependencies
    conda_env = {**name_and_channels, "dependencies": ["pip", {"pip": ["numpy"]}, "pandas"]}
    expected = {**name_and_channels, "dependencies": ["pip", {"pip": ["scipy"]}, "pandas"]}
    assert _overwrite_pip_deps(conda_env, ["scipy"]) == expected


def test_parse_pip_requirements(tmp_path):
    """_parse_pip_requirements returns a `(requirements, constraints)` tuple."""
    assert _parse_pip_requirements(None) == ([], [])
    assert _parse_pip_requirements([]) == ([], [])
    # Without version specifiers
    assert _parse_pip_requirements(["a", "b"]) == (["a", "b"], [])
    # With version specifiers
    assert _parse_pip_requirements(["a==0.0", "b>1.1"]) == (["a==0.0", "b>1.1"], [])
    # Environment marker (https://www.python.org/dev/peps/pep-0508/#environment-markers)
    assert _parse_pip_requirements(['a; python_version < "3.8"']) == (
        ['a; python_version < "3.8"'],
        [],
    )
    # GitHub URI
    mlflow_repo_uri = "git+https://github.com/mlflow/mlflow.git"
    assert _parse_pip_requirements([mlflow_repo_uri]) == ([mlflow_repo_uri], [])
    # Local file
    fake_whl = tmp_path.joinpath("fake.whl")
    fake_whl.write_text("")
    assert _parse_pip_requirements([str(fake_whl)]) == ([str(fake_whl)], [])


def test_parse_pip_requirements_with_relative_requirements_files(tmp_path, monkeypatch):
    """`-r <relative path>` references are resolved against the current directory."""
    monkeypatch.chdir(tmp_path)
    f1 = tmp_path.joinpath("requirements1.txt")
    f1.write_text("b")
    assert _parse_pip_requirements(f1.name) == (["b"], [])
    assert _parse_pip_requirements(["a", f"-r {f1.name}"]) == (["a", "b"], [])

    # Requirements files may themselves reference other requirements files.
    f2 = tmp_path.joinpath("requirements2.txt")
    f3 = tmp_path.joinpath("requirements3.txt")
    f2.write_text(f"b\n-r {f3.name}")
    f3.write_text("c")
    assert _parse_pip_requirements(f2.name) == (["b", "c"], [])
    assert _parse_pip_requirements(["a", f"-r {f2.name}"]) == (["a", "b", "c"], [])


def test_parse_pip_requirements_with_absolute_requirements_files(tmp_path):
    """`-r <absolute path>` references are expanded, including nested `-r` lines."""
    f1 = tmp_path.joinpath("requirements1.txt")
    f1.write_text("b")
    assert _parse_pip_requirements(str(f1)) == (["b"], [])
    assert _parse_pip_requirements(["a", f"-r {f1}"]) == (["a", "b"], [])

    f2 = tmp_path.joinpath("requirements2.txt")
    f3 = tmp_path.joinpath("requirements3.txt")
    f2.write_text(f"b\n-r {f3}")
    f3.write_text("c")
    assert _parse_pip_requirements(str(f2)) == (["b", "c"], [])
    assert _parse_pip_requirements(["a", f"-r {f2}"]) == (["a", "b", "c"], [])


def test_parse_pip_requirements_with_constraints_files(tmp_path):
    """`-c` constraint files populate the second element of the returned tuple."""
    con_file = tmp_path.joinpath("constraints.txt")
    con_file.write_text("b")
    assert _parse_pip_requirements(["a", f"-c {con_file}"]) == (["a"], ["b"])

    # Constraints referenced from within a requirements file are picked up too.
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text(f"-c {con_file}\n")
    assert _parse_pip_requirements(["a", f"-r {req_file}"]) == (["a"], ["b"])


def test_parse_pip_requirements_ignores_comments_and_blank_lines(tmp_path):
    """Comment lines, inline comments, and blank/whitespace lines are dropped."""
    reqs = [
        "# comment",
        "a # inline comment",
        # blank lines
        "",
        " ",
    ]
    f = tmp_path.joinpath("requirements.txt")
    f.write_text("\n".join(reqs))
    assert _parse_pip_requirements(reqs) == (["a"], [])
    assert _parse_pip_requirements(str(f)) == (["a"], [])


def test_parse_pip_requirements_removes_temporary_requirements_file():
    """The temp requirements file used during parsing is cleaned up, even on failure."""
    assert _parse_pip_requirements(["a"]) == (["a"], [])
    assert all(not x.endswith(".tmp.requirements.txt") for x in os.listdir())

    with pytest.raises(FileNotFoundError, match="No such file or directory"):
        _parse_pip_requirements(["a", "-r does_not_exist.txt"])
    # Ensure the temporary requirements file has been removed even when parsing fails
    assert all(not x.endswith(".tmp.requirements.txt") for x in os.listdir())


@pytest.mark.parametrize("invalid_argument", [0, True, [0]])
def test_parse_pip_requirements_with_invalid_argument_types(invalid_argument):
    """Non-string, non-list-of-strings inputs are rejected with a TypeError."""
    with pytest.raises(TypeError, match="`pip_requirements` must be either a string path"):
        _parse_pip_requirements(invalid_argument)


def test_validate_env_arguments():
    """At most one of conda_env / pip_requirements / extra_pip_requirements is allowed."""
    _validate_env_arguments(pip_requirements=None, extra_pip_requirements=None, conda_env=None)

    match = "Only one of `conda_env`, `pip_requirements`, and `extra_pip_requirements`"
    with pytest.raises(ValueError, match=match):
        _validate_env_arguments(conda_env={}, pip_requirements=[], extra_pip_requirements=None)

    with pytest.raises(ValueError, match=match):
        _validate_env_arguments(conda_env={}, pip_requirements=None, extra_pip_requirements=[])

    with pytest.raises(ValueError, match=match):
        _validate_env_arguments(conda_env=None, pip_requirements=[], extra_pip_requirements=[])

    with pytest.raises(ValueError, match=match):
        _validate_env_arguments(conda_env={}, pip_requirements=[], extra_pip_requirements=[])


def test_is_mlflow_requirement():
    """_is_mlflow_requirement matches mlflow/mlflow-skinny in any requirement form."""
    assert _is_mlflow_requirement("mlflow")
    assert _is_mlflow_requirement("MLFLOW")
    assert _is_mlflow_requirement("MLflow")
    assert _is_mlflow_requirement("mlflow==1.2.3")
    assert _is_mlflow_requirement("mlflow < 1.2.3")
    assert _is_mlflow_requirement("mlflow; python_version < '3.8'")
    assert _is_mlflow_requirement("mlflow @ https://github.com/mlflow/mlflow.git")
    assert _is_mlflow_requirement("mlflow @ file:///path/to/mlflow")
    assert _is_mlflow_requirement("mlflow-skinny==1.2.3")
    assert not _is_mlflow_requirement("foo")
    # Ensure packages that look like mlflow are NOT considered as mlflow.
    assert not _is_mlflow_requirement("mlflow-foo")
    assert not _is_mlflow_requirement("mlflow_foo")


def test_contains_mlflow_requirement():
    """_contains_mlflow_requirement scans a list for any mlflow requirement."""
    assert _contains_mlflow_requirement(["mlflow"])
    assert _contains_mlflow_requirement(["mlflow==1.2.3"])
    assert _contains_mlflow_requirement(["mlflow", "foo"])
    assert _contains_mlflow_requirement(["mlflow-skinny"])
    assert not _contains_mlflow_requirement([])
    assert not _contains_mlflow_requirement(["foo"])


def test_get_pip_requirement_specifier():
    """_get_pip_requirement_specifier strips pip flags (e.g. --hash) from a line."""
    assert _get_pip_requirement_specifier("") == ""
    assert _get_pip_requirement_specifier(" ") == " "
    assert _get_pip_requirement_specifier("mlflow") == "mlflow"
    assert _get_pip_requirement_specifier("mlflow==1.2.3") == "mlflow==1.2.3"
    assert _get_pip_requirement_specifier("-r reqs.txt") == ""
    # NOTE(review): restored a second leading space here — a single leading space
    # cannot leave a residual " " once the "-r" flag is stripped.
    assert _get_pip_requirement_specifier("  -r reqs.txt") == " "
    assert _get_pip_requirement_specifier("mlflow==1.2.3 --hash=foo") == "mlflow==1.2.3"
    # NOTE(review): restored the double space before --hash; as extracted, this
    # input was identical to the previous assert yet expected a different result.
    # The extra (empty) token before the flag survives, leaving a trailing space.
    assert _get_pip_requirement_specifier("mlflow==1.2.3  --hash=foo") == "mlflow==1.2.3 "
    assert _get_pip_requirement_specifier("mlflow-skinny==1.2 --foo=bar") == "mlflow-skinny==1.2"


def test_process_pip_requirements(tmp_path):
    """_process_pip_requirements merges default/explicit/extra pip requirements,
    always ensuring an mlflow requirement is present."""
    expected_mlflow_ver = _mlflow_major_version_string()
    conda_env, reqs, cons = _process_pip_requirements(["a"])
    assert _get_pip_deps(conda_env) == [expected_mlflow_ver, "a"]
    assert reqs == [expected_mlflow_ver, "a"]
    assert cons == []

    # Explicit pip_requirements replace the defaults entirely.
    conda_env, reqs, cons = _process_pip_requirements(["a"], pip_requirements=["b"])
    assert _get_pip_deps(conda_env) == [expected_mlflow_ver, "b"]
    assert reqs == [expected_mlflow_ver, "b"]
    assert cons == []

    # Ensure a requirement for mlflow is preserved
    conda_env, reqs, cons = _process_pip_requirements(["a"], pip_requirements=["mlflow==1.2.3"])
    assert _get_pip_deps(conda_env) == ["mlflow==1.2.3"]
    assert reqs == ["mlflow==1.2.3"]
    assert cons == []

    # Ensure a requirement for mlflow is preserved when package hashes are specified
    hash1 = "sha256:963c22532e82a93450674ab97d62f9e528ed0906b580fadb7c003e696197557c"
    hash2 = "sha256:b15ff0c7e5e64f864a0b40c99b9a582227315eca2065d9f831db9aeb8f24637b"
    conda_env, reqs, cons = _process_pip_requirements(
        ["a"],
        pip_requirements=[f"mlflow==1.20.2 --hash={hash1} --hash={hash2}"],
    )
    assert _get_pip_deps(conda_env) == [f"mlflow==1.20.2 --hash={hash1} --hash={hash2}"]
    assert reqs == [f"mlflow==1.20.2 --hash={hash1} --hash={hash2}"]
    assert cons == []

    # extra_pip_requirements are appended to the defaults.
    conda_env, reqs, cons = _process_pip_requirements(["a"], extra_pip_requirements=["b"])
    assert _get_pip_deps(conda_env) == [expected_mlflow_ver, "a", "b"]
    assert reqs == [expected_mlflow_ver, "a", "b"]
    assert cons == []

    # Constraints files are split out and referenced as "-c constraints.txt".
    con_file = tmp_path.joinpath("constraints.txt")
    con_file.write_text("c")
    conda_env, reqs, cons = _process_pip_requirements(
        ["a"], pip_requirements=["b", f"-c {con_file}"]
    )
    assert _get_pip_deps(conda_env) == [expected_mlflow_ver, "b", "-c constraints.txt"]
    assert reqs == [expected_mlflow_ver, "b", "-c constraints.txt"]
    assert cons == ["c"]

    # Extras on a duplicated requirement win over the bare name.
    conda_env, reqs, cons = _process_pip_requirements(["a"], extra_pip_requirements=["a[extras]"])
    assert _get_pip_deps(conda_env) == [expected_mlflow_ver, "a[extras]"]
    assert reqs == [expected_mlflow_ver, "a[extras]"]
    assert cons == []

    # Extras from defaults and extra_pip_requirements are merged; versions kept.
    conda_env, reqs, cons = _process_pip_requirements(
        ["mlflow==1.2.3", "b[extra1]", "a==1.2.3"],
        extra_pip_requirements=["b[extra2]", "a[extras]"],
    )
    assert _get_pip_deps(conda_env) == ["mlflow==1.2.3", "b[extra1,extra2]", "a[extras]==1.2.3"]
    assert reqs == ["mlflow==1.2.3", "b[extra1,extra2]", "a[extras]==1.2.3"]
    assert cons == []


def test_process_conda_env(tmp_path):
    """_process_conda_env extracts pip deps from a conda env (dict or YAML path)."""

    def make_conda_env(pip_deps):
        # Minimal conda env dict with the given pip dependencies.
        return {
            "name": "mlflow-env",
            "channels": ["conda-forge"],
            "dependencies": ["python=3.8.15", "pip", {"pip": pip_deps}],
        }

    expected_mlflow_ver = _mlflow_major_version_string()

    conda_env, reqs, cons = _process_conda_env(make_conda_env(["a"]))
    assert _get_pip_deps(conda_env) == [expected_mlflow_ver, "a"]
    assert reqs == [expected_mlflow_ver, "a"]
    assert cons == []

    # A YAML file path is accepted as well as a dict.
    conda_env_file = tmp_path.joinpath("conda_env.yaml")
    conda_env_file.write_text(yaml.dump(make_conda_env(["a"])))
    conda_env, reqs, cons = _process_conda_env(str(conda_env_file))
    assert _get_pip_deps(conda_env) == [expected_mlflow_ver, "a"]
    assert reqs == [expected_mlflow_ver, "a"]
    assert cons == []

    # Ensure a requirement for mlflow is preserved
    conda_env, reqs, cons = _process_conda_env(make_conda_env(["mlflow==1.2.3"]))
    assert _get_pip_deps(conda_env) == ["mlflow==1.2.3"]
    assert reqs == ["mlflow==1.2.3"]
    assert cons == []

    # Constraints are split out and referenced as "-c constraints.txt".
    con_file = tmp_path.joinpath("constraints.txt")
    con_file.write_text("c")
    conda_env, reqs, cons = _process_conda_env(make_conda_env(["a", f"-c {con_file}"]))
    assert _get_pip_deps(conda_env) == [expected_mlflow_ver, "a", "-c constraints.txt"]
    assert reqs == [expected_mlflow_ver, "a", "-c constraints.txt"]
    assert cons == ["c"]

    # NB: mlflow-skinny is not automatically attached to any model. If specified, it is
    # up to the user to pin a version.
    conda_env, reqs, cons = _process_conda_env(make_conda_env(["mlflow-skinny", "a", "b"]))
    assert _get_pip_deps(conda_env) == ["mlflow-skinny", "a", "b"]
    assert reqs == ["mlflow-skinny", "a", "b"]
    assert cons == []

    with pytest.raises(TypeError, match=r"Expected .+, but got `int`"):
        _process_conda_env(0)


@pytest.mark.parametrize(
    ("env_var", "fallbacks", "should_raise"),
    [
        # 1&2. If env var is True, always throw an exception from inference error
        (True, ["sklearn"], True),
        (True, None, True),
        # 3. If env var is False but fallback is provided, should not throw an exception
        (False, ["sklearn"], False),
        # 4. If fallback is not provided, should throw an exception
        (False, None, True),
    ],
)
def test_infer_requirements_error_handling(env_var, fallbacks, should_raise, monkeypatch):
    """infer_pip_requirements surfaces inference failures unless fallbacks are
    provided and MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS is false."""
    monkeypatch.setenv("MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS", str(env_var))
    # Disable UV auto-detect to ensure model-based inference is used
    monkeypatch.setenv("MLFLOW_UV_AUTO_DETECT", "false")

    call_args = ("path/to/model", "sklearn", fallbacks)
    with mock.patch(
        "mlflow.utils.requirements_utils._capture_imported_modules",
        side_effect=MlflowException("Failed to capture imported modules"),
    ):
        if should_raise:
            with pytest.raises(MlflowException, match="Failed to capture imported module"):
                infer_pip_requirements(*call_args)
        else:
            # Should just pass with warning
            with mock.patch("mlflow.utils.environment._logger.warning") as warning_mock:
                infer_pip_requirements(*call_args)
                warning_mock.assert_called_once()
                warning_text = warning_mock.call_args[0][0]
                assert "Encountered an unexpected error while inferring" in warning_text


@pytest.mark.parametrize(
    ("input_requirements", "expected"),
    [
        # Simple cases
        (["scikit-learn>1", "pandas"], ["scikit-learn>1", "pandas"]),
        # Duplicates without extras, preserving version restrictions
        (["packageA", "packageA==1.0"], ["packageA==1.0"]),
        # Duplicates with extras
        (["packageA", "packageA[extras]"], ["packageA[extras]"]),
        # Mixed cases
        (
            ["packageA", "packageB", "packageA[extras]", "packageC<=2.0"],
            ["packageA[extras]", "packageB", "packageC<=2.0"],
        ),
        # Mixed versions and extras
        (["markdown>=3.5.1", "markdown[extras]", "markdown<4"], ["markdown[extras]<4,>=3.5.1"]),
        # Overlapping extras
        (
            ["packageZ[extra1]", "packageZ[extra2]", "packageZ"],
            ["packageZ[extra1,extra2]"],
        ),
        # No version on extras with final version on non-extras
        (
            ["markdown[extra1]", "markdown[extra2]", "markdown>3", "markdown<4"],
            ["markdown[extra1,extra2]<4,>3"],
        ),
        # Version constraints with extras
        (["markdown>1.0", "markdown[extras]<4"], ["markdown[extras]<4,>1.0"]),
        # Verify duplicate specifiers are not preserved
        (
            ["markdown==3.5.1", "markdown[extras]==3.5.1", "markdown[extras]"],
            ["markdown[extras]==3.5.1"],
        ),
        # Verify duplicate extras are not preserved
        (["markdown[extras]", "markdown", "markdown[extras]"], ["markdown[extras]"]),
        # Marker-differentiated versions should be kept separate
        (
            [
                "numpy==2.2.6 ; python_full_version < '3.11'",
                "numpy==2.4.2 ; python_full_version >= '3.11'",
            ],
            [
                'numpy==2.2.6; python_full_version < "3.11"',
                'numpy==2.4.2; python_full_version >= "3.11"',
            ],
        ),
        # Same marker should still merge
        (
            [
                "numpy>=1.0 ; python_version >= '3.10'",
                "numpy<2.0 ; python_version >= '3.10'",
            ],
            ['numpy<2.0,>=1.0; python_version >= "3.10"'],
        ),
        # Local version label on second entry - prefer non-local (PyPI-installable)
        (["torch==2.7.1", "torch==2.7.1+cu128"], ["torch==2.7.1"]),
        # Local version label on first entry - prefer non-local (PyPI-installable)
        (["torch==2.7.1+cu128", "torch==2.7.1"], ["torch==2.7.1"]),
        # Both have the same local label - should deduplicate normally
        (["torch==2.7.1+cu128", "torch==2.7.1+cu128"], ["torch==2.7.1+cu128"]),
    ],
)
def test_deduplicate_requirements_resolve_correctly(input_requirements, expected):
    """_deduplicate_requirements merges compatible duplicates into one entry."""
    assert _deduplicate_requirements(input_requirements) == expected


@pytest.mark.parametrize(
    "input_requirements",
    [
        # Non-inclusive range with precise specifier
        ["scikit-learn==1.1", "scikit-learn<1"],
        # Incompatible ranges with extras
        ["markdown[extras]==3.5.1", "markdown<3.4"],
        # Invalid ranges
        ["markdown<3", "markdown>3"],
        # Conflicting versions
        ["markdown==3.0", "markdown==3.5"],
        # Differing local labels are a real conflict and should not be silently dropped
        ["torch==2.7.1+cu128", "torch==2.7.1+cpu"],
    ],
)
def test_invalid_requirements_raise(input_requirements):
    """Irreconcilable version constraints raise MlflowException."""
    with pytest.raises(
        MlflowException, match="The specified requirements versions are incompatible"
    ):
        _deduplicate_requirements(input_requirements)


@pytest.mark.parametrize(
    ("input_requirements", "expected"),
    [
        (["databricks-connect", "pyspark", "pyspark-connect"], ["databricks-connect"]),
        (["databricks-connect==1.15.0", "pyspark==3.0.0"], ["databricks-connect==1.15.0"]),
        (["databricks-connect==1.15.0", "pyspark-connect"], ["databricks-connect==1.15.0"]),
        (["pyspark==3.0.0", "pyspark-connect"], ["pyspark==3.0.0", "pyspark-connect"]),
        (
            ["pyspark==3.0.0", "pyspark-connect==1.0.0"],
            ["pyspark==3.0.0", "pyspark-connect==1.0.0"],
        ),
    ],
)
def test_remove_incompatible_requirements(input_requirements, expected):
    """databricks-connect supersedes pyspark/pyspark-connect; otherwise keep both."""
    assert _remove_incompatible_requirements(input_requirements) == expected