/ src / install / mod.rs
mod.rs
 1  mod errors;
 2  pub use errors::InstallError;
 3  
use log::info;
use std::{
    fs, io,
    path::{Path, PathBuf},
};
 6  
 7  const BASE_URL: &str = "https://huggingface.co/r0r-5chach/funko-diffusion/resolve/main";
 8  
 9  fn download(url: &str, dest: &PathBuf) -> Result<(), InstallError> {
10      info!("Downloading {}...", url);
11      let bytes = reqwest::blocking::get(url)?.bytes()?;
12      fs::write(dest, bytes)?;
13      Ok(())
14  }
15  
/// Locates and downloads the funko-diffusion model files for an application.
///
/// Files live under the platform data directory, in `<app_name>/funko-diffusion`
/// (see [`Installer::data_dir`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Installer {
    /// Name of the per-application directory created under the platform data directory.
    app_name: String,
}
20  impl Installer {
21      pub fn new(app_name: impl Into<String>) -> Self {
22          Self { app_name: app_name.into(), }
23      }
24  
25      pub fn data_dir(&self) -> Result<PathBuf, InstallError> {
26          Ok(dirs::data_dir()
27              .ok_or(InstallError::NoDataDir)?
28              .join(&self.app_name)
29              .join("funko-diffusion"))
30      }
31  
32      pub fn is_installed(&self) -> bool {
33          match self.data_dir() {
34              Ok(dir) => dir.join("text_encoders.onnx").exists()
35                  && dir.join("text_encoders.onnx.data").exists()
36                  && {
37                      #[cfg(not(feature = "cuda"))] {
38                          dir.join("unet.onnx").exists() && dir.join("unet.onnx.data").exists()
39                      }
40                      #[cfg(feature = "cuda")] {
41                          dir.join("unet_fp16.onnx").exists() && dir.join("unet_fp16.onnx.data").exists()
42                      }
43                  }
44                  && dir.join("vae_encoder.onnx").exists()
45                  && dir.join("vae_decoder.onnx").exists()
46                  && dir.join("tokenizer.json").exists()
47                  && dir.join("tokenizer_2.json").exists(),
48              Err(_) => false,
49          }
50      }
51  
52      pub fn install(&self) -> Result<PathBuf, InstallError> {
53          let dir = self.data_dir()?;
54          fs::create_dir_all(&dir)?;
55  
56          info!("Installing funko-diffusion to {}...", dir.display());
57  
58          download(&format!("{BASE_URL}/tokenizer.json"), &dir.join("tokenizer.json"))?;
59          download(&format!("{BASE_URL}/tokenizer_2.json"), &dir.join("tokenizer_2.json"))?;
60          download(&format!("{BASE_URL}/text_encoders.onnx"), &dir.join("text_encoders.onnx"))?;
61          download(&format!("{BASE_URL}/text_encoders.onnx.data"), &dir.join("text_encoders.onnx.data"))?;
62  
63          // CPU: fp32 UNet, data file is unet.onnx.data.
64          // CUDA: fp16 UNet; the external data file name is embedded in the model header as
65          //       unet_fp16.onnx.data — ORT looks for it by that name alongside unet.onnx.
66          #[cfg(not(feature = "cuda"))] {
67              download(&format!("{BASE_URL}/unet.onnx"), &dir.join("unet.onnx"))?;
68              download(&format!("{BASE_URL}/unet.onnx.data"), &dir.join("unet.onnx.data"))?;
69          }
70          #[cfg(feature = "cuda")] {
71              download(&format!("{BASE_URL}/unet_fp16.onnx"), &dir.join("unet_fp16.onnx"))?;
72              download(&format!("{BASE_URL}/unet_fp16.onnx.data"), &dir.join("unet_fp16.onnx.data"))?;
73          }
74  
75          download(&format!("{BASE_URL}/vae_encoder.onnx"), &dir.join("vae_encoder.onnx"))?;
76          download(&format!("{BASE_URL}/vae_decoder.onnx"), &dir.join("vae_decoder.onnx"))?;
77  
78          info!("Installation complete");
79          Ok(dir)
80      }
81  }