// src/model/onnx/vae_decoder.rs
 1  use std::path::Path;
 2  
 3  use ndarray::{Array, IxDyn};
 4  use ort::{inputs, session::Session, value::Value};
 5  use super::{OnnxError, OnnxModel, SessionOptions};
 6  
/// Wrapper around the ONNX Runtime session for the VAE decoder model.
/// The single field holds the session loaded from `vae_decoder.onnx`.
pub struct VaeDecoder(Session);
 8  
// Marker impl: no overrides needed — `VaeDecoder::new` uses `Self::load`,
// which presumably is a provided method on `OnnxModel` (defined in the
// parent module; not visible here — confirm against super::OnnxModel).
impl OnnxModel for VaeDecoder { }
10  
11  impl VaeDecoder {
12      pub fn new(path: &Path, opts: &SessionOptions) -> Result<Self, OnnxError> {
13          #[cfg(feature = "cuda")]
14          let cpu_only = opts.use_cpu_vae;
15          #[cfg(not(feature = "cuda"))]
16          let cpu_only = true;
17          Ok(Self(Self::load(&path.join("vae_decoder.onnx"), cpu_only, opts)?))
18      }
19  
20      pub fn decode(&mut self, latents: &Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, OnnxError> {
21          let outputs = self.0.run(inputs![
22              Value::from_array(latents.to_owned())?
23          ])?;
24  
25          let (shape, data) = outputs[0].try_extract_tensor::<f32>()?;
26          let dims: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
27          Ok(Array::from_shape_vec(IxDyn(&dims), data.to_vec())?)
28      }
29  }