vae_decoder.rs
1 use std::path::Path; 2 3 use ndarray::{Array, IxDyn}; 4 use ort::{inputs, session::Session, value::Value}; 5 use super::{OnnxError, OnnxModel, SessionOptions}; 6 7 pub struct VaeDecoder(Session); 8 9 impl OnnxModel for VaeDecoder { } 10 11 impl VaeDecoder { 12 pub fn new(path: &Path, opts: &SessionOptions) -> Result<Self, OnnxError> { 13 #[cfg(feature = "cuda")] 14 let cpu_only = opts.use_cpu_vae; 15 #[cfg(not(feature = "cuda"))] 16 let cpu_only = true; 17 Ok(Self(Self::load(&path.join("vae_decoder.onnx"), cpu_only, opts)?)) 18 } 19 20 pub fn decode(&mut self, latents: &Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, OnnxError> { 21 let outputs = self.0.run(inputs![ 22 Value::from_array(latents.to_owned())? 23 ])?; 24 25 let (shape, data) = outputs[0].try_extract_tensor::<f32>()?; 26 let dims: Vec<usize> = shape.iter().map(|&d| d as usize).collect(); 27 Ok(Array::from_shape_vec(IxDyn(&dims), data.to_vec())?) 28 } 29 }