/ src / source / registry.rs
registry.rs
  1  use crate::error::SourceError;
  2  use crate::source::traits::*;
  3  use compact_str::CompactString;
  4  use hashbrown::HashMap;
  5  use parking_lot::RwLock;
  6  use std::sync::Arc;
  7  
  8  #[cfg(feature = "remote")]
  9  use crate::source::remote::ExternalProviderAdapter;
 10  
 11  pub struct SourceRegistry {
 12      sync_sources: RwLock<HashMap<SourceId, Arc<dyn EnvSource>>>,
 13      #[cfg(feature = "async")]
 14      async_sources: RwLock<HashMap<SourceId, Arc<dyn AsyncEnvSource>>>,
 15      path_index: RwLock<HashMap<std::path::PathBuf, SourceId>>,
 16      factories: RwLock<HashMap<CompactString, Arc<dyn SourceFactory>>>,
 17      /// External provider adapters (out-of-process providers).
 18      #[cfg(feature = "remote")]
 19      external_providers: RwLock<HashMap<String, Arc<ExternalProviderAdapter>>>,
 20  }
 21  
 22  impl SourceRegistry {
 23      pub fn new() -> Self {
 24          let mut factories: HashMap<CompactString, Arc<dyn SourceFactory>> = HashMap::new();
 25  
 26          #[cfg(feature = "file")]
 27          factories.insert(
 28              CompactString::new("file"),
 29              Arc::new(FileSourceFactory) as Arc<dyn SourceFactory>,
 30          );
 31          #[cfg(feature = "shell")]
 32          factories.insert(
 33              CompactString::new("shell"),
 34              Arc::new(ShellSourceFactory) as Arc<dyn SourceFactory>,
 35          );
 36          factories.insert(
 37              CompactString::new("memory"),
 38              Arc::new(MemorySourceFactory) as Arc<dyn SourceFactory>,
 39          );
 40  
 41          Self {
 42              sync_sources: RwLock::new(HashMap::new()),
 43              #[cfg(feature = "async")]
 44              async_sources: RwLock::new(HashMap::new()),
 45              path_index: RwLock::new(HashMap::new()),
 46              factories: RwLock::new(factories),
 47              #[cfg(feature = "remote")]
 48              external_providers: RwLock::new(HashMap::new()),
 49          }
 50      }
 51  
 52      pub fn register_factory<F: SourceFactory + 'static>(&self, source_type: &str, factory: F) {
 53          self.factories
 54              .write()
 55              .insert(CompactString::new(source_type), Arc::new(factory));
 56      }
 57  
 58      pub fn register_sync(&self, source: Arc<dyn EnvSource>) -> SourceId {
 59          let id = source.id().clone();
 60          self.sync_sources.write().insert(id.clone(), source.clone());
 61  
 62          if source.source_type() == SourceType::File {
 63              if let Some(path) = id.as_str().strip_prefix("file:") {
 64                  let path_buf = std::path::PathBuf::from(path);
 65                  self.path_index.write().insert(path_buf, id.clone());
 66              }
 67          }
 68  
 69          id
 70      }
 71  
 72      #[cfg(feature = "async")]
 73      pub fn register_async(&self, source: Arc<dyn AsyncEnvSource>) -> SourceId {
 74          let id = source.id().clone();
 75          self.async_sources.write().insert(id.clone(), source);
 76          id
 77      }
 78  
 79      pub fn sync_sources_by_priority(&self) -> Vec<Arc<dyn EnvSource>> {
 80          let sources = self.sync_sources.read();
 81          let mut sorted: Vec<_> = sources.values().cloned().collect();
 82          sorted.sort_by_key(|a| std::cmp::Reverse(a.priority()));
 83          sorted
 84      }
 85  
 86      #[cfg(feature = "async")]
 87      pub fn async_sources(&self) -> Vec<Arc<dyn AsyncEnvSource>> {
 88          self.async_sources.read().values().cloned().collect()
 89      }
 90  
 91      #[cfg(feature = "async")]
 92      pub fn has_async_sources(&self) -> bool {
 93          !self.async_sources.read().is_empty()
 94      }
 95  
 96      #[cfg(feature = "async")]
 97      pub async fn load_all(&self) -> Result<Vec<SourceSnapshot>, SourceError> {
 98          let snapshots = {
 99              let mut snapshots = Vec::new();
100              let sources_guard = self.sync_sources.read();
101              for (_id, source) in sources_guard.iter() {
102                  let snapshot = source.load()?;
103                  snapshots.push(snapshot);
104              }
105              snapshots
106          };
107          let mut snapshots = snapshots;
108  
109          if self.has_async_sources() {
110              let async_sources = self.async_sources.read().clone();
111              let futures: Vec<_> = async_sources.values().map(|s| s.load()).collect();
112  
113              let results = futures::future::join_all(futures).await;
114              for result in results {
115                  snapshots.push(result?);
116              }
117          }
118  
119          Ok(snapshots)
120      }
121  
122      #[cfg(feature = "async")]
123      pub async fn refresh_async(&self) -> Result<(), SourceError> {
124          let async_sources = self.async_sources.read().clone();
125          let futures: Vec<_> = async_sources.values().map(|s| s.refresh()).collect();
126  
127          let results = futures::future::try_join_all(futures).await?;
128          for result in results {
129              if result {
130                  tracing::info!("Async source refreshed");
131              }
132          }
133  
134          Ok(())
135      }
136  
137      pub fn sources_of_type(&self, source_type: SourceType) -> Vec<Arc<dyn EnvSource>> {
138          let sources = self.sync_sources.read();
139          sources
140              .values()
141              .filter(|s| s.source_type() == source_type)
142              .cloned()
143              .collect()
144      }
145  
146      pub fn sources_for_paths(&self, paths: &[std::path::PathBuf]) -> Vec<Arc<dyn EnvSource>> {
147          let sources = self.sync_sources.read();
148          let path_index = self.path_index.read();
149          let mut result = Vec::new();
150          let mut seen_ids = std::collections::HashSet::new();
151  
152          for path in paths {
153              if let Some(source_id) = path_index.get(path) {
154                  if !seen_ids.contains(source_id) {
155                      if let Some(source) = sources.get(source_id) {
156                          result.push(Arc::clone(source));
157                          seen_ids.insert(source_id.clone());
158                      }
159                  }
160              }
161          }
162  
163          result
164      }
165  
166      pub fn invalidate_file(&self, _path: &std::path::Path) {
167          for source in self.sync_sources.read().values() {
168              if source.source_type() == SourceType::File {
169                  source.invalidate();
170              }
171          }
172      }
173  
174      pub fn is_registered(&self, id: &SourceId) -> bool {
175          self.sync_sources.read().contains_key(id)
176      }
177  
178      pub fn unregister_sync(&self, id: &SourceId) {
179          self.sync_sources.write().remove(id);
180  
181          if let Some(path) = id.as_str().strip_prefix("file:") {
182              let path_buf = std::path::PathBuf::from(path);
183              self.path_index.write().remove(&path_buf);
184          }
185      }
186  
187      pub fn registered_file_paths(&self) -> Vec<std::path::PathBuf> {
188          self.path_index.read().keys().cloned().collect()
189      }
190  
191      pub fn source_count(&self) -> usize {
192          let count = self.sync_sources.read().len();
193          #[cfg(feature = "async")]
194          let count = count + self.async_sources.read().len();
195          count
196      }
197  }
198  
// External provider methods (out-of-process providers).
//
// NOTE(review): these methods also use `async_sources`, a field that only exists
// with the `async` feature — presumably `remote` enables `async`; confirm.
#[cfg(feature = "remote")]
impl SourceRegistry {
    /// Registers an external provider adapter under its provider id.
    ///
    /// Also registers it as an async source for load_all(). Re-registering an
    /// existing provider id replaces the old adapter and removes the old
    /// adapter's async-source entry, so the replaced provider is not loaded
    /// again even if its source id changed.
    pub fn register_external_provider(&self, adapter: Arc<ExternalProviderAdapter>) {
        let provider_id = adapter.provider_id().to_string();
        let source_id = adapter.id().clone();

        // Statement-scoped guard: the providers lock is released before the
        // async-sources lock is taken, so the two are never held together.
        let previous = self
            .external_providers
            .write()
            .insert(provider_id, Arc::clone(&adapter));

        let mut async_sources = self.async_sources.write();
        if let Some(previous) = previous {
            // Drop the stale entry first; if the id is unchanged the insert
            // below simply re-adds it.
            let previous_id = previous.id().clone();
            async_sources.remove(&previous_id);
        }
        async_sources.insert(source_id, adapter);
    }

    /// Gets an external provider by ID.
    pub fn get_external_provider(&self, provider_id: &str) -> Option<Arc<ExternalProviderAdapter>> {
        self.external_providers.read().get(provider_id).cloned()
    }

    /// Lists all registered external providers.
    pub fn external_providers(&self) -> Vec<Arc<ExternalProviderAdapter>> {
        self.external_providers.read().values().cloned().collect()
    }

    /// Returns the number of registered external providers.
    pub fn external_provider_count(&self) -> usize {
        self.external_providers.read().len()
    }

    /// Unregisters an external provider by ID, together with its async-source
    /// entry. Unknown ids are ignored.
    pub fn unregister_external_provider(&self, provider_id: &str) {
        // Bind the result first so the providers write guard is dropped here,
        // not held for the whole `if let` body while `async_sources` is locked.
        let removed = self.external_providers.write().remove(provider_id);
        if let Some(adapter) = removed {
            let source_id = adapter.id().clone();
            self.async_sources.write().remove(&source_id);
        }
    }

    /// Lists registered external provider IDs.
    pub fn external_provider_ids(&self) -> Vec<String> {
        self.external_providers.read().keys().cloned().collect()
    }
}
241  
242  impl Default for SourceRegistry {
243      fn default() -> Self {
244          Self::new()
245      }
246  }
247  
248  #[cfg(not(feature = "async"))]
249  impl SourceRegistry {
250      pub fn has_async_sources(&self) -> bool {
251          false
252      }
253  
254      pub fn load_all(&self) -> Result<Vec<SourceSnapshot>, SourceError> {
255          let mut snapshots = Vec::new();
256          for source in self.sync_sources.read().values() {
257              snapshots.push(source.load()?);
258          }
259          Ok(snapshots)
260      }
261  }
262  
263  pub trait SourceFactory: Send + Sync {
264      fn create(&self, config: &SourceConfig) -> Result<Arc<dyn EnvSource>, SourceError>;
265      fn source_type(&self) -> &'static str;
266  }
267  
/// Declarative description of a source, handed to a [`SourceFactory`].
#[derive(Debug, Clone)]
pub struct SourceConfig {
    /// Kind of source to build; presumably matches a factory key such as
    /// `"file"`, `"shell"` or `"memory"` — confirm with callers.
    pub source_type: String,
    /// Backing file path. The file factory requires it; the shell and memory
    /// factories ignore the whole config.
    pub path: Option<std::path::PathBuf>,
    /// Whether the source is enabled.
    /// NOTE(review): none of the built-in factories in this module read this flag.
    pub enabled: bool,
}
274  
// Compiled only with the `file` feature, matching the gated registration in
// `SourceRegistry::new` (the only place this private type is used).
/// Builds `crate::source::file::FileSource`s from a [`SourceConfig`].
#[cfg(feature = "file")]
struct FileSourceFactory;

#[cfg(feature = "file")]
impl SourceFactory for FileSourceFactory {
    /// Creates a file source reading `config.path`.
    ///
    /// `config.enabled` is not consulted.
    ///
    /// # Errors
    /// Returns `SourceError::SourceRead` when no path is configured, or when
    /// `FileSource::new` fails (its error text becomes `reason`).
    fn create(&self, config: &SourceConfig) -> Result<Arc<dyn EnvSource>, SourceError> {
        let path = config.path.as_ref().ok_or_else(|| SourceError::SourceRead {
            source_name: "file".into(),
            reason: "No path specified".into(),
        })?;

        crate::source::file::FileSource::new(path)
            .map(|source| Arc::new(source) as Arc<dyn EnvSource>)
            .map_err(|e| SourceError::SourceRead {
                source_name: path.display().to_string(),
                reason: e.to_string(),
            })
    }

    fn source_type(&self) -> &'static str {
        "file"
    }
}
297  
// Compiled only with the `shell` feature, matching the gated registration in
// `SourceRegistry::new` (the only place this private type is used).
/// Builds `crate::source::shell::ShellSource`s; the configuration is ignored.
#[cfg(feature = "shell")]
struct ShellSourceFactory;

#[cfg(feature = "shell")]
impl SourceFactory for ShellSourceFactory {
    /// Creates a new shell source. Never fails.
    fn create(&self, _config: &SourceConfig) -> Result<Arc<dyn EnvSource>, SourceError> {
        let source: Arc<dyn EnvSource> = Arc::new(crate::source::shell::ShellSource::new());
        Ok(source)
    }

    fn source_type(&self) -> &'static str {
        "shell"
    }
}
308  
309  struct MemorySourceFactory;
310  impl SourceFactory for MemorySourceFactory {
311      fn create(&self, _config: &SourceConfig) -> Result<Arc<dyn EnvSource>, SourceError> {
312          Ok(Arc::new(crate::source::memory::MemorySource::new()) as Arc<dyn EnvSource>)
313      }
314  
315      fn source_type(&self) -> &'static str {
316          "memory"
317      }
318  }
319  
#[cfg(test)]
mod tests {
    use super::*;

    /// Config accepted by the memory factory (which ignores its contents).
    fn memory_config() -> SourceConfig {
        SourceConfig {
            source_type: "memory".to_string(),
            path: None,
            enabled: true,
        }
    }

    #[test]
    fn test_registry_basics() {
        let registry = SourceRegistry::new();
        assert_eq!(registry.source_count(), 0);
        assert!(registry.registered_file_paths().is_empty());
        assert!(registry.sync_sources_by_priority().is_empty());
    }

    #[test]
    fn test_memory_factory_is_always_registered() {
        let registry = SourceRegistry::new();
        assert!(registry.factories.read().contains_key("memory"));
    }

    #[test]
    fn test_register_and_unregister_sync_source() {
        let registry = SourceRegistry::new();
        let source = match MemorySourceFactory.create(&memory_config()) {
            Ok(source) => source,
            Err(_) => panic!("memory factory should never fail"),
        };

        let id = registry.register_sync(source);
        assert!(registry.is_registered(&id));
        assert_eq!(registry.sync_sources_by_priority().len(), 1);

        registry.unregister_sync(&id);
        assert!(!registry.is_registered(&id));
        assert_eq!(registry.source_count(), 0);
    }

    #[test]
    fn test_unknown_paths_are_ignored() {
        let registry = SourceRegistry::new();
        let paths = vec![std::path::PathBuf::from("does-not-exist.env")];
        assert!(registry.sources_for_paths(&paths).is_empty());
        // Must be a harmless no-op when no file sources are registered.
        registry.invalidate_file(std::path::Path::new("does-not-exist.env"));
    }

    #[cfg(feature = "file")]
    #[test]
    fn test_file_factory_requires_path() {
        let config = SourceConfig {
            source_type: "file".to_string(),
            path: None,
            enabled: true,
        };
        assert!(FileSourceFactory.create(&config).is_err());
    }
}