# mlflow/diffusers/wrapper.py
import io
import logging
import threading
from types import MappingProxyType
from typing import Any

import pandas as pd

from mlflow.diffusers import _detect_device
from mlflow.exceptions import MlflowException

_logger = logging.getLogger(__name__)


class _DiffusersAdapterWrapper:
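    """pyfunc-style wrapper that serves a diffusers pipeline with a LoRA adapter.

    The base pipeline is resolved from the logged flavor configuration (or a
    "base_model" override in model_config), the LoRA weights at adapter_path
    are applied on top, and loading is deferred until the first prediction.
    """
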
    def __init__(
        self,
        adapter_path: str,
        flavor_conf: dict[str, Any],
        model_config: dict[str, Any] | None = None,
    ):
        self._adapter_path = adapter_path
        self._flavor_conf = flavor_conf
        # Read-only view so callers cannot mutate the config after construction.
        self._model_config = MappingProxyType(model_config or {})
        # Loaded lazily on first use; see get_raw_model().
        self._pipeline = None
        self._load_lock = threading.Lock()

    def _load_pipeline(self):
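        """Load the base pipeline, apply the LoRA adapter, and move it to the device.

        Optional model_config overrides honored here: "base_model", "device",
        and "torch_dtype" (a torch.dtype or its string name).
        """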
        import torch
        from diffusers import DiffusionPipeline

        base_model = self._model_config.get("base_model") or self._flavor_conf["base_model"]
        base_model_revision = self._flavor_conf.get("base_model_revision")
        device = _detect_device(self._model_config.get("device"))

        load_kwargs = {}
        torch_dtype = self._model_config.get("torch_dtype")
        if torch_dtype is not None:
            if isinstance(torch_dtype, str):
                # diffusers expects an actual torch.dtype, not a string; accept
                # names like "float16" or "torch.float16" from model_config.
                torch_dtype = getattr(torch, torch_dtype.removeprefix("torch."))
            load_kwargs["torch_dtype"] = torch_dtype
        if base_model_revision:
            load_kwargs["revision"] = base_model_revision

        weight_name = self._flavor_conf.get("weight_name")
        lora_kwargs = {}
        if weight_name:
            lora_kwargs["weight_name"] = weight_name

        _logger.info("Loading base pipeline: %s", base_model)
        try:
            pipe = DiffusionPipeline.from_pretrained(base_model, **load_kwargs)
        except OSError as e:
            raise MlflowException(
                f"Failed to load base model '{base_model}'. If the model has moved, "
                "pass the correct location via "
                "model_config={'base_model': '<new_path_or_hub_id>'} "
                "when loading with mlflow.pyfunc.load_model()."
            ) from e

        _logger.info("Loading LoRA adapter from: %s", self._adapter_path)
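        # load_lora_weights accepts a local directory (or hub id); weight_name
        # selects a specific file such as "pytorch_lora_weights.safetensors".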
        pipe.load_lora_weights(self._adapter_path, **lora_kwargs)

        self._pipeline = pipe.to(device)

    def get_raw_model(self):
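        """Return the underlying diffusers pipeline, loading it on first use."""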
        if self._pipeline is None:
            # Double-checked locking: re-test under the lock so concurrent
            # first calls load the pipeline only once.
            with self._load_lock:
                if self._pipeline is None:
                    self._load_pipeline()
        return self._pipeline

    def _flatten_prompts(self, prompts):
        """Flatten one level of nested lists produced by schema enforcement.

        For example, [["a", "b"], "c"] becomes ["a", "b", "c"].
        """
        flat = []
        for item in prompts:
            if isinstance(item, list):
                flat.extend(item)
            else:
                flat.append(item)
        return flat

    def predict(self, data, params: dict[str, Any] | None = None):
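        """Generate images for the given prompts and return PNG-encoded bytes.

        A sketch of the accepted input shapes (exact coercion depends on the
        logged signature and MLflow's schema enforcement):

            wrapper.predict("a cat")                               # plain str
            wrapper.predict(["a cat", "a dog"])                    # list of str
            wrapper.predict({"prompt": ["a cat", "a dog"]})        # dict
            wrapper.predict(pd.DataFrame({"prompt": ["a cat"]}))   # DataFrame

        Supported params keys: num_inference_steps, guidance_scale, height,
        width, and negative_prompt.
        """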
        pipeline = self.get_raw_model()

        if isinstance(data, pd.DataFrame):
            if "prompt" in data.columns:
                prompts = data["prompt"].tolist()
            elif len(data.columns) == 1:
                # Schema enforcement wraps scalar strings into a single-column DataFrame
                prompts = data.iloc[:, 0].tolist()
            else:
                raise MlflowException(
                    "Input DataFrame must contain a 'prompt' column. "
                    f"Got columns: {list(data.columns)}"
                )
            # Schema enforcement may wrap {"prompt": ["a","b"]} into a
            # single-row DataFrame where the cell contains a list, producing
            # [["a","b"]] after tolist(). Flatten to ["a","b"].
            prompts = self._flatten_prompts(prompts)
        elif isinstance(data, str):
            prompts = [data]
        elif isinstance(data, dict):
            if "prompt" not in data:
                raise MlflowException(
                    f"Input dict must contain a 'prompt' key. Got keys: {list(data.keys())}"
                )
            prompts = data["prompt"]
            if isinstance(prompts, str):
                prompts = [prompts]
            elif isinstance(prompts, list):
                prompts = self._flatten_prompts(prompts)
            else:
                raise MlflowException(
                    "'prompt' value must be a string or list of strings, "
                    f"got {type(prompts).__name__}."
                )
        elif isinstance(data, list):
            prompts = self._flatten_prompts(data)
        else:
            raise MlflowException(f"Unsupported input type: {type(data).__name__}")

        if not prompts:
            raise MlflowException(
                "No prompts provided. Input must contain at least one prompt string."
            )

        # pandas represents missing values as NaN (a float), not None, so
        # check for non-strings generally rather than None specifically.
        if not all(isinstance(p, str) for p in prompts):
            raise MlflowException(
                "Prompt values must be strings. "
                "Check your input for missing or null values."
            )

        params = params or {}
        param_keys = ("num_inference_steps", "guidance_scale", "height", "width", "negative_prompt")
        gen_kwargs = {k: params[k] for k in param_keys if k in params}
        # Drop empty-string negative_prompt so the pipeline uses its own default
        if gen_kwargs.get("negative_prompt") == "":
            del gen_kwargs["negative_prompt"]
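        # Illustrative call (hypothetical values): params={"num_inference_steps": 30,
        # "guidance_scale": 7.5, "negative_prompt": "blurry"} passes exactly
        # those three keys through to the pipeline; unknown keys are dropped.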
        output = pipeline(prompt=prompts, **gen_kwargs)

        if not hasattr(output, "images") or not output.images:
            raise MlflowException(
                "Pipeline returned no images. The output may have been filtered "
                "by the safety checker, or the pipeline does not support image generation."
            )

        results = []
        for image in output.images:
            buf = io.BytesIO()
            image.save(buf, format="PNG")
            results.append(buf.getvalue())
            buf.close()

        return results
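
# A minimal usage sketch, assuming this wrapper is what the diffusers flavor's
# pyfunc loader returns (the model URI and config values below are hypothetical):
#
#     import mlflow
#
#     model = mlflow.pyfunc.load_model(
#         "models:/sdxl-lora/1",
#         model_config={"device": "cuda", "torch_dtype": "float16"},
#     )
#     images = model.predict(
#         {"prompt": ["a watercolor fox"]},
#         params={"num_inference_steps": 30, "guidance_scale": 7.5},
#     )
#     # images is a list of PNG-encoded bytes, one entry per prompt.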