scheduler.rs
1 use ndarray::{Array, IxDyn}; 2 3 pub struct DDIMScheduler { 4 alphas_cumprod: Vec<f32>, 5 timesteps: Vec<usize>, 6 } 7 8 impl DDIMScheduler { 9 pub fn new(num_inference_steps: usize) -> Self { 10 let num_train_timesteps = 1000usize; 11 let beta_start = 0.00085f32; 12 let beta_end = 0.012f32; 13 14 // Linear beta schedule 15 let betas: Vec<f32> = (0..num_train_timesteps) 16 .map(|i| { 17 beta_start + (beta_end - beta_start) * i as f32 / (num_train_timesteps - 1) as f32 18 }) 19 .collect(); 20 21 // Compute alphas cumprod 22 let mut alphas_cumprod = Vec::with_capacity(num_train_timesteps); 23 let mut prod = 1.0f32; 24 for &b in &betas { 25 prod *= 1.0 - b; 26 alphas_cumprod.push(prod); 27 } 28 29 // Evenly spaced timesteps from high to low 30 let step_size = num_train_timesteps / num_inference_steps; 31 let timesteps = (0..num_inference_steps) 32 .map(|i| num_train_timesteps - 1 - i * step_size) 33 .collect(); 34 35 Self { 36 alphas_cumprod, 37 timesteps, 38 } 39 } 40 41 pub fn timesteps(&self) -> &[usize] { 42 &self.timesteps 43 } 44 45 /// Add noise to latents up to a given timestep (for img2img/inpaint) 46 pub fn add_noise( 47 &self, 48 latents: &Array<f32, IxDyn>, 49 noise: &Array<f32, IxDyn>, 50 t: usize, 51 ) -> Array<f32, IxDyn> { 52 let alpha = self.alphas_cumprod[t]; 53 let sqrt_a = alpha.sqrt(); 54 let sqrt_1ma = (1.0 - alpha).sqrt(); 55 latents * sqrt_a + noise * sqrt_1ma 56 } 57 58 /// Returns `(latent_scale, noise_scale)` for the DDIM step at timestep `t`: 59 /// `latents_next = latents * latent_scale + guided_noise * noise_scale` 60 /// Fuse with guidance at the call site to avoid intermediate allocations. 61 pub fn step_coefficients(&self, t: usize) -> (f32, f32) { 62 let prev_t = if t > 0 { 63 self.timesteps 64 .iter() 65 .position(|&x| x == t) 66 .and_then(|i| self.timesteps.get(i + 1)) 67 .copied() 68 .unwrap_or(0) 69 } else { 70 0 71 }; 72 73 let alpha_prod = self.alphas_cumprod[t]; 74 let alpha_prod_prev = self.alphas_cumprod[prev_t]; 75 let beta_prod = 1.0 - alpha_prod; 76 77 let latent_scale = (alpha_prod_prev / alpha_prod).sqrt(); 78 let noise_scale = (1.0 - alpha_prod_prev).sqrt() 79 - (beta_prod * alpha_prod_prev / alpha_prod).sqrt(); 80 81 (latent_scale, noise_scale) 82 } 83 84 /// Get the start step index based on strength (for img2img/inpaint) 85 pub fn start_step(&self, strength: f32) -> usize { 86 let start = (self.timesteps.len() as f32 * (1.0 - strength)) as usize; 87 start.min(self.timesteps.len() - 1) 88 } 89 }