mod.rs
1 mod errors; 2 mod onnx; 3 pub mod params; 4 mod scheduler; 5 mod tokenizer; 6 use image::{DynamicImage, Rgb, RgbImage, imageops::FilterType}; 7 use onnx::{OnnxError, TextEncoders, UNet, VaeEncoder, VaeDecoder}; 8 pub use onnx::SessionOptions; 9 use ndarray::{Array, IxDyn, Zip}; 10 use rand::{SeedableRng, rngs::StdRng}; 11 use rand_distr::{Distribution, StandardNormal}; 12 13 pub use errors::ModelError; 14 use scheduler::DDIMScheduler; 15 use tokenizer::ClipTokenizers; 16 17 #[cfg(feature = "txt2img")] 18 use params::TextToImageParams; 19 20 #[cfg(feature = "img2img")] 21 use params::ImageToImageParams; 22 23 #[cfg(feature = "inpaint")] 24 use params::InpaintParams; 25 26 use crate::{install::Installer}; 27 28 pub enum Mode { 29 #[cfg(feature = "txt2img")] 30 Txt2img, 31 #[cfg(feature = "img2img")] 32 Img2img, 33 #[cfg(feature = "inpaint")] 34 Inpaint, 35 #[cfg(feature = "all_pipelines")] 36 All, 37 } 38 39 pub struct FunkoDiffusionModel { 40 tokenizers: ClipTokenizers, 41 text_encoders: TextEncoders, 42 unet: UNet, 43 vae_decoder: VaeDecoder, 44 vae_encoder: Option<VaeEncoder>, 45 } 46 impl FunkoDiffusionModel { 47 pub fn new(installer: &Installer, mode: Mode, opts: SessionOptions) -> Result<Self, ModelError> { 48 if !installer.is_installed() { 49 installer.install()?; 50 } 51 52 let data_dir = installer.data_dir()?; 53 54 Ok(Self { 55 tokenizers: ClipTokenizers::new(&data_dir)?, 56 text_encoders: TextEncoders::new(&data_dir, &opts)?, 57 unet: UNet::new(&data_dir, &opts)?, 58 vae_decoder: VaeDecoder::new(&data_dir, &opts)?, 59 vae_encoder: match mode { 60 #[cfg(feature = "txt2img")] 61 Mode::Txt2img => None, 62 _ => Some(VaeEncoder::new(&data_dir, &opts)?), 63 }, 64 }) 65 } 66 67 #[cfg(feature = "txt2img")] 68 pub fn text_to_image(&mut self, params: TextToImageParams) -> Result<DynamicImage, ModelError> { 69 self.encode_prompt(¶ms.prompt, params.negative_prompt.as_deref())?; 70 71 let time_ids = Array::from_shape_vec( 72 IxDyn(&[2, 6]), 73 vec![ 74 params.height as f32, params.width as f32, 0.0, 0.0, params.height as f32, params.width as f32, 75 params.height as f32, params.width as f32, 0.0, 0.0, params.height as f32, params.width as f32, 76 ] 77 ).map_err(OnnxError::Shape)?; 78 79 let lh = (params.height / 8) as usize; 80 let lw = (params.width / 8) as usize; 81 82 self.unet.prepare(lh, lw, time_ids)?; 83 84 // Fill latents buffer with random noise 85 let mut rng = StdRng::seed_from_u64(params.seed.unwrap_or_else(rand::random)); 86 self.unet.latents_mut() 87 .map_inplace(|v| *v = Distribution::sample(&StandardNormal, &mut rng)); 88 89 // Denoising loop 90 let scheduler = DDIMScheduler::new(params.steps); 91 for &t in scheduler.timesteps() { 92 let (a, b) = scheduler.step_coefficients(t); 93 self.unet.step(t as f32, a, b, params.guidance_scale)?; 94 } 95 96 let latents = self.unet.latents_view().mapv(|v| v / 0.13025); 97 let image_tensor = self.vae_decoder.decode(&latents)?; 98 tensor_to_image(&image_tensor) 99 } 100 101 #[cfg(feature = "img2img")] 102 pub fn image_to_image(&mut self, params: ImageToImageParams) -> Result<DynamicImage, ModelError> { 103 self.encode_prompt(¶ms.prompt, params.negative_prompt.as_deref())?; 104 105 let width = params.width.unwrap_or_else(|| params.image.width()); 106 let height = params.height.unwrap_or_else(|| params.image.height()); 107 108 let time_ids = Array::from_shape_vec( 109 IxDyn(&[2, 6]), 110 vec![ 111 height as f32, width as f32, 0.0, 0.0, height as f32, width as f32, 112 height as f32, width as f32, 0.0, 0.0, height as f32, width as f32, 113 ], 114 ).map_err(OnnxError::Shape)?; 115 116 let image_tensor = image_to_tensor(¶ms.image, width, height); 117 let image_latents = self.vae_encoder.as_mut() 118 .ok_or(ModelError::VaeEncoderNotLoaded)? 119 .encode(&image_tensor)? * 0.13025; 120 121 let lh = (height / 8) as usize; 122 let lw = (width / 8) as usize; 123 self.unet.prepare(lh, lw, time_ids)?; 124 125 let shape = image_latents.shape().to_vec(); 126 let mut rng = StdRng::seed_from_u64(params.seed.unwrap_or_else(rand::random)); 127 let noise_vec: Vec<f32> = (0..shape.iter().product::<usize>()) 128 .map(|_| Distribution::sample(&StandardNormal, &mut rng)) 129 .collect(); 130 let noise = Array::from_shape_vec(IxDyn(&shape), noise_vec) 131 .map_err(OnnxError::Shape)?; 132 133 let scheduler = DDIMScheduler::new(params.steps); 134 let start = scheduler.start_step(params.strength); 135 let t_start = scheduler.timesteps()[start]; 136 let noised = scheduler.add_noise(&image_latents, &noise, t_start); 137 self.unet.latents_mut().assign(&noised); 138 139 // Denoising loop 140 for &t in &scheduler.timesteps()[start..] { 141 let (a, b) = scheduler.step_coefficients(t); 142 self.unet.step(t as f32, a, b, params.guidance_scale)?; 143 } 144 145 let latents = self.unet.latents_view().mapv(|v| v / 0.13025); 146 let image_tensor = self.vae_decoder.decode(&latents)?; 147 tensor_to_image(&image_tensor) 148 } 149 150 #[cfg(feature = "inpaint")] 151 pub fn inpaint(&mut self, params: InpaintParams) -> Result<DynamicImage, ModelError> { 152 self.encode_prompt(¶ms.prompt, params.negative_prompt.as_deref())?; 153 154 let width = params.width.unwrap_or_else(|| params.image.width()); 155 let height = params.height.unwrap_or_else(|| params.image.height()); 156 157 let time_ids = Array::from_shape_vec( 158 IxDyn(&[2, 6]), 159 vec![ 160 height as f32, width as f32, 0.0, 0.0, height as f32, width as f32, 161 height as f32, width as f32, 0.0, 0.0, height as f32, width as f32, 162 ], 163 ).map_err(OnnxError::Shape)?; 164 165 let image_tensor = image_to_tensor(¶ms.image, width, height); 166 let image_latents = self.vae_encoder.as_mut() 167 .ok_or(ModelError::VaeEncoderNotLoaded)? 168 .encode(&image_tensor)? * 0.13025; 169 170 let lh = (height / 8) as usize; 171 let lw = (width / 8) as usize; 172 self.unet.prepare(lh, lw, time_ids)?; 173 174 // Prepare mask 175 let mask_img = params.mask 176 .resize_exact(lw as u32, lh as u32, FilterType::Nearest) 177 .to_luma8(); 178 let mask = Array::from_shape_fn(IxDyn(&[1, 1, lh, lw]), |idx| { 179 let y = idx[2]; 180 let x = idx[3]; 181 if mask_img.get_pixel(x as u32, y as u32)[0] > 127 { 1.0f32 } else { 0.0 } 182 }); 183 184 let shape = image_latents.shape().to_vec(); 185 let mut rng = StdRng::seed_from_u64(params.seed.unwrap_or_else(rand::random)); 186 let noise_vec: Vec<f32> = (0..shape.iter().product::<usize>()) 187 .map(|_| Distribution::sample(&StandardNormal, &mut rng)) 188 .collect(); 189 let noise = Array::from_shape_vec(IxDyn(&shape), noise_vec) 190 .map_err(OnnxError::Shape)?; 191 192 let scheduler = DDIMScheduler::new(params.steps); 193 let start = scheduler.start_step(params.strength); 194 let t_start = scheduler.timesteps()[start]; 195 let noised = scheduler.add_noise(&image_latents, &noise, t_start); 196 self.unet.latents_mut().assign(&noised); 197 198 // Denoising loop 199 for &t in &scheduler.timesteps()[start..] { 200 // Apply mask in-place: blend current latents with original image latents 201 Zip::from(self.unet.latents_mut()) 202 .and(&mask) 203 .and(&image_latents) 204 .for_each(|l, &m, &il| { 205 *l = *l * m + il * (1.0 - m); 206 }); 207 208 let (a, b) = scheduler.step_coefficients(t); 209 self.unet.step(t as f32, a, b, params.guidance_scale)?; 210 } 211 212 let latents = self.unet.latents_view().mapv(|v| v / 0.13025); 213 let image_tensor = self.vae_decoder.decode(&latents)?; 214 215 let output = tensor_to_image(&image_tensor)?; 216 217 let output_rgb = output.to_rgb8(); 218 let input_rgb = params.image 219 .resize_exact(width, height, FilterType::Lanczos3).to_rgb8(); 220 let mask_full = params.mask 221 .resize_exact(width, height, FilterType::Nearest).to_luma8(); 222 223 let mut blended = RgbImage::new(width, height); 224 for y in 0..height { 225 for x in 0..width { 226 let m = mask_full.get_pixel(x, y)[0] as f32 / 255.0; 227 let o = output_rgb.get_pixel(x, y); 228 let i = input_rgb.get_pixel(x, y); 229 blended.put_pixel(x, y, Rgb([ 230 (o[0] as f32 * m + i[0] as f32 * (1.0 - m)) as u8, 231 (o[1] as f32 * m + i[1] as f32 * (1.0 - m)) as u8, 232 (o[2] as f32 * m + i[2] as f32 * (1.0 - m)) as u8, 233 ])); 234 } 235 } 236 Ok(DynamicImage::ImageRgb8(blended)) 237 } 238 239 fn encode_prompt(&mut self, prompt: &str, negative_prompt: Option<&str>) -> Result<(), ModelError> { 240 let (pos_tokens1, pos_tokens2) = self.tokenizers.encode(prompt)?; 241 let (neg_tokens1, neg_tokens2) = self.tokenizers.encode(negative_prompt.unwrap_or(""))?; 242 243 // Stack neg and pos into batch=2 — [neg, pos] 244 let t1 = Array::from_shape_vec((2, 77), 245 neg_tokens1.into_iter().chain(pos_tokens1).collect() 246 ).map_err(OnnxError::Shape)?.into_dyn(); 247 let t2 = Array::from_shape_vec((2, 77), 248 neg_tokens2.into_iter().chain(pos_tokens2).collect() 249 ).map_err(OnnxError::Shape)?.into_dyn(); 250 251 let unet_fp16 = self.unet.is_fp16(); 252 self.text_encoders.encode(t1, t2, &mut self.unet.binding, unet_fp16)?; 253 254 Ok(()) 255 } 256 } 257 258 fn image_to_tensor(image: &DynamicImage, width: u32, height: u32) -> Array<f32, IxDyn> { 259 let rgb = image.resize_exact(width, height, FilterType::Lanczos3).to_rgb8(); 260 Array::from_shape_fn(IxDyn(&[1, 3, height as usize, width as usize]), |idx| { 261 let c = idx[1]; 262 let y = idx[2]; 263 let x = idx[3]; 264 rgb.get_pixel(x as u32, y as u32)[c] as f32 / 255.0 * 2.0 - 1.0 265 }) 266 } 267 268 fn tensor_to_image(tensor: &Array<f32, IxDyn>) -> Result<DynamicImage, ModelError> { 269 let shape = tensor.shape(); 270 let h = shape[2]; 271 let w = shape[3]; 272 273 let mut img = RgbImage::new(w as u32, h as u32); 274 for y in 0..h { 275 for x in 0..w { 276 let r = ((tensor[[0, 0, y, x]] + 1.0) * 127.5).clamp(0.0, 255.0) as u8; 277 let g = ((tensor[[0, 1, y, x]] + 1.0) * 127.5).clamp(0.0, 255.0) as u8; 278 let b = ((tensor[[0, 2, y, x]] + 1.0) * 127.5).clamp(0.0, 255.0) as u8; 279 img.put_pixel(x as u32, y as u32, Rgb([r, g, b])); 280 } 281 } 282 Ok(DynamicImage::ImageRgb8(img)) 283 }