# env_pack.py
import shutil
import subprocess
import sys
import tarfile
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, Literal

import yaml

from mlflow.artifacts import download_artifacts
from mlflow.exceptions import MlflowException
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils.databricks_utils import DatabricksRuntimeVersion, get_databricks_runtime_version
from mlflow.utils.environment import _REQUIREMENTS_FILE_NAME
from mlflow.utils.logging_utils import eprint

# The only env-pack flavor currently supported.
EnvPackType = Literal["databricks_model_serving"]


@dataclass(kw_only=True)
class EnvPackConfig:
    """Configuration for environment packing.

    Attributes:
        name: Which environment pack to build; only "databricks_model_serving"
            is supported.
        install_dependencies: Whether the model's pip requirements should be
            installed as part of packing.
    """

    name: EnvPackType
    install_dependencies: bool = True


# Directory created inside the packaged model to hold the env-pack tarballs.
_ARTIFACT_PATH = "_databricks"
# Tarball of the model version artifacts.
_MODEL_VERSION_TAR = "model_version.tar"
# Tarball of the Python environment (sys.prefix).
_MODEL_ENVIRONMENT_TAR = "model_environment.tar"


def _validate_env_pack(env_pack):
    """Checks if env_pack is a supported value.

    Supported values are:
    - the string "databricks_model_serving"
    - an ``EnvPackConfig`` with ``name == 'databricks_model_serving'`` and a boolean
      ``install_dependencies`` field.
    - None

    Returns:
        None if ``env_pack`` is None, otherwise a normalized ``EnvPackConfig``.

    Raises:
        MlflowException: If ``env_pack`` is any other value.
    """
    if env_pack is None:
        return None

    if isinstance(env_pack, str):
        # Normalize the shorthand string form into a full config.
        if env_pack == "databricks_model_serving":
            return EnvPackConfig(name="databricks_model_serving", install_dependencies=True)
        raise MlflowException.invalid_parameter_value(
            f"Invalid env_pack value: {env_pack!r}. Expected: 'databricks_model_serving'."
        )

    if isinstance(env_pack, EnvPackConfig):
        if env_pack.name != "databricks_model_serving":
            raise MlflowException.invalid_parameter_value(
                f"Invalid EnvPackConfig.name: {env_pack.name!r}. "
                "Expected 'databricks_model_serving'."
            )
        if not isinstance(env_pack.install_dependencies, bool):
            raise MlflowException.invalid_parameter_value(
                "EnvPackConfig.install_dependencies must be a bool."
            )
        return env_pack

    # Anything else is invalid
    raise MlflowException.invalid_parameter_value(
        "env_pack must be either None, the string 'databricks_model_serving', or an EnvPackConfig "
        "with a boolean 'install_dependencies' field."
    )


def _tar(root_path: Path, tar_path: Path) -> tarfile.TarFile:
    """
    Package all files under root_path into a tar at tar_path, excluding __pycache__, *.pyc, and
    wheels_info.json.

    Note: the returned TarFile has already been closed by the ``with`` block;
    callers in this module ignore the return value.
    """

    def exclude(tarinfo: tarfile.TarInfo):
        # Returning None from a tarfile filter drops the member entirely.
        name = tarinfo.name
        base = Path(name).name
        if "__pycache__" in name or base.endswith(".pyc") or base == "wheels_info.json":
            return None
        return tarinfo

    # Pull in symlinks (dereference=True archives the link target's contents).
    with tarfile.open(tar_path, "w", dereference=True) as tar:
        tar.add(root_path, arcname=".", filter=exclude)
    return tar


@contextmanager
def _get_source_artifacts(
    model_uri: str, local_model_path: str | None = None
) -> Generator[Path, None, None]:
    """
    Get source artifacts and handle cleanup of downloads.
    Does not mutate local_model_path contents if provided.

    Args:
        model_uri: The URI of the model to package.
        local_model_path: Optional local path to model artifacts.

    Yields:
        Path: The path to the source artifacts directory.
    """
    source_dir = Path(local_model_path or download_artifacts(artifact_uri=model_uri))
    try:
        yield source_dir
    finally:
        # Only remove directories we downloaded ourselves. The try/finally
        # ensures cleanup runs even when the caller's `with` body raises;
        # previously an exception in the body leaked the downloaded artifacts.
        if not local_model_path:
            shutil.rmtree(source_dir)


# TODO: Check pip requirements using uv instead.
@contextmanager
def pack_env_for_databricks_model_serving(
    model_uri: str,
    *,
    enforce_pip_requirements: bool = False,
    local_model_path: str | None = None,
) -> Generator[str, None, None]:
    """
    Generate Databricks artifacts for fast deployment.

    Args:
        model_uri: The URI of the model to package.
        enforce_pip_requirements: Whether to enforce pip requirements installation.
        local_model_path: Optional local path to model artifacts. If provided, pack
            the local artifacts instead of downloading.

    Yields:
        str: The path to the local artifacts directory containing the model artifacts and
        environment.

    Raises:
        ValueError: If not running in a serverless (client-image) runtime, if the
            model lacks a ``databricks_runtime`` field, or if the model's runtime
            major version differs from the current runtime's.
        MlflowException: If the source artifacts already contain a ``_databricks``
            directory.

    Example:
        >>> with pack_env_for_databricks_model_serving("models:/my-model/1") as artifacts_dir:
        ...     # Use artifacts_dir here
        ...     pass
    """
    dbr_version = DatabricksRuntimeVersion.parse()
    if not dbr_version.is_client_image:
        raise ValueError(
            f"Serverless environment is required when packing environment for Databricks Model "
            f"Serving. Current version: {dbr_version}"
        )

    with _get_source_artifacts(model_uri, local_model_path) as source_artifacts_dir:
        # Check runtime version consistency
        # We read the MLmodel file directly instead of using Model.to_dict() because to_dict() adds
        # the current runtime version via get_databricks_runtime_version(), which would prevent us
        # from detecting runtime version mismatches.
        mlmodel_path = source_artifacts_dir / MLMODEL_FILE_NAME
        with open(mlmodel_path) as f:
            model_dict = yaml.safe_load(f)
        if "databricks_runtime" not in model_dict:
            raise ValueError(
                "Model must have been created in a Databricks runtime environment. "
                "Missing 'databricks_runtime' field in MLmodel file."
            )

        # Reuse the version parsed above instead of parsing the current runtime
        # a second time.
        current_runtime = dbr_version
        model_runtime = DatabricksRuntimeVersion.parse(model_dict["databricks_runtime"])
        if current_runtime.major != model_runtime.major:
            raise ValueError(
                f"Runtime version mismatch. Model was created with runtime "
                f"{model_dict['databricks_runtime']} (major version {model_runtime.major}), "
                f"but current runtime is {get_databricks_runtime_version()} "
                f"(major version {current_runtime.major})"
            )

        # Check that _databricks directory does not exist in source
        if (source_artifacts_dir / _ARTIFACT_PATH).exists():
            raise MlflowException(
                f"Source artifacts contain a '{_ARTIFACT_PATH}' directory and is not "
                "eligible for use with env_pack.",
                error_code=INVALID_PARAMETER_VALUE,
            )

        if enforce_pip_requirements:
            eprint("Installing model requirements...")
            try:
                # shell=False list form; stderr folded into stdout so the full
                # pip log can be surfaced on failure.
                subprocess.run(
                    [
                        sys.executable,
                        "-m",
                        "pip",
                        "install",
                        "-r",
                        str(source_artifacts_dir / _REQUIREMENTS_FILE_NAME),
                    ],
                    check=True,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    text=True,
                )
            except subprocess.CalledProcessError as e:
                eprint("Error installing requirements:")
                eprint(e.stdout)
                raise

        with tempfile.TemporaryDirectory() as temp_dir:
            # Copy source artifacts to packaged_model_dir
            packaged_model_dir = Path(temp_dir) / "model"
            shutil.copytree(
                source_artifacts_dir, packaged_model_dir, dirs_exist_ok=False, symlinks=False
            )

            # Package model artifacts and env into packaged_model_dir/_databricks
            packaged_artifacts_dir = packaged_model_dir / _ARTIFACT_PATH
            packaged_artifacts_dir.mkdir(exist_ok=False)
            _tar(source_artifacts_dir, packaged_artifacts_dir / _MODEL_VERSION_TAR)
            _tar(Path(sys.prefix), packaged_artifacts_dir / _MODEL_ENVIRONMENT_TAR)

            yield str(packaged_model_dir)