# tests/utils/test_requirements_utils.py
  1  import importlib
  2  import os
  3  import sys
  4  from importlib.metadata import version
  5  from unittest import mock
  6  
  7  import cloudpickle
  8  import importlib_metadata
  9  import pytest
 10  
 11  import mlflow
 12  import mlflow.utils.requirements_utils
 13  from mlflow.exceptions import MlflowException
 14  from mlflow.utils.environment import infer_pip_requirements
 15  from mlflow.utils.os import is_windows
 16  from mlflow.utils.requirements_utils import (
 17      _capture_imported_modules,
 18      _check_requirement_satisfied,
 19      _get_installed_version,
 20      _get_pinned_requirement,
 21      _infer_requirements,
 22      _is_comment,
 23      _is_empty,
 24      _is_requirements_file,
 25      _join_continued_lines,
 26      _normalize_package_name,
 27      _parse_requirements,
 28      _prune_packages,
 29      _strip_inline_comment,
 30      _strip_local_version_label,
 31      warn_dependency_requirement_mismatches,
 32  )
 33  
 34  from tests.helper_functions import AnyStringWith
 35  
 36  
 37  def test_is_comment():
 38      assert _is_comment("# comment")
 39      assert _is_comment("#")
 40      assert _is_comment("### comment ###")
 41      assert not _is_comment("comment")
 42      assert not _is_comment("")
 43  
 44  
 45  def test_is_empty():
 46      assert _is_empty("")
 47      assert not _is_empty(" ")
 48      assert not _is_empty("a")
 49  
 50  
 51  def test_is_requirements_file():
 52      assert _is_requirements_file("-r req.txt")
 53      assert _is_requirements_file("-r  req.txt")
 54      assert _is_requirements_file("--requirement req.txt")
 55      assert _is_requirements_file("--requirement  req.txt")
 56      assert not _is_requirements_file("req")
 57  
 58  
 59  def test_strip_inline_comment():
 60      assert _strip_inline_comment("aaa # comment") == "aaa"
 61      assert _strip_inline_comment("aaa   # comment") == "aaa"
 62      assert _strip_inline_comment("aaa #   comment") == "aaa"
 63      assert _strip_inline_comment("aaa # com1 # com2") == "aaa"
 64      # Ensure a URI fragment is not stripped
 65      assert (
 66          _strip_inline_comment("git+https://git/repo.git#subdirectory=subdir")
 67          == "git+https://git/repo.git#subdirectory=subdir"
 68      )
 69  
 70  
 71  def test_join_continued_lines():
 72      assert list(_join_continued_lines(["a"])) == ["a"]
 73      assert list(_join_continued_lines(["a\\", "b"])) == ["ab"]
 74      assert list(_join_continued_lines(["a\\", "b\\", "c"])) == ["abc"]
 75      assert list(_join_continued_lines(["a\\", " b"])) == ["a b"]
 76      assert list(_join_continued_lines(["a\\", " b\\", " c"])) == ["a b c"]
 77      assert list(_join_continued_lines(["a\\", "\\", "b"])) == ["ab"]
 78      assert list(_join_continued_lines(["a\\", "b", "c\\", "d"])) == ["ab", "cd"]
 79      assert list(_join_continued_lines(["a\\", "", "b"])) == ["a", "b"]
 80      assert list(_join_continued_lines(["a\\"])) == ["a"]
 81      assert list(_join_continued_lines(["\\", "a"])) == ["a"]
 82  
 83  
def test_parse_requirements(tmp_path, monkeypatch):
    """_parse_requirements must mirror pip's parser on a requirements file that
    covers comments, version specifiers, environment markers, inline comments,
    git URIs, nested -r/--requirement and -c/--constraint files, and
    backslash line continuations (including at a blank line and at EOF).
    """
    root_req_src = """
# No version specifier
noverspec
no-ver-spec

# Version specifiers
verspec<1.0
ver-spec == 2.0

# Environment marker
env-marker; python_version < "3.8"

inline-comm # Inline comment
inlinecomm                        # Inline comment

# Git URIs
git+https://github.com/git/uri
git+https://github.com/sub/dir#subdirectory=subdir

# Requirements files
-r {relative_req}
--requirement {absolute_req}

# Constraints files
-c {relative_con}
--constraint {absolute_con}

# Line continuation
line-cont\
==\
1.0

# Line continuation with spaces
line-cont-space \
== \
1.0

# Line continuation with a blank line
line-cont-blank\

# Line continuation at EOF
line-cont-eof\
""".strip()

    # chdir so the relative -r/-c paths in the root file resolve against tmp_path.
    monkeypatch.chdir(tmp_path)
    root_req = tmp_path.joinpath("requirements.txt")
    # Requirements files
    rel_req = tmp_path.joinpath("relative_req.txt")
    abs_req = tmp_path.joinpath("absolute_req.txt")
    # Constraints files
    rel_con = tmp_path.joinpath("relative_con.txt")
    abs_con = tmp_path.joinpath("absolute_con.txt")

    # pip's requirements parser collapses an absolute requirements file path:
    # https://github.com/pypa/pip/issues/10121
    # As a workaround, use a relative path on Windows.
    absolute_req = abs_req.name if is_windows() else str(abs_req)
    absolute_con = abs_con.name if is_windows() else str(abs_con)
    root_req.write_text(
        root_req_src.format(
            relative_req=rel_req.name,
            absolute_req=absolute_req,
            relative_con=rel_con.name,
            absolute_con=absolute_con,
        )
    )
    rel_req.write_text("rel-req-xxx\nrel-req-yyy")
    abs_req.write_text("abs-req-zzz")
    rel_con.write_text("rel-con-xxx\nrel-con-yyy")
    abs_con.write_text("abs-con-zzz")

    # Uncomment this to get the expected output from pip's internal parser
    # from pip._internal.network.session import PipSession
    # from pip._internal.req import parse_requirements as pip_parse_requirements
    #
    # pip_reqs = list(pip_parse_requirements(root_req.name, session=PipSession()))
    # print(f"expected_reqs = {[r.requirement for r in pip_reqs if not r.constraint]}")
    # print(f"expected_cons = {[r.requirement for r in pip_reqs if r.constraint]}")

    expected_reqs = [
        "noverspec",
        "no-ver-spec",
        "verspec<1.0",
        "ver-spec == 2.0",
        'env-marker; python_version < "3.8"',
        "inline-comm",
        "inlinecomm",
        "git+https://github.com/git/uri",
        "git+https://github.com/sub/dir#subdirectory=subdir",
        "rel-req-xxx",
        "rel-req-yyy",
        "abs-req-zzz",
        "line-cont==1.0",
        "line-cont-space == 1.0",
        "line-cont-blank",
        "line-cont-eof",
    ]
    expected_cons = [
        "rel-con-xxx",
        "rel-con-yyy",
        "abs-con-zzz",
    ]

    # Entries from -c/--constraint files must be flagged via `is_constraint`.
    parsed_reqs = list(_parse_requirements(root_req.name, is_constraint=False))
    assert [r.req_str for r in parsed_reqs if not r.is_constraint] == expected_reqs
    assert [r.req_str for r in parsed_reqs if r.is_constraint] == expected_cons
191  
192  
def test_normalize_package_name():
    """Names are lowercased and runs of '-', '_', '.' collapse to a single '-'."""
    for raw, normalized in [
        ("abc", "abc"),
        ("ABC", "abc"),
        ("a-b-c", "a-b-c"),
        ("a.b.c", "a-b-c"),
        ("a_b_c", "a-b-c"),
        ("a--b--c", "a-b-c"),
        ("a-._b-._c", "a-b-c"),
    ]:
        assert _normalize_package_name(raw) == normalized
201  
202  
def test_prune_packages():
    """Pruning returns the input packages as a set when none subsumes another."""
    for packages, expected in [
        (["mlflow"], {"mlflow"}),
        (["mlflow", "scikit-learn"], {"mlflow", "scikit-learn"}),
    ]:
        assert _prune_packages(packages) == expected
206  
207  
def test_capture_imported_modules():
    """Imports via the `import` statement, `__import__`, and `importlib` are
    all recorded while the capture context is active."""
    from mlflow.utils._capture_modules import _CaptureImportedModules

    with _CaptureImportedModules() as captured:
        import math  # clint: disable=lazy-import  # noqa: F401

        __import__("pandas")
        importlib.import_module("numpy")

    for module_name in ("math", "pandas", "numpy"):
        assert module_name in captured.imported_modules
220  
221  
def test_strip_local_version_label():
    """The '+<label>' suffix is removed while rc/dev/post segments are kept;
    an unparseable version passes through unchanged."""
    for raw, stripped in [
        ("1.2.3", "1.2.3"),
        ("1.2.3+ab", "1.2.3"),
        ("1.2.3rc0+ab", "1.2.3rc0"),
        ("1.2.3.dev0+ab", "1.2.3.dev0"),
        ("1.2.3.post0+ab", "1.2.3.post0"),
        ("invalid", "invalid"),
    ]:
        assert _strip_local_version_label(raw) == stripped
229  
230  
def test_get_installed_version(tmp_path, monkeypatch):
    """_get_installed_version matches distribution metadata, honors the
    `module` override, and falls back to `__version__` when no distribution
    metadata exists."""
    assert _get_installed_version("mlflow") == mlflow.__version__
    for package in ("numpy", "pandas"):
        assert _get_installed_version(package) == version(package)
    assert _get_installed_version("scikit-learn", module="sklearn") == version("scikit-learn")

    # An importable module with no installed distribution metadata.
    stub_module = tmp_path.joinpath("not_found.py")
    stub_module.write_text("__version__ = '1.2.3'")
    monkeypatch.syspath_prepend(str(tmp_path))
    with pytest.raises(importlib_metadata.PackageNotFoundError, match=r".+"):
        importlib_metadata.version("not_found")
    # Fallback path: read `__version__` from the module itself.
    assert _get_installed_version("not_found") == "1.2.3"
243  
244  
def test_package_with_mismatched_pypi_and_import_name():
    """The `dspy` import name maps to the `dspy-ai` PyPI distribution;
    `_get_installed_version` must resolve the distribution version.

    Uses `pytest.importorskip` instead of wrapping the whole body in
    try/except ImportError: the old pattern would silently skip the test if
    the assertion itself (e.g. `_get_installed_version` internals) raised an
    ImportError, hiding a real failure.
    """
    pytest.importorskip("dspy", reason="Skipping test because 'dspy' package is not installed")
    assert _get_installed_version("dspy") == version("dspy-ai")
252  
253  
def test_get_pinned_requirement(tmp_path, monkeypatch):
    """Pinning produces '<name>==<version>', honoring an explicit version
    override and falling back to `__version__` for metadata-less modules."""
    assert _get_pinned_requirement("mlflow") == f"mlflow=={mlflow.__version__}"
    assert _get_pinned_requirement("mlflow", version="1.2.3") == "mlflow==1.2.3"

    # A module importable from sys.path but lacking distribution metadata.
    stub_module = tmp_path.joinpath("not_found.py")
    stub_module.write_text("__version__ = '1.2.3'")
    monkeypatch.syspath_prepend(str(tmp_path))
    with pytest.raises(importlib_metadata.PackageNotFoundError, match=r".+"):
        importlib_metadata.version("not_found")
    assert _get_pinned_requirement("not_found") == "not_found==1.2.3"
264  
265  
def test_get_pinned_requirement_local_version_label(tmp_path, monkeypatch):
    """Pinning strips the local version label ('+...') and emits a warning."""
    label = "abc.def.ghi"  # Local version label
    module_file = tmp_path.joinpath("my_package.py")
    module_file.write_text(f"__version__ = '1.2.3+{label}'")
    monkeypatch.syspath_prepend(str(tmp_path))

    with mock.patch("mlflow.utils.requirements_utils._logger.warning") as mock_warning:
        pinned = _get_pinned_requirement("my_package")
        mock_warning.assert_called_once()
        (warning_message,) = mock_warning.call_args[0]
        expected_prefix = (
            f"Found my_package version (1.2.3+{label}) contains a local version label (+{label})."
        )
        assert warning_message.startswith(expected_prefix)
    # The label must not appear in the pinned requirement.
    assert pinned == "my_package==1.2.3"
280  
281  
def test_infer_requirements_excludes_mlflow():
    """Even when 'mlflow' is among the captured modules, it is excluded from
    the inferred requirements (only pytest remains)."""
    capture_patch = mock.patch(
        "mlflow.utils.requirements_utils._capture_imported_modules",
        return_value=["mlflow", "pytest"],
    )
    with capture_patch:
        expected_dist = "mlflow-skinny" if "MLFLOW_SKINNY" in os.environ else "mlflow"
        assert expected_dist in importlib_metadata.packages_distributions()["mlflow"]
        assert _infer_requirements("path/to/model", "sklearn") == [f"pytest=={pytest.__version__}"]
290  
291  
def test_capture_imported_modules_scopes_databricks_imports(monkeypatch, tmp_path):
    """Only specific `databricks.*` submodules (automl, automl_runtime,
    model_monitoring) are recorded individually; other submodules are either
    ignored (automl_foo) or cause the bare `databricks` package to be recorded
    (other). NOTE: the second capture block relies on the `sys.modules` state
    left behind by the first — do not reorder these blocks.
    """
    from mlflow.utils._capture_modules import _CaptureImportedModules

    monkeypatch.chdir(tmp_path)
    monkeypatch.syspath_prepend(str(tmp_path))

    # Build a dummy `databricks` package with a mix of scoped and unscoped submodules.
    databricks_dir = os.path.join(tmp_path, "databricks")
    os.makedirs(databricks_dir)
    for file_name in [
        "__init__.py",
        "automl.py",
        "automl_runtime.py",
        "automl_foo.py",
        "model_monitoring.py",
        "other.py",
    ]:
        with open(os.path.join(databricks_dir, file_name), "w"):
            pass

    with _CaptureImportedModules() as cap:
        # Delete `databricks` from the cache to ensure we load from the dummy module created above.
        if "databricks" in sys.modules:
            del sys.modules["databricks"]
        import databricks
        import databricks.automl
        import databricks.automl_foo
        import databricks.automl_runtime
        import databricks.model_monitoring

    # Known submodules are captured individually; the bare package and the
    # unrecognized `automl_foo` submodule are not.
    assert "databricks.automl" in cap.imported_modules
    assert "databricks.model_monitoring" in cap.imported_modules
    assert "databricks" not in cap.imported_modules
    assert "databricks.automl_foo" not in cap.imported_modules

    with _CaptureImportedModules() as cap:
        import databricks.automl
        import databricks.automl_foo
        import databricks.automl_runtime
        import databricks.model_monitoring
        import databricks.other  # noqa: F401

    # Importing `databricks.other` causes the bare `databricks` package to be recorded too.
    assert "databricks.automl" in cap.imported_modules
    assert "databricks.model_monitoring" in cap.imported_modules
    assert "databricks" in cap.imported_modules
    assert "databricks.automl_foo" not in cap.imported_modules
337  
338  
def test_infer_pip_requirements_scopes_databricks_imports():
    """Scoped `databricks.*` modules resolve to their specific distributions
    (databricks-automl-runtime, databricks-model-monitoring), and those
    distributions are removed from the generic `databricks` mapping."""
    # Reset the module-level caches so the mocked `packages_distributions`
    # below is actually consulted.
    mlflow.utils.requirements_utils._MODULES_TO_PACKAGES = None
    mlflow.utils.requirements_utils._PACKAGES_TO_MODULES = None

    with (
        mock.patch(
            "mlflow.utils.requirements_utils._capture_imported_modules",
            return_value=[
                "databricks.automl",
                "databricks.model_monitoring",
                "databricks.automl_runtime",
            ],
        ),
        mock.patch(
            "mlflow.utils.requirements_utils._get_installed_version",
            return_value="1.0",
        ),
        mock.patch(
            "importlib_metadata.packages_distributions",
            return_value={
                "databricks": [
                    "databricks-automl-runtime",
                    "databricks-model-monitoring",
                    "koalas",
                ],
            },
        ),
    ):
        assert _infer_requirements("path/to/model", "sklearn") == [
            "databricks-automl-runtime==1.0",
            "databricks-model-monitoring==1.0",
        ]
        # The scoped distributions are pulled out of the generic mapping,
        # leaving only `koalas` associated with the bare `databricks` module.
        assert mlflow.utils.requirements_utils._MODULES_TO_PACKAGES["databricks"] == ["koalas"]
372  
373  
def test_capture_imported_modules_include_deps_by_params():
    """Modules imported only on the `params` code path are captured, because
    the logged input example carries params."""

    class ParamsAwareModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            if params is not None:
                import pandas as pd
                import sklearn  # noqa: F401

                return pd.DataFrame([params])
            return model_input

    params = {"a": 1, "b": "string", "c": True}

    with mlflow.start_run():
        logged_model = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=ParamsAwareModel(),
            input_example=(["input1"], params),
        )

    modules = _capture_imported_modules(logged_model.model_uri, "pyfunc")
    assert "pandas" in modules
    assert "sklearn" in modules
396  
397  
@pytest.mark.parametrize(
    ("module_to_import", "should_capture_extra"),
    [
        ("mlflow.gateway", True),
        ("mlflow.deployments.server.config", True),
        # The `mlflow[gateway]`` extra includes requirements for starting the deployment server,
        # but it is not required when the model only uses the deployment client. These test
        # cases validate that importing the deployment client alone does not add the extra.
        ("mlflow.deployments", False),
    ],
)
def test_capture_imported_modules_includes_gateway_extra(
    module_to_import, should_capture_extra, monkeypatch
):
    """Importing gateway/deployment-server modules inside predict must add the
    `mlflow[gateway]` extra to inferred requirements; the deployments client
    alone must not."""
    # Disable UV auto-detect to ensure model-based inference is used
    monkeypatch.setenv("MLFLOW_UV_AUTO_DETECT", "false")

    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, inputs, params=None):
            # `module_to_import` is closed over from the parametrized argument.
            importlib.import_module(module_to_import)

            return inputs

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=MyModel(),
            input_example=([1, 2, 3]),
        )

    captured_modules = _capture_imported_modules(model_info.model_uri, "pyfunc")
    assert ("mlflow.gateway" in captured_modules) == should_capture_extra

    pip_requirements = infer_pip_requirements(model_info.model_uri, "pyfunc")
    assert (f"mlflow[gateway]=={mlflow.__version__}" in pip_requirements) == should_capture_extra
433  
434  
def test_gateway_extra_not_captured_when_importing_deployment_client_only(monkeypatch):
    """Importing only the deployments client must not pull the mlflow[gateway]
    extra into inferred requirements."""
    # Disable UV auto-detect to ensure model-based inference is used
    monkeypatch.setenv("MLFLOW_UV_AUTO_DETECT", "false")

    class ClientOnlyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            from mlflow.deployments import get_deploy_client  # noqa: F401

            return model_input

    with mlflow.start_run():
        logged_model = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=ClientOnlyModel(),
            input_example=([1, 2, 3]),
        )

    modules = _capture_imported_modules(logged_model.model_uri, "pyfunc")
    assert "mlflow.gateway" not in modules

    requirements = infer_pip_requirements(logged_model.model_uri, "pyfunc")
    assert f"mlflow[gateway]=={mlflow.__version__}" not in requirements
457  
458  
def test_warn_dependency_requirement_mismatches():
    """`warn_dependency_requirement_mismatches` behavior matrix:

    - satisfied requirements -> no warning
    - mismatched versions -> one aggregated warning listing every mismatch
    - multi-specifier requirements -> checked against all specifiers
    - uninstalled packages -> reported as 'current: uninstalled'
    - specifier-less requirements and file paths -> ignored
    - unexpected internal errors -> warned about, never raised
    """
    import sklearn

    with mock.patch("mlflow.utils.requirements_utils._logger.warning") as mock_warning:
        # Test case: all packages satisfy requirements.
        warn_dependency_requirement_mismatches(
            model_requirements=[
                f"cloudpickle=={cloudpickle.__version__}",
                f"scikit-learn=={sklearn.__version__}",
            ]
        )
        mock_warning.assert_not_called()
        mock_warning.reset_mock()

        original_get_installed_version_fn = mlflow.utils.requirements_utils._get_installed_version

        # Build a `_get_installed_version` stand-in that returns fake versions
        # for selected packages and delegates to the real function otherwise.
        def gen_mock_get_installed_version_fn(mock_versions):
            def mock_get_installed_version_fn(package, module=None):
                if package in mock_versions:
                    return mock_versions[package]
                else:
                    return original_get_installed_version_fn(package, module)

            return mock_get_installed_version_fn

        # Test case: multiple mismatched packages
        with mock.patch(
            "mlflow.utils.requirements_utils._get_installed_version",
            gen_mock_get_installed_version_fn({
                "scikit-learn": "999.99.11",
                "cloudpickle": "999.99.22",
            }),
        ):
            warn_dependency_requirement_mismatches(
                model_requirements=[
                    f"cloudpickle=={cloudpickle.__version__}",
                    f"scikit-learn=={sklearn.__version__}",
                ]
            )
        # A single warning aggregates both mismatches, one per line.
        mock_warning.assert_called_once_with(
            f"""
Detected one or more mismatches between the model's dependencies and the current Python environment:
 - cloudpickle (current: 999.99.22, required: cloudpickle=={cloudpickle.__version__})
 - scikit-learn (current: 999.99.11, required: scikit-learn=={sklearn.__version__})
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the \
model's environment and install dependencies using the resulting environment file.
        """.strip()
        )
        mock_warning.reset_mock()

        # Test case: requirement with multiple version specifiers is satisfied
        with mock.patch(
            "mlflow.utils.requirements_utils._get_installed_version",
            gen_mock_get_installed_version_fn({"scikit-learn": "0.8.1"}),
        ):
            warn_dependency_requirement_mismatches(model_requirements=["scikit-learn>=0.8,<=0.9"])
        mock_warning.assert_not_called()
        mock_warning.reset_mock()

        # Test case: requirement with multiple version specifiers is not satisfied
        with mock.patch(
            "mlflow.utils.requirements_utils._get_installed_version",
            gen_mock_get_installed_version_fn({"scikit-learn": "0.7.1"}),
        ):
            warn_dependency_requirement_mismatches(model_requirements=["scikit-learn>=0.8,<=0.9"])
        mock_warning.assert_called_once_with(
            AnyStringWith(" - scikit-learn (current: 0.7.1, required: scikit-learn>=0.8,<=0.9)")
        )
        mock_warning.reset_mock()

        # Test case: required package is uninstalled.
        warn_dependency_requirement_mismatches(model_requirements=["uninstalled-pkg==1.2.3"])
        mock_warning.assert_called_once_with(
            AnyStringWith(
                " - uninstalled-pkg (current: uninstalled, required: uninstalled-pkg==1.2.3)"
            )
        )
        mock_warning.reset_mock()

        # Test case: requirement without version specifiers
        warn_dependency_requirement_mismatches(model_requirements=["mlflow"])
        mock_warning.assert_not_called()
        mock_warning.reset_mock()

        # Test case: an unexpected error happens while detecting mismatched packages.
        with mock.patch(
            "mlflow.utils.requirements_utils._check_requirement_satisfied",
            side_effect=RuntimeError("check_requirement_satisfied_fn_failed"),
        ):
            warn_dependency_requirement_mismatches(model_requirements=["mlflow"])
        mock_warning.assert_called_once_with(
            AnyStringWith(
                "Encountered an unexpected error "
                "(RuntimeError('check_requirement_satisfied_fn_failed')) while "
                "detecting model dependency mismatches"
            )
        )
        mock_warning.reset_mock()

        # Test case: ignore file path
        warn_dependency_requirement_mismatches(model_requirements=["/path/to/my.whl"])
        mock_warning.assert_not_called()
562  
def test_check_requirement_satisfied_skips_non_matching_marker():
    """A requirement whose environment marker cannot match the running
    interpreter is skipped entirely (returns None)."""
    assert _check_requirement_satisfied("numpy==999.0.0 ; python_full_version < '3.0'") is None
566  
567  
def test_check_requirement_satisfied_checks_matching_marker():
    """With a matching marker the requirement IS checked: the impossible pin
    numpy==999.0.0 yields a non-None (mismatch) result."""
    outcome = _check_requirement_satisfied("numpy==999.0.0 ; python_full_version >= '3.0'")
    assert outcome is not None
571  
572  
@pytest.mark.parametrize(
    "ignore_package_name",
    [
        "databricks-feature-lookup",
        "databricks-agents",
        "databricks_agents",
        "databricks.agents",
    ],
)
def test_suppress_warn_dependency_requirement_mismatches_ignore_some_packages(ignore_package_name):
    """Mismatches for certain Databricks packages (in any name spelling:
    '-', '_', or '.') are excluded from the warning; only the cloudpickle
    mismatch appears in the emitted message."""
    with mock.patch("mlflow.utils.requirements_utils._logger.warning") as mock_warning:
        original_get_installed_version_fn = mlflow.utils.requirements_utils._get_installed_version

        # Build a `_get_installed_version` stand-in that returns fake versions
        # for selected packages and delegates to the real function otherwise.
        def gen_mock_get_installed_version_fn(mock_versions):
            def mock_get_installed_version_fn(package, module=None):
                if package in mock_versions:
                    return mock_versions[package]
                else:
                    return original_get_installed_version_fn(package, module)

            return mock_get_installed_version_fn

        # Test case: multiple mismatched packages
        with mock.patch(
            "mlflow.utils.requirements_utils._get_installed_version",
            gen_mock_get_installed_version_fn({
                ignore_package_name: "9.99.11",
                "cloudpickle": "999.99.22",
            }),
        ):
            warn_dependency_requirement_mismatches(
                model_requirements=[
                    f"cloudpickle=={cloudpickle.__version__}",
                    f"{ignore_package_name}==999.1.1",
                ]
            )
            # Only cloudpickle is reported; the ignored package's mismatch is suppressed.
            mock_warning.assert_called_once_with(
                """
Detected one or more mismatches between the model's dependencies and the current Python environment:
 - cloudpickle (current: 999.99.22, required: cloudpickle=={cloudpickle_version})
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the \
model's environment and install dependencies using the resulting environment file.
""".strip().format(cloudpickle_version=cloudpickle.__version__)
            )
617  
618  
def test_capture_imported_modules_with_exception():
    """When predict raises, modules imported before the failure are still
    captured and a warning is logged; modules after the raise are not."""

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            import pandas  # noqa: F401

            raise Exception("Test exception")
            # Intentionally unreachable: verifies that imports after the
            # failure are NOT captured.
            import sklearn  # noqa: F401

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
            input_example="test",
        )

    with mock.patch("mlflow.utils.requirements_utils._logger.warning") as mock_warning:
        modules = _capture_imported_modules(model_info.model_uri, mlflow.pyfunc.FLAVOR_NAME)
        assert "pandas" in modules
        assert (
            "Failed to run predict on input_example, dependencies "
            "introduced in predict are not captured.\n" in mock_warning.call_args[0][0]
        )
        assert "sklearn" not in modules
642  
643  
def test_capture_imported_modules_raises_when_env_var_set(monkeypatch):
    """With MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS set, a failing predict
    aborts logging with an MlflowException instead of merely warning."""
    monkeypatch.setenv("MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS", "True")

    class AlwaysFailingModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            raise Exception("Intentional")

    expected_match = "Encountered an error while capturing imported modules"
    with pytest.raises(MlflowException, match=expected_match):
        with mlflow.start_run():
            mlflow.pyfunc.log_model(
                name="model",
                python_model=AlwaysFailingModel(),
                input_example="test",
            )
660  
661  
def test_capture_imported_modules_correct(monkeypatch):
    """A clean predict captures its lazy imports; with the raise-errors env var
    set, any capture failure would surface as an exception."""
    monkeypatch.setenv("MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS", "true")

    class PandasSklearnModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            import pandas  # noqa: F401
            import sklearn  # noqa: F401

            return model_input

    with mlflow.start_run():
        logged_model = mlflow.pyfunc.log_model(
            name="model",
            python_model=PandasSklearnModel(),
            input_example="test",
        )

    captured = _capture_imported_modules(logged_model.model_uri, mlflow.pyfunc.FLAVOR_NAME)
    assert "pandas" in captured
    assert "sklearn" in captured
682  
683  
def test_capture_imported_modules_extra_env_vars(monkeypatch):
    """`extra_env_vars` are visible to the subprocess that runs predict."""
    monkeypatch.setenv("MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS", "true")

    class EnvCheckingModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            # With the raise-errors env var set above, a missing TEST variable
            # would fail the capture call below.
            assert os.environ["TEST"] == "test"
            return model_input

    with mlflow.start_run():
        logged_model = mlflow.pyfunc.log_model(
            name="model",
            python_model=EnvCheckingModel(),
            input_example="test",
            pip_requirements=[],
        )

    _capture_imported_modules(
        logged_model.model_uri, mlflow.pyfunc.FLAVOR_NAME, extra_env_vars={"TEST": "test"}
    )
703  
704  
@pytest.mark.skipif(
    importlib.util.find_spec("databricks.agents") is None,
    reason="Requires databricks.agents",
)
def test_infer_pip_requirements_on_databricks_agents(tmp_path):
    """Inference keeps databricks-agents and databricks-connect but drops
    pyspark, which conflicts with databricks-connect."""
    # import here to avoid breaking this test suite on mlflow-skinny
    from mlflow.pyfunc import _get_pip_requirements_from_model_path

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            import databricks.agents  # noqa: F401
            import pyspark  # noqa: F401

            return model_input

    mlflow.pyfunc.save_model(
        tmp_path,
        python_model=TestModel(),
        input_example="test",
    )

    requirements = _get_pip_requirements_from_model_path(tmp_path)
    # Compare by package name only; versions depend on the environment.
    packages = [req.split("==")[0] for req in requirements]
    assert "databricks-agents" in packages
    # databricks-connect should not be pruned even it's a dependency of databricks-agents
    assert "databricks-connect" in packages
    # pyspark should not exist because it conflicts with databricks-connect
    assert "pyspark" not in packages
733  
734  
def test_capture_imported_modules_excludes_pyspark_gateway_env_vars(monkeypatch, tmp_path):
    """
    Test that PYSPARK_GATEWAY_PORT and PYSPARK_GATEWAY_SECRET are excluded from the
    subprocess environment when capturing imported modules.

    These env vars, if inherited by a subprocess, can cause the subprocess to connect
    to the parent's py4j gateway. Libraries like databricks-sdk may then corrupt the
    parent's gateway state, causing delayed py4j errors like
    "Error while obtaining a new communication channel".
    """
    monkeypatch.setenv("PYSPARK_GATEWAY_PORT", "12345")
    monkeypatch.setenv("PYSPARK_GATEWAY_SECRET", "secret123")

    captured_env = {}

    # Record the env that would be handed to the subprocess, then abort so no
    # subprocess actually runs.
    def mock_run_command(cmd, timeout_seconds, env):
        captured_env.update(env)
        raise MlflowException("Mocked - stopping before actual subprocess execution")

    with (
        mock.patch(
            "mlflow.utils.requirements_utils._run_command",
            side_effect=mock_run_command,
        ) as mock_run,
        # Avoid a real artifact download; any local path satisfies the capture flow.
        mock.patch(
            "mlflow.utils.requirements_utils._download_artifact_from_uri",
            return_value=str(tmp_path),
        ) as mock_download,
    ):
        with pytest.raises(MlflowException, match="Mocked"):
            _capture_imported_modules("fake/model/path", "pyfunc")

    mock_download.assert_called_once()
    mock_run.assert_called_once()
    # The gateway env vars set above must have been filtered out.
    assert "PYSPARK_GATEWAY_PORT" not in captured_env
    assert "PYSPARK_GATEWAY_SECRET" not in captured_env