# export_merged_model.py
import os

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline

# Every ONNX artifact produced by this script is written into this directory.
OUT_DIR = "./sdxl-funko-pop-merged/onnx"
os.makedirs(OUT_DIR, exist_ok=True)

# ── 1. Download & merge LoRA into UNet ────────────────────────────────────────
print("Downloading SDXL...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,  # fp32 keeps the ONNX export numerically exact
    use_safetensors=True,
)

print("Downloading and fusing LoRA...")
pipe.load_lora_weights("ProomptEngineer/pe-funko-pop-diffusion-style")
pipe.fuse_lora(lora_scale=0.9)   # bake the scaled LoRA deltas into base weights
pipe.unload_lora_weights()       # drop the now-redundant adapter tensors
print("✅ LoRA fused into UNet")


# ── 2. Export merged text encoders ────────────────────────────────────────────
class TextEncoders(torch.nn.Module):
    """Bundle both SDXL CLIP text encoders into one exportable module.

    Returns:
        embeds: concatenation of each encoder's penultimate hidden states
            along the feature dim (SDXL conditions on ``hidden_states[-2]``,
            not the final layer).
        pooled: the pooled/projected embedding from the second encoder only
            (first element of its output tuple).
    """

    def __init__(self, enc1, enc2):
        super().__init__()
        self.enc1 = enc1
        self.enc2 = enc2

    def forward(self, tokens1, tokens2):
        out1 = self.enc1(tokens1, output_hidden_states=True)
        out2 = self.enc2(tokens2, output_hidden_states=True)
        embeds = torch.cat([out1.hidden_states[-2], out2.hidden_states[-2]], dim=-1)
        pooled = out2[0]
        return embeds, pooled


print("Exporting merged text encoders...")
# eval() + no_grad(): no dropout and no autograd bookkeeping while tracing.
pipe.text_encoder.eval()
pipe.text_encoder_2.eval()
with torch.no_grad():
    torch.onnx.export(
        TextEncoders(pipe.text_encoder, pipe.text_encoder_2),
        (torch.zeros(1, 77, dtype=torch.int64), torch.zeros(1, 77, dtype=torch.int64)),
        f"{OUT_DIR}/text_encoders.onnx",
        input_names=["tokens1", "tokens2"],
        output_names=["prompt_embeds", "pooled_embeds"],
        dynamic_axes={},  # fixed batch=1, seq=77
        opset_version=17,
    )
print("✅ text_encoders.onnx")
# ── 3. Export UNet (LoRA baked in) ────────────────────────────────────────────
print("Exporting UNet...")
pipe.unet.eval()
with torch.no_grad():
    torch.onnx.export(
        pipe.unet,
        (
            torch.randn(1, 4, 128, 128),               # latents (1024 / 8 = 128)
            torch.tensor([999], dtype=torch.float32),  # diffusion timestep
            torch.randn(1, 77, 2048),                  # concatenated text embeds
            # torch.onnx.export treats a trailing dict in `args` as keyword
            # arguments to forward().  UNet2DConditionModel.forward has no
            # `text_embeds`/`time_ids` parameters — the SDXL micro-conditioning
            # must be passed via the `added_cond_kwargs` keyword, otherwise the
            # trace fails with an unexpected-keyword error.
            {
                "added_cond_kwargs": {
                    "text_embeds": torch.randn(1, 1280),
                    "time_ids": torch.tensor(
                        [[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.float32
                    ),
                }
            },
        ),
        "./sdxl-funko-pop-merged/onnx/unet.onnx",
        input_names=["latents", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"],
        output_names=["noise_pred"],
        dynamic_axes={},
        opset_version=17,
    )
print("✅ unet.onnx")


# ── 4. Export VAE Encoder ─────────────────────────────────────────────────────
class _VaeEncoder(torch.nn.Module):
    """Exportable VAE encode path.

    AutoencoderKL.encode() runs encoder -> quant_conv before building the
    latent distribution; exporting `vae.encoder` alone skips quant_conv and
    yields unusable latents.  This wrapper applies both and returns the
    distribution mean (deterministic 4-channel latents).  NOTE(review): the
    SDXL scaling factor (~0.13025) is expected to be applied by the caller
    at inference time — confirm against the runtime pipeline.
    """

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, image):
        moments = self.vae.quant_conv(self.vae.encoder(image))
        mean, _logvar = torch.chunk(moments, 2, dim=1)
        return mean


print("Exporting VAE encoder...")
pipe.vae.eval()
with torch.no_grad():
    torch.onnx.export(
        _VaeEncoder(pipe.vae),
        torch.randn(1, 3, 1024, 1024),
        "./sdxl-funko-pop-merged/onnx/vae_encoder.onnx",
        input_names=["image"],
        output_names=["latents"],
        dynamic_axes={},
        opset_version=17,
    )
print("✅ vae_encoder.onnx")


# ── 5. Export VAE Decoder ─────────────────────────────────────────────────────
class _VaeDecoder(torch.nn.Module):
    """Exportable VAE decode path.

    AutoencoderKL.decode() runs post_quant_conv -> decoder; exporting
    `vae.decoder` alone skips post_quant_conv, producing wrong images.
    """

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, latents):
        return self.vae.decoder(self.vae.post_quant_conv(latents))


print("Exporting VAE decoder...")
with torch.no_grad():
    torch.onnx.export(
        _VaeDecoder(pipe.vae),
        torch.randn(1, 4, 128, 128),
        "./sdxl-funko-pop-merged/onnx/vae_decoder.onnx",
        input_names=["latents"],
        output_names=["image"],
        dynamic_axes={},
        opset_version=17,
    )
print("✅ vae_decoder.onnx")
# ── 6. Validate ───────────────────────────────────────────────────────────────
print("\nValidating exports...")
import numpy as np
import onnxruntime as ort


def validate(path, inputs, torch_out):
    """Run the ONNX model on `inputs` and compare against the PyTorch output.

    Args:
        path: path to the exported .onnx file.
        inputs: list of torch tensors, in the session's input order.
        torch_out: reference output tensor from the PyTorch module.

    Raises:
        AssertionError: if the ONNX and PyTorch outputs diverge.
    """
    sess = ort.InferenceSession(path)
    names = [i.name for i in sess.get_inputs()]
    onnx_out = sess.run(None, dict(zip(names, [i.numpy() for i in inputs])))[0]
    # assert_allclose defaults to atol=0, so a pure rtol check spuriously
    # fails on near-zero elements; a small atol makes the tolerance sane.
    np.testing.assert_allclose(
        onnx_out, torch_out.detach().numpy(), rtol=1e-3, atol=1e-4
    )
    print(f"✅ {path} validated")


dummy_t1 = torch.zeros(1, 77, dtype=torch.int64)
dummy_t2 = torch.zeros(1, 77, dtype=torch.int64)
te = TextEncoders(pipe.text_encoder, pipe.text_encoder_2)
with torch.no_grad():  # reference forward pass needs no gradients
    ref_embeds = te(dummy_t1, dummy_t2)[0]
validate("./sdxl-funko-pop-merged/onnx/text_encoders.onnx", [dummy_t1, dummy_t2], ref_embeds)

print("\n✅ All done")
print("  ./sdxl-funko-pop-merged/onnx/text_encoders.onnx")
print("  ./sdxl-funko-pop-merged/onnx/unet.onnx")
print("  ./sdxl-funko-pop-merged/onnx/vae_encoder.onnx")
print("  ./sdxl-funko-pop-merged/onnx/vae_decoder.onnx")