// src/model/mod.rs
  1  mod errors;
  2  mod onnx;
  3  pub mod params;
  4  mod scheduler;
  5  mod tokenizer;
  6  use image::{DynamicImage, Rgb, RgbImage, imageops::FilterType};
  7  use onnx::{OnnxError, TextEncoders, UNet, VaeEncoder, VaeDecoder};
  8  pub use onnx::SessionOptions;
  9  use ndarray::{Array, IxDyn, Zip};
 10  use rand::{SeedableRng, rngs::StdRng};
 11  use rand_distr::{Distribution, StandardNormal};
 12  
 13  pub use errors::ModelError;
 14  use scheduler::DDIMScheduler;
 15  use tokenizer::ClipTokenizers;
 16  
 17  #[cfg(feature = "txt2img")]
 18  use params::TextToImageParams;
 19  
 20  #[cfg(feature = "img2img")]
 21  use params::ImageToImageParams;
 22  
 23  #[cfg(feature = "inpaint")]
 24  use params::InpaintParams;
 25  
 26  use crate::{install::Installer};
 27  
/// Pipeline selection passed to `FunkoDiffusionModel::new`.
///
/// Each variant only exists when the Cargo feature named in its `cfg`
/// attribute is enabled. `FunkoDiffusionModel::new` skips loading the VAE
/// encoder for `Txt2img` and loads it for every other variant.
pub enum Mode {
    /// Gated by the `txt2img` feature; `new` leaves the VAE encoder unloaded.
    #[cfg(feature = "txt2img")]
    Txt2img,
    /// Gated by the `img2img` feature; `new` loads the VAE encoder.
    #[cfg(feature = "img2img")]
    Img2img,
    /// Gated by the `inpaint` feature; `new` loads the VAE encoder.
    #[cfg(feature = "inpaint")]
    Inpaint,
    /// Gated by the `all_pipelines` feature; `new` loads the VAE encoder.
    #[cfg(feature = "all_pipelines")]
    All,
}
 38  
/// Bundles the tokenizers and the ONNX-backed model components used by the
/// generation pipelines implemented on this type.
pub struct FunkoDiffusionModel {
    /// Turns prompts into token ids in `encode_prompt`.
    tokenizers: ClipTokenizers,
    /// Receives the stacked negative/positive token ids in `encode_prompt`.
    text_encoders: TextEncoders,
    /// Owns the latents buffer and runs each denoising step.
    unet: UNet,
    /// Decodes the final latents into an image tensor.
    vae_decoder: VaeDecoder,
    /// Encodes input images into latents; `None` when built with `Mode::Txt2img`.
    vae_encoder: Option<VaeEncoder>,
}
 46  impl FunkoDiffusionModel {
 47      pub fn new(installer: &Installer, mode: Mode, opts: SessionOptions) -> Result<Self, ModelError> {
 48          if !installer.is_installed() {
 49              installer.install()?;
 50          }
 51  
 52          let data_dir = installer.data_dir()?;
 53  
 54          Ok(Self {
 55              tokenizers: ClipTokenizers::new(&data_dir)?,
 56              text_encoders: TextEncoders::new(&data_dir, &opts)?,
 57              unet: UNet::new(&data_dir, &opts)?,
 58              vae_decoder: VaeDecoder::new(&data_dir, &opts)?,
 59              vae_encoder: match mode {
 60                  #[cfg(feature = "txt2img")]
 61                  Mode::Txt2img => None,
 62                  _ => Some(VaeEncoder::new(&data_dir, &opts)?),
 63              },
 64          })
 65      }
 66  
 67      #[cfg(feature = "txt2img")]
 68      pub fn text_to_image(&mut self, params: TextToImageParams) -> Result<DynamicImage, ModelError> {
 69          self.encode_prompt(&params.prompt, params.negative_prompt.as_deref())?;
 70  
 71          let time_ids = Array::from_shape_vec(
 72              IxDyn(&[2, 6]),
 73              vec![
 74                  params.height as f32, params.width as f32, 0.0, 0.0, params.height as f32, params.width as f32,
 75                  params.height as f32, params.width as f32, 0.0, 0.0, params.height as f32, params.width as f32,
 76              ]
 77          ).map_err(OnnxError::Shape)?;
 78  
 79          let lh = (params.height / 8) as usize;
 80          let lw = (params.width / 8) as usize;
 81  
 82          self.unet.prepare(lh, lw, time_ids)?;
 83  
 84          // Fill latents buffer with random noise
 85          let mut rng = StdRng::seed_from_u64(params.seed.unwrap_or_else(rand::random));
 86          self.unet.latents_mut()
 87              .map_inplace(|v| *v = Distribution::sample(&StandardNormal, &mut rng));
 88  
 89          // Denoising loop
 90          let scheduler = DDIMScheduler::new(params.steps);
 91          for &t in scheduler.timesteps() {
 92              let (a, b) = scheduler.step_coefficients(t);
 93              self.unet.step(t as f32, a, b, params.guidance_scale)?;
 94          }
 95  
 96          let latents = self.unet.latents_view().mapv(|v| v / 0.13025);
 97          let image_tensor = self.vae_decoder.decode(&latents)?;
 98          tensor_to_image(&image_tensor)
 99      }
100  
101      #[cfg(feature = "img2img")]
102      pub fn image_to_image(&mut self, params: ImageToImageParams) -> Result<DynamicImage, ModelError> {
103          self.encode_prompt(&params.prompt, params.negative_prompt.as_deref())?;
104  
105          let width = params.width.unwrap_or_else(|| params.image.width());
106          let height = params.height.unwrap_or_else(|| params.image.height());
107  
108          let time_ids = Array::from_shape_vec(
109              IxDyn(&[2, 6]),
110              vec![
111                  height as f32, width as f32, 0.0, 0.0, height as f32, width as f32,
112                  height as f32, width as f32, 0.0, 0.0, height as f32, width as f32,
113              ],
114          ).map_err(OnnxError::Shape)?;
115  
116          let image_tensor = image_to_tensor(&params.image, width, height);
117          let image_latents = self.vae_encoder.as_mut()
118              .ok_or(ModelError::VaeEncoderNotLoaded)?
119              .encode(&image_tensor)? * 0.13025;
120  
121          let lh = (height / 8) as usize;
122          let lw = (width / 8) as usize;
123          self.unet.prepare(lh, lw, time_ids)?;
124  
125          let shape = image_latents.shape().to_vec();
126          let mut rng = StdRng::seed_from_u64(params.seed.unwrap_or_else(rand::random));
127          let noise_vec: Vec<f32> = (0..shape.iter().product::<usize>())
128              .map(|_| Distribution::sample(&StandardNormal, &mut rng))
129              .collect();
130          let noise = Array::from_shape_vec(IxDyn(&shape), noise_vec)
131              .map_err(OnnxError::Shape)?;
132  
133          let scheduler = DDIMScheduler::new(params.steps);
134          let start = scheduler.start_step(params.strength);
135          let t_start = scheduler.timesteps()[start];
136          let noised = scheduler.add_noise(&image_latents, &noise, t_start);
137          self.unet.latents_mut().assign(&noised);
138  
139          // Denoising loop
140          for &t in &scheduler.timesteps()[start..] {
141              let (a, b) = scheduler.step_coefficients(t);
142              self.unet.step(t as f32, a, b, params.guidance_scale)?;
143          }
144  
145          let latents = self.unet.latents_view().mapv(|v| v / 0.13025);
146          let image_tensor = self.vae_decoder.decode(&latents)?;
147          tensor_to_image(&image_tensor)
148      }
149  
150      #[cfg(feature = "inpaint")]
151      pub fn inpaint(&mut self, params: InpaintParams) -> Result<DynamicImage, ModelError> {
152          self.encode_prompt(&params.prompt, params.negative_prompt.as_deref())?;
153  
154          let width = params.width.unwrap_or_else(|| params.image.width());
155          let height = params.height.unwrap_or_else(|| params.image.height());
156  
157          let time_ids = Array::from_shape_vec(
158              IxDyn(&[2, 6]),
159              vec![
160                  height as f32, width as f32, 0.0, 0.0, height as f32, width as f32,
161                  height as f32, width as f32, 0.0, 0.0, height as f32, width as f32,
162              ],
163          ).map_err(OnnxError::Shape)?;
164  
165          let image_tensor = image_to_tensor(&params.image, width, height);
166          let image_latents = self.vae_encoder.as_mut()
167              .ok_or(ModelError::VaeEncoderNotLoaded)?
168              .encode(&image_tensor)? * 0.13025;
169  
170          let lh = (height / 8) as usize;
171          let lw = (width / 8) as usize;
172          self.unet.prepare(lh, lw, time_ids)?;
173  
174          // Prepare mask
175          let mask_img = params.mask
176              .resize_exact(lw as u32, lh as u32, FilterType::Nearest)
177              .to_luma8();
178          let mask = Array::from_shape_fn(IxDyn(&[1, 1, lh, lw]), |idx| {
179              let y = idx[2];
180              let x = idx[3];
181              if mask_img.get_pixel(x as u32, y as u32)[0] > 127 { 1.0f32 } else { 0.0 }
182          });
183  
184          let shape = image_latents.shape().to_vec();
185          let mut rng = StdRng::seed_from_u64(params.seed.unwrap_or_else(rand::random));
186          let noise_vec: Vec<f32> = (0..shape.iter().product::<usize>())
187              .map(|_| Distribution::sample(&StandardNormal, &mut rng))
188              .collect();
189          let noise = Array::from_shape_vec(IxDyn(&shape), noise_vec)
190              .map_err(OnnxError::Shape)?;
191  
192          let scheduler = DDIMScheduler::new(params.steps);
193          let start = scheduler.start_step(params.strength);
194          let t_start = scheduler.timesteps()[start];
195          let noised = scheduler.add_noise(&image_latents, &noise, t_start);
196          self.unet.latents_mut().assign(&noised);
197  
198          // Denoising loop
199          for &t in &scheduler.timesteps()[start..] {
200              // Apply mask in-place: blend current latents with original image latents
201              Zip::from(self.unet.latents_mut())
202                  .and(&mask)
203                  .and(&image_latents)
204                  .for_each(|l, &m, &il| {
205                      *l = *l * m + il * (1.0 - m);
206                  });
207  
208              let (a, b) = scheduler.step_coefficients(t);
209              self.unet.step(t as f32, a, b, params.guidance_scale)?;
210          }
211  
212          let latents = self.unet.latents_view().mapv(|v| v / 0.13025);
213          let image_tensor = self.vae_decoder.decode(&latents)?;
214  
215          let output = tensor_to_image(&image_tensor)?;
216  
217          let output_rgb = output.to_rgb8();
218          let input_rgb = params.image
219              .resize_exact(width, height, FilterType::Lanczos3).to_rgb8();
220          let mask_full = params.mask
221              .resize_exact(width, height, FilterType::Nearest).to_luma8();
222  
223          let mut blended = RgbImage::new(width, height);
224          for y in 0..height {
225              for x in 0..width {
226                  let m = mask_full.get_pixel(x, y)[0] as f32 / 255.0;
227                  let o = output_rgb.get_pixel(x, y);
228                  let i = input_rgb.get_pixel(x, y);
229                  blended.put_pixel(x, y, Rgb([
230                      (o[0] as f32 * m + i[0] as f32 * (1.0 - m)) as u8,
231                      (o[1] as f32 * m + i[1] as f32 * (1.0 - m)) as u8,
232                      (o[2] as f32 * m + i[2] as f32 * (1.0 - m)) as u8,
233                  ]));
234              }
235          }
236          Ok(DynamicImage::ImageRgb8(blended))
237      }
238  
239      fn encode_prompt(&mut self, prompt: &str, negative_prompt: Option<&str>) -> Result<(), ModelError> {
240          let (pos_tokens1, pos_tokens2) = self.tokenizers.encode(prompt)?;
241          let (neg_tokens1, neg_tokens2) = self.tokenizers.encode(negative_prompt.unwrap_or(""))?;
242  
243          // Stack neg and pos into batch=2 — [neg, pos]
244          let t1 = Array::from_shape_vec((2, 77),
245              neg_tokens1.into_iter().chain(pos_tokens1).collect()
246          ).map_err(OnnxError::Shape)?.into_dyn();
247          let t2 = Array::from_shape_vec((2, 77),
248              neg_tokens2.into_iter().chain(pos_tokens2).collect()
249          ).map_err(OnnxError::Shape)?.into_dyn();
250  
251          let unet_fp16 = self.unet.is_fp16();
252          self.text_encoders.encode(t1, t2, &mut self.unet.binding, unet_fp16)?;
253  
254          Ok(())
255      }
256  }
257  
258  fn image_to_tensor(image: &DynamicImage, width: u32, height: u32) -> Array<f32, IxDyn> {
259      let rgb = image.resize_exact(width, height, FilterType::Lanczos3).to_rgb8();
260      Array::from_shape_fn(IxDyn(&[1, 3, height as usize, width as usize]), |idx| {
261          let c = idx[1];
262          let y = idx[2];
263          let x = idx[3];
264          rgb.get_pixel(x as u32, y as u32)[c] as f32 / 255.0 * 2.0 - 1.0
265      })
266  }
267  
268  fn tensor_to_image(tensor: &Array<f32, IxDyn>) -> Result<DynamicImage, ModelError> {
269      let shape = tensor.shape();
270      let h = shape[2];
271      let w = shape[3];
272  
273      let mut img = RgbImage::new(w as u32, h as u32);
274      for y in 0..h {
275          for x in 0..w {
276              let r = ((tensor[[0, 0, y, x]] + 1.0) * 127.5).clamp(0.0, 255.0) as u8;
277              let g = ((tensor[[0, 1, y, x]] + 1.0) * 127.5).clamp(0.0, 255.0) as u8;
278              let b = ((tensor[[0, 2, y, x]] + 1.0) * 127.5).clamp(0.0, 255.0) as u8;
279              img.put_pixel(x as u32, y as u32, Rgb([r, g, b]));
280          }
281      }
282      Ok(DynamicImage::ImageRgb8(img))
283  }