/ src / watch / mod.rs
mod.rs
  1  #[cfg(all(feature = "watch", feature = "async"))]
  2  use std::path::{Path, PathBuf};
  3  
  4  #[cfg(all(feature = "watch", feature = "async"))]
  5  use compact_str::CompactString;
  6  
  7  #[cfg(all(feature = "watch", feature = "async"))]
  8  use notify::{Event, EventKind, RecursiveMode, Watcher};
  9  
 10  #[cfg(all(feature = "watch", feature = "async"))]
 11  use parking_lot::Mutex;
 12  
 13  #[cfg(all(feature = "watch", feature = "async"))]
 14  use std::collections::HashMap;
 15  
 16  #[cfg(all(feature = "watch", feature = "async"))]
 17  use std::sync::Arc;
 18  
 19  #[derive(Debug, Clone, PartialEq, Eq)]
 20  #[cfg(all(feature = "watch", feature = "async"))]
 21  pub struct FileChanged {
 22      pub path: PathBuf,
 23      pub kind: ChangeKind,
 24  }
 25  
 26  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 27  #[cfg(all(feature = "watch", feature = "async"))]
 28  pub enum ChangeKind {
 29      Created,
 30      Modified,
 31      Deleted,
 32  }
 33  
 34  #[cfg(all(feature = "watch", feature = "async"))]
 35  pub type WatchCallback = Arc<dyn Fn(FileChanged) + Send + Sync>;
 36  
 37  #[cfg(all(feature = "watch", feature = "async"))]
 38  pub struct FileWatcher {
 39      #[allow(dead_code)]
 40      watcher: Arc<Mutex<notify::RecommendedWatcher>>,
 41      paths: Arc<Mutex<HashMap<PathBuf, CompactString>>>,
 42      callbacks: Arc<Mutex<Vec<WatchCallback>>>,
 43  }
 44  
 45  #[cfg(all(feature = "watch", feature = "async"))]
 46  impl FileWatcher {
 47      pub fn new() -> Result<Self, notify::Error> {
 48          let paths = Arc::new(Mutex::new(HashMap::new()));
 49          let callbacks = Arc::new(Mutex::new(Vec::<WatchCallback>::new()));
 50          let paths_clone = Arc::clone(&paths);
 51          let callbacks_clone = Arc::clone(&callbacks);
 52  
 53          let watcher: notify::RecommendedWatcher = notify::recommended_watcher(move |res: Result<Event, _>| {
 54              if let Ok(event) = res {
 55                  for path in event.paths {
 56                      let canonical = path.canonicalize().unwrap_or_else(|_| path.clone());
 57  
 58                      let source_id = {
 59                          let paths = paths_clone.lock();
 60                          paths.get(&canonical).cloned()
 61                      };
 62  
 63                      if source_id.is_none() {
 64                          continue;
 65                      }
 66  
 67                      let kind = match event.kind {
 68                          EventKind::Create(_) => ChangeKind::Created,
 69                          EventKind::Modify(_) => ChangeKind::Modified,
 70                          EventKind::Remove(_) => ChangeKind::Deleted,
 71                          _ => continue,
 72                      };
 73  
 74                      let change = FileChanged { path, kind };
 75  
 76                      let callbacks = callbacks_clone.lock();
 77                      for callback in callbacks.iter() {
 78                          callback(change.clone());
 79                      }
 80                  }
 81              }
 82          })?;
 83  
 84          Ok(Self {
 85              watcher: Arc::new(Mutex::new(watcher)),
 86              paths,
 87              callbacks,
 88          })
 89      }
 90  
 91      pub fn watch(&self, path: impl AsRef<Path>, source_id: impl Into<CompactString>) {
 92          let path = path.as_ref().canonicalize().unwrap_or_else(|_| path.as_ref().to_path_buf());
 93          if let Err(e) = self.watcher.lock().watch(&path, RecursiveMode::NonRecursive) {
 94              tracing::warn!(path = %path.display(), error = %e, "Failed to watch path");
 95          }
 96          self.paths.lock().insert(path, source_id.into());
 97      }
 98  
 99      pub fn unwatch(&self, path: impl AsRef<Path>) {
100          let path = path.as_ref().canonicalize().unwrap_or_else(|_| path.as_ref().to_path_buf());
101          if let Err(e) = self.watcher.lock().unwatch(&path) {
102              tracing::warn!(path = %path.display(), error = %e, "Failed to unwatch path");
103          }
104          self.paths.lock().remove(&path);
105      }
106  
107      pub fn register_callback(&self, callback: WatchCallback) {
108          self.callbacks.lock().push(callback);
109      }
110  
111      pub fn paths(&self) -> Vec<PathBuf> {
112          self.paths.lock().keys().cloned().collect()
113      }
114  
115      pub fn is_watching(&self, path: impl AsRef<Path>) -> bool {
116          let path = path.as_ref().canonicalize().unwrap_or_else(|_| path.as_ref().to_path_buf());
117          self.paths.lock().contains_key(&path)
118      }
119  }
120  
#[cfg(all(feature = "watch", feature = "async"))]
impl Default for FileWatcher {
    /// Builds a watcher through [`FileWatcher::new`].
    ///
    /// # Panics
    ///
    /// Panics if the underlying `notify` watcher cannot be created.
    fn default() -> Self {
        FileWatcher::new().expect("Failed to create file watcher")
    }
}
127  
#[cfg(all(test, feature = "watch", feature = "async"))]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;
    use tempfile::TempDir;

    #[tokio::test]
    async fn test_file_watcher() {
        let temp_dir = TempDir::new().unwrap();
        let watcher = FileWatcher::new().unwrap();

        // The file does not exist yet; the path must still be tracked.
        let test_file = temp_dir.path().join("test.env");
        watcher.watch(&test_file, "test-source");

        assert!(watcher.is_watching(&test_file));
        assert_eq!(watcher.paths().len(), 1);

        watcher.unwatch(&test_file);
        assert!(!watcher.is_watching(&test_file));
        assert!(watcher.paths().is_empty());
    }

    #[tokio::test]
    async fn test_callback_registration() {
        let temp_dir = TempDir::new().unwrap();
        let watcher = FileWatcher::new().unwrap();

        // Create the file before watching so the OS-level watch can be registered.
        let test_file = temp_dir.path().join("test.env");
        std::fs::write(&test_file, "TEST=initial").unwrap();
        watcher.watch(&test_file, "test-source");

        let callback_called = Arc::new(AtomicBool::new(false));
        let callback_clone = Arc::clone(&callback_called);
        watcher.register_callback(Arc::new(move |_change| {
            callback_clone.store(true, Ordering::SeqCst);
        }));

        std::fs::write(&test_file, "TEST=value").unwrap();

        // Poll until the callback fires; the generous timeout only bounds the
        // failure case, the loop exits as soon as the event is delivered.
        let fired = tokio::time::timeout(Duration::from_secs(5), async {
            while !callback_called.load(Ordering::SeqCst) {
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await;
        assert!(
            fired.is_ok(),
            "callback was not invoked after modifying a watched file"
        );
    }
}