registry.rs
1 use crate::error::SourceError; 2 use crate::source::traits::*; 3 use compact_str::CompactString; 4 use hashbrown::HashMap; 5 use parking_lot::RwLock; 6 use std::sync::Arc; 7 8 #[cfg(feature = "remote")] 9 use crate::source::remote::ExternalProviderAdapter; 10 11 pub struct SourceRegistry { 12 sync_sources: RwLock<HashMap<SourceId, Arc<dyn EnvSource>>>, 13 #[cfg(feature = "async")] 14 async_sources: RwLock<HashMap<SourceId, Arc<dyn AsyncEnvSource>>>, 15 path_index: RwLock<HashMap<std::path::PathBuf, SourceId>>, 16 factories: RwLock<HashMap<CompactString, Arc<dyn SourceFactory>>>, 17 /// External provider adapters (out-of-process providers). 18 #[cfg(feature = "remote")] 19 external_providers: RwLock<HashMap<String, Arc<ExternalProviderAdapter>>>, 20 } 21 22 impl SourceRegistry { 23 pub fn new() -> Self { 24 let mut factories: HashMap<CompactString, Arc<dyn SourceFactory>> = HashMap::new(); 25 26 #[cfg(feature = "file")] 27 factories.insert( 28 CompactString::new("file"), 29 Arc::new(FileSourceFactory) as Arc<dyn SourceFactory>, 30 ); 31 #[cfg(feature = "shell")] 32 factories.insert( 33 CompactString::new("shell"), 34 Arc::new(ShellSourceFactory) as Arc<dyn SourceFactory>, 35 ); 36 factories.insert( 37 CompactString::new("memory"), 38 Arc::new(MemorySourceFactory) as Arc<dyn SourceFactory>, 39 ); 40 41 Self { 42 sync_sources: RwLock::new(HashMap::new()), 43 #[cfg(feature = "async")] 44 async_sources: RwLock::new(HashMap::new()), 45 path_index: RwLock::new(HashMap::new()), 46 factories: RwLock::new(factories), 47 #[cfg(feature = "remote")] 48 external_providers: RwLock::new(HashMap::new()), 49 } 50 } 51 52 pub fn register_factory<F: SourceFactory + 'static>(&self, source_type: &str, factory: F) { 53 self.factories 54 .write() 55 .insert(CompactString::new(source_type), Arc::new(factory)); 56 } 57 58 pub fn register_sync(&self, source: Arc<dyn EnvSource>) -> SourceId { 59 let id = source.id().clone(); 60 self.sync_sources.write().insert(id.clone(), source.clone()); 61 62 if source.source_type() == SourceType::File { 63 if let Some(path) = id.as_str().strip_prefix("file:") { 64 let path_buf = std::path::PathBuf::from(path); 65 self.path_index.write().insert(path_buf, id.clone()); 66 } 67 } 68 69 id 70 } 71 72 #[cfg(feature = "async")] 73 pub fn register_async(&self, source: Arc<dyn AsyncEnvSource>) -> SourceId { 74 let id = source.id().clone(); 75 self.async_sources.write().insert(id.clone(), source); 76 id 77 } 78 79 pub fn sync_sources_by_priority(&self) -> Vec<Arc<dyn EnvSource>> { 80 let sources = self.sync_sources.read(); 81 let mut sorted: Vec<_> = sources.values().cloned().collect(); 82 sorted.sort_by_key(|a| std::cmp::Reverse(a.priority())); 83 sorted 84 } 85 86 #[cfg(feature = "async")] 87 pub fn async_sources(&self) -> Vec<Arc<dyn AsyncEnvSource>> { 88 self.async_sources.read().values().cloned().collect() 89 } 90 91 #[cfg(feature = "async")] 92 pub fn has_async_sources(&self) -> bool { 93 !self.async_sources.read().is_empty() 94 } 95 96 #[cfg(feature = "async")] 97 pub async fn load_all(&self) -> Result<Vec<SourceSnapshot>, SourceError> { 98 let snapshots = { 99 let mut snapshots = Vec::new(); 100 let sources_guard = self.sync_sources.read(); 101 for (_id, source) in sources_guard.iter() { 102 let snapshot = source.load()?; 103 snapshots.push(snapshot); 104 } 105 snapshots 106 }; 107 let mut snapshots = snapshots; 108 109 if self.has_async_sources() { 110 let async_sources = self.async_sources.read().clone(); 111 let futures: Vec<_> = async_sources.values().map(|s| s.load()).collect(); 112 113 let results = futures::future::join_all(futures).await; 114 for result in results { 115 snapshots.push(result?); 116 } 117 } 118 119 Ok(snapshots) 120 } 121 122 #[cfg(feature = "async")] 123 pub async fn refresh_async(&self) -> Result<(), SourceError> { 124 let async_sources = self.async_sources.read().clone(); 125 let futures: Vec<_> = async_sources.values().map(|s| s.refresh()).collect(); 126 127 let results = futures::future::try_join_all(futures).await?; 128 for result in results { 129 if result { 130 tracing::info!("Async source refreshed"); 131 } 132 } 133 134 Ok(()) 135 } 136 137 pub fn sources_of_type(&self, source_type: SourceType) -> Vec<Arc<dyn EnvSource>> { 138 let sources = self.sync_sources.read(); 139 sources 140 .values() 141 .filter(|s| s.source_type() == source_type) 142 .cloned() 143 .collect() 144 } 145 146 pub fn sources_for_paths(&self, paths: &[std::path::PathBuf]) -> Vec<Arc<dyn EnvSource>> { 147 let sources = self.sync_sources.read(); 148 let path_index = self.path_index.read(); 149 let mut result = Vec::new(); 150 let mut seen_ids = std::collections::HashSet::new(); 151 152 for path in paths { 153 if let Some(source_id) = path_index.get(path) { 154 if !seen_ids.contains(source_id) { 155 if let Some(source) = sources.get(source_id) { 156 result.push(Arc::clone(source)); 157 seen_ids.insert(source_id.clone()); 158 } 159 } 160 } 161 } 162 163 result 164 } 165 166 pub fn invalidate_file(&self, _path: &std::path::Path) { 167 for source in self.sync_sources.read().values() { 168 if source.source_type() == SourceType::File { 169 source.invalidate(); 170 } 171 } 172 } 173 174 pub fn is_registered(&self, id: &SourceId) -> bool { 175 self.sync_sources.read().contains_key(id) 176 } 177 178 pub fn unregister_sync(&self, id: &SourceId) { 179 self.sync_sources.write().remove(id); 180 181 if let Some(path) = id.as_str().strip_prefix("file:") { 182 let path_buf = std::path::PathBuf::from(path); 183 self.path_index.write().remove(&path_buf); 184 } 185 } 186 187 pub fn registered_file_paths(&self) -> Vec<std::path::PathBuf> { 188 self.path_index.read().keys().cloned().collect() 189 } 190 191 pub fn source_count(&self) -> usize { 192 let count = self.sync_sources.read().len(); 193 #[cfg(feature = "async")] 194 let count = count + self.async_sources.read().len(); 195 count 196 } 197 } 198 199 // External provider methods (out-of-process providers) 200 #[cfg(feature = "remote")] 201 impl SourceRegistry { 202 /// Registers an external provider adapter. 203 /// 204 /// Also registers it as an async source for load_all(). 205 pub fn register_external_provider(&self, adapter: Arc<ExternalProviderAdapter>) { 206 let provider_id = adapter.provider_id().to_string(); 207 let source_id = adapter.id().clone(); 208 209 self.external_providers.write().insert(provider_id, Arc::clone(&adapter)); 210 self.async_sources.write().insert(source_id, adapter); 211 } 212 213 /// Gets an external provider by ID. 214 pub fn get_external_provider(&self, provider_id: &str) -> Option<Arc<ExternalProviderAdapter>> { 215 self.external_providers.read().get(provider_id).cloned() 216 } 217 218 /// Lists all registered external providers. 219 pub fn external_providers(&self) -> Vec<Arc<ExternalProviderAdapter>> { 220 self.external_providers.read().values().cloned().collect() 221 } 222 223 /// Returns the number of registered external providers. 224 pub fn external_provider_count(&self) -> usize { 225 self.external_providers.read().len() 226 } 227 228 /// Unregisters an external provider by ID. 229 pub fn unregister_external_provider(&self, provider_id: &str) { 230 if let Some(adapter) = self.external_providers.write().remove(provider_id) { 231 let source_id = adapter.id().clone(); 232 self.async_sources.write().remove(&source_id); 233 } 234 } 235 236 /// Lists registered external provider IDs. 237 pub fn external_provider_ids(&self) -> Vec<String> { 238 self.external_providers.read().keys().cloned().collect() 239 } 240 } 241 242 impl Default for SourceRegistry { 243 fn default() -> Self { 244 Self::new() 245 } 246 } 247 248 #[cfg(not(feature = "async"))] 249 impl SourceRegistry { 250 pub fn has_async_sources(&self) -> bool { 251 false 252 } 253 254 pub fn load_all(&self) -> Result<Vec<SourceSnapshot>, SourceError> { 255 let mut snapshots = Vec::new(); 256 for source in self.sync_sources.read().values() { 257 snapshots.push(source.load()?); 258 } 259 Ok(snapshots) 260 } 261 } 262 263 pub trait SourceFactory: Send + Sync { 264 fn create(&self, config: &SourceConfig) -> Result<Arc<dyn EnvSource>, SourceError>; 265 fn source_type(&self) -> &'static str; 266 } 267 268 #[derive(Debug, Clone)] 269 pub struct SourceConfig { 270 pub source_type: String, 271 pub path: Option<std::path::PathBuf>, 272 pub enabled: bool, 273 } 274 275 struct FileSourceFactory; 276 impl SourceFactory for FileSourceFactory { 277 fn create(&self, config: &SourceConfig) -> Result<Arc<dyn EnvSource>, SourceError> { 278 if let Some(path) = &config.path { 279 crate::source::file::FileSource::new(path) 280 .map(|s| Arc::new(s) as Arc<dyn EnvSource>) 281 .map_err(|e| SourceError::SourceRead { 282 source_name: path.display().to_string(), 283 reason: e.to_string(), 284 }) 285 } else { 286 Err(SourceError::SourceRead { 287 source_name: "file".into(), 288 reason: "No path specified".into(), 289 }) 290 } 291 } 292 293 fn source_type(&self) -> &'static str { 294 "file" 295 } 296 } 297 298 struct ShellSourceFactory; 299 impl SourceFactory for ShellSourceFactory { 300 fn create(&self, _config: &SourceConfig) -> Result<Arc<dyn EnvSource>, SourceError> { 301 Ok(Arc::new(crate::source::shell::ShellSource::new()) as Arc<dyn EnvSource>) 302 } 303 304 fn source_type(&self) -> &'static str { 305 "shell" 306 } 307 } 308 309 struct MemorySourceFactory; 310 impl SourceFactory for MemorySourceFactory { 311 fn create(&self, _config: &SourceConfig) -> Result<Arc<dyn EnvSource>, SourceError> { 312 Ok(Arc::new(crate::source::memory::MemorySource::new()) as Arc<dyn EnvSource>) 313 } 314 315 fn source_type(&self) -> &'static str { 316 "memory" 317 } 318 } 319 320 #[cfg(test)] 321 mod tests { 322 use super::*; 323 324 #[test] 325 fn test_registry_basics() { 326 let registry = SourceRegistry::new(); 327 assert_eq!(registry.source_count(), 0); 328 } 329 }