unet.rs
1 use std::path::Path; 2 3 use ndarray::{Array, ArrayView, ArrayViewMut, Axis, IxDyn, Slice, Zip}; 4 #[cfg(feature = "cuda")] 5 use half::f16; 6 use ort::{ 7 session::{Session, IoBinding}, 8 value::Value, 9 }; 10 11 use super::{OnnxError, OnnxModel, SessionOptions}; 12 13 pub struct UNet { 14 session: Session, 15 pub(in crate::model) binding: IoBinding, 16 latents: Option<Array<f32, IxDyn>>, 17 #[cfg(feature = "cuda")] 18 fp16: bool, 19 } 20 21 impl OnnxModel for UNet {} 22 23 impl UNet { 24 pub fn new(path: &Path, opts: &SessionOptions) -> Result<Self, OnnxError> { 25 #[cfg(feature = "cuda")] 26 let (file, cpu_only) = if opts.use_cpu_unet { 27 ("unet.onnx", true) 28 } else { 29 ("unet_fp16.onnx", false) 30 }; 31 #[cfg(not(feature = "cuda"))] 32 let (file, cpu_only) = ("unet.onnx", true); 33 let session = Self::load(&path.join(file), cpu_only, opts)?; 34 let binding = session.create_binding()?; 35 Ok(Self { 36 session, 37 binding, 38 latents: None, 39 #[cfg(feature = "cuda")] 40 fp16: !opts.use_cpu_unet, 41 }) 42 } 43 44 pub fn is_fp16(&self) -> bool { 45 #[cfg(feature = "cuda")] 46 return self.fp16; 47 #[cfg(not(feature = "cuda"))] 48 return false; 49 } 50 51 /// Called once per inference. Allocates the latents buffer, binds the 52 /// constant time_ids input, and pre-binds the noise_pred output. 53 pub fn prepare(&mut self, lh: usize, lw: usize, time_ids: Array<f32, IxDyn>) -> Result<(), OnnxError> { 54 self.latents = Some(Array::from_elem(IxDyn(&[2, 4, lh, lw]), 0.0f32)); 55 56 #[cfg(feature = "cuda")] 57 if self.fp16 { 58 let time_ids_f16 = time_ids.mapv(f16::from_f32); 59 self.binding.bind_input("time_ids", &Value::from_array(time_ids_f16)?)?; 60 } else { 61 self.binding.bind_input("time_ids", &Value::from_array(time_ids)?)?; 62 } 63 #[cfg(not(feature = "cuda"))] 64 { 65 self.binding.bind_input("time_ids", &Value::from_array(time_ids)?)?; 66 } 67 68 // Bind output once — reused across all denoising steps. 69 self.binding.bind_output_to_device("noise_pred", &self.session.allocator().memory_info())?; 70 71 Ok(()) 72 } 73 74 /// Mutable view into the first half of the latents [1, 4, lh, lw]. 75 pub fn latents_mut(&mut self) -> ArrayViewMut<'_, f32, IxDyn> { 76 self.latents.as_mut() 77 .expect("call prepare() before latents_mut()") 78 .view_mut() 79 .slice_axis_move(Axis(0), Slice::from(0..1)) 80 } 81 82 /// Immutable view into the current latents [1, 4, lh, lw]. 83 pub fn latents_view(&self) -> ArrayView<'_, f32, IxDyn> { 84 self.latents.as_ref() 85 .expect("call prepare() before latents_view()") 86 .view() 87 .slice_axis_move(Axis(0), Slice::from(0..1)) 88 } 89 90 /// Run a single denoising step. 91 /// Duplicates latents for CFG, runs the UNet, then updates latents in-place. 92 pub fn step( 93 &mut self, 94 timestep: f32, 95 a: f32, 96 b: f32, 97 guidance_scale: f32, 98 ) -> Result<(), OnnxError> { 99 let arr = self.latents.as_mut().unwrap(); 100 101 // CFG: copy first half into second half 102 { 103 let (first, mut second) = arr.view_mut().split_at(Axis(0), 1); 104 second.assign(&first); 105 } 106 107 #[cfg(feature = "cuda")] 108 if self.fp16 { 109 let latents_f16 = arr.mapv(f16::from_f32); 110 let timestep_f16 = Array::from_elem(IxDyn(&[1]), f16::from_f32(timestep)); 111 let latents_val = Value::from_array(latents_f16)?; 112 let ts_val = Value::from_array(timestep_f16)?; 113 self.binding.bind_input("latents", &latents_val)?; 114 self.binding.bind_input("timestep", &ts_val)?; 115 116 let mut outputs = self.session.run_binding(&self.binding)?; 117 let noise_pred_val = outputs.remove("noise_pred").unwrap(); 118 let noise_pred_f16 = noise_pred_val.try_extract_array::<f16>()?; 119 let noise_pred = noise_pred_f16.view().mapv(f32::from); 120 let (noise_uncond, noise_cond) = noise_pred.view().split_at(Axis(0), 1); 121 122 let latents = arr.view_mut().slice_axis_move(Axis(0), Slice::from(0..1)); 123 Zip::from(latents) 124 .and(&noise_uncond) 125 .and(&noise_cond) 126 .for_each(|l, &u, &c| { 127 *l = *l * a + (u + guidance_scale * (c - u)) * b; 128 }); 129 } else { 130 let latents_val = Value::from_array(arr.clone())?; 131 let ts_val = Value::from_array(Array::from_elem(IxDyn(&[1]), timestep))?; 132 self.binding.bind_input("latents", &latents_val)?; 133 self.binding.bind_input("timestep", &ts_val)?; 134 135 let mut outputs = self.session.run_binding(&self.binding)?; 136 let noise_pred_val = outputs.remove("noise_pred").unwrap(); 137 let noise_pred = noise_pred_val.try_extract_array::<f32>()?; 138 let (noise_uncond, noise_cond) = noise_pred.view().split_at(Axis(0), 1); 139 140 let latents = arr.view_mut().slice_axis_move(Axis(0), Slice::from(0..1)); 141 Zip::from(latents) 142 .and(&noise_uncond) 143 .and(&noise_cond) 144 .for_each(|l, &u, &c| { 145 *l = *l * a + (u + guidance_scale * (c - u)) * b; 146 }); 147 } 148 #[cfg(not(feature = "cuda"))] 149 { 150 let latents_val = Value::from_array(arr.clone())?; 151 let ts_val = Value::from_array(Array::from_elem(IxDyn(&[1]), timestep))?; 152 self.binding.bind_input("latents", &latents_val)?; 153 self.binding.bind_input("timestep", &ts_val)?; 154 155 let mut outputs = self.session.run_binding(&self.binding)?; 156 let noise_pred_val = outputs.remove("noise_pred").unwrap(); 157 let noise_pred = noise_pred_val.try_extract_array::<f32>()?; 158 let (noise_uncond, noise_cond) = noise_pred.view().split_at(Axis(0), 1); 159 160 let latents = arr.view_mut().slice_axis_move(Axis(0), Slice::from(0..1)); 161 Zip::from(latents) 162 .and(&noise_uncond) 163 .and(&noise_cond) 164 .for_each(|l, &u, &c| { 165 *l = *l * a + (u + guidance_scale * (c - u)) * b; 166 }); 167 } 168 169 Ok(()) 170 } 171 }