# export_merged_model.py — fuse a LoRA into SDXL and export the components to ONNX
  1  import torch
  2  import os
  3  from diffusers import StableDiffusionXLImg2ImgPipeline
  4  
# All exported ONNX artifacts land in this directory.
os.makedirs("./sdxl-funko-pop-merged/onnx", exist_ok=True)

# ── 1. Download & merge LoRA into UNet ────────────────────────────────────────
print("Downloading SDXL...")
# float32 keeps the ONNX export numerically straightforward (no fp16 casts).
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
)

print("Downloading and fusing LoRA...")
pipe.load_lora_weights("ProomptEngineer/pe-funko-pop-diffusion-style")
# fuse_lora() bakes the scaled LoRA deltas into the base weights so the
# exported graphs need no separate adapter; unload_lora_weights() then drops
# the adapter bookkeeping, leaving plain modules for ONNX tracing.
pipe.fuse_lora(lora_scale=0.9)
pipe.unload_lora_weights()
print("✅ LoRA fused into UNet")
 20  
 21  # ── 2. Export merged text encoders ────────────────────────────────────────────
 22  class TextEncoders(torch.nn.Module):
 23      def __init__(self, enc1, enc2):
 24          super().__init__()
 25          self.enc1 = enc1
 26          self.enc2 = enc2
 27  
 28      def forward(self, tokens1, tokens2):
 29          out1 = self.enc1(tokens1, output_hidden_states=True)
 30          out2 = self.enc2(tokens2, output_hidden_states=True)
 31          embeds = torch.cat([out1.hidden_states[-2], out2.hidden_states[-2]], dim=-1)
 32          pooled = out2[0]
 33          return embeds, pooled
 34  
 35  print("Exporting merged text encoders...")
 36  torch.onnx.export(
 37      TextEncoders(pipe.text_encoder, pipe.text_encoder_2),
 38      (torch.zeros(1, 77, dtype=torch.int64), torch.zeros(1, 77, dtype=torch.int64)),
 39      "./sdxl-funko-pop-merged/onnx/text_encoders.onnx",
 40      input_names=["tokens1", "tokens2"],
 41      output_names=["prompt_embeds", "pooled_embeds"],
 42      dynamic_axes={},
 43      opset_version=17,
 44  )
 45  print("✅ text_encoders.onnx")
 46  
 47  # ── 3. Export UNet (LoRA baked in) ────────────────────────────────────────────
 48  print("Exporting UNet...")
 49  torch.onnx.export(
 50      pipe.unet,
 51      (
 52          torch.randn(1, 4, 128, 128),
 53          torch.tensor([999], dtype=torch.float32),
 54          torch.randn(1, 77, 2048),
 55          {"text_embeds": torch.randn(1, 1280),
 56           "time_ids":    torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.float32)},
 57      ),
 58      "./sdxl-funko-pop-merged/onnx/unet.onnx",
 59      input_names=["latents", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"],
 60      output_names=["noise_pred"],
 61      dynamic_axes={},
 62      opset_version=17,
 63  )
 64  print("✅ unet.onnx")
 65  
 66  # ── 4. Export VAE Encoder ─────────────────────────────────────────────────────
 67  print("Exporting VAE encoder...")
 68  torch.onnx.export(
 69      pipe.vae.encoder,
 70      torch.randn(1, 3, 1024, 1024),
 71      "./sdxl-funko-pop-merged/onnx/vae_encoder.onnx",
 72      input_names=["image"],
 73      output_names=["latents"],
 74      dynamic_axes={},
 75      opset_version=17,
 76  )
 77  print("✅ vae_encoder.onnx")
 78  
 79  # ── 5. Export VAE Decoder ─────────────────────────────────────────────────────
 80  print("Exporting VAE decoder...")
 81  torch.onnx.export(
 82      pipe.vae.decoder,
 83      torch.randn(1, 4, 128, 128),
 84      "./sdxl-funko-pop-merged/onnx/vae_decoder.onnx",
 85      input_names=["latents"],
 86      output_names=["image"],
 87      dynamic_axes={},
 88      opset_version=17,
 89  )
 90  print("✅ vae_decoder.onnx")
 91  
 92  # ── 6. Validate ───────────────────────────────────────────────────────────────
 93  print("\nValidating exports...")
 94  import onnxruntime as ort
 95  import numpy as np
 96  
 97  def validate(path, inputs, torch_out):
 98      sess = ort.InferenceSession(path)
 99      names = [i.name for i in sess.get_inputs()]
100      onnx_out = sess.run(None, dict(zip(names, [i.numpy() for i in inputs])))[0]
101      np.testing.assert_allclose(onnx_out, torch_out.detach().numpy(), rtol=1e-3)
102      print(f"✅ {path} validated")
103  
dummy_t1 = torch.zeros(1, 77, dtype=torch.int64)
dummy_t2 = torch.zeros(1, 77, dtype=torch.int64)
te = TextEncoders(pipe.text_encoder, pipe.text_encoder_2)
# Only prompt_embeds (output 0) is checked; the pooled output — and the
# UNet/VAE graphs — are not validated here.
reference = te(dummy_t1, dummy_t2)[0]
validate("./sdxl-funko-pop-merged/onnx/text_encoders.onnx", [dummy_t1, dummy_t2], reference)

print("\n✅ All done")
for artifact in ("text_encoders", "unet", "vae_encoder", "vae_decoder"):
    print(f"   ./sdxl-funko-pop-merged/onnx/{artifact}.onnx")