wrapper.py
import io
import logging
import threading
from types import MappingProxyType
from typing import Any

import pandas as pd

from mlflow.diffusers import _detect_device
from mlflow.exceptions import MlflowException

_logger = logging.getLogger(__name__)


class _DiffusersAdapterWrapper:
    """Pyfunc-style wrapper serving a diffusers pipeline with LoRA adapter weights.

    The base pipeline is loaded lazily (and thread-safely) on first use, the
    LoRA weights from ``adapter_path`` are applied on top of it, and
    ``predict`` normalizes several input shapes (DataFrame, dict, list, bare
    string) into a flat list of prompts, returning one PNG-encoded image
    (``bytes``) per prompt.
    """

    def __init__(
        self,
        adapter_path: str,
        flavor_conf: dict[str, Any],
        model_config: dict[str, Any] | None = None,
    ):
        """
        Args:
            adapter_path: Local path to the saved LoRA adapter weights.
            flavor_conf: Flavor configuration; must contain ``base_model`` and
                may contain ``base_model_revision`` and ``weight_name``.
            model_config: Optional per-load overrides; recognized keys are
                ``base_model``, ``device``, and ``torch_dtype``.
        """
        self._adapter_path = adapter_path
        self._flavor_conf = flavor_conf
        # Read-only view so the config cannot be mutated after construction.
        self._model_config = MappingProxyType(model_config or {})
        self._pipeline = None
        # Guards lazy pipeline initialization across threads.
        self._load_lock = threading.Lock()

    def _load_pipeline(self):
        """Load the base pipeline, apply the LoRA weights, and move it to a device."""
        from diffusers import DiffusionPipeline

        # model_config may override the logged base model (e.g. if it moved on the hub).
        base_model = self._model_config.get("base_model") or self._flavor_conf["base_model"]
        base_model_revision = self._flavor_conf.get("base_model_revision")
        device = _detect_device(self._model_config.get("device"))
        torch_dtype = self._model_config.get("torch_dtype", "auto")

        load_kwargs = {"torch_dtype": torch_dtype}
        if base_model_revision:
            load_kwargs["revision"] = base_model_revision

        weight_name = self._flavor_conf.get("weight_name")
        lora_kwargs = {}
        if weight_name:
            lora_kwargs["weight_name"] = weight_name

        _logger.info("Loading base pipeline: %s", base_model)
        try:
            pipe = DiffusionPipeline.from_pretrained(base_model, **load_kwargs)
        except OSError as e:
            # BUG FIX: the hint below lived in a plain (non-f) string literal,
            # so the doubled braces were rendered verbatim to the user as
            # "model_config={{'base_model': ...}}". Single braces produce the
            # intended, copy-pasteable example.
            raise MlflowException(
                f"Failed to load base model '{base_model}'. If the model has moved, "
                "pass the correct location via "
                "model_config={'base_model': '<new_path_or_hub_id>'} "
                "when loading with mlflow.pyfunc.load_model()."
            ) from e

        _logger.info("Loading LoRA adapter from: %s", self._adapter_path)
        pipe.load_lora_weights(self._adapter_path, **lora_kwargs)

        self._pipeline = pipe.to(device)

    def get_raw_model(self):
        """Return the underlying pipeline, loading it on first access.

        Double-checked locking ensures that concurrent first calls trigger
        only a single load.
        """
        if self._pipeline is None:
            with self._load_lock:
                if self._pipeline is None:
                    self._load_pipeline()
        return self._pipeline

    def _flatten_prompts(self, prompts):
        """Flatten one level of nested lists produced by schema enforcement."""
        flat = []
        for item in prompts:
            if isinstance(item, list):
                flat.extend(item)
            else:
                flat.append(item)
        return flat

    def _extract_prompts(self, data):
        """Normalize a supported input shape into a flat list of prompts.

        Raises:
            MlflowException: If the input type or shape is not supported.
        """
        if isinstance(data, pd.DataFrame):
            if "prompt" in data.columns:
                prompts = data["prompt"].tolist()
            elif len(data.columns) == 1:
                # Schema enforcement wraps scalar strings into a single-column DataFrame
                prompts = data.iloc[:, 0].tolist()
            else:
                raise MlflowException(
                    f"Input DataFrame must contain a 'prompt' column. "
                    f"Got columns: {list(data.columns)}"
                )
            # Schema enforcement may wrap {"prompt": ["a","b"]} into a
            # single-row DataFrame where the cell contains a list, producing
            # [["a","b"]] after tolist(). Flatten to ["a","b"].
            return self._flatten_prompts(prompts)
        if isinstance(data, str):
            return [data]
        if isinstance(data, dict):
            if "prompt" not in data:
                raise MlflowException(
                    f"Input dict must contain a 'prompt' key. Got keys: {list(data.keys())}"
                )
            prompts = data["prompt"]
            if isinstance(prompts, str):
                return [prompts]
            if isinstance(prompts, list):
                return self._flatten_prompts(prompts)
            raise MlflowException(
                "'prompt' value must be a string or list of strings, "
                f"got {type(prompts).__name__}."
            )
        if isinstance(data, list):
            return self._flatten_prompts(data)
        raise MlflowException(f"Unsupported input type: {type(data)}")

    def predict(self, data, params: dict[str, Any] | None = None):
        """Generate one PNG image per prompt.

        Args:
            data: Prompts as a DataFrame (a ``prompt`` column or a single
                column), a dict with a ``prompt`` key, a list of strings, or a
                bare string.
            params: Optional generation parameters; recognized keys are
                ``num_inference_steps``, ``guidance_scale``, ``height``,
                ``width``, and ``negative_prompt``.

        Returns:
            list[bytes]: PNG-encoded images, one per prompt.

        Raises:
            MlflowException: On unsupported input, empty/null prompts, or when
                the pipeline produces no images.
        """
        pipeline = self.get_raw_model()
        prompts = self._extract_prompts(data)

        if not prompts:
            raise MlflowException(
                "No prompts provided. Input must contain at least one prompt string."
            )

        # pandas may surface missing values as None *or* NaN (float); reject
        # both, since either would fail cryptically inside the pipeline.
        if any(p is None or (isinstance(p, float) and p != p) for p in prompts):
            raise MlflowException(
                "Prompt values must be strings, not None. "
                "Check your input for missing or null values."
            )

        params = params or {}
        param_keys = ("num_inference_steps", "guidance_scale", "height", "width", "negative_prompt")
        gen_kwargs = {k: params[k] for k in param_keys if k in params}
        # Drop empty-string negative_prompt so the pipeline uses its own default
        if gen_kwargs.get("negative_prompt") == "":
            del gen_kwargs["negative_prompt"]

        output = pipeline(prompt=prompts, **gen_kwargs)

        if not hasattr(output, "images") or not output.images:
            raise MlflowException(
                "Pipeline returned no images. The output may have been filtered "
                "by the safety checker, or the pipeline does not support image generation."
            )

        # Encode each image to PNG bytes; the context manager releases each
        # buffer promptly (the original closed the buffer manually).
        results = []
        for image in output.images:
            with io.BytesIO() as buf:
                image.save(buf, format="PNG")
                results.append(buf.getvalue())
        return results