mod.rs
1 mod errors; 2 pub use errors::InstallError; 3 4 use log::info; 5 use std::{fs, path::PathBuf}; 6 7 const BASE_URL: &str = "https://huggingface.co/r0r-5chach/funko-diffusion/resolve/main"; 8 9 fn download(url: &str, dest: &PathBuf) -> Result<(), InstallError> { 10 info!("Downloading {}...", url); 11 let bytes = reqwest::blocking::get(url)?.bytes()?; 12 fs::write(dest, bytes)?; 13 Ok(()) 14 } 15 16 pub struct Installer { 17 app_name: String, 18 } 19 20 impl Installer { 21 pub fn new(app_name: impl Into<String>) -> Self { 22 Self { app_name: app_name.into(), } 23 } 24 25 pub fn data_dir(&self) -> Result<PathBuf, InstallError> { 26 Ok(dirs::data_dir() 27 .ok_or(InstallError::NoDataDir)? 28 .join(&self.app_name) 29 .join("funko-diffusion")) 30 } 31 32 pub fn is_installed(&self) -> bool { 33 match self.data_dir() { 34 Ok(dir) => dir.join("text_encoders.onnx").exists() 35 && dir.join("text_encoders.onnx.data").exists() 36 && { 37 #[cfg(not(feature = "cuda"))] { 38 dir.join("unet.onnx").exists() && dir.join("unet.onnx.data").exists() 39 } 40 #[cfg(feature = "cuda")] { 41 dir.join("unet_fp16.onnx").exists() && dir.join("unet_fp16.onnx.data").exists() 42 } 43 } 44 && dir.join("vae_encoder.onnx").exists() 45 && dir.join("vae_decoder.onnx").exists() 46 && dir.join("tokenizer.json").exists() 47 && dir.join("tokenizer_2.json").exists(), 48 Err(_) => false, 49 } 50 } 51 52 pub fn install(&self) -> Result<PathBuf, InstallError> { 53 let dir = self.data_dir()?; 54 fs::create_dir_all(&dir)?; 55 56 info!("Installing funko-diffusion to {}...", dir.display()); 57 58 download(&format!("{BASE_URL}/tokenizer.json"), &dir.join("tokenizer.json"))?; 59 download(&format!("{BASE_URL}/tokenizer_2.json"), &dir.join("tokenizer_2.json"))?; 60 download(&format!("{BASE_URL}/text_encoders.onnx"), &dir.join("text_encoders.onnx"))?; 61 download(&format!("{BASE_URL}/text_encoders.onnx.data"), &dir.join("text_encoders.onnx.data"))?; 62 63 // CPU: fp32 UNet, data file is unet.onnx.data. 64 // CUDA: fp16 UNet; the external data file name is embedded in the model header as 65 // unet_fp16.onnx.data — ORT looks for it by that name alongside unet.onnx. 66 #[cfg(not(feature = "cuda"))] { 67 download(&format!("{BASE_URL}/unet.onnx"), &dir.join("unet.onnx"))?; 68 download(&format!("{BASE_URL}/unet.onnx.data"), &dir.join("unet.onnx.data"))?; 69 } 70 #[cfg(feature = "cuda")] { 71 download(&format!("{BASE_URL}/unet_fp16.onnx"), &dir.join("unet_fp16.onnx"))?; 72 download(&format!("{BASE_URL}/unet_fp16.onnx.data"), &dir.join("unet_fp16.onnx.data"))?; 73 } 74 75 download(&format!("{BASE_URL}/vae_encoder.onnx"), &dir.join("vae_encoder.onnx"))?; 76 download(&format!("{BASE_URL}/vae_decoder.onnx"), &dir.join("vae_decoder.onnx"))?; 77 78 info!("Installation complete"); 79 Ok(dir) 80 } 81 }