# mlflow/diffusers/__init__.py
  1  """
  2  The ``mlflow.diffusers`` module provides an API for logging and loading diffusion model
  3  LoRA adapters as MLflow Models. This module exports adapter models with
  4  the following flavors:
  5  
  6  :py:mod:`mlflow.diffusers`
  7      Adapter weights in safetensors format, with a reference to the base model.
  8  
  9  :py:mod:`mlflow.pyfunc`
 10      Produced for use by generic pyfunc-based deployment tools and batch inference.
 11      The pyfunc wrapper loads the base diffusion pipeline and applies the adapter
 12      at inference time.
 13  """
 14  
 15  import importlib.util
 16  import logging
 17  import shutil
 18  from dataclasses import dataclass
 19  from pathlib import Path
 20  from typing import Any, Literal
 21  
 22  import yaml
 23  
 24  import mlflow
 25  from mlflow import pyfunc
 26  from mlflow.environment_variables import MLFLOW_DEFAULT_PREDICTION_DEVICE
 27  from mlflow.exceptions import MlflowException
 28  from mlflow.models import Model, ModelInputExample, ModelSignature
 29  from mlflow.models.model import MLMODEL_FILE_NAME
 30  from mlflow.models.utils import _save_example
 31  from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
 32  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 33  from mlflow.types import DataType, ParamSchema, ParamSpec, Schema
 34  from mlflow.types.schema import ColSpec
 35  from mlflow.utils.docstring_utils import (
 36      LOG_MODEL_PARAM_DOCS,
 37      docstring_version_compatibility_warning,
 38      format_docstring,
 39  )
 40  from mlflow.utils.environment import (
 41      _CONDA_ENV_FILE_NAME,
 42      _CONSTRAINTS_FILE_NAME,
 43      _PYTHON_ENV_FILE_NAME,
 44      _REQUIREMENTS_FILE_NAME,
 45      _mlflow_conda_env,
 46      _process_conda_env,
 47      _process_pip_requirements,
 48      _PythonEnv,
 49      _validate_env_arguments,
 50  )
 51  from mlflow.utils.file_utils import get_total_file_size, write_to
 52  from mlflow.utils.model_utils import (
 53      _add_code_from_conf_to_system_path,
 54      _get_flavor_configuration,
 55      _validate_and_copy_code_paths,
 56      _validate_and_prepare_target_save_path,
 57  )
 58  from mlflow.utils.requirements_utils import _get_pinned_requirement
 59  
_logger = logging.getLogger(__name__)

FLAVOR_NAME = "diffusers"

# Subdirectory under the saved model path where adapter weights are copied.
_ADAPTER_WEIGHTS_DIR = "adapter_weights"
# Canonical weight filename that diffusers' load_lora_weights() looks for;
# single-file adapters are renamed to this at save time.
_STANDARD_WEIGHT_NAME = "pytorch_lora_weights.safetensors"

# Adapter types accepted by save_model(); compared after lowercasing.
SUPPORTED_ADAPTER_TYPES = ("lora",)

# Flavor-config key under which the pinned HF Hub commit hash is stored.
_BASE_MODEL_REVISION_KEY = "base_model_revision"
 70  
 71  
 72  def _resolve_base_model_revision(base_model):
 73      """Resolve the HuggingFace Hub commit hash for a base model ID.
 74  
 75      Returns None if the ID looks like a local path or if resolution fails.
 76      """
 77      # Only treat as a local path if it's absolute or explicitly relative (./  ../).
 78      # Bare "org/model" strings should always be resolved as HF Hub IDs, even if
 79      # a matching directory happens to exist in the current working directory.
 80      p = Path(base_model)
 81      if p.is_absolute() or base_model.startswith(("./", "../")):
 82          return None
 83  
 84      try:
 85          from mlflow.utils.huggingface_utils import get_latest_commit_for_repo
 86  
 87          return get_latest_commit_for_repo(base_model)
 88      except Exception as e:
 89          # Broad catch is intentional: huggingface_hub types (HfHubHTTPError,
 90          # RepositoryNotFoundError) can't be imported unconditionally.
 91          # Revision pinning is optional — graceful degradation is preferred.
 92          _logger.warning(
 93              "Could not resolve HuggingFace commit hash for '%s' (%s). "
 94              "The base model revision will not be pinned.",
 95              base_model,
 96              type(e).__name__,
 97          )
 98          return None
 99  
100  
def _validate_safetensors_format(file_path):
    """Verify that ``file_path`` can be opened as a safetensors archive.

    Raises:
        MlflowException: if the ``safetensors`` package is not installed, or
            the file fails to parse as safetensors.
    """
    try:
        from safetensors import safe_open
    except ImportError as e:
        missing_pkg_msg = (
            "The 'safetensors' package is required to validate adapter weights. "
            "Install it with: pip install safetensors"
        )
        raise MlflowException.invalid_parameter_value(missing_pkg_msg) from e

    try:
        # Opening the file is sufficient to validate the header/format;
        # no tensors need to be read.
        with safe_open(str(file_path), framework="numpy"):
            pass
    except Exception as e:
        raise MlflowException.invalid_parameter_value(
            f"File is not a valid safetensors file: {file_path}. Error: {e}"
        ) from e
117  
118  
119  def _detect_device(device=None):
120      import torch
121  
122      if device is not None:
123          return device
124      if env_device := MLFLOW_DEFAULT_PREDICTION_DEVICE.get():
125          return env_device
126      if torch.cuda.is_available():
127          return "cuda"
128      try:
129          if torch.backends.mps.is_available():
130              return "mps"
131      except AttributeError:
132          pass
133      return "cpu"
134  
135  
def _get_default_signature():
    """Build the default signature: a string prompt in, binary image out,
    plus the common text-to-image generation parameters.
    """
    # (name, dtype, default) table for the inference-time parameters.
    default_params = [
        ("num_inference_steps", DataType.integer, 30),
        ("guidance_scale", DataType.double, 7.5),
        ("height", DataType.integer, 512),
        ("width", DataType.integer, 512),
        ("negative_prompt", DataType.string, ""),
    ]
    return ModelSignature(
        inputs=Schema([ColSpec(type=DataType.string, name="prompt")]),
        outputs=Schema([ColSpec(type=DataType.binary, name="image")]),
        params=ParamSchema(
            [ParamSpec(name=n, dtype=t, default=d) for n, t, d in default_params]
        ),
    )
148  
149  
def get_default_pip_requirements():
    """Return the default pip requirements (pinned to installed versions)
    for serving models saved with this flavor.
    """
    # peft: load_lora_weights() depends on it; safetensors: adapter format + validation
    packages = ["diffusers", "transformers", "torch", "peft", "safetensors"]
    # accelerate is optional — include it only when the current env has it.
    if importlib.util.find_spec("accelerate"):
        packages.append("accelerate")
    return [_get_pinned_requirement(pkg) for pkg in packages]
155  
156  
def get_default_conda_env():
    """Return the default Conda environment (as a dict) for models produced
    by this flavor, built from :py:func:`get_default_pip_requirements`.
    """
    return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())
159  
160  
161  @dataclass(frozen=True)
162  class DiffusersAdapterModel:
163      """A loaded LoRA adapter referencing a HuggingFace base model.
164  
165      Returned by :py:func:`load_model`. Call :py:meth:`load_pipeline` to get
166      a ready-to-use diffusers pipeline with the adapter applied.
167      """
168  
169      adapter_path: str
170      base_model: str
171      adapter_type: Literal["lora"]
172      base_model_revision: str | None = None
173      weight_name: str | None = None
174  
175      def load_pipeline(self, *, base_model: str | None = None, **kwargs):
176          """Download the base model and apply the LoRA adapter.
177  
178          Args:
179              base_model: Override the base model reference stored at save time.
180                  Useful when the original local path is no longer available.
181                  Accepts a HuggingFace model ID or a local directory path.
182              kwargs: Forwarded to ``DiffusionPipeline.from_pretrained()``.
183                  Common options include ``device``, ``torch_dtype``, and ``revision``.
184  
185          Returns:
186              A ``DiffusionPipeline`` with LoRA weights applied.
187          """
188          from diffusers import DiffusionPipeline
189  
190          effective_base_model = base_model or self.base_model
191          device = _detect_device(kwargs.pop("device", None))
192          kwargs.setdefault("torch_dtype", "auto")
193          if self.base_model_revision and "revision" not in kwargs:
194              kwargs["revision"] = self.base_model_revision
195  
196          try:
197              pipe = DiffusionPipeline.from_pretrained(effective_base_model, **kwargs)
198          except OSError as e:
199              raise MlflowException(
200                  f"Failed to load base model '{effective_base_model}'. If the model "
201                  "has moved, pass the correct location via "
202                  "load_pipeline(base_model=...)."
203              ) from e
204  
205          lora_kwargs = {}
206          if self.weight_name:
207              lora_kwargs["weight_name"] = self.weight_name
208          pipe.load_lora_weights(self.adapter_path, **lora_kwargs)
209          return pipe.to(device)
210  
211  
@docstring_version_compatibility_warning(integration_name=FLAVOR_NAME)
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="diffusers"))
def save_model(
    adapter_path: str,
    path: str,
    base_model: str,
    adapter_type: Literal["lora"] = "lora",
    conda_env=None,
    code_paths: list[str] | None = None,
    mlflow_model: Model | None = None,
    signature: ModelSignature | None = None,
    input_example: ModelInputExample | None = None,
    pip_requirements: list[str] | str | None = None,
    extra_pip_requirements: list[str] | str | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Save a diffusers adapter model to a path on the local file system.

    Args:
        adapter_path: Path to the adapter weights. Can be a single .safetensors file
            or a directory containing adapter files. Single files and directories
            containing a single safetensors file are normalized to
            ``pytorch_lora_weights.safetensors`` to match the convention expected
            by ``load_lora_weights()``. Directories with multiple weight files
            are copied as-is.
        path: Local path where the model is to be saved.
        base_model: HuggingFace model ID or local path of the base diffusion model
            that this adapter was trained on (e.g., "black-forest-labs/FLUX.1-dev").
        adapter_type: Type of adapter. Currently only "lora" is supported.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
        signature: {{ signature }}
        input_example: {{ input_example }}
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}

    Raises:
        MlflowException: if required packages ('diffusers', 'peft') are not
            installed, an argument is invalid, or the adapter weights are not
            valid safetensors files.
    """
    # Fail fast if the libraries needed to apply the adapter later are absent.
    try:
        import diffusers
    except ImportError as e:
        raise MlflowException.invalid_parameter_value(
            "The 'diffusers' package is required to save a diffusers adapter model. "
            "Install it with: pip install diffusers"
        ) from e

    try:
        import peft  # noqa: F401
    except ImportError as e:
        raise MlflowException.invalid_parameter_value(
            "The 'peft' package is required to save a diffusers LoRA adapter model. "
            "Install it with: pip install peft"
        ) from e

    # Recorded in the flavor config for version-compatibility checks at load time.
    diffusers_version = diffusers.__version__

    _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)

    # Validate user-supplied identifiers before touching the filesystem.
    if not isinstance(base_model, str) or not base_model.strip():
        raise MlflowException.invalid_parameter_value(
            "base_model must be a non-empty string (HuggingFace model ID or local path)."
        )

    if not isinstance(adapter_type, str):
        raise MlflowException.invalid_parameter_value(
            f"adapter_type must be a string, got {type(adapter_type).__name__}"
        )
    adapter_type = adapter_type.lower()
    if adapter_type not in SUPPORTED_ADAPTER_TYPES:
        raise MlflowException.invalid_parameter_value(
            f"Unsupported adapter type: {adapter_type}. Supported types: {SUPPORTED_ADAPTER_TYPES}"
        )

    adapter_path = Path(adapter_path)
    if not adapter_path.exists():
        raise MlflowException.invalid_parameter_value(
            f"Adapter path does not exist: {adapter_path}"
        )

    path = Path(path)

    _validate_and_prepare_target_save_path(path)
    code_path_subdir = _validate_and_copy_code_paths(code_paths, path)

    if mlflow_model is None:
        mlflow_model = Model()

    _save_example(mlflow_model, input_example, path)

    # Default to the canonical text-to-image signature when none is given.
    if signature is None:
        signature = _get_default_signature()
    mlflow_model.signature = signature
    if metadata is not None:
        mlflow_model.metadata = metadata

    # Copy adapter weights — normalize to the standard filename that
    # load_lora_weights() expects, so inference works regardless of
    # what the training framework named the file.
    weights_dst = path / _ADAPTER_WEIGHTS_DIR
    weight_name = None
    if adapter_path.is_file():
        if adapter_path.suffix != ".safetensors":
            raise MlflowException.invalid_parameter_value(
                f"Single-file adapter must be a .safetensors file, got: {adapter_path.suffix}"
            )
        _validate_safetensors_format(adapter_path)
        weights_dst.mkdir(parents=True, exist_ok=True)
        shutil.copy2(adapter_path, weights_dst / _STANDARD_WEIGHT_NAME)
    elif adapter_path.is_dir():
        # Filter hidden files (.DS_Store, etc.) that break single-file detection
        all_files = [p for p in adapter_path.iterdir() if not p.name.startswith(".")]
        safetensor_files = sorted(
            (p for p in all_files if p.suffix == ".safetensors"),
            key=lambda p: p.name,
        )
        if not safetensor_files:
            raise MlflowException.invalid_parameter_value(
                f"Adapter directory contains no .safetensors files: {adapter_path}"
            )
        # Validate every weight file up front so a bad file fails the save,
        # not a later inference call.
        for sf in safetensor_files:
            _validate_safetensors_format(sf)
        if len(safetensor_files) == 1 and len(all_files) == 1:
            # Directory with a single safetensors file — normalize its name
            weights_dst.mkdir(parents=True, exist_ok=True)
            shutil.copy2(safetensor_files[0], weights_dst / _STANDARD_WEIGHT_NAME)
        else:
            # Multiple files or companion files — copy entire directory as-is
            shutil.copytree(adapter_path, weights_dst)
            # If no standard weight file exists, record which file
            # load_lora_weights should target so inference doesn't silently
            # pick an arbitrary file or fail in offline mode.
            has_standard = any(sf.name == _STANDARD_WEIGHT_NAME for sf in safetensor_files)
            if not has_standard:
                # Deterministic choice: files were sorted by name above.
                weight_name = safetensor_files[0].name
                if len(safetensor_files) >= 2:
                    _logger.warning(
                        "Adapter directory contains %d .safetensors files but none named "
                        "'%s'. Will use '%s' as the primary weight file at inference time. "
                        "Consider renaming it to '%s' to avoid ambiguity.",
                        len(safetensor_files),
                        _STANDARD_WEIGHT_NAME,
                        weight_name,
                        _STANDARD_WEIGHT_NAME,
                    )
    else:
        raise MlflowException.invalid_parameter_value(
            f"Adapter path is neither a file nor a directory: {adapter_path}"
        )

    # Flavor configuration read back by load_model() and _load_pyfunc().
    flavor_kwargs = {
        "base_model": base_model,
        "adapter_type": adapter_type,
        "adapter_weights": _ADAPTER_WEIGHTS_DIR,
        "diffusers_version": diffusers_version,
        "code": code_path_subdir,
    }
    # Best-effort pin of the HF Hub commit hash for reproducible inference.
    if revision := _resolve_base_model_revision(base_model):
        flavor_kwargs[_BASE_MODEL_REVISION_KEY] = revision
    if weight_name:
        flavor_kwargs["weight_name"] = weight_name
    mlflow_model.add_flavor(FLAVOR_NAME, **flavor_kwargs)
    pyfunc.add_to_model(
        mlflow_model,
        loader_module="mlflow.diffusers",
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
        code=code_path_subdir,
    )

    # NOTE: size is computed (and MLmodel written) before the environment
    # files below are created, so model_size_bytes excludes them.
    if size := get_total_file_size(path):
        mlflow_model.model_size_bytes = size
    mlflow_model.save(str(path / MLMODEL_FILE_NAME))

    # Save environment files
    if conda_env is None:
        default_reqs = get_default_pip_requirements() if pip_requirements is None else None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(path / _CONDA_ENV_FILE_NAME, "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    if pip_constraints:
        write_to(str(path / _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    write_to(str(path / _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))
    _PythonEnv.current().to_yaml(str(path / _PYTHON_ENV_FILE_NAME))
404  
405  
@docstring_version_compatibility_warning(integration_name=FLAVOR_NAME)
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="diffusers"))
def log_model(
    adapter_path,
    base_model,
    adapter_type: Literal["lora"] = "lora",
    artifact_path: str | None = None,
    conda_env=None,
    code_paths=None,
    registered_model_name=None,
    signature: ModelSignature | None = None,
    input_example: ModelInputExample | None = None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    metadata=None,
    params: dict[str, Any] | None = None,
    tags: dict[str, Any] | None = None,
    model_type: str | None = None,
    step: int = 0,
    model_id: str | None = None,
    name: str | None = None,
    **kwargs,
):
    """Log a diffusers adapter model as an MLflow artifact for the current run.

    Thin wrapper around :py:func:`mlflow.models.Model.log`, which in turn
    delegates the flavor-specific work to :py:func:`save_model`.

    Args:
        adapter_path: Path to the adapter weights. Can be a single .safetensors file
            or a directory containing adapter files.
        base_model: HuggingFace model ID or local path of the base diffusion model.
        adapter_type: Type of adapter. Currently only "lora" is supported.
        artifact_path: Deprecated. Use ``name`` instead.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        registered_model_name: If given, create a model version under this name.
        signature: {{ signature }}
        input_example: {{ input_example }}
        await_registration_for: Number of seconds to wait for model version creation.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}
        name: {{ name }}
        kwargs: Extra arguments to pass to :py:func:`mlflow.models.Model.log`.

    Returns:
        A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance.
    """
    log_args = {
        "artifact_path": artifact_path,
        "name": name,
        "flavor": mlflow.diffusers,
        "adapter_path": adapter_path,
        "base_model": base_model,
        "adapter_type": adapter_type,
        "conda_env": conda_env,
        "code_paths": code_paths,
        "registered_model_name": registered_model_name,
        "signature": signature,
        "input_example": input_example,
        "await_registration_for": await_registration_for,
        "pip_requirements": pip_requirements,
        "extra_pip_requirements": extra_pip_requirements,
        "metadata": metadata,
        "params": params,
        "tags": tags,
        "model_type": model_type,
        "step": step,
        "model_id": model_id,
    }
    # Double-splat so a duplicate key in **kwargs raises TypeError exactly as
    # it would for explicit keyword arguments.
    return Model.log(**log_args, **kwargs)
481  
482  
@docstring_version_compatibility_warning(integration_name=FLAVOR_NAME)
def load_model(model_uri, dst_path=None):
    """Load a diffusers adapter model from a local file or a run.

    Args:
        model_uri: The location, in URI format, of the MLflow model. Examples:

            - ``/Users/me/path/to/local/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``models:/<model_name>/<model_version>``

        dst_path: The local filesystem path to download the model artifact to.

    Returns:
        A :py:class:`DiffusersAdapterModel` with adapter_path, base_model,
        and adapter_type. Call ``.load_pipeline()`` to get a ready-to-use
        diffusers pipeline with the adapter applied.
    """
    # Materialize the artifacts locally, then read the flavor config written
    # by save_model().
    downloaded = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
    model_dir = Path(downloaded)
    conf = _get_flavor_configuration(model_path=str(model_dir), flavor_name=FLAVOR_NAME)
    _add_code_from_conf_to_system_path(str(model_dir), conf)

    weights_dir = model_dir / conf["adapter_weights"]
    return DiffusersAdapterModel(
        adapter_path=str(weights_dir),
        base_model=conf["base_model"],
        adapter_type=conf["adapter_type"],
        base_model_revision=conf.get(_BASE_MODEL_REVISION_KEY),
        weight_name=conf.get("weight_name"),
    )
518  
519  
def _load_pyfunc(path, model_config=None):
    """Build the pyfunc wrapper for a saved diffusers adapter model.

    Called by ``mlflow.pyfunc.load_model``; ``path`` is the local model
    directory containing the MLmodel file and adapter weights.
    """
    from mlflow.diffusers.wrapper import _DiffusersAdapterWrapper

    model_dir = Path(path)
    conf = _get_flavor_configuration(model_path=str(model_dir), flavor_name=FLAVOR_NAME)
    weights_dir = model_dir / conf["adapter_weights"]

    return _DiffusersAdapterWrapper(
        adapter_path=str(weights_dir),
        flavor_conf=conf,
        model_config=model_config,
    )
531  
532  
# Public API of the flavor module; _load_pyfunc is intentionally excluded
# (it is invoked by mlflow.pyfunc, not by users).
__all__ = [
    "DiffusersAdapterModel",
    "load_model",
    "save_model",
    "log_model",
    "get_default_pip_requirements",
    "get_default_conda_env",
]