# mlflow/utils/environment.py
   1  import hashlib
   2  import importlib.metadata
   3  import logging
   4  import os
   5  import pathlib
   6  import re
   7  import shutil
   8  import subprocess
   9  import sys
  10  import tempfile
  11  from copy import deepcopy
  12  
  13  import yaml
  14  from packaging.requirements import InvalidRequirement, Requirement
  15  from packaging.specifiers import SpecifierSet
  16  from packaging.version import Version
  17  
  18  from mlflow.environment_variables import (
  19      _MLFLOW_ACTIVE_MODEL_ID,
  20      _MLFLOW_TESTING,
  21      MLFLOW_EXPERIMENT_ID,
  22      MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT,
  23      MLFLOW_LOCK_MODEL_DEPENDENCIES,
  24      MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS,
  25      MLFLOW_UV_AUTO_DETECT,
  26  )
  27  from mlflow.exceptions import MlflowException
  28  from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
  29  from mlflow.tracking import get_tracking_uri
  30  from mlflow.tracking.fluent import _get_experiment_id, get_active_model_id
  31  from mlflow.utils import PYTHON_VERSION
  32  from mlflow.utils.databricks_utils import (
  33      _get_databricks_serverless_env_vars,
  34      get_databricks_env_vars,
  35      is_databricks_connect,
  36      is_in_databricks_runtime,
  37  )
  38  from mlflow.utils.os import is_windows
  39  from mlflow.utils.process import _exec_cmd
  40  from mlflow.utils.requirements_utils import (
  41      _get_local_version_label,
  42      _infer_requirements,
  43      _parse_requirements,
  44      _strip_local_version_label,
  45      warn_dependency_requirement_mismatches,
  46  )
  47  from mlflow.utils.timeout import MlflowTimeoutError, run_with_timeout
  48  from mlflow.utils.uv_utils import (
  49      detect_uv_project,
  50      export_uv_requirements,
  51  )
  52  from mlflow.version import VERSION
  53  
_logger = logging.getLogger(__name__)

# Header prepended to every generated conda.yaml: fixed environment name with
# conda-forge as the only default channel.
_conda_header = """\
name: mlflow-env
channels:
  - conda-forge
"""

# Canonical file names for the environment artifacts written alongside a model.
_CONDA_ENV_FILE_NAME = "conda.yaml"
_REQUIREMENTS_FILE_NAME = "requirements.txt"
_CONSTRAINTS_FILE_NAME = "constraints.txt"
_PYTHON_ENV_FILE_NAME = "python_env.yaml"


# Note this regular expression does not cover all possible patterns
# Matches simple conda dependency specs such as "python=3.8.13" or "pip>=21.0",
# capturing the package name, comparison operator, and version as named groups.
_CONDA_DEPENDENCY_REGEX = re.compile(
    r"^(?P<package>python|pip|setuptools|wheel)"
    r"(?P<operator><|>|<=|>=|=|==|!=)?"
    r"(?P<version>[\d.]+)?$"
)
  74  
  75  
class _PythonEnv:
    """In-memory representation of a ``python_env.yaml`` environment spec.

    Holds the Python version, build dependencies (installed before the runtime
    dependencies), and runtime dependencies, with converters to/from YAML and
    conda.yaml files.
    """

    # Packages that must be installed before `dependencies` can be installed.
    BUILD_PACKAGES = ("pip", "setuptools", "wheel")

    def __init__(self, python=None, build_dependencies=None, dependencies=None):
        """
        Represents environment information for MLflow Models and Projects.

        Args:
            python: Python version for the environment. If unspecified, defaults to the current
                Python version.
            build_dependencies: List of build dependencies for the environment that must
                be installed before installing ``dependencies``. If unspecified,
                defaults to an empty list.
            dependencies: List of dependencies for the environment. If unspecified, defaults to
                an empty list.

        Raises:
            TypeError: If ``python`` is not a string, or ``build_dependencies`` /
                ``dependencies`` is not a list.
        """
        if python is not None and not isinstance(python, str):
            raise TypeError(f"`python` must be a string but got {type(python)}")
        if build_dependencies is not None and not isinstance(build_dependencies, list):
            raise TypeError(
                f"`build_dependencies` must be a list but got {type(build_dependencies)}"
            )
        if dependencies is not None and not isinstance(dependencies, list):
            raise TypeError(f"`dependencies` must be a list but got {type(dependencies)}")

        self.python = python or PYTHON_VERSION
        self.build_dependencies = build_dependencies or []
        self.dependencies = dependencies or []

    def __str__(self):
        """Return the dict representation as a string (useful for logging)."""
        return str(self.to_dict())

    @classmethod
    def current(cls):
        """Build a ``_PythonEnv`` describing the currently running interpreter.

        Runtime dependencies are deferred to the sibling requirements.txt file
        via a ``-r requirements.txt`` directive.
        """
        return cls(
            python=PYTHON_VERSION,
            build_dependencies=cls.get_current_build_dependencies(),
            dependencies=[f"-r {_REQUIREMENTS_FILE_NAME}"],
        )

    @staticmethod
    def get_current_build_dependencies():
        """Return BUILD_PACKAGES pinned to their currently installed versions.

        A package that is not installed is listed without a version specifier.
        """
        build_dependencies = []
        for package in _PythonEnv.BUILD_PACKAGES:
            version = _get_package_version(package)
            dep = (package + "==" + version) if version else package
            build_dependencies.append(dep)
        return build_dependencies

    def to_dict(self):
        """Return a shallow copy of the instance attributes as a dict."""
        return self.__dict__.copy()

    @classmethod
    def from_dict(cls, dct):
        """Construct a ``_PythonEnv`` from a dict whose keys match ``__init__`` params."""
        return cls(**dct)

    def to_yaml(self, path):
        """Serialize this environment spec to a YAML file at ``path``."""
        with open(path, "w") as f:
            # Exclude None and empty lists
            data = {k: v for k, v in self.to_dict().items() if v}
            yaml.safe_dump(data, f, sort_keys=False, default_flow_style=False)

    @classmethod
    def from_yaml(cls, path):
        """Load a ``_PythonEnv`` from a YAML file previously written by ``to_yaml``."""
        with open(path) as f:
            return cls.from_dict(yaml.safe_load(f))

    @staticmethod
    def get_dependencies_from_conda_yaml(path):
        """Translate the conda.yaml at ``path`` into ``_PythonEnv`` constructor kwargs.

        String dependencies matching ``_CONDA_DEPENDENCY_REGEX`` become the python
        version or build dependencies; a ``{"pip": [...]}`` entry becomes the runtime
        dependencies. Unrecognized string dependencies are dropped with a warning.

        Raises:
            MlflowException: If the python dependency is unpinned or uses a comparator
                pip cannot represent, or if an entry is neither a string nor a
                ``{"pip": [...]}`` dict.
        """
        with open(path) as f:
            conda_env = yaml.safe_load(f)

        python = None
        build_dependencies = None
        unmatched_dependencies = []
        dependencies = None
        for dep in conda_env.get("dependencies", []):
            if isinstance(dep, str):
                match = _CONDA_DEPENDENCY_REGEX.match(dep)
                if not match:
                    # Not python/pip/setuptools/wheel — collected only for the warning below.
                    unmatched_dependencies.append(dep)
                    continue
                package = match.group("package")
                operator = match.group("operator")
                version = match.group("version")

                # Python
                if not python and package == "python":
                    if operator is None:
                        raise MlflowException.invalid_parameter_value(
                            f"Invalid dependency for python: {dep}. "
                            "It must be pinned (e.g. python=3.8.13)."
                        )

                    if operator in ("<", ">", "!="):
                        raise MlflowException(
                            f"Invalid version comparator for python: '{operator}'. "
                            "Must be one of ['<=', '>=', '=', '=='].",
                            error_code=INVALID_PARAMETER_VALUE,
                        )
                    python = version
                    continue

                # Build packages
                if build_dependencies is None:
                    build_dependencies = []
                # "=" is an invalid operator for pip
                operator = "==" if operator == "=" else operator
                build_dependencies.append(package + (operator or "") + (version or ""))
            elif _is_pip_deps(dep):
                dependencies = dep["pip"]
            else:
                raise MlflowException(
                    f"Invalid conda dependency: {dep}. Must be str or dict in the form of "
                    '{"pip": [...]}',
                    error_code=INVALID_PARAMETER_VALUE,
                )

        if python is None:
            _logger.warning(
                f"{path} does not include a python version specification. "
                f"Using the current python version {PYTHON_VERSION}."
            )
            python = PYTHON_VERSION

        if unmatched_dependencies:
            _logger.warning(
                "The following conda dependencies will not be installed in the resulting "
                "environment: %s",
                unmatched_dependencies,
            )

        return {
            "python": python,
            "build_dependencies": build_dependencies,
            "dependencies": dependencies,
        }

    @classmethod
    def from_conda_yaml(cls, path):
        """Construct a ``_PythonEnv`` directly from a conda.yaml file."""
        return cls.from_dict(cls.get_dependencies_from_conda_yaml(path))
 217  
 218  
 219  def _mlflow_conda_env(
 220      path=None,
 221      additional_conda_deps=None,
 222      additional_pip_deps=None,
 223      additional_conda_channels=None,
 224      install_mlflow=True,
 225  ):
 226      """Creates a Conda environment with the specified package channels and dependencies. If there
 227      are any pip dependencies, including from the install_mlflow parameter, then pip will be added to
 228      the conda dependencies. This is done to ensure that the pip inside the conda environment is
 229      used to install the pip dependencies.
 230  
 231      Args:
 232          path: Local filesystem path where the conda env file is to be written. If unspecified,
 233              the conda env will not be written to the filesystem; it will still be returned
 234              in dictionary format.
 235          additional_conda_deps: List of additional conda dependencies passed as strings.
 236          additional_pip_deps: List of additional pip dependencies passed as strings.
 237          additional_conda_channels: List of additional conda channels to search when resolving
 238              packages.
 239  
 240      Returns:
 241          None if path is specified. Otherwise, the a dictionary representation of the
 242          Conda environment.
 243  
 244      """
 245      additional_pip_deps = additional_pip_deps or []
 246      mlflow_deps = (
 247          [f"mlflow=={VERSION}"]
 248          if install_mlflow and not _contains_mlflow_requirement(additional_pip_deps)
 249          else []
 250      )
 251      pip_deps = mlflow_deps + additional_pip_deps
 252      conda_deps = additional_conda_deps or []
 253      if pip_deps:
 254          pip_version = _get_package_version("pip")
 255          if pip_version is not None:
 256              # When a new version of pip is released on PyPI, it takes a while until that version is
 257              # uploaded to conda-forge. This time lag causes `conda create` to fail with
 258              # a `ResolvePackageNotFound` error. As a workaround for this issue, use `<=` instead
 259              # of `==` so conda installs `pip_version - 1` when `pip_version` is unavailable.
 260              conda_deps.append(f"pip<={pip_version}")
 261          else:
 262              _logger.warning(
 263                  "Failed to resolve installed pip version. ``pip`` will be added to conda.yaml"
 264                  " environment spec without a version specifier."
 265              )
 266              conda_deps.append("pip")
 267  
 268      env = yaml.safe_load(_conda_header)
 269      env["dependencies"] = [f"python={PYTHON_VERSION}"]
 270      env["dependencies"] += conda_deps
 271      env["dependencies"].append({"pip": pip_deps})
 272      if additional_conda_channels is not None:
 273          env["channels"] += additional_conda_channels
 274  
 275      if path is not None:
 276          with open(path, "w") as out:
 277              yaml.safe_dump(env, stream=out, default_flow_style=False)
 278          return None
 279      else:
 280          return env
 281  
 282  
 283  def _get_package_version(package_name: str) -> str | None:
 284      try:
 285          return importlib.metadata.version(package_name)
 286      except importlib.metadata.PackageNotFoundError:
 287          return None
 288  
 289  
 290  def _mlflow_additional_pip_env(pip_deps, path=None):
 291      requirements = "\n".join(pip_deps)
 292      if path is not None:
 293          with open(path, "w") as out:
 294              out.write(requirements)
 295          return None
 296      else:
 297          return requirements
 298  
 299  
 300  def _is_pip_deps(dep):
 301      """
 302      Returns True if `dep` is a dict representing pip dependencies
 303      """
 304      return isinstance(dep, dict) and "pip" in dep
 305  
 306  
 307  def _get_pip_deps(conda_env):
 308      """
 309      Returns:
 310          The pip dependencies from the conda env.
 311      """
 312      if conda_env is not None:
 313          for dep in conda_env["dependencies"]:
 314              if _is_pip_deps(dep):
 315                  return dep["pip"]
 316      return []
 317  
 318  
 319  def _overwrite_pip_deps(conda_env, new_pip_deps):
 320      """
 321      Overwrites the pip dependencies section in the given conda env dictionary.
 322  
 323      {
 324          "name": "env",
 325          "channels": [...],
 326          "dependencies": [
 327              ...,
 328              "pip",
 329              {"pip": [...]},  <- Overwrite this
 330          ],
 331      }
 332      """
 333      deps = conda_env.get("dependencies", [])
 334      new_deps = []
 335      contains_pip_deps = False
 336      for dep in deps:
 337          if _is_pip_deps(dep):
 338              contains_pip_deps = True
 339              new_deps.append({"pip": new_pip_deps})
 340          else:
 341              new_deps.append(dep)
 342  
 343      if not contains_pip_deps:
 344          new_deps.append({"pip": new_pip_deps})
 345  
 346      return {**conda_env, "dependencies": new_deps}
 347  
 348  
 349  def _log_pip_requirements(conda_env, path, requirements_file=_REQUIREMENTS_FILE_NAME):
 350      pip_deps = _get_pip_deps(conda_env)
 351      _mlflow_additional_pip_env(pip_deps, path=os.path.join(path, requirements_file))
 352  
 353  
 354  def _parse_pip_requirements(pip_requirements):
 355      """Parses an iterable of pip requirement strings or a pip requirements file.
 356  
 357      Args:
 358          pip_requirements: Either an iterable of pip requirement strings
 359              (e.g. ``["scikit-learn", "-r requirements.txt"]``) or the string path to a pip
 360              requirements file on the local filesystem (e.g. ``"requirements.txt"``). If ``None``,
 361              an empty list will be returned.
 362  
 363      Returns:
 364          A tuple of parsed requirements and constraints.
 365      """
 366      if pip_requirements is None:
 367          return [], []
 368  
 369      def _is_string(x):
 370          return isinstance(x, str)
 371  
 372      def _is_iterable(x):
 373          try:
 374              iter(x)
 375              return True
 376          except Exception:
 377              return False
 378  
 379      if _is_string(pip_requirements):
 380          with open(pip_requirements) as f:
 381              return _parse_pip_requirements(f.read().splitlines())
 382      elif _is_iterable(pip_requirements) and all(map(_is_string, pip_requirements)):
 383          requirements = []
 384          constraints = []
 385          for req_or_con in _parse_requirements(pip_requirements, is_constraint=False):
 386              if req_or_con.is_constraint:
 387                  constraints.append(req_or_con.req_str)
 388              else:
 389                  requirements.append(req_or_con.req_str)
 390  
 391          return requirements, constraints
 392      else:
 393          raise TypeError(
 394              "`pip_requirements` must be either a string path to a pip requirements file on the "
 395              "local filesystem or an iterable of pip requirement strings, but got `{}`".format(
 396                  type(pip_requirements)
 397              )
 398          )
 399  
 400  
# Warning template used by `infer_pip_requirements` when inference fails and the
# provided fallback requirement list is returned instead of raising.
_INFER_PIP_REQUIREMENTS_GENERAL_ERROR_MESSAGE = (
    "Encountered an unexpected error while inferring pip requirements "
    "(model URI: {model_uri}, flavor: {flavor}). Fall back to return {fallback}. "
    "Set logging level to DEBUG to see the full traceback. "
)
 406  
 407  
def infer_pip_requirements(
    model_uri,
    flavor,
    fallback=None,
    timeout=None,
    extra_env_vars=None,
    uv_project_dir=None,
    uv_groups=None,
    uv_extras=None,
):
    """Infers the pip requirements of the specified model by creating a subprocess and loading
    the model in it to determine which packages are imported.

    If a uv project is detected (contains both uv.lock and pyproject.toml), this function
    will first attempt to export dependencies via ``uv export``. If that succeeds, those
    requirements are returned. Otherwise, falls back to inferring dependencies by capturing
    imported packages during model inference.

    Args:
        model_uri: The URI of the model.
        flavor: The flavor name of the model.
        fallback: If provided, an unexpected error during the inference procedure is swallowed
            and the value of ``fallback`` is returned. Otherwise, the error is raised.
        timeout: If specified, the inference operation is bound by the timeout (in seconds).
        extra_env_vars: A dictionary of extra environment variables to pass to the subprocess.
            Default to None.
        uv_project_dir: Explicit path to a uv project directory. When provided, overrides
            the ``MLFLOW_UV_AUTO_DETECT`` environment variable and searches the specified
            directory instead of cwd. Default to None (auto-detect from cwd).
        uv_groups: Optional list of uv dependency groups to include when exporting
            requirements. Maps to ``uv export --group <name>``.
        uv_extras: Optional list of uv extras (optional dependency sets) to include
            when exporting requirements. Maps to ``uv export --extra <name>``.

    Returns:
        A list of inferred pip requirements (e.g. ``["scikit-learn==0.24.2", ...]``).

    """
    # Check for uv project first - if detected, use uv export instead of
    # inferring model dependencies by capturing imported packages during model inference.
    # An explicit uv_project_dir overrides the MLFLOW_UV_AUTO_DETECT env var.
    if uv_project_dir is not None or MLFLOW_UV_AUTO_DETECT.get():
        if uv_project := detect_uv_project(uv_project_dir):
            _logger.info(
                f"Detected uv project at {uv_project.uv_lock.parent}. "
                "Attempting to export requirements via 'uv export'."
            )
            if uv_requirements := export_uv_requirements(
                uv_project.uv_lock.parent,
                groups=uv_groups,
                extras=uv_extras,
            ):
                _logger.info(
                    f"Successfully exported {len(uv_requirements)} requirements from uv project. "
                    "Skipping package capture based inference."
                )
                return uv_requirements
            else:
                # uv export failed; fall through to import-capture inference below.
                _logger.warning(
                    "uv export failed or returned no requirements. "
                    "Falling back to package capture based inference."
                )
        elif uv_groups or uv_extras:
            # The uv-specific knobs are meaningless without a detected uv project.
            _logger.warning(
                "uv_groups and/or uv_extras were specified but no uv project was detected. "
                "These parameters will be ignored. Falling back to package capture based inference."
            )

    # When set, inference errors are re-raised even if a `fallback` was provided.
    raise_on_error = MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS.get()

    if timeout and is_windows():
        # NOTE(review): timeout support is disabled on Windows — presumably the
        # mechanism in mlflow.utils.timeout is signal-based; confirm there.
        timeout = None
        _logger.warning(
            "On Windows, timeout is not supported for model requirement inference. Therefore, "
            "the operation is not bound by a timeout and may hang indefinitely. If it hangs, "
            "please consider specifying the signature manually."
        )

    try:
        if timeout:
            with run_with_timeout(timeout):
                return _infer_requirements(
                    model_uri, flavor, raise_on_error=raise_on_error, extra_env_vars=extra_env_vars
                )
        else:
            return _infer_requirements(
                model_uri, flavor, raise_on_error=raise_on_error, extra_env_vars=extra_env_vars
            )
    except Exception as e:
        # Only swallow the error when falling back is both allowed and possible.
        if raise_on_error or (fallback is None):
            raise

        if isinstance(e, MlflowTimeoutError):
            msg = (
                "Attempted to infer pip requirements for the saved model or pipeline but the "
                f"operation timed out in {timeout} seconds. Fall back to return {fallback}. "
                "You can specify a different timeout by setting the environment variable "
                f"{MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT}."
            )
        else:
            msg = _INFER_PIP_REQUIREMENTS_GENERAL_ERROR_MESSAGE.format(
                model_uri=model_uri, flavor=flavor, fallback=fallback
            )
        _logger.warning(msg)
        # Full traceback only at DEBUG level, as promised in the warning message.
        _logger.debug("", exc_info=True)
        return fallback
 514  
 515  
 516  def _get_uv_options_for_databricks() -> tuple[list[str], dict[str, str]] | None:
 517      """
 518      Retrieves the predefined secrets to configure `pip` for Databricks, and converts them into
 519      command-line arguments and environment variables for `uv`.
 520  
 521      References:
 522      - https://docs.databricks.com/aws/en/compute/serverless/dependencies#predefined-secret-scope-name
 523      - https://docs.astral.sh/uv/configuration/environment/#environment-variables
 524      """
 525      from databricks.sdk import WorkspaceClient
 526  
 527      from mlflow.utils.databricks_utils import (
 528          _get_dbutils,
 529          _NoDbutilsError,
 530          is_in_databricks_runtime,
 531      )
 532  
 533      if not is_in_databricks_runtime():
 534          return None
 535  
 536      workspace_client = WorkspaceClient()
 537      secret_scopes = workspace_client.secrets.list_scopes()
 538      if not any(s.name == "databricks-package-management" for s in secret_scopes):
 539          return None
 540  
 541      try:
 542          dbutils = _get_dbutils()
 543      except _NoDbutilsError:
 544          return None
 545  
 546      def get_secret(key: str) -> str | None:
 547          """
 548          Retrieves a secret from the Databricks secrets scope.
 549          """
 550          try:
 551              return dbutils.secrets.get(scope="databricks-package-management", key=key)
 552          except Exception as e:
 553              _logger.debug(f"Failed to fetch secret '{key}': {e}")
 554              return None
 555  
 556      args: list[str] = []
 557      if url := get_secret("pip-index-url"):
 558          args.append(f"--index-url={url}")
 559  
 560      if urls := get_secret("pip-extra-index-urls"):
 561          args.append(f"--extra-index-url={urls}")
 562  
 563      # There is no command-line option for SSL_CERT_FILE in `uv`.
 564      envs: dict[str, str] = {}
 565      if cert := get_secret("pip-cert"):
 566          envs["SSL_CERT_FILE"] = cert
 567  
 568      _logger.debug(f"uv arguments and environment variables: {args}, {envs}")
 569      return args, envs
 570  
 571  
 572  def _lock_requirements(
 573      requirements: list[str], constraints: list[str] | None = None
 574  ) -> list[str] | None:
 575      """
 576      Locks the given requirements using `uv`. Returns the locked requirements when the locking is
 577      performed successfully, otherwise returns None.
 578      """
 579      if not MLFLOW_LOCK_MODEL_DEPENDENCIES.get():
 580          return None
 581  
 582      uv_bin = shutil.which("uv")
 583      if uv_bin is None:
 584          _logger.debug("`uv` binary not found. Skipping locking requirements.")
 585          return None
 586  
 587      _logger.info("Locking requirements...")
 588      with tempfile.TemporaryDirectory() as tmp_dir:
 589          tmp_dir_path = pathlib.Path(tmp_dir)
 590          in_file = tmp_dir_path / "requirements.in"
 591          in_file.write_text("\n".join(requirements))
 592          out_file = tmp_dir_path / "requirements.out"
 593          constraints_opt: list[str] = []
 594          if constraints:
 595              constraints_file = tmp_dir_path / "constraints.txt"
 596              constraints_file.write_text("\n".join(constraints))
 597              constraints_opt = [f"--constraints={constraints_file}"]
 598          elif pip_constraint := os.environ.get("PIP_CONSTRAINT"):
 599              # If PIP_CONSTRAINT is set, use it as a constraint file
 600              constraints_opt = [f"--constraints={pip_constraint}"]
 601  
 602          try:
 603              if res := _get_uv_options_for_databricks():
 604                  uv_options, uv_envs = res
 605              else:
 606                  uv_options = []
 607                  uv_envs = {}
 608              out = subprocess.check_output(
 609                  [
 610                      uv_bin,
 611                      "pip",
 612                      "compile",
 613                      "--color=never",
 614                      "--universal",
 615                      "--no-annotate",
 616                      "--no-header",
 617                      f"--python-version={PYTHON_VERSION}",
 618                      f"--output-file={out_file}",
 619                      *uv_options,
 620                      *constraints_opt,
 621                      in_file,
 622                  ],
 623                  stderr=subprocess.STDOUT,
 624                  env=os.environ.copy() | uv_envs,
 625                  text=True,
 626              )
 627              _logger.debug(f"Successfully compiled requirements with `uv`:\n{out}")
 628          except subprocess.CalledProcessError as e:
 629              _logger.warning(f"Failed to lock requirements:\n{e.output}")
 630              return None
 631  
 632          return [
 633              "# Original requirements",
 634              *(f"# {l}" for l in requirements),  # Preserve original requirements as comments
 635              "#",
 636              "# Locked requirements",
 637              *out_file.read_text().splitlines(),
 638          ]
 639  
 640  
 641  def _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements):
 642      """
 643      Validates that only one or none of `conda_env`, `pip_requirements`, and
 644      `extra_pip_requirements` is specified.
 645      """
 646      args = [
 647          conda_env,
 648          pip_requirements,
 649          extra_pip_requirements,
 650      ]
 651      specified = [arg for arg in args if arg is not None]
 652      if len(specified) > 1:
 653          raise ValueError(
 654              "Only one of `conda_env`, `pip_requirements`, and "
 655              "`extra_pip_requirements` can be specified"
 656          )
 657  
 658  
 659  # PIP requirement parser inspired from https://github.com/pypa/pip/blob/b392833a0f1cff1bbee1ac6dbe0270cccdd0c11f/src/pip/_internal/req/req_file.py#L400
 660  def _get_pip_requirement_specifier(requirement_string):
 661      tokens = requirement_string.split(" ")
 662      for idx, token in enumerate(tokens):
 663          if token.startswith("-"):
 664              return " ".join(tokens[:idx])
 665      return requirement_string
 666  
 667  
 668  def _is_mlflow_requirement(requirement_string):
 669      """
 670      Returns True if `requirement_string` represents a requirement for mlflow (e.g. 'mlflow==1.2.3').
 671      """
 672      # "/opt/mlflow" is the path where we mount the mlflow source code in the Docker container
 673      # when running tests.
 674      if _MLFLOW_TESTING.get() and requirement_string == "/opt/mlflow":
 675          return True
 676  
 677      try:
 678          # `Requirement` throws an `InvalidRequirement` exception if `requirement_string` doesn't
 679          # conform to PEP 508 (https://www.python.org/dev/peps/pep-0508).
 680          return Requirement(requirement_string).name.lower() in [
 681              "mlflow",
 682              "mlflow-skinny",
 683              "mlflow-tracing",
 684          ]
 685      except InvalidRequirement:
 686          # A local file path or URL falls into this branch.
 687  
 688          # `Requirement` throws an `InvalidRequirement` exception if `requirement_string` contains
 689          # per-requirement options (ex: package hashes)
 690          # GitHub issue: https://github.com/pypa/packaging/issues/488
 691          # Per-requirement-option spec: https://pip.pypa.io/en/stable/reference/requirements-file-format/#per-requirement-options
 692          requirement_specifier = _get_pip_requirement_specifier(requirement_string)
 693          try:
 694              # Try again with the per-requirement options removed
 695              return Requirement(requirement_specifier).name.lower() == "mlflow"
 696          except InvalidRequirement:
 697              # Support defining branch dependencies for local builds or direct GitHub builds
 698              # from source.
 699              # Example: mlflow @ git+https://github.com/mlflow/mlflow@branch_2.0
 700              repository_matches = ["/mlflow", "mlflow@git"]
 701  
 702              return any(
 703                  match in requirement_string.replace(" ", "").lower() for match in repository_matches
 704              )
 705  
 706  
 707  def _generate_mlflow_version_pinning() -> str:
 708      """Returns a pinned requirement for the current MLflow version (e.g., "mlflow==3.2.1").
 709  
 710      Returns:
 711          A pinned requirement for the current MLflow version.
 712  
 713      """
 714      if _MLFLOW_TESTING.get():
 715          # The local PyPI server should be running. It serves a wheel for the current MLflow version.
 716          return f"mlflow=={VERSION}"
 717  
 718      version = Version(VERSION)
 719      if not version.is_devrelease:
 720          # mlflow is installed from PyPI.
 721          return f"mlflow=={VERSION}"
 722  
 723      # We reach here when mlflow is installed from the source outside of the MLflow CI environment
 724      # (e.g., Databricks notebook).
 725  
 726      # mlflow installed from the source for development purposes. A dev version (e.g., 2.8.1.dev0)
 727      # is always a micro-version ahead of the latest release (unless it's manually modified)
 728      # and can't be installed from PyPI. We therefore subtract 1 from the micro version when running
 729      # tests.
 730      return f"mlflow=={version.major}.{version.minor}.{version.micro - 1}"
 731  
 732  
 733  def _contains_mlflow_requirement(requirements):
 734      """
 735      Returns True if `requirements` contains a requirement for mlflow (e.g. 'mlflow==1.2.3').
 736      """
 737      return any(map(_is_mlflow_requirement, requirements))
 738  
 739  
 740  def _process_pip_requirements(
 741      default_pip_requirements, pip_requirements=None, extra_pip_requirements=None
 742  ):
 743      """
 744      Processes `pip_requirements` and `extra_pip_requirements` passed to `mlflow.*.save_model` or
 745      `mlflow.*.log_model`, and returns a tuple of (conda_env, pip_requirements, pip_constraints).
 746      """
 747      constraints = []
 748      if pip_requirements is not None:
 749          pip_reqs, constraints = _parse_pip_requirements(pip_requirements)
 750      elif extra_pip_requirements is not None:
 751          extra_pip_requirements, constraints = _parse_pip_requirements(extra_pip_requirements)
 752          pip_reqs = default_pip_requirements + extra_pip_requirements
 753      else:
 754          pip_reqs = default_pip_requirements
 755  
 756      if not _contains_mlflow_requirement(pip_reqs):
 757          pip_reqs.insert(0, _generate_mlflow_version_pinning())
 758  
 759      sanitized_pip_reqs = _deduplicate_requirements(pip_reqs)
 760      sanitized_pip_reqs = _remove_incompatible_requirements(sanitized_pip_reqs)
 761  
 762      # Check if pip requirements contain incompatible version with the current environment
 763      warn_dependency_requirement_mismatches(sanitized_pip_reqs)
 764  
 765      if locked_requirements := _lock_requirements(sanitized_pip_reqs, constraints):
 766          # Locking requirements was performed successfully
 767          sanitized_pip_reqs = locked_requirements
 768      else:
 769          # Locking requirements was skipped or failed
 770          if constraints:
 771              sanitized_pip_reqs.append(f"-c {_CONSTRAINTS_FILE_NAME}")
 772  
 773      # Set `install_mlflow` to False because `pip_reqs` already contains `mlflow`
 774      conda_env = _mlflow_conda_env(additional_pip_deps=sanitized_pip_reqs, install_mlflow=False)
 775      return conda_env, sanitized_pip_reqs, constraints
 776  
 777  
def _deduplicate_requirements(requirements: list[str]) -> list[str]:
    """
    De-duplicates a list of pip package requirements, handling complex scenarios such as merging
    extras and combining version constraints.

    This function processes a list of pip package requirements and de-duplicates them. It handles
    standard PyPI packages and non-standard requirements (like URLs or local paths). The function
    merges extras and combines version constraints for duplicate packages. The most restrictive
    version specifications or the ones with extras are prioritized. If incompatible version
    constraints are detected, it raises an MlflowException.

    Args:
        requirements (list of str): A list of pip package requirement strings.

    Returns:
        list of str: A deduplicated list of pip package requirements.

    Raises:
        MlflowException: If incompatible version constraints are detected among the provided
                         requirements.

    Examples:
        - Input: ["packageA", "packageA==1.0"]
          Output: ["packageA==1.0"]

        - Input: ["packageX>1.0", "packageX[extras]", "packageX<2.0"]
          Output: ["packageX[extras]<2.0,>1.0"]

        - Input: ["markdown[extra1]>=3.5.1", "markdown[extra2]<4", "markdown"]
          Output: ["markdown[extra1,extra2]<4,>=3.5.1"]

        - Input: ["scikit-learn==1.1", "scikit-learn<1"]
          Raises MlflowException indicating incompatible versions.

    Note:
        - Non-standard requirements (like URLs or file paths) are included as-is.
        - If a requirement appears multiple times with different sets of extras, they are merged.
        - The function uses `_validate_version_constraints` to check for incompatible version
          constraints by doing a dry-run pip install of a requirements collection.
    """
    # Maps (package_name, marker_string) -> merged Requirement. Unparseable requirement
    # strings (URLs, local paths, flags) are stored under the raw string itself.
    deduped_reqs = {}

    for req in requirements:
        try:
            parsed_req = Requirement(req)
            base_pkg = parsed_req.name
            # Requirements that differ only in environment markers are kept as distinct entries.
            key = (base_pkg, str(parsed_req.marker) if parsed_req.marker else "")

            existing_req = deduped_reqs.get(key)

            if not existing_req:
                deduped_reqs[key] = parsed_req
            else:
                # Verify that there are not unresolvable constraints applied if set and combine
                # if possible
                if (
                    existing_req.specifier
                    and parsed_req.specifier
                    and existing_req.specifier != parsed_req.specifier
                ):
                    existing_specs = list(existing_req.specifier)
                    new_specs = list(parsed_req.specifier)
                    # When uv export preserves local version labels (e.g. torch==2.7.1+cu128)
                    # but _get_pinned_requirement strips them (e.g. torch==2.7.1), both end up
                    # in the merged list. Detect this case and prefer the non-local version
                    # (PyPI-installable) rather than failing validation.
                    if (
                        len(existing_specs) == 1
                        and len(new_specs) == 1
                        and existing_specs[0].operator == "=="
                        and new_specs[0].operator == "=="
                        and _strip_local_version_label(existing_specs[0].version)
                        == _strip_local_version_label(new_specs[0].version)
                        and bool(_get_local_version_label(existing_specs[0].version))
                        != bool(_get_local_version_label(new_specs[0].version))
                    ):
                        # Keep whichever specifier has no local label (PyPI-installable).
                        # If the new spec carries the label, overwrite it with the existing
                        # (label-free) spec; otherwise the new spec is already label-free
                        # and wins by default below.
                        if local_label := _get_local_version_label(new_specs[0].version):
                            _logger.debug(
                                f"Dropping local version label (+{local_label}) from "
                                f"'{parsed_req.name}=={new_specs[0].version}' to keep the "
                                f"PyPI-installable version "
                                f"'{parsed_req.name}=={existing_specs[0].version}'."
                            )
                            parsed_req.specifier = existing_req.specifier
                    else:
                        # Dry-run pip install to reject genuinely conflicting constraints,
                        # then intersect the two specifier sets.
                        _validate_version_constraints([str(existing_req), req])
                        parsed_req.specifier = SpecifierSet(
                            ",".join([
                                str(existing_req.specifier),
                                str(parsed_req.specifier),
                            ])
                        )

                # Preserve existing specifiers
                if existing_req.specifier and not parsed_req.specifier:
                    parsed_req.specifier = existing_req.specifier

                # Combine and apply extras if specified
                if (
                    existing_req.extras
                    and parsed_req.extras
                    and existing_req.extras != parsed_req.extras
                ):
                    parsed_req.extras = sorted(set(existing_req.extras).union(parsed_req.extras))
                elif existing_req.extras and not parsed_req.extras:
                    parsed_req.extras = existing_req.extras

                deduped_reqs[key] = parsed_req

        except InvalidRequirement:
            # Include non-standard package strings as-is
            if req not in deduped_reqs:
                deduped_reqs[req] = req
    return [str(req) for req in deduped_reqs.values()]
 893  
 894  
 895  def _parse_requirement_name(req: str) -> str:
 896      try:
 897          return Requirement(req).name
 898      except InvalidRequirement:
 899          return req
 900  
 901  
 902  def _remove_incompatible_requirements(requirements: list[str]) -> list[str]:
 903      req_names = {_parse_requirement_name(req) for req in requirements}
 904      if "databricks-connect" in req_names and req_names.intersection({"pyspark", "pyspark-connect"}):
 905          _logger.debug(
 906              "Found incompatible requirements: 'databricks-connect' with 'pyspark' or "
 907              "'pyspark-connect'. Removing 'pyspark' or 'pyspark-connect' from the requirements."
 908          )
 909          requirements = [
 910              req
 911              for req in requirements
 912              if _parse_requirement_name(req) not in ["pyspark", "pyspark-connect"]
 913          ]
 914      return requirements
 915  
 916  
 917  def _validate_version_constraints(requirements):
 918      """
 919      Validates the version constraints of given Python package requirements using pip's resolver with
 920      the `--dry-run` option enabled that performs validation only (will not install packages).
 921  
 922      This function writes the requirements to a temporary file and then attempts to resolve
 923      them using pip's `--dry-run` install option. If any version conflicts are detected, it
 924      raises an MlflowException with details of the conflict.
 925  
 926      Args:
 927          requirements (list of str): A list of package requirements (e.g., `["pandas>=1.15",
 928          "pandas<2"]`).
 929  
 930      Raises:
 931          MlflowException: If any version conflicts are detected among the provided requirements.
 932  
 933      Returns:
 934          None: This function does not return anything. It either completes successfully or raises
 935          an MlflowException.
 936  
 937      Example:
 938          _validate_version_constraints(["tensorflow<2.0", "tensorflow>2.3"])
 939          # This will raise an exception due to boundary validity.
 940      """
 941      with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp_file:
 942          tmp_file.write("\n".join(requirements))
 943          tmp_file_name = tmp_file.name
 944  
 945      try:
 946          subprocess.run(
 947              [sys.executable, "-m", "pip", "install", "--dry-run", "-r", tmp_file_name],
 948              check=True,
 949              capture_output=True,
 950          )
 951      except subprocess.CalledProcessError as e:
 952          raise MlflowException.invalid_parameter_value(
 953              "The specified requirements versions are incompatible. Detected "
 954              f"conflicts: \n{e.stderr.decode()}"
 955          )
 956      finally:
 957          os.remove(tmp_file_name)
 958  
 959  
 960  def _process_conda_env(conda_env):
 961      """
 962      Processes `conda_env` passed to `mlflow.*.save_model` or `mlflow.*.log_model`, and returns
 963      a tuple of (conda_env, pip_requirements, pip_constraints).
 964      """
 965      if isinstance(conda_env, str):
 966          with open(conda_env) as f:
 967              conda_env = yaml.safe_load(f)
 968      elif not isinstance(conda_env, dict):
 969          raise TypeError(
 970              "Expected a string path to a conda env yaml file or a `dict` representing a conda env, "
 971              f"but got `{type(conda_env).__name__}`"
 972          )
 973  
 974      # User-specified `conda_env` may contain requirements/constraints file references
 975      pip_reqs = _get_pip_deps(conda_env)
 976      pip_reqs, constraints = _parse_pip_requirements(pip_reqs)
 977      if not _contains_mlflow_requirement(pip_reqs):
 978          pip_reqs.insert(0, _generate_mlflow_version_pinning())
 979  
 980      # Check if pip requirements contain incompatible version with the current environment
 981      warn_dependency_requirement_mismatches(pip_reqs)
 982  
 983      if constraints:
 984          pip_reqs.append(f"-c {_CONSTRAINTS_FILE_NAME}")
 985  
 986      conda_env = _overwrite_pip_deps(conda_env, pip_reqs)
 987      return conda_env, pip_reqs, constraints
 988  
 989  
 990  def _get_mlflow_env_name(s):
 991      """Creates an environment name for an MLflow model by hashing the given string.
 992  
 993      Args:
 994          s: String to hash (e.g. the content of `conda.yaml`).
 995  
 996      Returns:
 997          String in the form of "mlflow-{hash}"
 998          (e.g. "mlflow-da39a3ee5e6b4b0d3255bfef95601890afd80709")
 999  
1000      """
1001      return "mlflow-" + hashlib.sha1(s.encode("utf-8"), usedforsecurity=False).hexdigest()
1002  
1003  
1004  def _get_pip_install_mlflow():
1005      """
1006      Returns a command to pip-install mlflow. If the MLFLOW_HOME environment variable exists,
1007      returns "pip install -e {MLFLOW_HOME} 1>&2", otherwise
1008      "pip install mlflow=={mlflow.__version__} 1>&2".
1009      """
1010      if mlflow_home := os.environ.get("MLFLOW_HOME"):  # dev version
1011          return f"pip install -e {mlflow_home} 1>&2"
1012      else:
1013          return f"pip install mlflow=={VERSION} 1>&2"
1014  
1015  
def _get_requirements_from_file(
    file_path: pathlib.Path,
) -> list[Requirement]:
    """Parse pip requirements from a conda env YAML or a requirements-style text file."""
    content = file_path.read_text()
    if file_path.name == _CONDA_ENV_FILE_NAME:
        # For conda.yaml, the pip requirements live under the pip dependency section.
        raw_reqs = _get_pip_deps(yaml.safe_load(content))
    else:
        raw_reqs = content.splitlines()
    # Skip blank lines before parsing.
    return [Requirement(line) for line in raw_reqs if line]
1026  
1027  
def _write_requirements_to_file(
    file_path: pathlib.Path,
    new_reqs: list[str],
) -> None:
    """Write requirements back to a conda env YAML or a requirements-style text file."""
    if file_path.name != _CONDA_ENV_FILE_NAME:
        file_path.write_text("\n".join(new_reqs))
        return
    # For conda.yaml, replace only the pip dependency section and re-dump the YAML.
    conda_env = _overwrite_pip_deps(yaml.safe_load(file_path.read_text()), new_reqs)
    with file_path.open("w") as out:
        yaml.dump(conda_env, out)
1039  
1040  
def _add_or_overwrite_requirements(
    new_reqs: list[Requirement],
    old_reqs: list[Requirement],
) -> list[str]:
    """Merge `new_reqs` into `old_reqs`, overwriting entries with the same package name."""
    deduped = [
        Requirement(req) for req in _deduplicate_requirements([str(req) for req in new_reqs])
    ]
    # Later updates win: new requirements replace old ones keyed by package name.
    merged = {req.name: str(req) for req in old_reqs}
    merged.update({req.name: str(req) for req in deduped})
    return list(merged.values())
1052  
1053  
def _remove_requirements(
    reqs_to_remove: list[Requirement],
    old_reqs: list[Requirement],
) -> list[str]:
    """Remove `reqs_to_remove` (by package name) from `old_reqs`, warning on misses."""
    remaining = {req.name: str(req) for req in old_reqs}
    for req in reqs_to_remove:
        if req.name in remaining:
            del remaining[req.name]
        else:
            _logger.warning(f'"{req.name}" not found in requirements, ignoring')
    return list(remaining.values())
1064  
1065  
class Environment:
    """An activatable execution environment for running commands in.

    Wraps the shell command(s) needed to activate the environment (e.g. a conda or
    virtualenv activation command) and runs user commands after activation, with
    MLflow/Databricks-related environment variables applied.
    """

    def __init__(self, activate_cmd, extra_env=None):
        # Normalize to a list so activation may consist of multiple chained commands.
        if not isinstance(activate_cmd, list):
            activate_cmd = [activate_cmd]
        self._activate_cmd = activate_cmd
        self._extra_env = extra_env or {}

    def get_activate_command(self):
        """Return the list of activation command(s) supplied at construction time."""
        return self._activate_cmd

    def execute(
        self,
        command,
        command_env=None,
        preexec_fn=None,
        capture_output=False,
        stdout=None,
        stderr=None,
        stdin=None,
        synchronous=True,
    ):
        """Run `command` inside this environment via a shell (`bash -c` / `cmd /c`).

        Args:
            command: A command string or list of commands joined after the
                activation command(s).
            command_env: Base environment variables; defaults to a copy of
                `os.environ`. The caller's mapping is never mutated (deep-copied).
            preexec_fn, capture_output, stdout, stderr, stdin, synchronous:
                Forwarded to `_exec_cmd`.

        Returns:
            The result of `_exec_cmd` (presumably a process handle or completed
            process depending on `synchronous` — see `_exec_cmd`).
        """
        # Copy/deep-copy so the caller's mapping and os.environ are never mutated.
        command_env = os.environ.copy() if command_env is None else deepcopy(command_env)
        # NOTE: update order matters — later updates override earlier ones, so
        # `self._extra_env` takes precedence over everything added below.
        if is_in_databricks_runtime():
            command_env.update(get_databricks_env_vars(get_tracking_uri()))
        if is_databricks_connect():
            command_env.update(_get_databricks_serverless_env_vars())
        if exp_id := _get_experiment_id():
            command_env[MLFLOW_EXPERIMENT_ID.name] = exp_id
        if active_model_id := get_active_model_id():
            command_env[_MLFLOW_ACTIVE_MODEL_ID.name] = active_model_id
        command_env.update(self._extra_env)
        if not isinstance(command, list):
            command = [command]

        # Windows `cmd` chains with `&`; POSIX shells with `&&` (abort on failure).
        separator = " && " if not is_windows() else " & "

        # Activation command(s) run first, then the user command(s), in one shell line.
        command = separator.join(map(str, self._activate_cmd + command))
        command = ["bash", "-c", command] if not is_windows() else ["cmd", "/c", command]
        _logger.info("=== Running command '%s'", command)
        return _exec_cmd(
            command,
            env=command_env,
            capture_output=capture_output,
            synchronous=synchronous,
            preexec_fn=preexec_fn,
            close_fds=True,
            stdout=stdout,
            stderr=stderr,
            stdin=stdin,
        )