text_encoders.rs
1 use std::path::Path; 2 3 use ndarray::{Array, IxDyn}; 4 use ort::{session::{Session, IoBinding}, value::Value}; 5 #[cfg(feature = "cuda")] 6 use half::f16; 7 use super::{OnnxError, OnnxModel, SessionOptions}; 8 9 10 pub struct TextEncoders(pub(in crate::model) Session); 11 12 impl OnnxModel for TextEncoders {} 13 14 impl TextEncoders { 15 pub fn new(path: &Path, opts: &SessionOptions) -> Result<Self, OnnxError> { 16 Ok(Self( 17 Self::load(&path.join("text_encoders.onnx"), true, opts)? 18 )) 19 } 20 21 /// Run a batched [neg, pos] forward pass and bind the outputs 22 /// into the UNet's IoBinding. 23 /// Text encoders are always fp32. On CUDA with fp16 UNet, outputs are 24 /// cast to f16 before binding. With fp32 UNet (cpu_unet), kept as fp32. 25 #[cfg_attr(not(feature = "cuda"), allow(unused_variables))] 26 pub fn encode(&mut self, t1: Array<i64, IxDyn>, t2: Array<i64, IxDyn>, unet_binding: &mut IoBinding, unet_fp16: bool) -> Result<(), OnnxError> { 27 #[cfg(feature = "cuda")] 28 { 29 let mut binding = self.0.create_binding()?; 30 binding.bind_input("tokens1", &Value::from_array(t1)?)?; 31 binding.bind_input("tokens2", &Value::from_array(t2)?)?; 32 binding.bind_output_to_device("prompt_embeds", &self.0.allocator().memory_info())?; 33 binding.bind_output_to_device("pooled_embeds", &self.0.allocator().memory_info())?; 34 let mut out = self.0.run_binding(&binding)?; 35 let embeds_f32 = out.remove("prompt_embeds").unwrap().try_extract_array::<f32>()?.view().into_dyn().into_owned(); 36 let pooled_f32 = out.remove("pooled_embeds").unwrap().try_extract_array::<f32>()?.view().into_dyn().into_owned(); 37 if unet_fp16 { 38 unet_binding.bind_input("encoder_hidden_states", &Value::from_array(embeds_f32.mapv(f16::from_f32))?)?; 39 unet_binding.bind_input("text_embeds", &Value::from_array(pooled_f32.mapv(f16::from_f32))?)?; 40 } else { 41 unet_binding.bind_input("encoder_hidden_states", &Value::from_array(embeds_f32)?)?; 42 unet_binding.bind_input("text_embeds", &Value::from_array(pooled_f32)?)?; 43 } 44 } 45 #[cfg(not(feature = "cuda"))] 46 { 47 let v1 = Value::from_array(t1)?; 48 let v2 = Value::from_array(t2)?; 49 let outputs = self.0.run(ort::inputs![ 50 "tokens1" => v1, 51 "tokens2" => v2, 52 ])?; 53 unet_binding.bind_input("encoder_hidden_states", &outputs["prompt_embeds"])?; 54 unet_binding.bind_input("text_embeds", &outputs["pooled_embeds"])?; 55 } 56 Ok(()) 57 } 58 }