# test_requirements_utils.py
"""Tests for the requirement-inference helpers in ``mlflow.utils.requirements_utils``.

NOTE(review): this module was reconstructed from a whitespace-mangled paste; code
tokens are unchanged, only formatting was restored and comments/docstrings added.
Whitespace inside collapsed string literals was restored on a best-effort basis —
confirm against the upstream file.
"""

import importlib
import os
import sys
from importlib.metadata import version
from unittest import mock

import cloudpickle
import importlib_metadata
import pytest

import mlflow
import mlflow.utils.requirements_utils
from mlflow.exceptions import MlflowException
from mlflow.utils.environment import infer_pip_requirements
from mlflow.utils.os import is_windows
from mlflow.utils.requirements_utils import (
    _capture_imported_modules,
    _check_requirement_satisfied,
    _get_installed_version,
    _get_pinned_requirement,
    _infer_requirements,
    _is_comment,
    _is_empty,
    _is_requirements_file,
    _join_continued_lines,
    _normalize_package_name,
    _parse_requirements,
    _prune_packages,
    _strip_inline_comment,
    _strip_local_version_label,
    warn_dependency_requirement_mismatches,
)

from tests.helper_functions import AnyStringWith


def test_is_comment():
    assert _is_comment("# comment")
    assert _is_comment("#")
    assert _is_comment("### comment ###")
    assert not _is_comment("comment")
    assert not _is_comment("")


def test_is_empty():
    assert _is_empty("")
    assert not _is_empty(" ")
    assert not _is_empty("a")


def test_is_requirements_file():
    assert _is_requirements_file("-r req.txt")
    assert _is_requirements_file("-r req.txt")
    assert _is_requirements_file("--requirement req.txt")
    assert _is_requirements_file("--requirement req.txt")
    assert not _is_requirements_file("req")


def test_strip_inline_comment():
    assert _strip_inline_comment("aaa # comment") == "aaa"
    assert _strip_inline_comment("aaa # comment") == "aaa"
    assert _strip_inline_comment("aaa # comment") == "aaa"
    assert _strip_inline_comment("aaa # com1 # com2") == "aaa"
    # Ensure a URI fragment is not stripped
    assert (
        _strip_inline_comment("git+https://git/repo.git#subdirectory=subdir")
        == "git+https://git/repo.git#subdirectory=subdir"
    )


def test_join_continued_lines():
    assert list(_join_continued_lines(["a"])) == ["a"]
    assert list(_join_continued_lines(["a\\", "b"])) == ["ab"]
    assert list(_join_continued_lines(["a\\", "b\\", "c"])) == ["abc"]
    assert list(_join_continued_lines(["a\\", " b"])) == ["a b"]
    assert list(_join_continued_lines(["a\\", " b\\", " c"])) == ["a b c"]
    assert list(_join_continued_lines(["a\\", "\\", "b"])) == ["ab"]
    assert list(_join_continued_lines(["a\\", "b", "c\\", "d"])) == ["ab", "cd"]
    # A blank line or EOF terminates a continuation instead of joining into it.
    assert list(_join_continued_lines(["a\\", "", "b"])) == ["a", "b"]
    assert list(_join_continued_lines(["a\\"])) == ["a"]
    assert list(_join_continued_lines(["\\", "a"])) == ["a"]


def test_parse_requirements(tmp_path, monkeypatch):
    """Exercise `_parse_requirements` against every construct pip's own parser accepts."""
    root_req_src = """
# No version specifier
noverspec
no-ver-spec

# Version specifiers
verspec<1.0
ver-spec == 2.0

# Environment marker
env-marker; python_version < "3.8"

inline-comm # Inline comment
inlinecomm # Inline comment

# Git URIs
git+https://github.com/git/uri
git+https://github.com/sub/dir#subdirectory=subdir

# Requirements files
-r {relative_req}
--requirement {absolute_req}

# Constraints files
-c {relative_con}
--constraint {absolute_con}

# Line continuation
line-cont\
==\
1.0

# Line continuation with spaces
line-cont-space \
== \
1.0

# Line continuation with a blank line
line-cont-blank\

# Line continuation at EOF
line-cont-eof\
""".strip()

    monkeypatch.chdir(tmp_path)
    root_req = tmp_path.joinpath("requirements.txt")
    # Requirements files
    rel_req = tmp_path.joinpath("relative_req.txt")
    abs_req = tmp_path.joinpath("absolute_req.txt")
    # Constraints files
    rel_con = tmp_path.joinpath("relative_con.txt")
    abs_con = tmp_path.joinpath("absolute_con.txt")

    # pip's requirements parser collapses an absolute requirements file path:
    # https://github.com/pypa/pip/issues/10121
    # As a workaround, use a relative path on Windows.
    absolute_req = abs_req.name if is_windows() else str(abs_req)
    absolute_con = abs_con.name if is_windows() else str(abs_con)
    root_req.write_text(
        root_req_src.format(
            relative_req=rel_req.name,
            absolute_req=absolute_req,
            relative_con=rel_con.name,
            absolute_con=absolute_con,
        )
    )
    rel_req.write_text("rel-req-xxx\nrel-req-yyy")
    abs_req.write_text("abs-req-zzz")
    rel_con.write_text("rel-con-xxx\nrel-con-yyy")
    abs_con.write_text("abs-con-zzz")

    # Uncomment this to get the expected output from pip's internal parser
    # from pip._internal.network.session import PipSession
    # from pip._internal.req import parse_requirements as pip_parse_requirements
    #
    # pip_reqs = list(pip_parse_requirements(root_req.name, session=PipSession()))
    # print(f"expected_reqs = {[r.requirement for r in pip_reqs if not r.constraint]}")
    # print(f"expected_cons = {[r.requirement for r in pip_reqs if r.constraint]}")

    expected_reqs = [
        "noverspec",
        "no-ver-spec",
        "verspec<1.0",
        "ver-spec == 2.0",
        'env-marker; python_version < "3.8"',
        "inline-comm",
        "inlinecomm",
        "git+https://github.com/git/uri",
        "git+https://github.com/sub/dir#subdirectory=subdir",
        "rel-req-xxx",
        "rel-req-yyy",
        "abs-req-zzz",
        "line-cont==1.0",
        "line-cont-space == 1.0",
        "line-cont-blank",
        "line-cont-eof",
    ]
    expected_cons = [
        "rel-con-xxx",
        "rel-con-yyy",
        "abs-con-zzz",
    ]

    parsed_reqs = list(_parse_requirements(root_req.name, is_constraint=False))
    assert [r.req_str for r in parsed_reqs if not r.is_constraint] == expected_reqs
    assert [r.req_str for r in parsed_reqs if r.is_constraint] == expected_cons


def test_normalize_package_name():
    # PEP 503-style normalization: lowercase; runs of `-`, `_`, `.` collapse to `-`.
    assert _normalize_package_name("abc") == "abc"
    assert _normalize_package_name("ABC") == "abc"
    assert _normalize_package_name("a-b-c") == "a-b-c"
    assert _normalize_package_name("a.b.c") == "a-b-c"
    assert _normalize_package_name("a_b_c") == "a-b-c"
    assert _normalize_package_name("a--b--c") == "a-b-c"
    assert _normalize_package_name("a-._b-._c") == "a-b-c"


def test_prune_packages():
    assert _prune_packages(["mlflow"]) == {"mlflow"}
    assert _prune_packages(["mlflow", "scikit-learn"]) == {"mlflow", "scikit-learn"}


def test_capture_imported_modules():
    from mlflow.utils._capture_modules import _CaptureImportedModules

    # All three import mechanisms (statement, __import__, importlib) should be captured.
    with _CaptureImportedModules() as cap:
        import math  # clint: disable=lazy-import # noqa: F401

        __import__("pandas")
        importlib.import_module("numpy")

    assert "math" in cap.imported_modules
    assert "pandas" in cap.imported_modules
    assert "numpy" in cap.imported_modules


def test_strip_local_version_label():
    assert _strip_local_version_label("1.2.3") == "1.2.3"
    assert _strip_local_version_label("1.2.3+ab") == "1.2.3"
    assert _strip_local_version_label("1.2.3rc0+ab") == "1.2.3rc0"
    assert _strip_local_version_label("1.2.3.dev0+ab") == "1.2.3.dev0"
    assert _strip_local_version_label("1.2.3.post0+ab") == "1.2.3.post0"
    # A non-PEP-440 string is returned unchanged.
    assert _strip_local_version_label("invalid") == "invalid"


def test_get_installed_version(tmp_path, monkeypatch):
    assert _get_installed_version("mlflow") == mlflow.__version__
    assert _get_installed_version("numpy") == version("numpy")
    assert _get_installed_version("pandas") == version("pandas")
    assert _get_installed_version("scikit-learn", module="sklearn") == version("scikit-learn")

    # When package metadata is unavailable, fall back to the module's __version__ attribute.
    not_found_package = tmp_path.joinpath("not_found.py")
    not_found_package.write_text("__version__ = '1.2.3'")
    monkeypatch.syspath_prepend(str(tmp_path))
    with pytest.raises(importlib_metadata.PackageNotFoundError, match=r".+"):
        importlib_metadata.version("not_found")
    assert _get_installed_version("not_found") == "1.2.3"


def test_package_with_mismatched_pypi_and_import_name():
    try:
        import dspy  # noqa: F401

        # The import name `dspy` maps to the PyPI distribution `dspy-ai`.
        assert _get_installed_version("dspy") == version("dspy-ai")
    except ImportError:
        pytest.skip("Skipping test because 'dspy' package is not installed")


def test_get_pinned_requirement(tmp_path, monkeypatch):
    assert _get_pinned_requirement("mlflow") == f"mlflow=={mlflow.__version__}"
    assert _get_pinned_requirement("mlflow", version="1.2.3") == "mlflow==1.2.3"

    not_found_package = tmp_path.joinpath("not_found.py")
    not_found_package.write_text("__version__ = '1.2.3'")
    monkeypatch.syspath_prepend(str(tmp_path))
    with pytest.raises(importlib_metadata.PackageNotFoundError, match=r".+"):
        importlib_metadata.version("not_found")
    assert _get_pinned_requirement("not_found") == "not_found==1.2.3"


def test_get_pinned_requirement_local_version_label(tmp_path, monkeypatch):
    package = tmp_path.joinpath("my_package.py")
    lvl = "abc.def.ghi"  # Local version label
    package.write_text(f"__version__ = '1.2.3+{lvl}'")
    monkeypatch.syspath_prepend(str(tmp_path))

    # Pinning strips the local version label and warns about it.
    with mock.patch("mlflow.utils.requirements_utils._logger.warning") as mock_warning:
        req = _get_pinned_requirement("my_package")
        mock_warning.assert_called_once()
        (first_pos_arg,) = mock_warning.call_args[0]
        assert first_pos_arg.startswith(
            f"Found my_package version (1.2.3+{lvl}) contains a local version label (+{lvl})."
        )
    assert req == "my_package==1.2.3"


def test_infer_requirements_excludes_mlflow():
    with mock.patch(
        "mlflow.utils.requirements_utils._capture_imported_modules",
        return_value=["mlflow", "pytest"],
    ):
        mlflow_package = "mlflow-skinny" if "MLFLOW_SKINNY" in os.environ else "mlflow"
        assert mlflow_package in importlib_metadata.packages_distributions()["mlflow"]
        # mlflow itself must not appear in the inferred requirements.
        assert _infer_requirements("path/to/model", "sklearn") == [f"pytest=={pytest.__version__}"]


def test_capture_imported_modules_scopes_databricks_imports(monkeypatch, tmp_path):
    from mlflow.utils._capture_modules import _CaptureImportedModules

    monkeypatch.chdir(tmp_path)
    monkeypatch.syspath_prepend(str(tmp_path))

    # Build a dummy `databricks` namespace with both recognized and unrecognized submodules.
    databricks_dir = os.path.join(tmp_path, "databricks")
    os.makedirs(databricks_dir)
    for file_name in [
        "__init__.py",
        "automl.py",
        "automl_runtime.py",
        "automl_foo.py",
        "model_monitoring.py",
        "other.py",
    ]:
        with open(os.path.join(databricks_dir, file_name), "w"):
            pass

    with _CaptureImportedModules() as cap:
        # Delete `databricks` from the cache to ensure we load from the dummy module created above.
        if "databricks" in sys.modules:
            del sys.modules["databricks"]
        import databricks
        import databricks.automl
        import databricks.automl_foo
        import databricks.automl_runtime
        import databricks.model_monitoring

    # Only the known databricks submodules are recorded; the bare top-level package and
    # unknown submodules are not.
    assert "databricks.automl" in cap.imported_modules
    assert "databricks.model_monitoring" in cap.imported_modules
    assert "databricks" not in cap.imported_modules
    assert "databricks.automl_foo" not in cap.imported_modules

    with _CaptureImportedModules() as cap:
        import databricks.automl
        import databricks.automl_foo
        import databricks.automl_runtime
        import databricks.model_monitoring
        import databricks.other  # noqa: F401

    # Importing an un-scoped submodule (databricks.other) makes the top-level package appear.
    assert "databricks.automl" in cap.imported_modules
    assert "databricks.model_monitoring" in cap.imported_modules
    assert "databricks" in cap.imported_modules
    assert "databricks.automl_foo" not in cap.imported_modules


def test_infer_pip_requirements_scopes_databricks_imports():
    # Reset module-level caches so the patched packages_distributions below is picked up.
    mlflow.utils.requirements_utils._MODULES_TO_PACKAGES = None
    mlflow.utils.requirements_utils._PACKAGES_TO_MODULES = None

    with (
        mock.patch(
            "mlflow.utils.requirements_utils._capture_imported_modules",
            return_value=[
                "databricks.automl",
                "databricks.model_monitoring",
                "databricks.automl_runtime",
            ],
        ),
        mock.patch(
            "mlflow.utils.requirements_utils._get_installed_version",
            return_value="1.0",
        ),
        mock.patch(
            "importlib_metadata.packages_distributions",
            return_value={
                "databricks": [
                    "databricks-automl-runtime",
                    "databricks-model-monitoring",
                    "koalas",
                ],
            },
        ),
    ):
        assert _infer_requirements("path/to/model", "sklearn") == [
            "databricks-automl-runtime==1.0",
            "databricks-model-monitoring==1.0",
        ]
        # The recognized databricks distributions are consumed; only `koalas` remains mapped.
        assert mlflow.utils.requirements_utils._MODULES_TO_PACKAGES["databricks"] == ["koalas"]


def test_capture_imported_modules_include_deps_by_params():
    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            # These imports only happen when params are supplied; capture must still see them
            # because the logged input_example includes params.
            if params is not None:
                import pandas as pd
                import sklearn  # noqa: F401

                return pd.DataFrame([params])
            return model_input

    params = {"a": 1, "b": "string", "c": True}

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=MyModel(),
            input_example=(["input1"], params),
        )

    captured_modules = _capture_imported_modules(model_info.model_uri, "pyfunc")
    assert "pandas" in captured_modules
    assert "sklearn" in captured_modules


@pytest.mark.parametrize(
    ("module_to_import", "should_capture_extra"),
    [
        ("mlflow.gateway", True),
        ("mlflow.deployments.server.config", True),
        # The `mlflow[gateway]` extra includes requirements for starting the deployment server,
        # but it is not required when the model only uses the deployment client. These test
        # cases validate that importing the deployment client alone does not add the extra.
        ("mlflow.deployments", False),
    ],
)
def test_capture_imported_modules_includes_gateway_extra(
    module_to_import, should_capture_extra, monkeypatch
):
    # Disable UV auto-detect to ensure model-based inference is used
    monkeypatch.setenv("MLFLOW_UV_AUTO_DETECT", "false")

    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, inputs, params=None):
            importlib.import_module(module_to_import)

            return inputs

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=MyModel(),
            input_example=([1, 2, 3]),
        )

    captured_modules = _capture_imported_modules(model_info.model_uri, "pyfunc")
    assert ("mlflow.gateway" in captured_modules) == should_capture_extra

    pip_requirements = infer_pip_requirements(model_info.model_uri, "pyfunc")
    assert (f"mlflow[gateway]=={mlflow.__version__}" in pip_requirements) == should_capture_extra


def test_gateway_extra_not_captured_when_importing_deployment_client_only(monkeypatch):
    # Disable UV auto-detect to ensure model-based inference is used
    monkeypatch.setenv("MLFLOW_UV_AUTO_DETECT", "false")

    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            from mlflow.deployments import get_deploy_client  # noqa: F401

            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=MyModel(),
            input_example=([1, 2, 3]),
        )

    captured_modules = _capture_imported_modules(model_info.model_uri, "pyfunc")
    assert "mlflow.gateway" not in captured_modules

    pip_requirements = infer_pip_requirements(model_info.model_uri, "pyfunc")
    assert f"mlflow[gateway]=={mlflow.__version__}" not in pip_requirements


def test_warn_dependency_requirement_mismatches():
    import sklearn

    with mock.patch("mlflow.utils.requirements_utils._logger.warning") as mock_warning:
        # Test case: all packages satisfy requirements.
        warn_dependency_requirement_mismatches(
            model_requirements=[
                f"cloudpickle=={cloudpickle.__version__}",
                f"scikit-learn=={sklearn.__version__}",
            ]
        )
        mock_warning.assert_not_called()
        mock_warning.reset_mock()

        original_get_installed_version_fn = mlflow.utils.requirements_utils._get_installed_version

        def gen_mock_get_installed_version_fn(mock_versions):
            # Return a drop-in replacement for _get_installed_version that fakes versions
            # for the given packages and delegates for everything else.
            def mock_get_installed_version_fn(package, module=None):
                if package in mock_versions:
                    return mock_versions[package]
                else:
                    return original_get_installed_version_fn(package, module)

            return mock_get_installed_version_fn

        # Test case: multiple mismatched packages
        with mock.patch(
            "mlflow.utils.requirements_utils._get_installed_version",
            gen_mock_get_installed_version_fn({
                "scikit-learn": "999.99.11",
                "cloudpickle": "999.99.22",
            }),
        ):
            warn_dependency_requirement_mismatches(
                model_requirements=[
                    f"cloudpickle=={cloudpickle.__version__}",
                    f"scikit-learn=={sklearn.__version__}",
                ]
            )
            mock_warning.assert_called_once_with(
                f"""
Detected one or more mismatches between the model's dependencies and the current Python environment:
 - cloudpickle (current: 999.99.22, required: cloudpickle=={cloudpickle.__version__})
 - scikit-learn (current: 999.99.11, required: scikit-learn=={sklearn.__version__})
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the \
model's environment and install dependencies using the resulting environment file.
""".strip()
            )
        mock_warning.reset_mock()

        # Test case: requirement with multiple version specifiers is satisfied
        with mock.patch(
            "mlflow.utils.requirements_utils._get_installed_version",
            gen_mock_get_installed_version_fn({"scikit-learn": "0.8.1"}),
        ):
            warn_dependency_requirement_mismatches(model_requirements=["scikit-learn>=0.8,<=0.9"])
            mock_warning.assert_not_called()
        mock_warning.reset_mock()

        # Test case: requirement with multiple version specifiers is not satisfied
        with mock.patch(
            "mlflow.utils.requirements_utils._get_installed_version",
            gen_mock_get_installed_version_fn({"scikit-learn": "0.7.1"}),
        ):
            warn_dependency_requirement_mismatches(model_requirements=["scikit-learn>=0.8,<=0.9"])
            mock_warning.assert_called_once_with(
                AnyStringWith(" - scikit-learn (current: 0.7.1, required: scikit-learn>=0.8,<=0.9)")
            )
        mock_warning.reset_mock()

        # Test case: required package is uninstalled.
        warn_dependency_requirement_mismatches(model_requirements=["uninstalled-pkg==1.2.3"])
        mock_warning.assert_called_once_with(
            AnyStringWith(
                " - uninstalled-pkg (current: uninstalled, required: uninstalled-pkg==1.2.3)"
            )
        )
        mock_warning.reset_mock()

        # Test case: requirement without version specifiers
        warn_dependency_requirement_mismatches(model_requirements=["mlflow"])
        mock_warning.assert_not_called()
        mock_warning.reset_mock()

        # Test case: an unexpected error happens while detecting mismatched packages.
        with mock.patch(
            "mlflow.utils.requirements_utils._check_requirement_satisfied",
            side_effect=RuntimeError("check_requirement_satisfied_fn_failed"),
        ):
            warn_dependency_requirement_mismatches(model_requirements=["mlflow"])
            mock_warning.assert_called_once_with(
                AnyStringWith(
                    "Encountered an unexpected error "
                    "(RuntimeError('check_requirement_satisfied_fn_failed')) while "
                    "detecting model dependency mismatches"
                )
            )
        mock_warning.reset_mock()

        # Test case: ignore file path
        warn_dependency_requirement_mismatches(model_requirements=["/path/to/my.whl"])
        mock_warning.assert_not_called()


def test_check_requirement_satisfied_skips_non_matching_marker():
    # The environment marker never matches, so the check is skipped entirely.
    result = _check_requirement_satisfied("numpy==999.0.0 ; python_full_version < '3.0'")
    assert result is None


def test_check_requirement_satisfied_checks_matching_marker():
    # The marker matches, so the (unsatisfiable) version requirement is actually checked.
    result = _check_requirement_satisfied("numpy==999.0.0 ; python_full_version >= '3.0'")
    assert result is not None


@pytest.mark.parametrize(
    "ignore_package_name",
    [
        "databricks-feature-lookup",
        "databricks-agents",
        "databricks_agents",
        "databricks.agents",
    ],
)
def test_suppress_warn_dependency_requirement_mismatches_ignore_some_packages(ignore_package_name):
    with mock.patch("mlflow.utils.requirements_utils._logger.warning") as mock_warning:
        original_get_installed_version_fn = mlflow.utils.requirements_utils._get_installed_version

        def gen_mock_get_installed_version_fn(mock_versions):
            def mock_get_installed_version_fn(package, module=None):
                if package in mock_versions:
                    return mock_versions[package]
                else:
                    return original_get_installed_version_fn(package, module)

            return mock_get_installed_version_fn

        # Test case: multiple mismatched packages
        with mock.patch(
            "mlflow.utils.requirements_utils._get_installed_version",
            gen_mock_get_installed_version_fn({
                ignore_package_name: "9.99.11",
                "cloudpickle": "999.99.22",
            }),
        ):
            warn_dependency_requirement_mismatches(
                model_requirements=[
                    f"cloudpickle=={cloudpickle.__version__}",
                    f"{ignore_package_name}==999.1.1",
                ]
            )
            # The ignored package's mismatch is suppressed; only cloudpickle is reported.
            mock_warning.assert_called_once_with(
                """
Detected one or more mismatches between the model's dependencies and the current Python environment:
 - cloudpickle (current: 999.99.22, required: cloudpickle=={cloudpickle_version})
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the \
model's environment and install dependencies using the resulting environment file.
""".strip().format(cloudpickle_version=cloudpickle.__version__)
            )


def test_capture_imported_modules_with_exception():
    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            import pandas  # noqa: F401

            raise Exception("Test exception")
            # Deliberately unreachable: capture must NOT record imports after the raise.
            import sklearn  # noqa: F401

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
            input_example="test",
        )

    with mock.patch("mlflow.utils.requirements_utils._logger.warning") as mock_warning:
        modules = _capture_imported_modules(model_info.model_uri, mlflow.pyfunc.FLAVOR_NAME)
        assert "pandas" in modules
        assert (
            "Failed to run predict on input_example, dependencies "
            "introduced in predict are not captured.\n" in mock_warning.call_args[0][0]
        )
        assert "sklearn" not in modules


def test_capture_imported_modules_raises_when_env_var_set(monkeypatch):
    monkeypatch.setenv("MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS", "True")

    class BadModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            raise Exception("Intentional")

    with pytest.raises(
        MlflowException, match="Encountered an error while capturing imported modules"
    ):
        with mlflow.start_run():
            mlflow.pyfunc.log_model(
                name="model",
                python_model=BadModel(),
                input_example="test",
            )


def test_capture_imported_modules_correct(monkeypatch):
    monkeypatch.setenv("MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS", "true")

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            import pandas  # noqa: F401
            import sklearn  # noqa: F401

            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
            input_example="test",
        )

    modules = _capture_imported_modules(model_info.model_uri, mlflow.pyfunc.FLAVOR_NAME)
    assert "pandas" in modules
    assert "sklearn" in modules


def test_capture_imported_modules_extra_env_vars(monkeypatch):
    monkeypatch.setenv("MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS", "true")

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            # Fails (and, with RAISE_ERRORS set, surfaces) unless extra_env_vars propagated.
            assert os.environ["TEST"] == "test"
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
            input_example="test",
            pip_requirements=[],
        )

    _capture_imported_modules(
        model_info.model_uri, mlflow.pyfunc.FLAVOR_NAME, extra_env_vars={"TEST": "test"}
    )


@pytest.mark.skipif(
    importlib.util.find_spec("databricks.agents") is None,
    reason="Requires databricks.agents",
)
def test_infer_pip_requirements_on_databricks_agents(tmp_path):
    # import here to avoid breaking this test suite on mlflow-skinny
    from mlflow.pyfunc import _get_pip_requirements_from_model_path

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            import databricks.agents  # noqa: F401
            import pyspark  # noqa: F401

            return model_input

    mlflow.pyfunc.save_model(
        tmp_path,
        python_model=TestModel(),
        input_example="test",
    )

    requirements = _get_pip_requirements_from_model_path(tmp_path)
    packages = [req.split("==")[0] for req in requirements]
    assert "databricks-agents" in packages
    # databricks-connect should not be pruned even it's a dependency of databricks-agents
    assert "databricks-connect" in packages
    # pyspark should not exist because it conflicts with databricks-connect
    assert "pyspark" not in packages


def test_capture_imported_modules_excludes_pyspark_gateway_env_vars(monkeypatch, tmp_path):
    """
    Test that PYSPARK_GATEWAY_PORT and PYSPARK_GATEWAY_SECRET are excluded from the
    subprocess environment when capturing imported modules.

    These env vars, if inherited by a subprocess, can cause the subprocess to connect
    to the parent's py4j gateway. Libraries like databricks-sdk may then corrupt the
    parent's gateway state, causing delayed py4j errors like
    "Error while obtaining a new communication channel".
    """
    monkeypatch.setenv("PYSPARK_GATEWAY_PORT", "12345")
    monkeypatch.setenv("PYSPARK_GATEWAY_SECRET", "secret123")

    captured_env = {}

    def mock_run_command(cmd, timeout_seconds, env):
        # Record the env the subprocess would receive, then abort before it runs.
        captured_env.update(env)
        raise MlflowException("Mocked - stopping before actual subprocess execution")

    with (
        mock.patch(
            "mlflow.utils.requirements_utils._run_command",
            side_effect=mock_run_command,
        ) as mock_run,
        mock.patch(
            "mlflow.utils.requirements_utils._download_artifact_from_uri",
            return_value=str(tmp_path),
        ) as mock_download,
    ):
        with pytest.raises(MlflowException, match="Mocked"):
            _capture_imported_modules("fake/model/path", "pyfunc")

    mock_download.assert_called_once()
    mock_run.assert_called_once()
    assert "PYSPARK_GATEWAY_PORT" not in captured_env
    assert "PYSPARK_GATEWAY_SECRET" not in captured_env