# mlflow/diffusers/__init__.py
1 """ 2 The ``mlflow.diffusers`` module provides an API for logging and loading diffusion model 3 LoRA adapters as MLflow Models. This module exports adapter models with 4 the following flavors: 5 6 :py:mod:`mlflow.diffusers` 7 Adapter weights in safetensors format, with a reference to the base model. 8 9 :py:mod:`mlflow.pyfunc` 10 Produced for use by generic pyfunc-based deployment tools and batch inference. 11 The pyfunc wrapper loads the base diffusion pipeline and applies the adapter 12 at inference time. 13 """ 14 15 import importlib.util 16 import logging 17 import shutil 18 from dataclasses import dataclass 19 from pathlib import Path 20 from typing import Any, Literal 21 22 import yaml 23 24 import mlflow 25 from mlflow import pyfunc 26 from mlflow.environment_variables import MLFLOW_DEFAULT_PREDICTION_DEVICE 27 from mlflow.exceptions import MlflowException 28 from mlflow.models import Model, ModelInputExample, ModelSignature 29 from mlflow.models.model import MLMODEL_FILE_NAME 30 from mlflow.models.utils import _save_example 31 from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS 32 from mlflow.tracking.artifact_utils import _download_artifact_from_uri 33 from mlflow.types import DataType, ParamSchema, ParamSpec, Schema 34 from mlflow.types.schema import ColSpec 35 from mlflow.utils.docstring_utils import ( 36 LOG_MODEL_PARAM_DOCS, 37 docstring_version_compatibility_warning, 38 format_docstring, 39 ) 40 from mlflow.utils.environment import ( 41 _CONDA_ENV_FILE_NAME, 42 _CONSTRAINTS_FILE_NAME, 43 _PYTHON_ENV_FILE_NAME, 44 _REQUIREMENTS_FILE_NAME, 45 _mlflow_conda_env, 46 _process_conda_env, 47 _process_pip_requirements, 48 _PythonEnv, 49 _validate_env_arguments, 50 ) 51 from mlflow.utils.file_utils import get_total_file_size, write_to 52 from mlflow.utils.model_utils import ( 53 _add_code_from_conf_to_system_path, 54 _get_flavor_configuration, 55 _validate_and_copy_code_paths, 56 _validate_and_prepare_target_save_path, 57 ) 58 from 
mlflow.utils.requirements_utils import _get_pinned_requirement 59 60 _logger = logging.getLogger(__name__) 61 62 FLAVOR_NAME = "diffusers" 63 64 _ADAPTER_WEIGHTS_DIR = "adapter_weights" 65 _STANDARD_WEIGHT_NAME = "pytorch_lora_weights.safetensors" 66 67 SUPPORTED_ADAPTER_TYPES = ("lora",) 68 69 _BASE_MODEL_REVISION_KEY = "base_model_revision" 70 71 72 def _resolve_base_model_revision(base_model): 73 """Resolve the HuggingFace Hub commit hash for a base model ID. 74 75 Returns None if the ID looks like a local path or if resolution fails. 76 """ 77 # Only treat as a local path if it's absolute or explicitly relative (./ ../). 78 # Bare "org/model" strings should always be resolved as HF Hub IDs, even if 79 # a matching directory happens to exist in the current working directory. 80 p = Path(base_model) 81 if p.is_absolute() or base_model.startswith(("./", "../")): 82 return None 83 84 try: 85 from mlflow.utils.huggingface_utils import get_latest_commit_for_repo 86 87 return get_latest_commit_for_repo(base_model) 88 except Exception as e: 89 # Broad catch is intentional: huggingface_hub types (HfHubHTTPError, 90 # RepositoryNotFoundError) can't be imported unconditionally. 91 # Revision pinning is optional — graceful degradation is preferred. 92 _logger.warning( 93 "Could not resolve HuggingFace commit hash for '%s' (%s). " 94 "The base model revision will not be pinned.", 95 base_model, 96 type(e).__name__, 97 ) 98 return None 99 100 101 def _validate_safetensors_format(file_path): 102 try: 103 from safetensors import safe_open 104 except ImportError as e: 105 raise MlflowException.invalid_parameter_value( 106 "The 'safetensors' package is required to validate adapter weights. " 107 "Install it with: pip install safetensors" 108 ) from e 109 110 try: 111 with safe_open(str(file_path), framework="numpy"): 112 pass 113 except Exception as e: 114 raise MlflowException.invalid_parameter_value( 115 f"File is not a valid safetensors file: {file_path}. 
Error: {e}" 116 ) from e 117 118 119 def _detect_device(device=None): 120 import torch 121 122 if device is not None: 123 return device 124 if env_device := MLFLOW_DEFAULT_PREDICTION_DEVICE.get(): 125 return env_device 126 if torch.cuda.is_available(): 127 return "cuda" 128 try: 129 if torch.backends.mps.is_available(): 130 return "mps" 131 except AttributeError: 132 pass 133 return "cpu" 134 135 136 def _get_default_signature(): 137 return ModelSignature( 138 inputs=Schema([ColSpec(type=DataType.string, name="prompt")]), 139 outputs=Schema([ColSpec(type=DataType.binary, name="image")]), 140 params=ParamSchema([ 141 ParamSpec(name="num_inference_steps", dtype=DataType.integer, default=30), 142 ParamSpec(name="guidance_scale", dtype=DataType.double, default=7.5), 143 ParamSpec(name="height", dtype=DataType.integer, default=512), 144 ParamSpec(name="width", dtype=DataType.integer, default=512), 145 ParamSpec(name="negative_prompt", dtype=DataType.string, default=""), 146 ]), 147 ) 148 149 150 def get_default_pip_requirements(): 151 # peft: load_lora_weights() depends on it; safetensors: adapter format + validation 152 packages = ["diffusers", "transformers", "torch", "peft", "safetensors"] 153 packages.extend(pkg for pkg in ["accelerate"] if importlib.util.find_spec(pkg)) 154 return [_get_pinned_requirement(pkg) for pkg in packages] 155 156 157 def get_default_conda_env(): 158 return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements()) 159 160 161 @dataclass(frozen=True) 162 class DiffusersAdapterModel: 163 """A loaded LoRA adapter referencing a HuggingFace base model. 164 165 Returned by :py:func:`load_model`. Call :py:meth:`load_pipeline` to get 166 a ready-to-use diffusers pipeline with the adapter applied. 
167 """ 168 169 adapter_path: str 170 base_model: str 171 adapter_type: Literal["lora"] 172 base_model_revision: str | None = None 173 weight_name: str | None = None 174 175 def load_pipeline(self, *, base_model: str | None = None, **kwargs): 176 """Download the base model and apply the LoRA adapter. 177 178 Args: 179 base_model: Override the base model reference stored at save time. 180 Useful when the original local path is no longer available. 181 Accepts a HuggingFace model ID or a local directory path. 182 kwargs: Forwarded to ``DiffusionPipeline.from_pretrained()``. 183 Common options include ``device``, ``torch_dtype``, and ``revision``. 184 185 Returns: 186 A ``DiffusionPipeline`` with LoRA weights applied. 187 """ 188 from diffusers import DiffusionPipeline 189 190 effective_base_model = base_model or self.base_model 191 device = _detect_device(kwargs.pop("device", None)) 192 kwargs.setdefault("torch_dtype", "auto") 193 if self.base_model_revision and "revision" not in kwargs: 194 kwargs["revision"] = self.base_model_revision 195 196 try: 197 pipe = DiffusionPipeline.from_pretrained(effective_base_model, **kwargs) 198 except OSError as e: 199 raise MlflowException( 200 f"Failed to load base model '{effective_base_model}'. If the model " 201 "has moved, pass the correct location via " 202 "load_pipeline(base_model=...)." 
203 ) from e 204 205 lora_kwargs = {} 206 if self.weight_name: 207 lora_kwargs["weight_name"] = self.weight_name 208 pipe.load_lora_weights(self.adapter_path, **lora_kwargs) 209 return pipe.to(device) 210 211 212 @docstring_version_compatibility_warning(integration_name=FLAVOR_NAME) 213 @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="diffusers")) 214 def save_model( 215 adapter_path: str, 216 path: str, 217 base_model: str, 218 adapter_type: Literal["lora"] = "lora", 219 conda_env=None, 220 code_paths: list[str] | None = None, 221 mlflow_model: Model | None = None, 222 signature: ModelSignature | None = None, 223 input_example: ModelInputExample | None = None, 224 pip_requirements: list[str] | str | None = None, 225 extra_pip_requirements: list[str] | str | None = None, 226 metadata: dict[str, Any] | None = None, 227 ) -> None: 228 """Save a diffusers adapter model to a path on the local file system. 229 230 Args: 231 adapter_path: Path to the adapter weights. Can be a single .safetensors file 232 or a directory containing adapter files. Single files and directories 233 containing a single safetensors file are normalized to 234 ``pytorch_lora_weights.safetensors`` to match the convention expected 235 by ``load_lora_weights()``. Directories with multiple weight files 236 are copied as-is. 237 path: Local path where the model is to be saved. 238 base_model: HuggingFace model ID or local path of the base diffusion model 239 that this adapter was trained on (e.g., "black-forest-labs/FLUX.1-dev"). 240 adapter_type: Type of adapter. Currently only "lora" is supported. 241 conda_env: {{ conda_env }} 242 code_paths: {{ code_paths }} 243 mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to. 
def _validate_save_dependencies() -> str:
    """Ensure diffusers and peft are importable; return the diffusers version."""
    try:
        import diffusers
    except ImportError as e:
        raise MlflowException.invalid_parameter_value(
            "The 'diffusers' package is required to save a diffusers adapter model. "
            "Install it with: pip install diffusers"
        ) from e

    try:
        import peft  # noqa: F401
    except ImportError as e:
        raise MlflowException.invalid_parameter_value(
            "The 'peft' package is required to save a diffusers LoRA adapter model. "
            "Install it with: pip install peft"
        ) from e

    return diffusers.__version__


def _copy_adapter_weights(adapter_path: Path, weights_dst: Path) -> str | None:
    """Copy adapter weights into ``weights_dst``, normalizing names.

    Single files and directories containing a single safetensors file are
    renamed to the standard ``pytorch_lora_weights.safetensors`` expected by
    ``load_lora_weights()``; directories with companion files are copied as-is.

    Returns:
        The weight file name ``load_lora_weights()`` should target when it
        differs from the standard name, otherwise None.

    Raises:
        MlflowException: if the adapter is not a valid safetensors file/dir.
    """
    if adapter_path.is_file():
        if adapter_path.suffix != ".safetensors":
            raise MlflowException.invalid_parameter_value(
                f"Single-file adapter must be a .safetensors file, got: {adapter_path.suffix}"
            )
        _validate_safetensors_format(adapter_path)
        weights_dst.mkdir(parents=True, exist_ok=True)
        shutil.copy2(adapter_path, weights_dst / _STANDARD_WEIGHT_NAME)
        return None

    if adapter_path.is_dir():
        # Filter hidden files (.DS_Store, etc.) that break single-file detection
        all_files = [p for p in adapter_path.iterdir() if not p.name.startswith(".")]
        safetensor_files = sorted(
            (p for p in all_files if p.suffix == ".safetensors"),
            key=lambda p: p.name,
        )
        if not safetensor_files:
            raise MlflowException.invalid_parameter_value(
                f"Adapter directory contains no .safetensors files: {adapter_path}"
            )
        for sf in safetensor_files:
            _validate_safetensors_format(sf)

        if len(safetensor_files) == 1 and len(all_files) == 1:
            # Directory with a single safetensors file — normalize its name
            weights_dst.mkdir(parents=True, exist_ok=True)
            shutil.copy2(safetensor_files[0], weights_dst / _STANDARD_WEIGHT_NAME)
            return None

        # Multiple files or companion files — copy entire directory as-is
        shutil.copytree(adapter_path, weights_dst)
        if any(sf.name == _STANDARD_WEIGHT_NAME for sf in safetensor_files):
            return None

        # No standard weight file exists: record which file load_lora_weights
        # should target so inference doesn't silently pick an arbitrary file
        # or fail in offline mode.
        weight_name = safetensor_files[0].name
        if len(safetensor_files) >= 2:
            _logger.warning(
                "Adapter directory contains %d .safetensors files but none named "
                "'%s'. Will use '%s' as the primary weight file at inference time. "
                "Consider renaming it to '%s' to avoid ambiguity.",
                len(safetensor_files),
                _STANDARD_WEIGHT_NAME,
                weight_name,
                _STANDARD_WEIGHT_NAME,
            )
        return weight_name

    # exists() passed but neither is_file() nor is_dir(): e.g. a broken symlink
    raise MlflowException.invalid_parameter_value(
        f"Adapter path is neither a file nor a directory: {adapter_path}"
    )


@docstring_version_compatibility_warning(integration_name=FLAVOR_NAME)
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="diffusers"))
def save_model(
    adapter_path: str,
    path: str,
    base_model: str,
    adapter_type: Literal["lora"] = "lora",
    conda_env=None,
    code_paths: list[str] | None = None,
    mlflow_model: Model | None = None,
    signature: ModelSignature | None = None,
    input_example: ModelInputExample | None = None,
    pip_requirements: list[str] | str | None = None,
    extra_pip_requirements: list[str] | str | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Save a diffusers adapter model to a path on the local file system.

    Args:
        adapter_path: Path to the adapter weights. Can be a single .safetensors file
            or a directory containing adapter files. Single files and directories
            containing a single safetensors file are normalized to
            ``pytorch_lora_weights.safetensors`` to match the convention expected
            by ``load_lora_weights()``. Directories with multiple weight files
            are copied as-is.
        path: Local path where the model is to be saved.
        base_model: HuggingFace model ID or local path of the base diffusion model
            that this adapter was trained on (e.g., "black-forest-labs/FLUX.1-dev").
        adapter_type: Type of adapter. Currently only "lora" is supported.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
        signature: {{ signature }}
        input_example: {{ input_example }}
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}
    """
    # Fail fast on missing runtime dependencies before touching the filesystem.
    diffusers_version = _validate_save_dependencies()

    _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)

    if not isinstance(base_model, str) or not base_model.strip():
        raise MlflowException.invalid_parameter_value(
            "base_model must be a non-empty string (HuggingFace model ID or local path)."
        )

    if not isinstance(adapter_type, str):
        raise MlflowException.invalid_parameter_value(
            f"adapter_type must be a string, got {type(adapter_type).__name__}"
        )
    adapter_type = adapter_type.lower()
    if adapter_type not in SUPPORTED_ADAPTER_TYPES:
        raise MlflowException.invalid_parameter_value(
            f"Unsupported adapter type: {adapter_type}. "
            f"Supported types: {SUPPORTED_ADAPTER_TYPES}"
        )

    adapter_path = Path(adapter_path)
    if not adapter_path.exists():
        raise MlflowException.invalid_parameter_value(
            f"Adapter path does not exist: {adapter_path}"
        )

    path = Path(path)

    _validate_and_prepare_target_save_path(path)
    code_path_subdir = _validate_and_copy_code_paths(code_paths, path)

    if mlflow_model is None:
        mlflow_model = Model()

    _save_example(mlflow_model, input_example, path)

    if signature is None:
        signature = _get_default_signature()
    mlflow_model.signature = signature
    if metadata is not None:
        mlflow_model.metadata = metadata

    # Copy adapter weights — normalize to the standard filename that
    # load_lora_weights() expects, so inference works regardless of
    # what the training framework named the file.
    weight_name = _copy_adapter_weights(adapter_path, path / _ADAPTER_WEIGHTS_DIR)

    flavor_kwargs = {
        "base_model": base_model,
        "adapter_type": adapter_type,
        "adapter_weights": _ADAPTER_WEIGHTS_DIR,
        "diffusers_version": diffusers_version,
        "code": code_path_subdir,
    }
    # Best-effort: pin the Hub commit hash so inference reproduces training-time
    # base weights; skipped for local paths or when resolution fails.
    if revision := _resolve_base_model_revision(base_model):
        flavor_kwargs[_BASE_MODEL_REVISION_KEY] = revision
    if weight_name:
        flavor_kwargs["weight_name"] = weight_name
    mlflow_model.add_flavor(FLAVOR_NAME, **flavor_kwargs)
    pyfunc.add_to_model(
        mlflow_model,
        loader_module="mlflow.diffusers",
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
        code=code_path_subdir,
    )

    if size := get_total_file_size(path):
        mlflow_model.model_size_bytes = size
    mlflow_model.save(str(path / MLMODEL_FILE_NAME))

    # Save environment files
    if conda_env is None:
        default_reqs = get_default_pip_requirements() if pip_requirements is None else None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(path / _CONDA_ENV_FILE_NAME, "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    if pip_constraints:
        write_to(str(path / _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    write_to(str(path / _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))
    _PythonEnv.current().to_yaml(str(path / _PYTHON_ENV_FILE_NAME))
@docstring_version_compatibility_warning(integration_name=FLAVOR_NAME)
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="diffusers"))
def log_model(
    adapter_path,
    base_model,
    adapter_type: Literal["lora"] = "lora",
    artifact_path: str | None = None,
    conda_env=None,
    code_paths=None,
    registered_model_name=None,
    signature: ModelSignature | None = None,
    input_example: ModelInputExample | None = None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    metadata=None,
    params: dict[str, Any] | None = None,
    tags: dict[str, Any] | None = None,
    model_type: str | None = None,
    step: int = 0,
    model_id: str | None = None,
    name: str | None = None,
    **kwargs,
):
    """Log a diffusers adapter model as an MLflow artifact for the current run.

    Args:
        adapter_path: Path to the adapter weights. Can be a single .safetensors file
            or a directory containing adapter files.
        base_model: HuggingFace model ID or local path of the base diffusion model.
        adapter_type: Type of adapter. Currently only "lora" is supported.
        artifact_path: Deprecated. Use ``name`` instead.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        registered_model_name: If given, create a model version under this name.
        signature: {{ signature }}
        input_example: {{ input_example }}
        await_registration_for: Number of seconds to wait for model version creation.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}
        name: {{ name }}
        kwargs: Extra arguments to pass to :py:func:`mlflow.models.Model.log`.

    Returns:
        A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance.
    """
    # Pure pass-through: Model.log() drives save_model() with these arguments.
    log_call_kwargs = {
        "artifact_path": artifact_path,
        "name": name,
        "flavor": mlflow.diffusers,
        "adapter_path": adapter_path,
        "base_model": base_model,
        "adapter_type": adapter_type,
        "conda_env": conda_env,
        "code_paths": code_paths,
        "registered_model_name": registered_model_name,
        "signature": signature,
        "input_example": input_example,
        "await_registration_for": await_registration_for,
        "pip_requirements": pip_requirements,
        "extra_pip_requirements": extra_pip_requirements,
        "metadata": metadata,
        "params": params,
        "tags": tags,
        "model_type": model_type,
        "step": step,
        "model_id": model_id,
    }
    # Double-splat keeps the original duplicate-keyword TypeError semantics.
    return Model.log(**log_call_kwargs, **kwargs)
500 """ 501 local_model_path = Path( 502 _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path) 503 ) 504 flavor_conf = _get_flavor_configuration( 505 model_path=str(local_model_path), flavor_name=FLAVOR_NAME 506 ) 507 _add_code_from_conf_to_system_path(str(local_model_path), flavor_conf) 508 509 adapter_weights_path = local_model_path / flavor_conf["adapter_weights"] 510 511 return DiffusersAdapterModel( 512 adapter_path=str(adapter_weights_path), 513 base_model=flavor_conf["base_model"], 514 adapter_type=flavor_conf["adapter_type"], 515 base_model_revision=flavor_conf.get(_BASE_MODEL_REVISION_KEY), 516 weight_name=flavor_conf.get("weight_name"), 517 ) 518 519 520 def _load_pyfunc(path, model_config=None): 521 from mlflow.diffusers.wrapper import _DiffusersAdapterWrapper 522 523 path = Path(path) 524 flavor_conf = _get_flavor_configuration(model_path=str(path), flavor_name=FLAVOR_NAME) 525 526 return _DiffusersAdapterWrapper( 527 adapter_path=str(path / flavor_conf["adapter_weights"]), 528 flavor_conf=flavor_conf, 529 model_config=model_config, 530 ) 531 532 533 __all__ = [ 534 "DiffusersAdapterModel", 535 "load_model", 536 "save_model", 537 "log_model", 538 "get_default_pip_requirements", 539 "get_default_conda_env", 540 ]