mod.rs
#[cfg(all(feature = "watch", feature = "async"))]
use std::path::{Path, PathBuf};

#[cfg(all(feature = "watch", feature = "async"))]
use compact_str::CompactString;

#[cfg(all(feature = "watch", feature = "async"))]
use notify::{Event, EventKind, RecursiveMode, Watcher};

#[cfg(all(feature = "watch", feature = "async"))]
use parking_lot::Mutex;

#[cfg(all(feature = "watch", feature = "async"))]
use std::collections::HashMap;

#[cfg(all(feature = "watch", feature = "async"))]
use std::sync::Arc;

/// A change notification delivered to every registered [`WatchCallback`].
#[cfg(all(feature = "watch", feature = "async"))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileChanged {
    /// Path of the affected file, as reported by the `notify` backend.
    pub path: PathBuf,
    /// What happened to the file.
    pub kind: ChangeKind,
}

/// The kinds of file changes that [`FileWatcher`] reports.
#[cfg(all(feature = "watch", feature = "async"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChangeKind {
    /// Mapped from `notify::EventKind::Create`.
    Created,
    /// Mapped from `notify::EventKind::Modify`.
    Modified,
    /// Mapped from `notify::EventKind::Remove`.
    Deleted,
}

/// Callback invoked from the `notify` event handler for each change to a watched path.
#[cfg(all(feature = "watch", feature = "async"))]
pub type WatchCallback = Arc<dyn Fn(FileChanged) + Send + Sync>;

/// Watches individual files (non-recursively) and fans change events out to registered callbacks.
///
/// Every path passed in is normalized (canonicalized, or, for a path that does not exist,
/// canonicalized via its parent directory). As a result, the same file is recognized no matter
/// how it was spelled, and whether or not it existed at the time of the call.
#[cfg(all(feature = "watch", feature = "async"))]
pub struct FileWatcher {
    /// Underlying `notify` watcher, locked for `watch`/`unwatch` calls.
    watcher: Arc<Mutex<notify::RecommendedWatcher>>,
    /// Normalized watched paths mapped to the caller-supplied source identifier.
    paths: Arc<Mutex<HashMap<PathBuf, CompactString>>>,
    /// Registered change callbacks, shared with the `notify` event handler.
    callbacks: Arc<Mutex<Vec<WatchCallback>>>,
}

/// Normalizes `path` into the key used by the watcher's path map.
///
/// An existing path is canonicalized. A path that does not exist (a file not created yet, or
/// one that was just deleted) cannot be canonicalized directly. In that case the parent
/// directory is canonicalized and the file name re-appended. The key then matches the one
/// computed while the file exists, even when the parent is reached through a symlink. If
/// neither step succeeds, the path is returned unchanged.
#[cfg(all(feature = "watch", feature = "async"))]
fn normalize_path(path: &Path) -> PathBuf {
    if let Ok(canonical) = path.canonicalize() {
        return canonical;
    }
    let parent = match path.parent() {
        // A bare relative file name has an empty parent; treat it as the current directory.
        Some(parent) if parent.as_os_str().is_empty() => Path::new("."),
        Some(parent) => parent,
        None => return path.to_path_buf(),
    };
    match path.file_name() {
        Some(name) => parent
            .canonicalize()
            .map(|dir| dir.join(name))
            .unwrap_or_else(|_| path.to_path_buf()),
        None => path.to_path_buf(),
    }
}

#[cfg(all(feature = "watch", feature = "async"))]
impl FileWatcher {
    /// Creates a watcher with no watched paths and no registered callbacks.
    ///
    /// # Errors
    ///
    /// Returns the `notify` error if the platform watcher backend cannot be created.
    pub fn new() -> Result<Self, notify::Error> {
        let paths: Arc<Mutex<HashMap<PathBuf, CompactString>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let callbacks = Arc::new(Mutex::new(Vec::<WatchCallback>::new()));
        let handler_paths = Arc::clone(&paths);
        let handler_callbacks = Arc::clone(&callbacks);

        let watcher: notify::RecommendedWatcher =
            notify::recommended_watcher(move |res: Result<Event, notify::Error>| {
                let event = match res {
                    Ok(event) => event,
                    Err(e) => {
                        tracing::warn!(error = %e, "File watcher reported an error");
                        return;
                    }
                };

                // The kind is the same for every path in the event, so map it once.
                // Other kinds (access, "any", "other") are deliberately not reported.
                let kind = match event.kind {
                    EventKind::Create(_) => ChangeKind::Created,
                    EventKind::Modify(_) => ChangeKind::Modified,
                    EventKind::Remove(_) => ChangeKind::Deleted,
                    _ => return,
                };

                for path in event.paths {
                    // Normalize before taking the lock so no filesystem I/O happens under it.
                    let key = normalize_path(&path);
                    if !handler_paths.lock().contains_key(&key) {
                        continue;
                    }

                    // Snapshot the callbacks and release the lock before invoking any of them.
                    // The mutex is not reentrant, so a callback that calls `register_callback`
                    // would otherwise deadlock.
                    let snapshot: Vec<WatchCallback> = handler_callbacks.lock().clone();
                    let change = FileChanged { path, kind };
                    for callback in &snapshot {
                        callback(change.clone());
                    }
                }
            })?;

        Ok(Self {
            watcher: Arc::new(Mutex::new(watcher)),
            paths,
            callbacks,
        })
    }

    /// Starts watching `path` (non-recursively) and associates it with `source_id`.
    ///
    /// If the backend refuses the path (for example, one that does not exist yet), a warning is
    /// logged. The path is still recorded, so [`is_watching`](Self::is_watching) and
    /// [`paths`](Self::paths) report it. Watching the same path again replaces its `source_id`.
    pub fn watch(&self, path: impl AsRef<Path>, source_id: impl Into<CompactString>) {
        let path = normalize_path(path.as_ref());
        if let Err(e) = self.watcher.lock().watch(&path, RecursiveMode::NonRecursive) {
            tracing::warn!(path = %path.display(), error = %e, "Failed to watch path");
        }
        self.paths.lock().insert(path, source_id.into());
    }

    /// Stops watching `path` and forgets its `source_id`.
    ///
    /// A backend error (for example, a path that was never successfully watched) is logged.
    /// The path is removed from the watched set regardless.
    pub fn unwatch(&self, path: impl AsRef<Path>) {
        let path = normalize_path(path.as_ref());
        if let Err(e) = self.watcher.lock().unwatch(&path) {
            tracing::warn!(path = %path.display(), error = %e, "Failed to unwatch path");
        }
        self.paths.lock().remove(&path);
    }

    /// Registers `callback` to be invoked for every change to a watched path.
    ///
    /// Callbacks run in registration order from the `notify` event handler. A callback may
    /// itself call `register_callback`; the new callback takes effect from the next change.
    pub fn register_callback(&self, callback: WatchCallback) {
        self.callbacks.lock().push(callback);
    }

    /// Returns the normalized paths currently recorded as watched, in arbitrary order.
    pub fn paths(&self) -> Vec<PathBuf> {
        self.paths.lock().keys().cloned().collect()
    }

    /// Returns `true` if `path` (after normalization) is recorded as watched.
    pub fn is_watching(&self, path: impl AsRef<Path>) -> bool {
        let path = normalize_path(path.as_ref());
        self.paths.lock().contains_key(&path)
    }
}

#[cfg(all(feature = "watch", feature = "async"))]
impl Default for FileWatcher {
    /// Equivalent to [`FileWatcher::new`].
    ///
    /// # Panics
    ///
    /// Panics if the platform watcher backend cannot be created.
    fn default() -> Self {
        Self::new().expect("Failed to create file watcher")
    }
}

#[cfg(all(test, feature = "watch", feature = "async"))]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;
    use tempfile::TempDir;

    #[tokio::test]
    async fn test_file_watcher() {
        let temp_dir = TempDir::new().unwrap();
        let watcher = FileWatcher::new().unwrap();

        let test_file = temp_dir.path().join("test.env");
        watcher.watch(&test_file, "test-source");

        assert!(watcher.is_watching(&test_file));
        assert_eq!(watcher.paths().len(), 1);

        watcher.unwatch(&test_file);
        assert!(!watcher.is_watching(&test_file));
    }

    #[tokio::test]
    async fn test_watch_before_create_matches_after_create() {
        let temp_dir = TempDir::new().unwrap();
        let watcher = FileWatcher::new().unwrap();

        // Watch a path that does not exist yet, then create it. The recorded key must
        // still match once the file exists (and can therefore be canonicalized).
        let test_file = temp_dir.path().join("late.env");
        watcher.watch(&test_file, "late-source");
        std::fs::write(&test_file, "LATE=1").unwrap();

        assert!(watcher.is_watching(&test_file));
        watcher.unwatch(&test_file);
        assert!(watcher.paths().is_empty());
    }

    #[tokio::test]
    async fn test_callback_registration() {
        let temp_dir = TempDir::new().unwrap();
        let watcher = FileWatcher::new().unwrap();

        // Create the file before watching it: most backends refuse to watch a missing path.
        let test_file = temp_dir.path().join("test.env");
        std::fs::write(&test_file, "TEST=initial").unwrap();
        watcher.watch(&test_file, "test-source");

        let callback_called = Arc::new(AtomicBool::new(false));
        let callback_clone = Arc::clone(&callback_called);

        watcher.register_callback(Arc::new(move |_change| {
            callback_clone.store(true, Ordering::SeqCst);
        }));

        std::fs::write(&test_file, "TEST=value").unwrap();

        let fired = tokio::time::timeout(Duration::from_secs(5), async {
            while !callback_called.load(Ordering::SeqCst) {
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .is_ok();

        assert!(fired, "callback was not invoked after modifying a watched file");
    }
}