/ mlflow / utils / env_pack.py
env_pack.py
  1  import shutil
  2  import subprocess
  3  import sys
  4  import tarfile
  5  import tempfile
  6  from contextlib import contextmanager
  7  from dataclasses import dataclass
  8  from pathlib import Path
  9  from typing import Generator, Literal
 10  
 11  import yaml
 12  
 13  from mlflow.artifacts import download_artifacts
 14  from mlflow.exceptions import MlflowException
 15  from mlflow.models.model import MLMODEL_FILE_NAME
 16  from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
 17  from mlflow.utils.databricks_utils import DatabricksRuntimeVersion, get_databricks_runtime_version
 18  from mlflow.utils.environment import _REQUIREMENTS_FILE_NAME
 19  from mlflow.utils.logging_utils import eprint
 20  
 21  EnvPackType = Literal["databricks_model_serving"]
 22  
 23  
 24  @dataclass(kw_only=True)
 25  class EnvPackConfig:
 26      name: EnvPackType
 27      install_dependencies: bool = True
 28  
 29  
 30  _ARTIFACT_PATH = "_databricks"
 31  _MODEL_VERSION_TAR = "model_version.tar"
 32  _MODEL_ENVIRONMENT_TAR = "model_environment.tar"
 33  
 34  
 35  def _validate_env_pack(env_pack):
 36      """Checks if env_pack is a supported value
 37  
 38      Supported values are:
 39      - the string "databricks_model_serving"
 40      - an ``EnvPackConfig`` with ``name == 'databricks_model_serving'`` and a boolean
 41        ``install_dependencies`` field.
 42      - None
 43      """
 44      if env_pack is None:
 45          return None
 46  
 47      if isinstance(env_pack, str):
 48          if env_pack == "databricks_model_serving":
 49              return EnvPackConfig(name="databricks_model_serving", install_dependencies=True)
 50          raise MlflowException.invalid_parameter_value(
 51              f"Invalid env_pack value: {env_pack!r}. Expected: 'databricks_model_serving'."
 52          )
 53  
 54      if isinstance(env_pack, EnvPackConfig):
 55          if env_pack.name != "databricks_model_serving":
 56              raise MlflowException.invalid_parameter_value(
 57                  f"Invalid EnvPackConfig.name: {env_pack.name!r}. "
 58                  "Expected 'databricks_model_serving'."
 59              )
 60          if not isinstance(env_pack.install_dependencies, bool):
 61              raise MlflowException.invalid_parameter_value(
 62                  "EnvPackConfig.install_dependencies must be a bool."
 63              )
 64          return env_pack
 65  
 66      # Anything else is invalid
 67      raise MlflowException.invalid_parameter_value(
 68          "env_pack must be either None, the string 'databricks_model_serving', or an EnvPackConfig "
 69          "with a boolean 'install_dependencies' field."
 70      )
 71  
 72  
 73  def _tar(root_path: Path, tar_path: Path) -> tarfile.TarFile:
 74      """
 75      Package all files under root_path into a tar at tar_path, excluding __pycache__, *.pyc, and
 76      wheels_info.json.
 77      """
 78  
 79      def exclude(tarinfo: tarfile.TarInfo):
 80          name = tarinfo.name
 81          base = Path(name).name
 82          if "__pycache__" in name or base.endswith(".pyc") or base == "wheels_info.json":
 83              return None
 84          return tarinfo
 85  
 86      # Pull in symlinks
 87      with tarfile.open(tar_path, "w", dereference=True) as tar:
 88          tar.add(root_path, arcname=".", filter=exclude)
 89      return tar
 90  
 91  
 92  @contextmanager
 93  def _get_source_artifacts(
 94      model_uri: str, local_model_path: str | None = None
 95  ) -> Generator[Path, None, None]:
 96      """
 97      Get source artifacts and handle cleanup of downloads.
 98      Does not mutate local_model_path contents if provided.
 99  
100      Args:
101          model_uri: The URI of the model to package.
102          local_model_path: Optional local path to model artifacts.
103  
104      Yields:
105          Path: The path to the source artifacts directory.
106      """
107      source_dir = Path(local_model_path or download_artifacts(artifact_uri=model_uri))
108  
109      yield source_dir
110  
111      if not local_model_path:
112          shutil.rmtree(source_dir)
113  
114  
115  # TODO: Check pip requirements using uv instead.
116  @contextmanager
117  def pack_env_for_databricks_model_serving(
118      model_uri: str,
119      *,
120      enforce_pip_requirements: bool = False,
121      local_model_path: str | None = None,
122  ) -> Generator[str, None, None]:
123      """
124      Generate Databricks artifacts for fast deployment.
125  
126      Args:
127          model_uri: The URI of the model to package.
128          enforce_pip_requirements: Whether to enforce pip requirements installation.
129          local_model_path: Optional local path to model artifacts. If provided, pack
130              the local artifacts instead of downloading.
131  
132      Yields:
133          str: The path to the local artifacts directory containing the model artifacts and
134              environment.
135  
136      Example:
137          >>> with pack_env_for_databricks_model_serving("models:/my-model/1") as artifacts_dir:
138          ...     # Use artifacts_dir here
139          ...     pass
140      """
141      dbr_version = DatabricksRuntimeVersion.parse()
142      if not dbr_version.is_client_image:
143          raise ValueError(
144              f"Serverless environment is required when packing environment for Databricks Model "
145              f"Serving. Current version: {dbr_version}"
146          )
147  
148      with _get_source_artifacts(model_uri, local_model_path) as source_artifacts_dir:
149          # Check runtime version consistency
150          # We read the MLmodel file directly instead of using Model.to_dict() because to_dict() adds
151          # the current runtime version via get_databricks_runtime_version(), which would prevent us
152          # from detecting runtime version mismatches.
153          mlmodel_path = source_artifacts_dir / MLMODEL_FILE_NAME
154          with open(mlmodel_path) as f:
155              model_dict = yaml.safe_load(f)
156          if "databricks_runtime" not in model_dict:
157              raise ValueError(
158                  "Model must have been created in a Databricks runtime environment. "
159                  "Missing 'databricks_runtime' field in MLmodel file."
160              )
161  
162          current_runtime = DatabricksRuntimeVersion.parse()
163          model_runtime = DatabricksRuntimeVersion.parse(model_dict["databricks_runtime"])
164          if current_runtime.major != model_runtime.major:
165              raise ValueError(
166                  f"Runtime version mismatch. Model was created with runtime "
167                  f"{model_dict['databricks_runtime']} (major version {model_runtime.major}), "
168                  f"but current runtime is {get_databricks_runtime_version()} "
169                  f"(major version {current_runtime.major})"
170              )
171  
172          # Check that _databricks directory does not exist in source
173          if (source_artifacts_dir / _ARTIFACT_PATH).exists():
174              raise MlflowException(
175                  f"Source artifacts contain a '{_ARTIFACT_PATH}' directory and is not "
176                  "eligible for use with env_pack.",
177                  error_code=INVALID_PARAMETER_VALUE,
178              )
179  
180          if enforce_pip_requirements:
181              eprint("Installing model requirements...")
182              try:
183                  subprocess.run(
184                      [
185                          sys.executable,
186                          "-m",
187                          "pip",
188                          "install",
189                          "-r",
190                          str(source_artifacts_dir / _REQUIREMENTS_FILE_NAME),
191                      ],
192                      check=True,
193                      stdout=subprocess.PIPE,
194                      stderr=subprocess.STDOUT,
195                      text=True,
196                  )
197              except subprocess.CalledProcessError as e:
198                  eprint("Error installing requirements:")
199                  eprint(e.stdout)
200                  raise
201  
202          with tempfile.TemporaryDirectory() as temp_dir:
203              # Copy source artifacts to packaged_model_dir
204              packaged_model_dir = Path(temp_dir) / "model"
205              shutil.copytree(
206                  source_artifacts_dir, packaged_model_dir, dirs_exist_ok=False, symlinks=False
207              )
208  
209              # Package model artifacts and env into packaged_model_dir/_databricks
210              packaged_artifacts_dir = packaged_model_dir / _ARTIFACT_PATH
211              packaged_artifacts_dir.mkdir(exist_ok=False)
212              _tar(source_artifacts_dir, packaged_artifacts_dir / _MODEL_VERSION_TAR)
213              _tar(Path(sys.prefix), packaged_artifacts_dir / _MODEL_ENVIRONMENT_TAR)
214  
215              yield str(packaged_model_dir)