# demo_diffusers_adapter.py
"""
Demo: MLflow Diffusers Adapter Flavor (LoRA)

This script demonstrates the full workflow of logging and loading a diffusion
model LoRA adapter using the native mlflow.diffusers flavor.

No GPU or real model weights required — uses a fake adapter for validation.
"""

import tempfile
from pathlib import Path

import numpy as np
import yaml
from safetensors.numpy import save_file

import mlflow
import mlflow.diffusers


def create_fake_lora_adapter(output_dir: Path) -> Path:
    """Write a minimal fake LoRA adapter into *output_dir*.

    Creates a ``pytorch_lora_weights.safetensors`` file containing two small
    rank-4 weight matrices (a down/up projection pair), which is enough for
    the mlflow.diffusers flavor to log and reload without real model weights.

    Args:
        output_dir: Directory to create the adapter in (created if missing).

    Returns:
        The adapter directory (``output_dir`` itself, not the weights file).
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Simulate LoRA weight matrices (small random tensors). A seeded
    # Generator makes the demo artifact reproducible across runs.
    rng = np.random.default_rng(seed=0)
    tensors = {
        "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_down.weight": (
            rng.standard_normal((4, 320)).astype(np.float32)
        ),
        "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_up.weight": (
            rng.standard_normal((320, 4)).astype(np.float32)
        ),
    }
    adapter_file = output_dir / "pytorch_lora_weights.safetensors"
    save_file(tensors, str(adapter_file))

    print(f"Created fake LoRA adapter at: {adapter_file}")
    print(f" Adapter size: {adapter_file.stat().st_size} bytes")
    return output_dir


def demo_log_and_load() -> None:
    """Demonstrate the full log -> load -> inspect cycle."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # 1. Create fake adapter
        adapter_dir = create_fake_lora_adapter(Path(tmpdir) / "my_lora")

        # 2. Log the adapter with MLflow
        print("\n--- Logging adapter with mlflow.diffusers.log_model() ---")
        mlflow.set_experiment("diffusers-adapter-poc")

        with mlflow.start_run(run_name="lora-adapter-demo") as run:
            model_info = mlflow.diffusers.log_model(
                adapter_path=str(adapter_dir),
                base_model="black-forest-labs/FLUX.1-dev",
                adapter_type="lora",
                name="lora_model",
                metadata={
                    "lora_rank": 4,
                    "training_steps": 1000,
                    "trigger_word": "sks style",
                },
            )

            print(f" Run ID: {run.info.run_id}")
            print(f" Model URI: {model_info.model_uri}")

        # 3. Inspect the MLmodel file. A runs:/ URI stays resolvable after
        # the run context closes, so download and parse the artifact here.
        print("\n--- MLmodel file contents ---")
        model_uri = f"runs:/{run.info.run_id}/lora_model"
        local_path = mlflow.artifacts.download_artifacts(model_uri)
        mlmodel_path = Path(local_path) / "MLmodel"

        with open(mlmodel_path, encoding="utf-8") as f:
            mlmodel = yaml.safe_load(f)

        print(yaml.dump(mlmodel, default_flow_style=False, indent=2))

        # 4. Load the model back
        print("--- Loading model back with mlflow.diffusers.load_model() ---")
        loaded = mlflow.diffusers.load_model(model_uri)

        print(f" Type: {type(loaded).__name__}")
        print(f" Base model: {loaded.base_model}")
        print(f" Adapter type: {loaded.adapter_type}")
        print(f" Adapter path: {loaded.adapter_path}")
        print(f" Adapter files: {list(Path(loaded.adapter_path).iterdir())}")

        # 5. Verify flavor config from MLmodel
        print("\n--- Flavor config ---")
        flavor_conf = mlmodel["flavors"]["diffusers"]
        print(f" base_model: {flavor_conf['base_model']}")
        print(f" adapter_type: {flavor_conf['adapter_type']}")
        print(f" adapter_weights: {flavor_conf['adapter_weights']}")

        # 6. Show that pyfunc interface is available. The real pyfunc load is
        # skipped because it would trigger a base-model download.
        print("\n--- Pyfunc model interface ---")
        print(" mlflow.pyfunc.load_model() would return a wrapper with predict()")
        print(" predict() accepts: DataFrame/dict with 'prompt' column")
        print(" predict() returns: list of PNG-encoded image bytes")
        print(" (Skipping actual pyfunc load — requires base model download)")

        print("\n--- Demo complete! ---")
        print(
            "The adapter is logged as a first-class MLflow model with full model registry support."
        )
        print(
            "To generate images, call loaded.load_pipeline() on a machine "
            "with the base model available."
        )


if __name__ == "__main__":
    demo_log_and_load()