/ src / gui / mod.rs
mod.rs
  1  use eframe::{CreationContext, Frame, NativeOptions};
  2  use egui::{CentralPanel, Color32, ColorImage, Context, DragValue, Grid, SidePanel, TextureHandle, TopBottomPanel, ViewportBuilder};
  3  use funko_diffusion::{FunkoDiffusionModel, Installer, Mode, SessionOptions};
  4  use image::DynamicImage;
  5  use rfd::FileDialog;
  6  use std::path::PathBuf;
  7  
  8  #[derive(PartialEq)]
  9  enum Tab {
 10      #[cfg(feature = "txt2img")]
 11      Txt2Img,
 12      #[cfg(feature = "img2img")]
 13      Img2Img,
 14      #[cfg(feature = "inpaint")]
 15      Inpaint,
 16      Settings,
 17  }
 18  
 19  struct App {
 20      tab: Tab,
 21      prompt: String,
 22      negative_prompt: String,
 23      input_path: Option<PathBuf>,
 24      mask_path: Option<PathBuf>,
 25      steps: usize,
 26      guidance_scale: f32,
 27      strength: f32,
 28      width: u32,
 29      height: u32,
 30      use_seed: bool,
 31      seed: u64,
 32      #[cfg(feature = "cuda")]
 33      use_cpu_vae: bool,
 34      #[cfg(feature = "cuda")]
 35      use_cpu_unet: bool,
 36      output_texture: Option<TextureHandle>,
 37      output_image: Option<DynamicImage>,
 38      model: FunkoDiffusionModel,
 39      generating: bool,
 40      error: Option<String>,
 41  }
 42  
 43  impl App {
 44      fn new(_cc: &CreationContext) -> Self {
 45          let installer = Installer::new("funko-diffusion");
 46          let opts = SessionOptions::default();
 47          let model = FunkoDiffusionModel::new(&installer, Mode::All, opts)
 48              .expect("Failed to load model");
 49  
 50          Self {
 51              #[cfg(feature = "txt2img")]
 52              tab: Tab::Txt2Img,
 53              #[cfg(all(not(feature = "txt2img"), feature = "img2img"))]
 54              tab: Tab::Img2Img,
 55              #[cfg(all(not(feature = "txt2img"), not(feature = "img2img"), feature = "inpaint"))]
 56              tab: Tab::Inpaint,
 57              #[cfg(all(not(feature = "txt2img"), not(feature = "img2img"), not(feature = "inpaint")))]
 58              tab: Tab::Settings,
 59              prompt: String::new(),
 60              negative_prompt: String::new(),
 61              input_path: None,
 62              mask_path: None,
 63              steps: 40,
 64              guidance_scale: 7.5,
 65              strength: 0.75,
 66              width: 1024,
 67              height: 1024,
 68              use_seed: false,
 69              seed: 0,
 70              #[cfg(feature = "cuda")]
 71              use_cpu_vae: false,
 72              #[cfg(feature = "cuda")]
 73              use_cpu_unet: false,
 74              output_texture: None,
 75              output_image: None,
 76              model,
 77              generating: false,
 78              error: None,
 79          }
 80      }
 81  
 82      fn reload_model(&mut self) {
 83          let installer = Installer::new("funko-diffusion");
 84          let opts = SessionOptions {
 85              intra_op_num_threads: None,
 86              #[cfg(feature = "cuda")]
 87              memory_pattern: true,
 88              #[cfg(feature = "cuda")]
 89              use_cpu_vae: self.use_cpu_vae,
 90              #[cfg(feature = "cuda")]
 91              use_cpu_unet: self.use_cpu_unet,
 92          };
 93          match FunkoDiffusionModel::new(&installer, Mode::All, opts) {
 94              Ok(model) => {
 95                  self.model = model;
 96                  self.error = None;
 97              }
 98              Err(e) => {
 99                  self.error = Some(format!("Failed to reload model: {e}"));
100              }
101          }
102      }
103  
104      fn generate(&mut self, ctx: &Context) {
105          self.generating = true;
106          self.error = None;
107  
108          let result: Result<DynamicImage, _> = match self.tab {
109              #[cfg(feature = "txt2img")]
110              Tab::Txt2Img => {
111                  use funko_diffusion::params::TextToImageParams;
112                  self.model.text_to_image(TextToImageParams {
113                      prompt: self.prompt.clone(),
114                      negative_prompt: if self.negative_prompt.is_empty() { None } else { Some(self.negative_prompt.clone()) },
115                      width: self.width,
116                      height: self.height,
117                      steps: self.steps,
118                      guidance_scale: self.guidance_scale,
119                      seed: if self.use_seed { Some(self.seed) } else { None },
120                  })
121              }
122  
123              #[cfg(feature = "img2img")]
124              Tab::Img2Img => {
125                  use funko_diffusion::params::ImageToImageParams;
126                  let input = image::open(self.input_path.as_ref().expect("No input image")).expect("Failed to open image");
127                  self.model.image_to_image(ImageToImageParams {
128                      prompt: self.prompt.clone(),
129                      negative_prompt: if self.negative_prompt.is_empty() { None } else { Some(self.negative_prompt.clone()) },
130                      image: input,
131                      width: Some(self.width),
132                      height: Some(self.height),
133                      strength: self.strength,
134                      steps: self.steps,
135                      guidance_scale: self.guidance_scale,
136                      seed: if self.use_seed { Some(self.seed) } else { None },
137                  })
138              }
139  
140              #[cfg(feature = "inpaint")]
141              Tab::Inpaint => {
142                  use funko_diffusion::params::InpaintParams;
143                  let input = image::open(self.input_path.as_ref().expect("No input image")).expect("Failed to open image");
144                  let mask  = image::open(self.mask_path.as_ref().expect("No mask image")).expect("Failed to open mask");
145                  self.model.inpaint(InpaintParams {
146                      prompt: self.prompt.clone(),
147                      negative_prompt: if self.negative_prompt.is_empty() { None } else { Some(self.negative_prompt.clone()) },
148                      image: input,
149                      mask,
150                      width: Some(self.width),
151                      height: Some(self.height),
152                      strength: self.strength,
153                      steps: self.steps,
154                      guidance_scale: self.guidance_scale,
155                      seed: if self.use_seed { Some(self.seed) } else { None },
156                  })
157              }
158  
159              Tab::Settings => return,
160          };
161  
162          match result {
163              Ok(image) => {
164                  let rgba = image.to_rgba8();
165                  let size = [rgba.width() as usize, rgba.height() as usize];
166                  let pixels = rgba.into_raw();
167                  let color_image = ColorImage::from_rgba_unmultiplied(size, &pixels);
168                  self.output_texture = Some(ctx.load_texture("output", color_image, Default::default()));
169                  self.output_image = Some(image);
170              }
171              Err(e) => {
172                  self.error = Some(e.to_string());
173              }
174          }
175  
176          self.generating = false;
177      }
178  }
179  
180  impl eframe::App for App {
181      fn update(&mut self, ctx: &Context, _frame: &mut Frame) {
182          TopBottomPanel::top("mode_bar").show(ctx, |ui| {
183              ui.horizontal(|ui| {
184                  ui.heading("Funko Diffusion");
185                  ui.separator();
186                  #[cfg(feature = "txt2img")]
187                  ui.selectable_value(&mut self.tab, Tab::Txt2Img, "Txt2Img");
188                  #[cfg(feature = "img2img")]
189                  ui.selectable_value(&mut self.tab, Tab::Img2Img, "Img2Img");
190                  #[cfg(feature = "inpaint")]
191                  ui.selectable_value(&mut self.tab, Tab::Inpaint, "Inpaint");
192                  ui.selectable_value(&mut self.tab, Tab::Settings, "Settings");
193              });
194          });
195  
196          SidePanel::right("output_panel").min_width(512.0).show(ctx, |ui| {
197              ui.heading("Output");
198              if let Some(image) = &self.output_image {
199                  if ui.button("Save").clicked() {
200                      if let Some(path) = FileDialog::new()
201                          .add_filter("PNG", &["png"])
202                          .add_filter("JPEG", &["jpg", "jpeg"])
203                          .add_filter("WebP", &["webp"])
204                          .set_file_name("output.png")
205                          .save_file()
206                      {
207                          if let Err(e) = image.save(&path) {
208                              self.error = Some(format!("Failed to save: {e}"));
209                          }
210                      }
211                  }
212              }
213              ui.separator();
214              if let Some(texture) = &self.output_texture {
215                  let size = texture.size_vec2();
216                  let available = ui.available_size();
217                  let scale = (available.x / size.x).min(available.y / size.y).min(1.0);
218                  ui.image((texture.id(), size * scale));
219              } else {
220                  ui.centered_and_justified(|ui| {
221                      ui.label("Output will appear here");
222                  });
223              }
224          });
225  
226          CentralPanel::default().show(ctx, |ui| {
227              if self.tab == Tab::Settings {
228                  ui.heading("Session Settings");
229                  ui.label("Changes take effect after reloading the model.");
230                  ui.separator();
231  
232                  #[cfg(feature = "cuda")]
233                  Grid::new("settings")
234                      .num_columns(2)
235                      .spacing([8.0, 8.0])
236                      .show(ui, |ui| {
237                          ui.label("CPU VAE");
238                          ui.checkbox(&mut self.use_cpu_vae, "Run VAE on CPU (reduces VRAM usage)");
239                          ui.end_row();
240  
241                          ui.label("CPU UNet");
242                          ui.checkbox(&mut self.use_cpu_unet, "Run UNet on CPU/fp32 (no VRAM, much slower)");
243                          ui.end_row();
244                      });
245  
246                  ui.separator();
247                  if ui.button("Reload Model").clicked() {
248                      self.reload_model();
249                  }
250                  if let Some(error) = &self.error {
251                      ui.colored_label(Color32::RED, error);
252                  }
253                  return;
254              }
255  
256              Grid::new("params")
257                  .num_columns(2)
258                  .spacing([8.0, 8.0])
259                  .show(ui, |ui| {
260                      ui.label("Prompt");
261                      ui.text_edit_singleline(&mut self.prompt);
262                      ui.end_row();
263  
264                      ui.label("Negative Prompt");
265                      ui.text_edit_singleline(&mut self.negative_prompt);
266                      ui.end_row();
267  
268                      #[cfg(any(feature = "img2img", feature = "inpaint"))]
269                      if self.tab == Tab::Img2Img || self.tab == Tab::Inpaint {
270                          ui.label("Input Image");
271                          ui.horizontal(|ui| {
272                              if ui.button("Browse...").clicked() {
273                                  if let Some(path) = FileDialog::new()
274                                      .add_filter("Images", &["png", "jpg", "jpeg", "webp"])
275                                      .pick_file()
276                                  {
277                                      self.input_path = Some(path);
278                                  }
279                              }
280                              if let Some(path) = &self.input_path {
281                                  ui.label(path.file_name().unwrap_or_default().to_string_lossy().as_ref());
282                              }
283                          });
284                          ui.end_row();
285                      }
286  
287                      #[cfg(feature = "inpaint")]
288                      if self.tab == Tab::Inpaint {
289                          ui.label("Mask");
290                          ui.horizontal(|ui| {
291                              if ui.button("Browse...").clicked() {
292                                  if let Some(path) = FileDialog::new()
293                                      .add_filter("Images", &["png", "jpg", "jpeg", "webp"])
294                                      .pick_file()
295                                  {
296                                      self.mask_path = Some(path);
297                                  }
298                              }
299                              if let Some(path) = &self.mask_path {
300                                  ui.label(path.file_name().unwrap_or_default().to_string_lossy().as_ref());
301                              }
302                          });
303                          ui.end_row();
304                      }
305  
306                      ui.label("Steps");
307                      ui.add(DragValue::new(&mut self.steps).range(1..=150));
308                      ui.end_row();
309  
310                      ui.label("Guidance Scale");
311                      ui.add(DragValue::new(&mut self.guidance_scale).range(1.0..=20.0).speed(0.1));
312                      ui.end_row();
313  
314                      #[cfg(any(feature = "img2img", feature = "inpaint"))]
315                      if self.tab == Tab::Img2Img || self.tab == Tab::Inpaint {
316                          ui.label("Strength");
317                          ui.add(DragValue::new(&mut self.strength).range(0.0..=1.0).speed(0.01));
318                          ui.end_row();
319                      }
320  
321                      ui.label("Width");
322                      ui.add(DragValue::new(&mut self.width).range(512..=2048));
323                      ui.end_row();
324  
325                      ui.label("Height");
326                      ui.add(DragValue::new(&mut self.height).range(512..=2048));
327                      ui.end_row();
328  
329                      ui.label("Seed");
330                      ui.horizontal(|ui| {
331                          ui.checkbox(&mut self.use_seed, "");
332                          ui.add_enabled(
333                              self.use_seed,
334                              DragValue::new(&mut self.seed).range(0u64..=u64::MAX),
335                          );
336                      });
337                      ui.end_row();
338                  });
339  
340              ui.separator();
341  
342              if let Some(error) = &self.error {
343                  ui.colored_label(Color32::RED, error);
344              }
345  
346              ui.add_enabled_ui(!self.generating, |ui| {
347                  if ui.button(if self.generating { "Generating..." } else { "Generate" }).clicked() {
348                      self.generate(ctx);
349                  }
350              });
351          });
352      }
353  }
354  
355  pub fn run() {
356      let options = NativeOptions {
357          viewport: ViewportBuilder::default()
358              .with_title("Funko Diffusion")
359              .with_inner_size([1200.0, 700.0]),
360          ..Default::default()
361      };
362  
363      eframe::run_native(
364          "Funko Diffusion",
365          options,
366          Box::new(|cc| Ok(Box::new(App::new(cc)))),
367      ).expect("Failed to start GUI");
368  }