/ src / model / scheduler.rs
scheduler.rs
 1  use ndarray::{Array, IxDyn};
 2  
 3  pub struct DDIMScheduler {
 4      alphas_cumprod:      Vec<f32>,
 5      timesteps:           Vec<usize>,
 6  }
 7  
 8  impl DDIMScheduler {
 9      pub fn new(num_inference_steps: usize) -> Self {
10          let num_train_timesteps = 1000usize;
11          let beta_start = 0.00085f32;
12          let beta_end   = 0.012f32;
13  
14          // Linear beta schedule
15          let betas: Vec<f32> = (0..num_train_timesteps)
16              .map(|i| {
17                  beta_start + (beta_end - beta_start) * i as f32 / (num_train_timesteps - 1) as f32
18              })
19              .collect();
20  
21          // Compute alphas cumprod
22          let mut alphas_cumprod = Vec::with_capacity(num_train_timesteps);
23          let mut prod = 1.0f32;
24          for &b in &betas {
25              prod *= 1.0 - b;
26              alphas_cumprod.push(prod);
27          }
28  
29          // Evenly spaced timesteps from high to low
30          let step_size = num_train_timesteps / num_inference_steps;
31          let timesteps = (0..num_inference_steps)
32              .map(|i| num_train_timesteps - 1 - i * step_size)
33              .collect();
34  
35          Self {
36              alphas_cumprod,
37              timesteps,
38          }
39      }
40  
41      pub fn timesteps(&self) -> &[usize] {
42          &self.timesteps
43      }
44  
45      /// Add noise to latents up to a given timestep (for img2img/inpaint)
46      pub fn add_noise(
47          &self,
48          latents: &Array<f32, IxDyn>,
49          noise:   &Array<f32, IxDyn>,
50          t:       usize,
51      ) -> Array<f32, IxDyn> {
52          let alpha    = self.alphas_cumprod[t];
53          let sqrt_a   = alpha.sqrt();
54          let sqrt_1ma = (1.0 - alpha).sqrt();
55          latents * sqrt_a + noise * sqrt_1ma
56      }
57  
58      /// Returns `(latent_scale, noise_scale)` for the DDIM step at timestep `t`:
59      ///   `latents_next = latents * latent_scale + guided_noise * noise_scale`
60      /// Fuse with guidance at the call site to avoid intermediate allocations.
61      pub fn step_coefficients(&self, t: usize) -> (f32, f32) {
62          let prev_t = if t > 0 {
63              self.timesteps
64                  .iter()
65                  .position(|&x| x == t)
66                  .and_then(|i| self.timesteps.get(i + 1))
67                  .copied()
68                  .unwrap_or(0)
69          } else {
70              0
71          };
72  
73          let alpha_prod      = self.alphas_cumprod[t];
74          let alpha_prod_prev = self.alphas_cumprod[prev_t];
75          let beta_prod       = 1.0 - alpha_prod;
76  
77          let latent_scale = (alpha_prod_prev / alpha_prod).sqrt();
78          let noise_scale  = (1.0 - alpha_prod_prev).sqrt()
79              - (beta_prod * alpha_prod_prev / alpha_prod).sqrt();
80  
81          (latent_scale, noise_scale)
82      }
83  
84      /// Get the start step index based on strength (for img2img/inpaint)
85      pub fn start_step(&self, strength: f32) -> usize {
86          let start = (self.timesteps.len() as f32 * (1.0 - strength)) as usize;
87          start.min(self.timesteps.len() - 1)
88      }
89  }