/ src / model / onnx / text_encoders.rs
text_encoders.rs
 1  use std::path::Path;
 2  
 3  use ndarray::{Array, IxDyn};
 4  use ort::{session::{Session, IoBinding}, value::Value};
 5  #[cfg(feature = "cuda")]
 6  use half::f16;
 7  use super::{OnnxError, OnnxModel, SessionOptions};
 8  
 9  
/// Text-encoder ONNX model: a thin wrapper around an `ort` [`Session`].
///
/// The session is loaded from `text_encoders.onnx` by `TextEncoders::new`
/// and is accessible directly (as field `.0`) anywhere within `crate::model`.
pub struct TextEncoders(pub(in crate::model) Session);
11  
// No overrides: `TextEncoders` uses the `OnnxModel` default methods as-is
// (presumably including the `Self::load` called from `new` — confirm in the trait).
impl OnnxModel for TextEncoders {}
13  
14  impl TextEncoders {
15      pub fn new(path: &Path, opts: &SessionOptions) -> Result<Self, OnnxError> {
16          Ok(Self(
17              Self::load(&path.join("text_encoders.onnx"), true, opts)?
18          ))
19      }
20  
21      /// Run a batched [neg, pos] forward pass and bind the outputs
22      /// into the UNet's IoBinding.
23      /// Text encoders are always fp32. On CUDA with fp16 UNet, outputs are
24      /// cast to f16 before binding. With fp32 UNet (cpu_unet), kept as fp32.
25      #[cfg_attr(not(feature = "cuda"), allow(unused_variables))]
26      pub fn encode(&mut self, t1: Array<i64, IxDyn>, t2: Array<i64, IxDyn>, unet_binding: &mut IoBinding, unet_fp16: bool) -> Result<(), OnnxError> {
27          #[cfg(feature = "cuda")]
28          {
29              let mut binding = self.0.create_binding()?;
30              binding.bind_input("tokens1", &Value::from_array(t1)?)?;
31              binding.bind_input("tokens2", &Value::from_array(t2)?)?;
32              binding.bind_output_to_device("prompt_embeds", &self.0.allocator().memory_info())?;
33              binding.bind_output_to_device("pooled_embeds", &self.0.allocator().memory_info())?;
34              let mut out = self.0.run_binding(&binding)?;
35              let embeds_f32 = out.remove("prompt_embeds").unwrap().try_extract_array::<f32>()?.view().into_dyn().into_owned();
36              let pooled_f32 = out.remove("pooled_embeds").unwrap().try_extract_array::<f32>()?.view().into_dyn().into_owned();
37              if unet_fp16 {
38                  unet_binding.bind_input("encoder_hidden_states", &Value::from_array(embeds_f32.mapv(f16::from_f32))?)?;
39                  unet_binding.bind_input("text_embeds", &Value::from_array(pooled_f32.mapv(f16::from_f32))?)?;
40              } else {
41                  unet_binding.bind_input("encoder_hidden_states", &Value::from_array(embeds_f32)?)?;
42                  unet_binding.bind_input("text_embeds", &Value::from_array(pooled_f32)?)?;
43              }
44          }
45          #[cfg(not(feature = "cuda"))]
46          {
47              let v1 = Value::from_array(t1)?;
48              let v2 = Value::from_array(t2)?;
49              let outputs = self.0.run(ort::inputs![
50                  "tokens1" => v1,
51                  "tokens2" => v2,
52              ])?;
53              unet_binding.bind_input("encoder_hidden_states", &outputs["prompt_embeds"])?;
54              unet_binding.bind_input("text_embeds", &outputs["pooled_embeds"])?;
55          }
56          Ok(())
57      }
58  }