/ src / model / onnx / unet.rs
unet.rs
  1  use std::path::Path;
  2  
  3  use ndarray::{Array, ArrayView, ArrayViewMut, Axis, IxDyn, Slice, Zip};
  4  #[cfg(feature = "cuda")]
  5  use half::f16;
  6  use ort::{
  7      session::{Session, IoBinding},
  8      value::Value,
  9  };
 10  
 11  use super::{OnnxError, OnnxModel, SessionOptions};
 12  
 13  pub struct UNet {
 14      session: Session,
 15      pub(in crate::model) binding: IoBinding,
 16      latents: Option<Array<f32, IxDyn>>,
 17      #[cfg(feature = "cuda")]
 18      fp16: bool,
 19  }
 20  
 21  impl OnnxModel for UNet {}
 22  
 23  impl UNet {
 24      pub fn new(path: &Path, opts: &SessionOptions) -> Result<Self, OnnxError> {
 25          #[cfg(feature = "cuda")]
 26          let (file, cpu_only) = if opts.use_cpu_unet {
 27              ("unet.onnx", true)
 28          } else {
 29              ("unet_fp16.onnx", false)
 30          };
 31          #[cfg(not(feature = "cuda"))]
 32          let (file, cpu_only) = ("unet.onnx", true);
 33          let session = Self::load(&path.join(file), cpu_only, opts)?;
 34          let binding = session.create_binding()?;
 35          Ok(Self {
 36              session,
 37              binding,
 38              latents: None,
 39              #[cfg(feature = "cuda")]
 40              fp16: !opts.use_cpu_unet,
 41          })
 42      }
 43  
 44      pub fn is_fp16(&self) -> bool {
 45          #[cfg(feature = "cuda")]
 46          return self.fp16;
 47          #[cfg(not(feature = "cuda"))]
 48          return false;
 49      }
 50  
 51      /// Called once per inference. Allocates the latents buffer, binds the
 52      /// constant time_ids input, and pre-binds the noise_pred output.
 53      pub fn prepare(&mut self, lh: usize, lw: usize, time_ids: Array<f32, IxDyn>) -> Result<(), OnnxError> {
 54          self.latents = Some(Array::from_elem(IxDyn(&[2, 4, lh, lw]), 0.0f32));
 55  
 56          #[cfg(feature = "cuda")]
 57          if self.fp16 {
 58              let time_ids_f16 = time_ids.mapv(f16::from_f32);
 59              self.binding.bind_input("time_ids", &Value::from_array(time_ids_f16)?)?;
 60          } else {
 61              self.binding.bind_input("time_ids", &Value::from_array(time_ids)?)?;
 62          }
 63          #[cfg(not(feature = "cuda"))]
 64          {
 65              self.binding.bind_input("time_ids", &Value::from_array(time_ids)?)?;
 66          }
 67  
 68          // Bind output once — reused across all denoising steps.
 69          self.binding.bind_output_to_device("noise_pred", &self.session.allocator().memory_info())?;
 70  
 71          Ok(())
 72      }
 73  
 74      /// Mutable view into the first half of the latents [1, 4, lh, lw].
 75      pub fn latents_mut(&mut self) -> ArrayViewMut<'_, f32, IxDyn> {
 76          self.latents.as_mut()
 77              .expect("call prepare() before latents_mut()")
 78              .view_mut()
 79              .slice_axis_move(Axis(0), Slice::from(0..1))
 80      }
 81  
 82      /// Immutable view into the current latents [1, 4, lh, lw].
 83      pub fn latents_view(&self) -> ArrayView<'_, f32, IxDyn> {
 84          self.latents.as_ref()
 85              .expect("call prepare() before latents_view()")
 86              .view()
 87              .slice_axis_move(Axis(0), Slice::from(0..1))
 88      }
 89  
 90      /// Run a single denoising step.
 91      /// Duplicates latents for CFG, runs the UNet, then updates latents in-place.
 92      pub fn step(
 93          &mut self,
 94          timestep: f32,
 95          a: f32,
 96          b: f32,
 97          guidance_scale: f32,
 98      ) -> Result<(), OnnxError> {
 99          let arr = self.latents.as_mut().unwrap();
100  
101          // CFG: copy first half into second half
102          {
103              let (first, mut second) = arr.view_mut().split_at(Axis(0), 1);
104              second.assign(&first);
105          }
106  
107          #[cfg(feature = "cuda")]
108          if self.fp16 {
109              let latents_f16 = arr.mapv(f16::from_f32);
110              let timestep_f16 = Array::from_elem(IxDyn(&[1]), f16::from_f32(timestep));
111              let latents_val = Value::from_array(latents_f16)?;
112              let ts_val = Value::from_array(timestep_f16)?;
113              self.binding.bind_input("latents", &latents_val)?;
114              self.binding.bind_input("timestep", &ts_val)?;
115  
116              let mut outputs = self.session.run_binding(&self.binding)?;
117              let noise_pred_val = outputs.remove("noise_pred").unwrap();
118              let noise_pred_f16 = noise_pred_val.try_extract_array::<f16>()?;
119              let noise_pred = noise_pred_f16.view().mapv(f32::from);
120              let (noise_uncond, noise_cond) = noise_pred.view().split_at(Axis(0), 1);
121  
122              let latents = arr.view_mut().slice_axis_move(Axis(0), Slice::from(0..1));
123              Zip::from(latents)
124                  .and(&noise_uncond)
125                  .and(&noise_cond)
126                  .for_each(|l, &u, &c| {
127                      *l = *l * a + (u + guidance_scale * (c - u)) * b;
128                  });
129          } else {
130              let latents_val = Value::from_array(arr.clone())?;
131              let ts_val = Value::from_array(Array::from_elem(IxDyn(&[1]), timestep))?;
132              self.binding.bind_input("latents", &latents_val)?;
133              self.binding.bind_input("timestep", &ts_val)?;
134  
135              let mut outputs = self.session.run_binding(&self.binding)?;
136              let noise_pred_val = outputs.remove("noise_pred").unwrap();
137              let noise_pred = noise_pred_val.try_extract_array::<f32>()?;
138              let (noise_uncond, noise_cond) = noise_pred.view().split_at(Axis(0), 1);
139  
140              let latents = arr.view_mut().slice_axis_move(Axis(0), Slice::from(0..1));
141              Zip::from(latents)
142                  .and(&noise_uncond)
143                  .and(&noise_cond)
144                  .for_each(|l, &u, &c| {
145                      *l = *l * a + (u + guidance_scale * (c - u)) * b;
146                  });
147          }
148          #[cfg(not(feature = "cuda"))]
149          {
150              let latents_val = Value::from_array(arr.clone())?;
151              let ts_val = Value::from_array(Array::from_elem(IxDyn(&[1]), timestep))?;
152              self.binding.bind_input("latents", &latents_val)?;
153              self.binding.bind_input("timestep", &ts_val)?;
154  
155              let mut outputs = self.session.run_binding(&self.binding)?;
156              let noise_pred_val = outputs.remove("noise_pred").unwrap();
157              let noise_pred = noise_pred_val.try_extract_array::<f32>()?;
158              let (noise_uncond, noise_cond) = noise_pred.view().split_at(Axis(0), 1);
159  
160              let latents = arr.view_mut().slice_axis_move(Axis(0), Slice::from(0..1));
161              Zip::from(latents)
162                  .and(&noise_uncond)
163                  .and(&noise_cond)
164                  .for_each(|l, &u, &c| {
165                      *l = *l * a + (u + guidance_scale * (c - u)) * b;
166                  });
167          }
168  
169          Ok(())
170      }
171  }