# tests/utils/test_environment.py
  1  import importlib.metadata
  2  import os
  3  from unittest import mock
  4  
  5  import pytest
  6  import yaml
  7  
  8  from mlflow.exceptions import MlflowException
  9  from mlflow.utils.environment import (
 10      _contains_mlflow_requirement,
 11      _deduplicate_requirements,
 12      _get_pip_deps,
 13      _get_pip_requirement_specifier,
 14      _is_mlflow_requirement,
 15      _is_pip_deps,
 16      _mlflow_conda_env,
 17      _overwrite_pip_deps,
 18      _parse_pip_requirements,
 19      _process_conda_env,
 20      _process_pip_requirements,
 21      _remove_incompatible_requirements,
 22      _validate_env_arguments,
 23      infer_pip_requirements,
 24  )
 25  
 26  from tests.helper_functions import _mlflow_major_version_string
 27  
 28  
 29  @pytest.fixture
 30  def conda_env_path(tmp_path):
 31      return os.path.join(tmp_path, "conda_env.yaml")
 32  
 33  
 34  def test_mlflow_conda_env_returns_none_when_output_path_is_specified(conda_env_path):
 35      env_creation_output = _mlflow_conda_env(
 36          path=conda_env_path,
 37          additional_conda_deps=["conda-dep-1=0.0.1", "conda-dep-2"],
 38          additional_pip_deps=["pip-dep-1", "pip-dep2==0.1.0"],
 39      )
 40  
 41      assert env_creation_output is None
 42  
 43  
 44  def test_mlflow_conda_env_returns_expected_env_dict_when_output_path_is_not_specified():
 45      conda_deps = ["conda-dep-1=0.0.1", "conda-dep-2"]
 46      env = _mlflow_conda_env(path=None, additional_conda_deps=conda_deps)
 47  
 48      for conda_dep in conda_deps:
 49          assert conda_dep in env["dependencies"]
 50  
 51  
 52  @pytest.mark.parametrize("conda_deps", [["conda-dep-1=0.0.1", "conda-dep-2"], None])
 53  def test_mlflow_conda_env_includes_pip_dependencies_but_pip_is_not_specified(conda_deps):
 54      additional_pip_deps = ["pip-dep==0.0.1"]
 55      env = _mlflow_conda_env(
 56          path=None, additional_conda_deps=conda_deps, additional_pip_deps=additional_pip_deps
 57      )
 58      if conda_deps is not None:
 59          for conda_dep in conda_deps:
 60              assert conda_dep in env["dependencies"]
 61      pip_version = importlib.metadata.version("pip")
 62      assert f"pip<={pip_version}" in env["dependencies"]
 63  
 64  
 65  @pytest.mark.parametrize("pip_specification", ["pip", "pip==20.0.02"])
 66  def test_mlflow_conda_env_includes_pip_dependencies_and_pip_is_specified(pip_specification):
 67      conda_deps = ["conda-dep-1=0.0.1", "conda-dep-2", pip_specification]
 68      additional_pip_deps = ["pip-dep==0.0.1"]
 69      env = _mlflow_conda_env(
 70          path=None, additional_conda_deps=conda_deps, additional_pip_deps=additional_pip_deps
 71      )
 72      for conda_dep in conda_deps:
 73          assert conda_dep in env["dependencies"]
 74      assert pip_specification in env["dependencies"]
 75  
 76  
 77  def test_is_pip_deps():
 78      assert _is_pip_deps({"pip": ["a"]})
 79      assert not _is_pip_deps({"ipi": ["a"]})
 80      assert not _is_pip_deps("")
 81      assert not _is_pip_deps([])
 82  
 83  
 84  def test_overwrite_pip_deps():
 85      # dependencies field doesn't exist
 86      name_and_channels = {"name": "env", "channels": ["conda-forge"]}
 87      expected = {**name_and_channels, "dependencies": [{"pip": ["scipy"]}]}
 88      assert _overwrite_pip_deps(name_and_channels, ["scipy"]) == expected
 89  
 90      # dependencies field doesn't contain pip dependencies
 91      conda_env = {**name_and_channels, "dependencies": ["pip"]}
 92      expected = {**name_and_channels, "dependencies": ["pip", {"pip": ["scipy"]}]}
 93      assert _overwrite_pip_deps(conda_env, ["scipy"]) == expected
 94  
 95      # dependencies field contains pip dependencies
 96      conda_env = {**name_and_channels, "dependencies": ["pip", {"pip": ["numpy"]}, "pandas"]}
 97      expected = {**name_and_channels, "dependencies": ["pip", {"pip": ["scipy"]}, "pandas"]}
 98      assert _overwrite_pip_deps(conda_env, ["scipy"]) == expected
 99  
100  
def test_parse_pip_requirements(tmp_path):
    """_parse_pip_requirements returns a (requirements, constraints) pair."""
    # Empty inputs yield empty results
    assert _parse_pip_requirements(None) == ([], [])
    assert _parse_pip_requirements([]) == ([], [])
    # Plain names and version specifiers pass through unchanged
    assert _parse_pip_requirements(["a", "b"]) == (["a", "b"], [])
    assert _parse_pip_requirements(["a==0.0", "b>1.1"]) == (["a==0.0", "b>1.1"], [])
    # PEP 508 environment markers are preserved
    # (https://www.python.org/dev/peps/pep-0508/#environment-markers)
    assert _parse_pip_requirements(['a; python_version < "3.8"']) == (
        ['a; python_version < "3.8"'],
        [],
    )
    # VCS URIs are preserved
    repo_uri = "git+https://github.com/mlflow/mlflow.git"
    assert _parse_pip_requirements([repo_uri]) == ([repo_uri], [])
    # Local wheel paths are preserved
    wheel = tmp_path.joinpath("fake.whl")
    wheel.write_text("")
    assert _parse_pip_requirements([str(wheel)]) == ([str(wheel)], [])
120  
121  
def test_parse_pip_requirements_with_relative_requirements_files(tmp_path, monkeypatch):
    """Relative `-r` references resolve against the current working directory."""
    monkeypatch.chdir(tmp_path)
    reqs1 = tmp_path.joinpath("requirements1.txt")
    reqs1.write_text("b")
    assert _parse_pip_requirements(reqs1.name) == (["b"], [])
    assert _parse_pip_requirements(["a", f"-r {reqs1.name}"]) == (["a", "b"], [])

    # Requirements files may recursively include other requirements files
    reqs2 = tmp_path.joinpath("requirements2.txt")
    reqs3 = tmp_path.joinpath("requirements3.txt")
    reqs2.write_text(f"b\n-r {reqs3.name}")
    reqs3.write_text("c")
    assert _parse_pip_requirements(reqs2.name) == (["b", "c"], [])
    assert _parse_pip_requirements(["a", f"-r {reqs2.name}"]) == (["a", "b", "c"], [])
135  
136  
def test_parse_pip_requirements_with_absolute_requirements_files(tmp_path):
    """Absolute `-r` references work regardless of the current working directory."""
    reqs1 = tmp_path.joinpath("requirements1.txt")
    reqs1.write_text("b")
    assert _parse_pip_requirements(str(reqs1)) == (["b"], [])
    assert _parse_pip_requirements(["a", f"-r {reqs1}"]) == (["a", "b"], [])

    # Requirements files may recursively include other requirements files
    reqs2 = tmp_path.joinpath("requirements2.txt")
    reqs3 = tmp_path.joinpath("requirements3.txt")
    reqs2.write_text(f"b\n-r {reqs3}")
    reqs3.write_text("c")
    assert _parse_pip_requirements(str(reqs2)) == (["b", "c"], [])
    assert _parse_pip_requirements(["a", f"-r {reqs2}"]) == (["a", "b", "c"], [])
149  
150  
def test_parse_pip_requirements_with_constraints_files(tmp_path):
    """`-c` entries are split out into the separate constraints list."""
    constraints_file = tmp_path.joinpath("constraints.txt")
    constraints_file.write_text("b")
    assert _parse_pip_requirements(["a", f"-c {constraints_file}"]) == (["a"], ["b"])

    # A constraints file referenced from within a requirements file is also honored
    requirements_file = tmp_path.joinpath("requirements.txt")
    requirements_file.write_text(f"-c {constraints_file}\n")
    assert _parse_pip_requirements(["a", f"-r {requirements_file}"]) == (["a"], ["b"])
159  
160  
def test_parse_pip_requirements_ignores_comments_and_blank_lines(tmp_path):
    """Comment lines, inline comments, and blank/whitespace-only lines are dropped."""
    lines = [
        "# comment",
        "a # inline comment",
        "",
        " ",
    ]
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("\n".join(lines))
    # Same result whether the input is a list of strings or a file path
    assert _parse_pip_requirements(lines) == (["a"], [])
    assert _parse_pip_requirements(str(req_file)) == (["a"], [])
173  
174  
def test_parse_pip_requirements_removes_temporary_requirements_file():
    """The temporary requirements file is cleaned up on success and on failure."""

    def no_temp_files_left():
        # The parser materializes inputs into a "*.tmp.requirements.txt" in the cwd
        return not any(name.endswith(".tmp.requirements.txt") for name in os.listdir())

    assert _parse_pip_requirements(["a"]) == (["a"], [])
    assert no_temp_files_left()

    with pytest.raises(FileNotFoundError, match="No such file or directory"):
        _parse_pip_requirements(["a", "-r does_not_exist.txt"])
    # Cleanup must happen even when parsing fails midway
    assert no_temp_files_left()
183  
184  
@pytest.mark.parametrize("invalid_argument", [0, True, [0]])
def test_parse_pip_requirements_with_invalid_argument_types(invalid_argument):
    """Inputs that are neither a path string nor a list of strings raise TypeError."""
    err = "`pip_requirements` must be either a string path"
    with pytest.raises(TypeError, match=err):
        _parse_pip_requirements(invalid_argument)
189  
190  
def test_validate_env_arguments():
    """At most one of conda_env / pip_requirements / extra_pip_requirements is allowed."""
    # All-None is valid: nothing was specified
    _validate_env_arguments(pip_requirements=None, extra_pip_requirements=None, conda_env=None)

    err = "Only one of `conda_env`, `pip_requirements`, and `extra_pip_requirements`"
    # Every combination that specifies two or more arguments must be rejected
    invalid_combos = [
        ({}, [], None),
        ({}, None, []),
        (None, [], []),
        ({}, [], []),
    ]
    for conda_env, pip_reqs, extra_reqs in invalid_combos:
        with pytest.raises(ValueError, match=err):
            _validate_env_arguments(
                conda_env=conda_env,
                pip_requirements=pip_reqs,
                extra_pip_requirements=extra_reqs,
            )
206  
207  
def test_is_mlflow_requirement():
    """mlflow / mlflow-skinny are recognized in any case and any specifier form."""
    recognized = [
        "mlflow",
        "MLFLOW",
        "MLflow",
        "mlflow==1.2.3",
        "mlflow < 1.2.3",
        "mlflow; python_version < '3.8'",
        "mlflow @ https://github.com/mlflow/mlflow.git",
        "mlflow @ file:///path/to/mlflow",
        "mlflow-skinny==1.2.3",
    ]
    for requirement in recognized:
        assert _is_mlflow_requirement(requirement)
    # Packages that merely look like mlflow must NOT be considered as mlflow
    for requirement in ("foo", "mlflow-foo", "mlflow_foo"):
        assert not _is_mlflow_requirement(requirement)
222  
223  
def test_contains_mlflow_requirement():
    """True iff any entry in the requirement list refers to mlflow."""
    for requirements in (["mlflow"], ["mlflow==1.2.3"], ["mlflow", "foo"], ["mlflow-skinny"]):
        assert _contains_mlflow_requirement(requirements)
    for requirements in ([], ["foo"]):
        assert not _contains_mlflow_requirement(requirements)
231  
232  
def test_get_pip_requirement_specifier():
    """Everything from the first pip option (e.g. --hash) onward is stripped."""
    cases = [
        ("", ""),
        (" ", " "),
        ("mlflow", "mlflow"),
        ("mlflow==1.2.3", "mlflow==1.2.3"),
        # Option-only lines reduce to their leading whitespace
        ("-r reqs.txt", ""),
        ("  -r reqs.txt", " "),
        # Trailing whitespace before the option is kept as-is
        ("mlflow==1.2.3 --hash=foo", "mlflow==1.2.3"),
        ("mlflow==1.2.3       --hash=foo", "mlflow==1.2.3      "),
        ("mlflow-skinny==1.2 --foo=bar", "mlflow-skinny==1.2"),
    ]
    for requirement, expected in cases:
        assert _get_pip_requirement_specifier(requirement) == expected
243  
244  
def test_process_pip_requirements(tmp_path):
    """_process_pip_requirements merges inferred and user-supplied pip requirements."""
    mlflow_req = _mlflow_major_version_string()

    # Inferred requirements alone: an mlflow requirement is prepended automatically
    env, requirements, constraints = _process_pip_requirements(["a"])
    assert _get_pip_deps(env) == [mlflow_req, "a"]
    assert requirements == [mlflow_req, "a"]
    assert constraints == []

    # Explicit pip_requirements replace the inferred ones
    env, requirements, constraints = _process_pip_requirements(["a"], pip_requirements=["b"])
    assert _get_pip_deps(env) == [mlflow_req, "b"]
    assert requirements == [mlflow_req, "b"]
    assert constraints == []

    # A user-supplied mlflow pin is preserved as-is
    env, requirements, constraints = _process_pip_requirements(
        ["a"], pip_requirements=["mlflow==1.2.3"]
    )
    assert _get_pip_deps(env) == ["mlflow==1.2.3"]
    assert requirements == ["mlflow==1.2.3"]
    assert constraints == []

    # An mlflow pin carrying package hashes is preserved as well
    hash1 = "sha256:963c22532e82a93450674ab97d62f9e528ed0906b580fadb7c003e696197557c"
    hash2 = "sha256:b15ff0c7e5e64f864a0b40c99b9a582227315eca2065d9f831db9aeb8f24637b"
    env, requirements, constraints = _process_pip_requirements(
        ["a"],
        pip_requirements=[f"mlflow==1.20.2 --hash={hash1} --hash={hash2}"],
    )
    assert _get_pip_deps(env) == [f"mlflow==1.20.2 --hash={hash1} --hash={hash2}"]
    assert requirements == [f"mlflow==1.20.2 --hash={hash1} --hash={hash2}"]
    assert constraints == []

    # extra_pip_requirements are appended after the inferred ones
    env, requirements, constraints = _process_pip_requirements(["a"], extra_pip_requirements=["b"])
    assert _get_pip_deps(env) == [mlflow_req, "a", "b"]
    assert requirements == [mlflow_req, "a", "b"]
    assert constraints == []

    # Constraints files are split out and referenced under a fixed name
    con_file = tmp_path.joinpath("constraints.txt")
    con_file.write_text("c")
    env, requirements, constraints = _process_pip_requirements(
        ["a"], pip_requirements=["b", f"-c {con_file}"]
    )
    assert _get_pip_deps(env) == [mlflow_req, "b", "-c constraints.txt"]
    assert requirements == [mlflow_req, "b", "-c constraints.txt"]
    assert constraints == ["c"]

    # A duplicate package with extras is merged into one entry
    env, requirements, constraints = _process_pip_requirements(
        ["a"], extra_pip_requirements=["a[extras]"]
    )
    assert _get_pip_deps(env) == [mlflow_req, "a[extras]"]
    assert requirements == [mlflow_req, "a[extras]"]
    assert constraints == []

    # Extras are unioned and version specifiers kept when merging duplicates
    env, requirements, constraints = _process_pip_requirements(
        ["mlflow==1.2.3", "b[extra1]", "a==1.2.3"],
        extra_pip_requirements=["b[extra2]", "a[extras]"],
    )
    assert _get_pip_deps(env) == ["mlflow==1.2.3", "b[extra1,extra2]", "a[extras]==1.2.3"]
    assert requirements == ["mlflow==1.2.3", "b[extra1,extra2]", "a[extras]==1.2.3"]
    assert constraints == []
300  
301  
def test_process_conda_env(tmp_path):
    """_process_conda_env normalizes a conda env (dict or YAML file) and its pip deps."""

    def conda_env_with(pip_deps):
        # Minimal conda env skeleton wrapping the given pip dependencies
        return {
            "name": "mlflow-env",
            "channels": ["conda-forge"],
            "dependencies": ["python=3.8.15", "pip", {"pip": pip_deps}],
        }

    mlflow_req = _mlflow_major_version_string()

    # Dict input: an mlflow requirement is prepended to the pip deps
    env, requirements, constraints = _process_conda_env(conda_env_with(["a"]))
    assert _get_pip_deps(env) == [mlflow_req, "a"]
    assert requirements == [mlflow_req, "a"]
    assert constraints == []

    # A YAML file path is accepted and yields the same result
    env_file = tmp_path.joinpath("conda_env.yaml")
    env_file.write_text(yaml.dump(conda_env_with(["a"])))
    env, requirements, constraints = _process_conda_env(str(env_file))
    assert _get_pip_deps(env) == [mlflow_req, "a"]
    assert requirements == [mlflow_req, "a"]
    assert constraints == []

    # A user-supplied mlflow pin is preserved as-is
    env, requirements, constraints = _process_conda_env(conda_env_with(["mlflow==1.2.3"]))
    assert _get_pip_deps(env) == ["mlflow==1.2.3"]
    assert requirements == ["mlflow==1.2.3"]
    assert constraints == []

    # Constraints files are split out and referenced under a fixed name
    con_file = tmp_path.joinpath("constraints.txt")
    con_file.write_text("c")
    env, requirements, constraints = _process_conda_env(conda_env_with(["a", f"-c {con_file}"]))
    assert _get_pip_deps(env) == [mlflow_req, "a", "-c constraints.txt"]
    assert requirements == [mlflow_req, "a", "-c constraints.txt"]
    assert constraints == ["c"]

    # NB: mlflow-skinny is not automatically attached to any model. If specified, it is
    # up to the user to pin a version.
    env, requirements, constraints = _process_conda_env(conda_env_with(["mlflow-skinny", "a", "b"]))
    assert _get_pip_deps(env) == ["mlflow-skinny", "a", "b"]
    assert requirements == ["mlflow-skinny", "a", "b"]
    assert constraints == []

    # Non-dict, non-path inputs are rejected
    with pytest.raises(TypeError, match=r"Expected .+, but got `int`"):
        _process_conda_env(0)
346  
347  
@pytest.mark.parametrize(
    ("env_var", "fallbacks", "should_raise"),
    [
        # If the env var is set, inference errors always propagate
        (True, ["sklearn"], True),
        (True, None, True),
        # Env var unset and a fallback is available: warn and fall back
        (False, ["sklearn"], False),
        # Env var unset but nothing to fall back to: raise
        (False, None, True),
    ],
)
def test_infer_requirements_error_handling(env_var, fallbacks, should_raise, monkeypatch):
    """Inference failures either raise or fall back depending on the env var and fallbacks."""
    monkeypatch.setenv("MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS", str(env_var))
    # Disable UV auto-detect to ensure model-based inference is used
    monkeypatch.setenv("MLFLOW_UV_AUTO_DETECT", "false")

    args = ("path/to/model", "sklearn", fallbacks)
    with mock.patch(
        "mlflow.utils.requirements_utils._capture_imported_modules",
        side_effect=MlflowException("Failed to capture imported modules"),
    ):
        if should_raise:
            with pytest.raises(MlflowException, match="Failed to capture imported module"):
                infer_pip_requirements(*args)
        else:
            # The fallback path completes but emits exactly one warning
            with mock.patch("mlflow.utils.environment._logger.warning") as warning_mock:
                infer_pip_requirements(*args)
            warning_mock.assert_called_once()
            logged = warning_mock.call_args[0][0]
            assert "Encountered an unexpected error while inferring" in logged
380  
381  
@pytest.mark.parametrize(
    ("input_requirements", "expected"),
    [
        # Simple cases
        (["scikit-learn>1", "pandas"], ["scikit-learn>1", "pandas"]),
        # Duplicates without extras, preserving version restrictions
        (["packageA", "packageA==1.0"], ["packageA==1.0"]),
        # Duplicates with extras
        (["packageA", "packageA[extras]"], ["packageA[extras]"]),
        # Mixed cases
        (
            ["packageA", "packageB", "packageA[extras]", "packageC<=2.0"],
            ["packageA[extras]", "packageB", "packageC<=2.0"],
        ),
        # Mixed versions and extras
        (["markdown>=3.5.1", "markdown[extras]", "markdown<4"], ["markdown[extras]<4,>=3.5.1"]),
        # Overlapping extras
        (
            ["packageZ[extra1]", "packageZ[extra2]", "packageZ"],
            ["packageZ[extra1,extra2]"],
        ),
        # No version on extras with final version on non-extras
        (
            ["markdown[extra1]", "markdown[extra2]", "markdown>3", "markdown<4"],
            ["markdown[extra1,extra2]<4,>3"],
        ),
        # Version constraints with extras
        (["markdown>1.0", "markdown[extras]<4"], ["markdown[extras]<4,>1.0"]),
        # Verify duplicate specifiers are not preserved
        (
            ["markdown==3.5.1", "markdown[extras]==3.5.1", "markdown[extras]"],
            ["markdown[extras]==3.5.1"],
        ),
        # Verify duplicate extras are not preserved
        (["markdown[extras]", "markdown", "markdown[extras]"], ["markdown[extras]"]),
        # Marker-differentiated versions should be kept separate
        (
            [
                "numpy==2.2.6 ; python_full_version < '3.11'",
                "numpy==2.4.2 ; python_full_version >= '3.11'",
            ],
            # NOTE: output markers are normalized to double-quoted form
            [
                'numpy==2.2.6; python_full_version < "3.11"',
                'numpy==2.4.2; python_full_version >= "3.11"',
            ],
        ),
        # Same marker should still merge
        (
            [
                "numpy>=1.0 ; python_version >= '3.10'",
                "numpy<2.0 ; python_version >= '3.10'",
            ],
            ['numpy<2.0,>=1.0; python_version >= "3.10"'],
        ),
        # Local version label on second entry - prefer non-local (PyPI-installable)
        (["torch==2.7.1", "torch==2.7.1+cu128"], ["torch==2.7.1"]),
        # Local version label on first entry - prefer non-local (PyPI-installable)
        (["torch==2.7.1+cu128", "torch==2.7.1"], ["torch==2.7.1"]),
        # Both have the same local label - should deduplicate normally
        (["torch==2.7.1+cu128", "torch==2.7.1+cu128"], ["torch==2.7.1+cu128"]),
    ],
)
def test_deduplicate_requirements_resolve_correctly(input_requirements, expected):
    """Duplicate package entries are merged: extras unioned, version specifiers combined."""
    assert _deduplicate_requirements(input_requirements) == expected
446  
447  
@pytest.mark.parametrize(
    "input_requirements",
    [
        # Non-inclusive range with precise specifier
        ["scikit-learn==1.1", "scikit-learn<1"],
        # Incompatible ranges with extras
        ["markdown[extras]==3.5.1", "markdown<3.4"],
        # Invalid ranges
        ["markdown<3", "markdown>3"],
        # Conflicting versions
        ["markdown==3.0", "markdown==3.5"],
        # Differing local labels are a real conflict and should not be silently dropped
        ["torch==2.7.1+cu128", "torch==2.7.1+cpu"],
    ],
)
def test_invalid_requirements_raise(input_requirements):
    """Unsatisfiable version combinations surface as an MlflowException."""
    err = "The specified requirements versions are incompatible"
    with pytest.raises(MlflowException, match=err):
        _deduplicate_requirements(input_requirements)
468  
469  
@pytest.mark.parametrize(
    ("input_requirements", "expected"),
    [
        # databricks-connect supersedes both pyspark and pyspark-connect
        (["databricks-connect", "pyspark", "pyspark-connect"], ["databricks-connect"]),
        (["databricks-connect==1.15.0", "pyspark==3.0.0"], ["databricks-connect==1.15.0"]),
        (["databricks-connect==1.15.0", "pyspark-connect"], ["databricks-connect==1.15.0"]),
        # Without databricks-connect, pyspark and pyspark-connect are left untouched
        (["pyspark==3.0.0", "pyspark-connect"], ["pyspark==3.0.0", "pyspark-connect"]),
        (
            ["pyspark==3.0.0", "pyspark-connect==1.0.0"],
            ["pyspark==3.0.0", "pyspark-connect==1.0.0"],
        ),
    ],
)
def test_remove_incompatible_requirements(input_requirements, expected):
    """Packages known to conflict (databricks-connect vs pyspark*) are pruned."""
    assert _remove_incompatible_requirements(input_requirements) == expected