/ examples / diffusers / demo_diffusers_adapter.py
demo_diffusers_adapter.py
  1  """
  2  Demo: MLflow Diffusers Adapter Flavor (LoRA)
  3  
  4  This script demonstrates the full workflow of logging and loading a diffusion
  5  model LoRA adapter using the native mlflow.diffusers flavor.
  6  
  7  No GPU or real model weights required — uses a fake adapter for validation.
  8  """
  9  
 10  import tempfile
 11  from pathlib import Path
 12  
 13  import numpy as np
 14  import yaml
 15  from safetensors.numpy import save_file
 16  
 17  import mlflow
 18  import mlflow.diffusers
 19  
 20  
 21  def create_fake_lora_adapter(output_dir: Path) -> Path:
 22      output_dir.mkdir(parents=True, exist_ok=True)
 23  
 24      # Simulate LoRA weight matrices (small random tensors)
 25      tensors = {
 26          "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_down.weight": (
 27              np.random.randn(4, 320).astype(np.float32)
 28          ),
 29          "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_up.weight": (
 30              np.random.randn(320, 4).astype(np.float32)
 31          ),
 32      }
 33      adapter_file = output_dir / "pytorch_lora_weights.safetensors"
 34      save_file(tensors, str(adapter_file))
 35  
 36      print(f"Created fake LoRA adapter at: {adapter_file}")
 37      print(f"  Adapter size: {adapter_file.stat().st_size} bytes")
 38      return output_dir
 39  
 40  
 41  def demo_log_and_load():
 42      """Demonstrate the full log -> load -> inspect cycle."""
 43      with tempfile.TemporaryDirectory() as tmpdir:
 44          # 1. Create fake adapter
 45          adapter_dir = create_fake_lora_adapter(Path(tmpdir) / "my_lora")
 46  
 47          # 2. Log the adapter with MLflow
 48          print("\n--- Logging adapter with mlflow.diffusers.log_model() ---")
 49          mlflow.set_experiment("diffusers-adapter-poc")
 50  
 51          with mlflow.start_run(run_name="lora-adapter-demo") as run:
 52              model_info = mlflow.diffusers.log_model(
 53                  adapter_path=str(adapter_dir),
 54                  base_model="black-forest-labs/FLUX.1-dev",
 55                  adapter_type="lora",
 56                  name="lora_model",
 57                  metadata={
 58                      "lora_rank": 4,
 59                      "training_steps": 1000,
 60                      "trigger_word": "sks style",
 61                  },
 62              )
 63  
 64              print(f"  Run ID: {run.info.run_id}")
 65              print(f"  Model URI: {model_info.model_uri}")
 66  
 67          # 3. Inspect the MLmodel file
 68          print("\n--- MLmodel file contents ---")
 69          model_uri = f"runs:/{run.info.run_id}/lora_model"
 70          local_path = mlflow.artifacts.download_artifacts(model_uri)
 71          mlmodel_path = Path(local_path) / "MLmodel"
 72  
 73          with open(mlmodel_path) as f:
 74              mlmodel = yaml.safe_load(f)
 75  
 76          print(yaml.dump(mlmodel, default_flow_style=False, indent=2))
 77  
 78          # 4. Load the model back
 79          print("--- Loading model back with mlflow.diffusers.load_model() ---")
 80          loaded = mlflow.diffusers.load_model(model_uri)
 81  
 82          print(f"  Type: {type(loaded).__name__}")
 83          print(f"  Base model: {loaded.base_model}")
 84          print(f"  Adapter type: {loaded.adapter_type}")
 85          print(f"  Adapter path: {loaded.adapter_path}")
 86          print(f"  Adapter files: {list(Path(loaded.adapter_path).iterdir())}")
 87  
 88          # 5. Verify flavor config from MLmodel
 89          print("\n--- Flavor config ---")
 90          flavor_conf = mlmodel["flavors"]["diffusers"]
 91          print(f"  base_model: {flavor_conf['base_model']}")
 92          print(f"  adapter_type: {flavor_conf['adapter_type']}")
 93          print(f"  adapter_weights: {flavor_conf['adapter_weights']}")
 94  
 95          # 6. Show that pyfunc interface is available
 96          print("\n--- Pyfunc model interface ---")
 97          print("  mlflow.pyfunc.load_model() would return a wrapper with predict()")
 98          print("  predict() accepts: DataFrame/dict with 'prompt' column")
 99          print("  predict() returns: list of PNG-encoded image bytes")
100          print("  (Skipping actual pyfunc load — requires base model download)")
101  
102          print("\n--- Demo complete! ---")
103          print(
104              "The adapter is logged as a first-class MLflow model with full model registry support."
105          )
106          print(
107              "To generate images, call loaded.load_pipeline() on a machine "
108              "with the base model available."
109          )
110  
111  
if __name__ == "__main__":
    # Script entry point: run the end-to-end adapter walkthrough.
    demo_log_and_load()