environment.py
import hashlib
import importlib.metadata
import logging
import os
import pathlib
import re
import shutil
import subprocess
import sys
import tempfile
from copy import deepcopy

import yaml
from packaging.requirements import InvalidRequirement, Requirement
from packaging.specifiers import SpecifierSet
from packaging.version import Version

from mlflow.environment_variables import (
    _MLFLOW_ACTIVE_MODEL_ID,
    _MLFLOW_TESTING,
    MLFLOW_EXPERIMENT_ID,
    MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT,
    MLFLOW_LOCK_MODEL_DEPENDENCIES,
    MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS,
    MLFLOW_UV_AUTO_DETECT,
)
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.tracking import get_tracking_uri
from mlflow.tracking.fluent import _get_experiment_id, get_active_model_id
from mlflow.utils import PYTHON_VERSION
from mlflow.utils.databricks_utils import (
    _get_databricks_serverless_env_vars,
    get_databricks_env_vars,
    is_databricks_connect,
    is_in_databricks_runtime,
)
from mlflow.utils.os import is_windows
from mlflow.utils.process import _exec_cmd
from mlflow.utils.requirements_utils import (
    _get_local_version_label,
    _infer_requirements,
    _parse_requirements,
    _strip_local_version_label,
    warn_dependency_requirement_mismatches,
)
from mlflow.utils.timeout import MlflowTimeoutError, run_with_timeout
from mlflow.utils.uv_utils import (
    detect_uv_project,
    export_uv_requirements,
)
from mlflow.version import VERSION

_logger = logging.getLogger(__name__)

# Header shared by every generated conda.yaml; the dependencies section is
# appended by `_mlflow_conda_env`.
# NOTE(review): list-item indentation reconstructed from mangled source —
# YAML parses it identically with or without the two-space indent.
_conda_header = """\
name: mlflow-env
channels:
  - conda-forge
"""

# Canonical file names for the environment artifacts logged alongside a model.
_CONDA_ENV_FILE_NAME = "conda.yaml"
_REQUIREMENTS_FILE_NAME = "requirements.txt"
_CONSTRAINTS_FILE_NAME = "constraints.txt"
_PYTHON_ENV_FILE_NAME = "python_env.yaml"


# Note this regular expression does not cover all possible patterns
69 _CONDA_DEPENDENCY_REGEX = re.compile( 70 r"^(?P<package>python|pip|setuptools|wheel)" 71 r"(?P<operator><|>|<=|>=|=|==|!=)?" 72 r"(?P<version>[\d.]+)?$" 73 ) 74 75 76 class _PythonEnv: 77 BUILD_PACKAGES = ("pip", "setuptools", "wheel") 78 79 def __init__(self, python=None, build_dependencies=None, dependencies=None): 80 """ 81 Represents environment information for MLflow Models and Projects. 82 83 Args: 84 python: Python version for the environment. If unspecified, defaults to the current 85 Python version. 86 build_dependencies: List of build dependencies for the environment that must 87 be installed before installing ``dependencies``. If unspecified, 88 defaults to an empty list. 89 dependencies: List of dependencies for the environment. If unspecified, defaults to 90 an empty list. 91 """ 92 if python is not None and not isinstance(python, str): 93 raise TypeError(f"`python` must be a string but got {type(python)}") 94 if build_dependencies is not None and not isinstance(build_dependencies, list): 95 raise TypeError( 96 f"`build_dependencies` must be a list but got {type(build_dependencies)}" 97 ) 98 if dependencies is not None and not isinstance(dependencies, list): 99 raise TypeError(f"`dependencies` must be a list but got {type(dependencies)}") 100 101 self.python = python or PYTHON_VERSION 102 self.build_dependencies = build_dependencies or [] 103 self.dependencies = dependencies or [] 104 105 def __str__(self): 106 return str(self.to_dict()) 107 108 @classmethod 109 def current(cls): 110 return cls( 111 python=PYTHON_VERSION, 112 build_dependencies=cls.get_current_build_dependencies(), 113 dependencies=[f"-r {_REQUIREMENTS_FILE_NAME}"], 114 ) 115 116 @staticmethod 117 def get_current_build_dependencies(): 118 build_dependencies = [] 119 for package in _PythonEnv.BUILD_PACKAGES: 120 version = _get_package_version(package) 121 dep = (package + "==" + version) if version else package 122 build_dependencies.append(dep) 123 return build_dependencies 124 125 
def to_dict(self): 126 return self.__dict__.copy() 127 128 @classmethod 129 def from_dict(cls, dct): 130 return cls(**dct) 131 132 def to_yaml(self, path): 133 with open(path, "w") as f: 134 # Exclude None and empty lists 135 data = {k: v for k, v in self.to_dict().items() if v} 136 yaml.safe_dump(data, f, sort_keys=False, default_flow_style=False) 137 138 @classmethod 139 def from_yaml(cls, path): 140 with open(path) as f: 141 return cls.from_dict(yaml.safe_load(f)) 142 143 @staticmethod 144 def get_dependencies_from_conda_yaml(path): 145 with open(path) as f: 146 conda_env = yaml.safe_load(f) 147 148 python = None 149 build_dependencies = None 150 unmatched_dependencies = [] 151 dependencies = None 152 for dep in conda_env.get("dependencies", []): 153 if isinstance(dep, str): 154 match = _CONDA_DEPENDENCY_REGEX.match(dep) 155 if not match: 156 unmatched_dependencies.append(dep) 157 continue 158 package = match.group("package") 159 operator = match.group("operator") 160 version = match.group("version") 161 162 # Python 163 if not python and package == "python": 164 if operator is None: 165 raise MlflowException.invalid_parameter_value( 166 f"Invalid dependency for python: {dep}. " 167 "It must be pinned (e.g. python=3.8.13)." 168 ) 169 170 if operator in ("<", ">", "!="): 171 raise MlflowException( 172 f"Invalid version comparator for python: '{operator}'. " 173 "Must be one of ['<=', '>=', '=', '=='].", 174 error_code=INVALID_PARAMETER_VALUE, 175 ) 176 python = version 177 continue 178 179 # Build packages 180 if build_dependencies is None: 181 build_dependencies = [] 182 # "=" is an invalid operator for pip 183 operator = "==" if operator == "=" else operator 184 build_dependencies.append(package + (operator or "") + (version or "")) 185 elif _is_pip_deps(dep): 186 dependencies = dep["pip"] 187 else: 188 raise MlflowException( 189 f"Invalid conda dependency: {dep}. 
Must be str or dict in the form of " 190 '{"pip": [...]}', 191 error_code=INVALID_PARAMETER_VALUE, 192 ) 193 194 if python is None: 195 _logger.warning( 196 f"{path} does not include a python version specification. " 197 f"Using the current python version {PYTHON_VERSION}." 198 ) 199 python = PYTHON_VERSION 200 201 if unmatched_dependencies: 202 _logger.warning( 203 "The following conda dependencies will not be installed in the resulting " 204 "environment: %s", 205 unmatched_dependencies, 206 ) 207 208 return { 209 "python": python, 210 "build_dependencies": build_dependencies, 211 "dependencies": dependencies, 212 } 213 214 @classmethod 215 def from_conda_yaml(cls, path): 216 return cls.from_dict(cls.get_dependencies_from_conda_yaml(path)) 217 218 219 def _mlflow_conda_env( 220 path=None, 221 additional_conda_deps=None, 222 additional_pip_deps=None, 223 additional_conda_channels=None, 224 install_mlflow=True, 225 ): 226 """Creates a Conda environment with the specified package channels and dependencies. If there 227 are any pip dependencies, including from the install_mlflow parameter, then pip will be added to 228 the conda dependencies. This is done to ensure that the pip inside the conda environment is 229 used to install the pip dependencies. 230 231 Args: 232 path: Local filesystem path where the conda env file is to be written. If unspecified, 233 the conda env will not be written to the filesystem; it will still be returned 234 in dictionary format. 235 additional_conda_deps: List of additional conda dependencies passed as strings. 236 additional_pip_deps: List of additional pip dependencies passed as strings. 237 additional_conda_channels: List of additional conda channels to search when resolving 238 packages. 239 240 Returns: 241 None if path is specified. Otherwise, the a dictionary representation of the 242 Conda environment. 
243 244 """ 245 additional_pip_deps = additional_pip_deps or [] 246 mlflow_deps = ( 247 [f"mlflow=={VERSION}"] 248 if install_mlflow and not _contains_mlflow_requirement(additional_pip_deps) 249 else [] 250 ) 251 pip_deps = mlflow_deps + additional_pip_deps 252 conda_deps = additional_conda_deps or [] 253 if pip_deps: 254 pip_version = _get_package_version("pip") 255 if pip_version is not None: 256 # When a new version of pip is released on PyPI, it takes a while until that version is 257 # uploaded to conda-forge. This time lag causes `conda create` to fail with 258 # a `ResolvePackageNotFound` error. As a workaround for this issue, use `<=` instead 259 # of `==` so conda installs `pip_version - 1` when `pip_version` is unavailable. 260 conda_deps.append(f"pip<={pip_version}") 261 else: 262 _logger.warning( 263 "Failed to resolve installed pip version. ``pip`` will be added to conda.yaml" 264 " environment spec without a version specifier." 265 ) 266 conda_deps.append("pip") 267 268 env = yaml.safe_load(_conda_header) 269 env["dependencies"] = [f"python={PYTHON_VERSION}"] 270 env["dependencies"] += conda_deps 271 env["dependencies"].append({"pip": pip_deps}) 272 if additional_conda_channels is not None: 273 env["channels"] += additional_conda_channels 274 275 if path is not None: 276 with open(path, "w") as out: 277 yaml.safe_dump(env, stream=out, default_flow_style=False) 278 return None 279 else: 280 return env 281 282 283 def _get_package_version(package_name: str) -> str | None: 284 try: 285 return importlib.metadata.version(package_name) 286 except importlib.metadata.PackageNotFoundError: 287 return None 288 289 290 def _mlflow_additional_pip_env(pip_deps, path=None): 291 requirements = "\n".join(pip_deps) 292 if path is not None: 293 with open(path, "w") as out: 294 out.write(requirements) 295 return None 296 else: 297 return requirements 298 299 300 def _is_pip_deps(dep): 301 """ 302 Returns True if `dep` is a dict representing pip dependencies 303 """ 304 
return isinstance(dep, dict) and "pip" in dep 305 306 307 def _get_pip_deps(conda_env): 308 """ 309 Returns: 310 The pip dependencies from the conda env. 311 """ 312 if conda_env is not None: 313 for dep in conda_env["dependencies"]: 314 if _is_pip_deps(dep): 315 return dep["pip"] 316 return [] 317 318 319 def _overwrite_pip_deps(conda_env, new_pip_deps): 320 """ 321 Overwrites the pip dependencies section in the given conda env dictionary. 322 323 { 324 "name": "env", 325 "channels": [...], 326 "dependencies": [ 327 ..., 328 "pip", 329 {"pip": [...]}, <- Overwrite this 330 ], 331 } 332 """ 333 deps = conda_env.get("dependencies", []) 334 new_deps = [] 335 contains_pip_deps = False 336 for dep in deps: 337 if _is_pip_deps(dep): 338 contains_pip_deps = True 339 new_deps.append({"pip": new_pip_deps}) 340 else: 341 new_deps.append(dep) 342 343 if not contains_pip_deps: 344 new_deps.append({"pip": new_pip_deps}) 345 346 return {**conda_env, "dependencies": new_deps} 347 348 349 def _log_pip_requirements(conda_env, path, requirements_file=_REQUIREMENTS_FILE_NAME): 350 pip_deps = _get_pip_deps(conda_env) 351 _mlflow_additional_pip_env(pip_deps, path=os.path.join(path, requirements_file)) 352 353 354 def _parse_pip_requirements(pip_requirements): 355 """Parses an iterable of pip requirement strings or a pip requirements file. 356 357 Args: 358 pip_requirements: Either an iterable of pip requirement strings 359 (e.g. ``["scikit-learn", "-r requirements.txt"]``) or the string path to a pip 360 requirements file on the local filesystem (e.g. ``"requirements.txt"``). If ``None``, 361 an empty list will be returned. 362 363 Returns: 364 A tuple of parsed requirements and constraints. 
365 """ 366 if pip_requirements is None: 367 return [], [] 368 369 def _is_string(x): 370 return isinstance(x, str) 371 372 def _is_iterable(x): 373 try: 374 iter(x) 375 return True 376 except Exception: 377 return False 378 379 if _is_string(pip_requirements): 380 with open(pip_requirements) as f: 381 return _parse_pip_requirements(f.read().splitlines()) 382 elif _is_iterable(pip_requirements) and all(map(_is_string, pip_requirements)): 383 requirements = [] 384 constraints = [] 385 for req_or_con in _parse_requirements(pip_requirements, is_constraint=False): 386 if req_or_con.is_constraint: 387 constraints.append(req_or_con.req_str) 388 else: 389 requirements.append(req_or_con.req_str) 390 391 return requirements, constraints 392 else: 393 raise TypeError( 394 "`pip_requirements` must be either a string path to a pip requirements file on the " 395 "local filesystem or an iterable of pip requirement strings, but got `{}`".format( 396 type(pip_requirements) 397 ) 398 ) 399 400 401 _INFER_PIP_REQUIREMENTS_GENERAL_ERROR_MESSAGE = ( 402 "Encountered an unexpected error while inferring pip requirements " 403 "(model URI: {model_uri}, flavor: {flavor}). Fall back to return {fallback}. " 404 "Set logging level to DEBUG to see the full traceback. " 405 ) 406 407 408 def infer_pip_requirements( 409 model_uri, 410 flavor, 411 fallback=None, 412 timeout=None, 413 extra_env_vars=None, 414 uv_project_dir=None, 415 uv_groups=None, 416 uv_extras=None, 417 ): 418 """Infers the pip requirements of the specified model by creating a subprocess and loading 419 the model in it to determine which packages are imported. 420 421 If a uv project is detected (contains both uv.lock and pyproject.toml), this function 422 will first attempt to export dependencies via ``uv export``. If that succeeds, those 423 requirements are returned. Otherwise, falls back to inferring dependencies by capturing 424 imported packages during model inference. 425 426 Args: 427 model_uri: The URI of the model. 
428 flavor: The flavor name of the model. 429 fallback: If provided, an unexpected error during the inference procedure is swallowed 430 and the value of ``fallback`` is returned. Otherwise, the error is raised. 431 timeout: If specified, the inference operation is bound by the timeout (in seconds). 432 extra_env_vars: A dictionary of extra environment variables to pass to the subprocess. 433 Default to None. 434 uv_project_dir: Explicit path to a uv project directory. When provided, overrides 435 the ``MLFLOW_UV_AUTO_DETECT`` environment variable and searches the specified 436 directory instead of cwd. Default to None (auto-detect from cwd). 437 uv_groups: Optional list of uv dependency groups to include when exporting 438 requirements. Maps to ``uv export --group <name>``. 439 uv_extras: Optional list of uv extras (optional dependency sets) to include 440 when exporting requirements. Maps to ``uv export --extra <name>``. 441 442 Returns: 443 A list of inferred pip requirements (e.g. ``["scikit-learn==0.24.2", ...]``). 444 445 """ 446 # Check for uv project first - if detected, use uv export instead of 447 # inferring model dependencies by capturing imported packages during model inference. 448 # An explicit uv_project_dir overrides the MLFLOW_UV_AUTO_DETECT env var. 449 if uv_project_dir is not None or MLFLOW_UV_AUTO_DETECT.get(): 450 if uv_project := detect_uv_project(uv_project_dir): 451 _logger.info( 452 f"Detected uv project at {uv_project.uv_lock.parent}. " 453 "Attempting to export requirements via 'uv export'." 454 ) 455 if uv_requirements := export_uv_requirements( 456 uv_project.uv_lock.parent, 457 groups=uv_groups, 458 extras=uv_extras, 459 ): 460 _logger.info( 461 f"Successfully exported {len(uv_requirements)} requirements from uv project. " 462 "Skipping package capture based inference." 463 ) 464 return uv_requirements 465 else: 466 _logger.warning( 467 "uv export failed or returned no requirements. 
" 468 "Falling back to package capture based inference." 469 ) 470 elif uv_groups or uv_extras: 471 _logger.warning( 472 "uv_groups and/or uv_extras were specified but no uv project was detected. " 473 "These parameters will be ignored. Falling back to package capture based inference." 474 ) 475 476 raise_on_error = MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS.get() 477 478 if timeout and is_windows(): 479 timeout = None 480 _logger.warning( 481 "On Windows, timeout is not supported for model requirement inference. Therefore, " 482 "the operation is not bound by a timeout and may hang indefinitely. If it hangs, " 483 "please consider specifying the signature manually." 484 ) 485 486 try: 487 if timeout: 488 with run_with_timeout(timeout): 489 return _infer_requirements( 490 model_uri, flavor, raise_on_error=raise_on_error, extra_env_vars=extra_env_vars 491 ) 492 else: 493 return _infer_requirements( 494 model_uri, flavor, raise_on_error=raise_on_error, extra_env_vars=extra_env_vars 495 ) 496 except Exception as e: 497 if raise_on_error or (fallback is None): 498 raise 499 500 if isinstance(e, MlflowTimeoutError): 501 msg = ( 502 "Attempted to infer pip requirements for the saved model or pipeline but the " 503 f"operation timed out in {timeout} seconds. Fall back to return {fallback}. " 504 "You can specify a different timeout by setting the environment variable " 505 f"{MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT}." 506 ) 507 else: 508 msg = _INFER_PIP_REQUIREMENTS_GENERAL_ERROR_MESSAGE.format( 509 model_uri=model_uri, flavor=flavor, fallback=fallback 510 ) 511 _logger.warning(msg) 512 _logger.debug("", exc_info=True) 513 return fallback 514 515 516 def _get_uv_options_for_databricks() -> tuple[list[str], dict[str, str]] | None: 517 """ 518 Retrieves the predefined secrets to configure `pip` for Databricks, and converts them into 519 command-line arguments and environment variables for `uv`. 
520 521 References: 522 - https://docs.databricks.com/aws/en/compute/serverless/dependencies#predefined-secret-scope-name 523 - https://docs.astral.sh/uv/configuration/environment/#environment-variables 524 """ 525 from databricks.sdk import WorkspaceClient 526 527 from mlflow.utils.databricks_utils import ( 528 _get_dbutils, 529 _NoDbutilsError, 530 is_in_databricks_runtime, 531 ) 532 533 if not is_in_databricks_runtime(): 534 return None 535 536 workspace_client = WorkspaceClient() 537 secret_scopes = workspace_client.secrets.list_scopes() 538 if not any(s.name == "databricks-package-management" for s in secret_scopes): 539 return None 540 541 try: 542 dbutils = _get_dbutils() 543 except _NoDbutilsError: 544 return None 545 546 def get_secret(key: str) -> str | None: 547 """ 548 Retrieves a secret from the Databricks secrets scope. 549 """ 550 try: 551 return dbutils.secrets.get(scope="databricks-package-management", key=key) 552 except Exception as e: 553 _logger.debug(f"Failed to fetch secret '{key}': {e}") 554 return None 555 556 args: list[str] = [] 557 if url := get_secret("pip-index-url"): 558 args.append(f"--index-url={url}") 559 560 if urls := get_secret("pip-extra-index-urls"): 561 args.append(f"--extra-index-url={urls}") 562 563 # There is no command-line option for SSL_CERT_FILE in `uv`. 564 envs: dict[str, str] = {} 565 if cert := get_secret("pip-cert"): 566 envs["SSL_CERT_FILE"] = cert 567 568 _logger.debug(f"uv arguments and environment variables: {args}, {envs}") 569 return args, envs 570 571 572 def _lock_requirements( 573 requirements: list[str], constraints: list[str] | None = None 574 ) -> list[str] | None: 575 """ 576 Locks the given requirements using `uv`. Returns the locked requirements when the locking is 577 performed successfully, otherwise returns None. 578 """ 579 if not MLFLOW_LOCK_MODEL_DEPENDENCIES.get(): 580 return None 581 582 uv_bin = shutil.which("uv") 583 if uv_bin is None: 584 _logger.debug("`uv` binary not found. 
Skipping locking requirements.") 585 return None 586 587 _logger.info("Locking requirements...") 588 with tempfile.TemporaryDirectory() as tmp_dir: 589 tmp_dir_path = pathlib.Path(tmp_dir) 590 in_file = tmp_dir_path / "requirements.in" 591 in_file.write_text("\n".join(requirements)) 592 out_file = tmp_dir_path / "requirements.out" 593 constraints_opt: list[str] = [] 594 if constraints: 595 constraints_file = tmp_dir_path / "constraints.txt" 596 constraints_file.write_text("\n".join(constraints)) 597 constraints_opt = [f"--constraints={constraints_file}"] 598 elif pip_constraint := os.environ.get("PIP_CONSTRAINT"): 599 # If PIP_CONSTRAINT is set, use it as a constraint file 600 constraints_opt = [f"--constraints={pip_constraint}"] 601 602 try: 603 if res := _get_uv_options_for_databricks(): 604 uv_options, uv_envs = res 605 else: 606 uv_options = [] 607 uv_envs = {} 608 out = subprocess.check_output( 609 [ 610 uv_bin, 611 "pip", 612 "compile", 613 "--color=never", 614 "--universal", 615 "--no-annotate", 616 "--no-header", 617 f"--python-version={PYTHON_VERSION}", 618 f"--output-file={out_file}", 619 *uv_options, 620 *constraints_opt, 621 in_file, 622 ], 623 stderr=subprocess.STDOUT, 624 env=os.environ.copy() | uv_envs, 625 text=True, 626 ) 627 _logger.debug(f"Successfully compiled requirements with `uv`:\n{out}") 628 except subprocess.CalledProcessError as e: 629 _logger.warning(f"Failed to lock requirements:\n{e.output}") 630 return None 631 632 return [ 633 "# Original requirements", 634 *(f"# {l}" for l in requirements), # Preserve original requirements as comments 635 "#", 636 "# Locked requirements", 637 *out_file.read_text().splitlines(), 638 ] 639 640 641 def _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements): 642 """ 643 Validates that only one or none of `conda_env`, `pip_requirements`, and 644 `extra_pip_requirements` is specified. 
645 """ 646 args = [ 647 conda_env, 648 pip_requirements, 649 extra_pip_requirements, 650 ] 651 specified = [arg for arg in args if arg is not None] 652 if len(specified) > 1: 653 raise ValueError( 654 "Only one of `conda_env`, `pip_requirements`, and " 655 "`extra_pip_requirements` can be specified" 656 ) 657 658 659 # PIP requirement parser inspired from https://github.com/pypa/pip/blob/b392833a0f1cff1bbee1ac6dbe0270cccdd0c11f/src/pip/_internal/req/req_file.py#L400 660 def _get_pip_requirement_specifier(requirement_string): 661 tokens = requirement_string.split(" ") 662 for idx, token in enumerate(tokens): 663 if token.startswith("-"): 664 return " ".join(tokens[:idx]) 665 return requirement_string 666 667 668 def _is_mlflow_requirement(requirement_string): 669 """ 670 Returns True if `requirement_string` represents a requirement for mlflow (e.g. 'mlflow==1.2.3'). 671 """ 672 # "/opt/mlflow" is the path where we mount the mlflow source code in the Docker container 673 # when running tests. 674 if _MLFLOW_TESTING.get() and requirement_string == "/opt/mlflow": 675 return True 676 677 try: 678 # `Requirement` throws an `InvalidRequirement` exception if `requirement_string` doesn't 679 # conform to PEP 508 (https://www.python.org/dev/peps/pep-0508). 680 return Requirement(requirement_string).name.lower() in [ 681 "mlflow", 682 "mlflow-skinny", 683 "mlflow-tracing", 684 ] 685 except InvalidRequirement: 686 # A local file path or URL falls into this branch. 
687 688 # `Requirement` throws an `InvalidRequirement` exception if `requirement_string` contains 689 # per-requirement options (ex: package hashes) 690 # GitHub issue: https://github.com/pypa/packaging/issues/488 691 # Per-requirement-option spec: https://pip.pypa.io/en/stable/reference/requirements-file-format/#per-requirement-options 692 requirement_specifier = _get_pip_requirement_specifier(requirement_string) 693 try: 694 # Try again with the per-requirement options removed 695 return Requirement(requirement_specifier).name.lower() == "mlflow" 696 except InvalidRequirement: 697 # Support defining branch dependencies for local builds or direct GitHub builds 698 # from source. 699 # Example: mlflow @ git+https://github.com/mlflow/mlflow@branch_2.0 700 repository_matches = ["/mlflow", "mlflow@git"] 701 702 return any( 703 match in requirement_string.replace(" ", "").lower() for match in repository_matches 704 ) 705 706 707 def _generate_mlflow_version_pinning() -> str: 708 """Returns a pinned requirement for the current MLflow version (e.g., "mlflow==3.2.1"). 709 710 Returns: 711 A pinned requirement for the current MLflow version. 712 713 """ 714 if _MLFLOW_TESTING.get(): 715 # The local PyPI server should be running. It serves a wheel for the current MLflow version. 716 return f"mlflow=={VERSION}" 717 718 version = Version(VERSION) 719 if not version.is_devrelease: 720 # mlflow is installed from PyPI. 721 return f"mlflow=={VERSION}" 722 723 # We reach here when mlflow is installed from the source outside of the MLflow CI environment 724 # (e.g., Databricks notebook). 725 726 # mlflow installed from the source for development purposes. A dev version (e.g., 2.8.1.dev0) 727 # is always a micro-version ahead of the latest release (unless it's manually modified) 728 # and can't be installed from PyPI. We therefore subtract 1 from the micro version when running 729 # tests. 
730 return f"mlflow=={version.major}.{version.minor}.{version.micro - 1}" 731 732 733 def _contains_mlflow_requirement(requirements): 734 """ 735 Returns True if `requirements` contains a requirement for mlflow (e.g. 'mlflow==1.2.3'). 736 """ 737 return any(map(_is_mlflow_requirement, requirements)) 738 739 740 def _process_pip_requirements( 741 default_pip_requirements, pip_requirements=None, extra_pip_requirements=None 742 ): 743 """ 744 Processes `pip_requirements` and `extra_pip_requirements` passed to `mlflow.*.save_model` or 745 `mlflow.*.log_model`, and returns a tuple of (conda_env, pip_requirements, pip_constraints). 746 """ 747 constraints = [] 748 if pip_requirements is not None: 749 pip_reqs, constraints = _parse_pip_requirements(pip_requirements) 750 elif extra_pip_requirements is not None: 751 extra_pip_requirements, constraints = _parse_pip_requirements(extra_pip_requirements) 752 pip_reqs = default_pip_requirements + extra_pip_requirements 753 else: 754 pip_reqs = default_pip_requirements 755 756 if not _contains_mlflow_requirement(pip_reqs): 757 pip_reqs.insert(0, _generate_mlflow_version_pinning()) 758 759 sanitized_pip_reqs = _deduplicate_requirements(pip_reqs) 760 sanitized_pip_reqs = _remove_incompatible_requirements(sanitized_pip_reqs) 761 762 # Check if pip requirements contain incompatible version with the current environment 763 warn_dependency_requirement_mismatches(sanitized_pip_reqs) 764 765 if locked_requirements := _lock_requirements(sanitized_pip_reqs, constraints): 766 # Locking requirements was performed successfully 767 sanitized_pip_reqs = locked_requirements 768 else: 769 # Locking requirements was skipped or failed 770 if constraints: 771 sanitized_pip_reqs.append(f"-c {_CONSTRAINTS_FILE_NAME}") 772 773 # Set `install_mlflow` to False because `pip_reqs` already contains `mlflow` 774 conda_env = _mlflow_conda_env(additional_pip_deps=sanitized_pip_reqs, install_mlflow=False) 775 return conda_env, sanitized_pip_reqs, constraints 
776 777 778 def _deduplicate_requirements(requirements): 779 """ 780 De-duplicates a list of pip package requirements, handling complex scenarios such as merging 781 extras and combining version constraints. 782 783 This function processes a list of pip package requirements and de-duplicates them. It handles 784 standard PyPI packages and non-standard requirements (like URLs or local paths). The function 785 merges extras and combines version constraints for duplicate packages. The most restrictive 786 version specifications or the ones with extras are prioritized. If incompatible version 787 constraints are detected, it raises an MlflowException. 788 789 Args: 790 requirements (list of str): A list of pip package requirement strings. 791 792 Returns: 793 list of str: A deduplicated list of pip package requirements. 794 795 Raises: 796 MlflowException: If incompatible version constraints are detected among the provided 797 requirements. 798 799 Examples: 800 - Input: ["packageA", "packageA==1.0"] 801 Output: ["packageA==1.0"] 802 803 - Input: ["packageX>1.0", "packageX[extras]", "packageX<2.0"] 804 Output: ["packageX[extras]<2.0,>1.0"] 805 806 - Input: ["markdown[extra1]>=3.5.1", "markdown[extra2]<4", "markdown"] 807 Output: ["markdown[extra1,extra2]<4,>=3.5.1"] 808 809 - Input: ["scikit-learn==1.1", "scikit-learn<1"] 810 Raises MlflowException indicating incompatible versions. 811 812 Note: 813 - Non-standard requirements (like URLs or file paths) are included as-is. 814 - If a requirement appears multiple times with different sets of extras, they are merged. 815 - The function uses `_validate_version_constraints` to check for incompatible version 816 constraints by doing a dry-run pip install of a requirements collection. 
817 """ 818 deduped_reqs = {} 819 820 for req in requirements: 821 try: 822 parsed_req = Requirement(req) 823 base_pkg = parsed_req.name 824 key = (base_pkg, str(parsed_req.marker) if parsed_req.marker else "") 825 826 existing_req = deduped_reqs.get(key) 827 828 if not existing_req: 829 deduped_reqs[key] = parsed_req 830 else: 831 # Verify that there are not unresolvable constraints applied if set and combine 832 # if possible 833 if ( 834 existing_req.specifier 835 and parsed_req.specifier 836 and existing_req.specifier != parsed_req.specifier 837 ): 838 existing_specs = list(existing_req.specifier) 839 new_specs = list(parsed_req.specifier) 840 # When uv export preserves local version labels (e.g. torch==2.7.1+cu128) 841 # but _get_pinned_requirement strips them (e.g. torch==2.7.1), both end up 842 # in the merged list. Detect this case and prefer the non-local version 843 # (PyPI-installable) rather than failing validation. 844 if ( 845 len(existing_specs) == 1 846 and len(new_specs) == 1 847 and existing_specs[0].operator == "==" 848 and new_specs[0].operator == "==" 849 and _strip_local_version_label(existing_specs[0].version) 850 == _strip_local_version_label(new_specs[0].version) 851 and bool(_get_local_version_label(existing_specs[0].version)) 852 != bool(_get_local_version_label(new_specs[0].version)) 853 ): 854 # Keep whichever specifier has no local label (PyPI-installable) 855 if local_label := _get_local_version_label(new_specs[0].version): 856 _logger.debug( 857 f"Dropping local version label (+{local_label}) from " 858 f"'{parsed_req.name}=={new_specs[0].version}' to keep the " 859 f"PyPI-installable version " 860 f"'{parsed_req.name}=={existing_specs[0].version}'." 
861 ) 862 parsed_req.specifier = existing_req.specifier 863 else: 864 _validate_version_constraints([str(existing_req), req]) 865 parsed_req.specifier = SpecifierSet( 866 ",".join([ 867 str(existing_req.specifier), 868 str(parsed_req.specifier), 869 ]) 870 ) 871 872 # Preserve existing specifiers 873 if existing_req.specifier and not parsed_req.specifier: 874 parsed_req.specifier = existing_req.specifier 875 876 # Combine and apply extras if specified 877 if ( 878 existing_req.extras 879 and parsed_req.extras 880 and existing_req.extras != parsed_req.extras 881 ): 882 parsed_req.extras = sorted(set(existing_req.extras).union(parsed_req.extras)) 883 elif existing_req.extras and not parsed_req.extras: 884 parsed_req.extras = existing_req.extras 885 886 deduped_reqs[key] = parsed_req 887 888 except InvalidRequirement: 889 # Include non-standard package strings as-is 890 if req not in deduped_reqs: 891 deduped_reqs[req] = req 892 return [str(req) for req in deduped_reqs.values()] 893 894 895 def _parse_requirement_name(req: str) -> str: 896 try: 897 return Requirement(req).name 898 except InvalidRequirement: 899 return req 900 901 902 def _remove_incompatible_requirements(requirements: list[str]) -> list[str]: 903 req_names = {_parse_requirement_name(req) for req in requirements} 904 if "databricks-connect" in req_names and req_names.intersection({"pyspark", "pyspark-connect"}): 905 _logger.debug( 906 "Found incompatible requirements: 'databricks-connect' with 'pyspark' or " 907 "'pyspark-connect'. Removing 'pyspark' or 'pyspark-connect' from the requirements." 
908 ) 909 requirements = [ 910 req 911 for req in requirements 912 if _parse_requirement_name(req) not in ["pyspark", "pyspark-connect"] 913 ] 914 return requirements 915 916 917 def _validate_version_constraints(requirements): 918 """ 919 Validates the version constraints of given Python package requirements using pip's resolver with 920 the `--dry-run` option enabled that performs validation only (will not install packages). 921 922 This function writes the requirements to a temporary file and then attempts to resolve 923 them using pip's `--dry-run` install option. If any version conflicts are detected, it 924 raises an MlflowException with details of the conflict. 925 926 Args: 927 requirements (list of str): A list of package requirements (e.g., `["pandas>=1.15", 928 "pandas<2"]`). 929 930 Raises: 931 MlflowException: If any version conflicts are detected among the provided requirements. 932 933 Returns: 934 None: This function does not return anything. It either completes successfully or raises 935 an MlflowException. 936 937 Example: 938 _validate_version_constraints(["tensorflow<2.0", "tensorflow>2.3"]) 939 # This will raise an exception due to boundary validity. 940 """ 941 with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp_file: 942 tmp_file.write("\n".join(requirements)) 943 tmp_file_name = tmp_file.name 944 945 try: 946 subprocess.run( 947 [sys.executable, "-m", "pip", "install", "--dry-run", "-r", tmp_file_name], 948 check=True, 949 capture_output=True, 950 ) 951 except subprocess.CalledProcessError as e: 952 raise MlflowException.invalid_parameter_value( 953 "The specified requirements versions are incompatible. Detected " 954 f"conflicts: \n{e.stderr.decode()}" 955 ) 956 finally: 957 os.remove(tmp_file_name) 958 959 960 def _process_conda_env(conda_env): 961 """ 962 Processes `conda_env` passed to `mlflow.*.save_model` or `mlflow.*.log_model`, and returns 963 a tuple of (conda_env, pip_requirements, pip_constraints). 
    """
    if isinstance(conda_env, str):
        # A string is treated as a path to a conda env YAML file
        with open(conda_env) as f:
            conda_env = yaml.safe_load(f)
    elif not isinstance(conda_env, dict):
        raise TypeError(
            "Expected a string path to a conda env yaml file or a `dict` representing a conda env, "
            f"but got `{type(conda_env).__name__}`"
        )

    # User-specified `conda_env` may contain requirements/constraints file references
    pip_reqs = _get_pip_deps(conda_env)
    pip_reqs, constraints = _parse_pip_requirements(pip_reqs)
    # Ensure mlflow itself is always pinned as a dependency of the saved model
    if not _contains_mlflow_requirement(pip_reqs):
        pip_reqs.insert(0, _generate_mlflow_version_pinning())

    # Check if pip requirements contain incompatible version with the current environment
    warn_dependency_requirement_mismatches(pip_reqs)

    if constraints:
        # Reference the constraints file stored alongside the requirements
        pip_reqs.append(f"-c {_CONSTRAINTS_FILE_NAME}")

    conda_env = _overwrite_pip_deps(conda_env, pip_reqs)
    return conda_env, pip_reqs, constraints


def _get_mlflow_env_name(s):
    """Creates an environment name for an MLflow model by hashing the given string.

    Args:
        s: String to hash (e.g. the content of `conda.yaml`).

    Returns:
        String in the form of "mlflow-{hash}"
        (e.g. "mlflow-da39a3ee5e6b4b0d3255bfef95601890afd80709")

    """
    # usedforsecurity=False: sha1 is used only as a stable fingerprint here,
    # not for any cryptographic purpose.
    return "mlflow-" + hashlib.sha1(s.encode("utf-8"), usedforsecurity=False).hexdigest()


def _get_pip_install_mlflow():
    """
    Returns a command to pip-install mlflow. If the MLFLOW_HOME environment variable exists,
    returns "pip install -e {MLFLOW_HOME} 1>&2", otherwise
    "pip install mlflow=={mlflow.__version__} 1>&2".
    """
    if mlflow_home := os.environ.get("MLFLOW_HOME"):  # dev version
        return f"pip install -e {mlflow_home} 1>&2"
    else:
        return f"pip install mlflow=={VERSION} 1>&2"


def _get_requirements_from_file(
    file_path: pathlib.Path,
) -> list[Requirement]:
    """Read pip requirements from a conda.yaml or a requirements-style file.

    Blank lines are skipped. Each remaining entry must be parseable as a
    PEP 508 requirement, otherwise `InvalidRequirement` propagates to the
    caller.
    """
    data = file_path.read_text()
    if file_path.name == _CONDA_ENV_FILE_NAME:
        # conda.yaml: requirements live under the env's pip dependencies
        conda_env = yaml.safe_load(data)
        reqs = _get_pip_deps(conda_env)
    else:
        reqs = data.splitlines()
    return [Requirement(req) for req in reqs if req]


def _write_requirements_to_file(
    file_path: pathlib.Path,
    new_reqs: list[str],
) -> None:
    """Write `new_reqs` back to a conda.yaml (pip section) or plain text file."""
    if file_path.name == _CONDA_ENV_FILE_NAME:
        # Rewrite only the pip dependencies, preserving the rest of the env
        conda_env = yaml.safe_load(file_path.read_text())
        conda_env = _overwrite_pip_deps(conda_env, new_reqs)
        with file_path.open("w") as file:
            yaml.dump(conda_env, file)
    else:
        file_path.write_text("\n".join(new_reqs))


def _add_or_overwrite_requirements(
    new_reqs: list[Requirement],
    old_reqs: list[Requirement],
) -> list[str]:
    """Merge `new_reqs` into `old_reqs`; new entries win on name collisions.

    Returns the merged requirements as strings, preserving `old_reqs` order
    for untouched entries.
    """
    deduped_new_reqs = _deduplicate_requirements([str(req) for req in new_reqs])
    deduped_new_reqs = [Requirement(req) for req in deduped_new_reqs]

    # dict.update makes the new requirements overwrite old ones by package name
    old_reqs_dict = {req.name: str(req) for req in old_reqs}
    new_reqs_dict = {req.name: str(req) for req in deduped_new_reqs}
    old_reqs_dict.update(new_reqs_dict)
    return list(old_reqs_dict.values())


def _remove_requirements(
    reqs_to_remove: list[Requirement],
    old_reqs: list[Requirement],
) -> list[str]:
    """Remove `reqs_to_remove` (matched by name) from `old_reqs`.

    Removal is best-effort: names not present in `old_reqs` are skipped with
    a warning rather than raising.
    """
    old_reqs_dict = {req.name: str(req) for req in old_reqs}
    for req in reqs_to_remove:
        if req.name not in old_reqs_dict:
            _logger.warning(f'"{req.name}" not found in requirements, ignoring')
        old_reqs_dict.pop(req.name, None)
    return list(old_reqs_dict.values())


class Environment:
    """An activatable execution environment.

    `activate_cmd` is a shell command (or list of commands) that activates
    the environment; `execute` chains it in front of user commands so they
    run inside the environment.
    """

    def __init__(self, activate_cmd, extra_env=None):
        # Normalize to a list so activation may consist of multiple commands
        if not isinstance(activate_cmd, list):
            activate_cmd = [activate_cmd]
        self._activate_cmd = activate_cmd
        # Extra environment variables applied on every `execute` call
        self._extra_env = extra_env or {}

    def get_activate_command(self):
        """Return the activation command(s) as a list."""
        return self._activate_cmd

    def execute(
        self,
        command,
        command_env=None,
        preexec_fn=None,
        capture_output=False,
        stdout=None,
        stderr=None,
        stdin=None,
        synchronous=True,
    ):
        """Run `command` inside this environment via a single shell invocation.

        The activation command(s) and `command` are joined into one shell
        line and executed with `bash -c` (or `cmd /c` on Windows). Tracking
        related environment variables (Databricks credentials, experiment id,
        active model id) and `extra_env` are propagated to the subprocess.
        Returns the result of `_exec_cmd`.
        """
        # Copy so neither os.environ nor the caller's mapping is mutated
        command_env = os.environ.copy() if command_env is None else deepcopy(command_env)
        if is_in_databricks_runtime():
            command_env.update(get_databricks_env_vars(get_tracking_uri()))
        if is_databricks_connect():
            command_env.update(_get_databricks_serverless_env_vars())
        if exp_id := _get_experiment_id():
            command_env[MLFLOW_EXPERIMENT_ID.name] = exp_id
        if active_model_id := get_active_model_id():
            command_env[_MLFLOW_ACTIVE_MODEL_ID.name] = active_model_id
        command_env.update(self._extra_env)
        if not isinstance(command, list):
            command = [command]

        # "&&" chains on POSIX shells; "&" is the cmd.exe equivalent
        separator = " && " if not is_windows() else " & "

        command = separator.join(map(str, self._activate_cmd + command))
        command = ["bash", "-c", command] if not is_windows() else ["cmd", "/c", command]
        _logger.info("=== Running command '%s'", command)
        return _exec_cmd(
            command,
            env=command_env,
            capture_output=capture_output,
            synchronous=synchronous,
            preexec_fn=preexec_fn,
            close_fds=True,
            stdout=stdout,
            stderr=stderr,
            stdin=stdin,
        )