/ src / source / remote / manager.rs
manager.rs
  1  //! Provider manager for coordinating multiple external providers.
  2  //!
  3  //! The `ProviderManager` is the central coordinator for all external providers.
  4  //! It handles:
  5  //!
  6  //! - Provider discovery and registration
  7  //! - Lazy/eager spawning based on configuration
  8  //! - Health check scheduling
  9  //! - Crash recovery with exponential backoff
 10  //! - Graceful shutdown of all providers
 11  //!
 12  //! ## Usage
 13  //!
 14  //! ```ignore
 15  //! let manager = ProviderManager::with_discovery(config, discovery);
 16  //!
 17  //! // Register providers from configuration
 18  //! manager.register_from_config().await?;
 19  //!
 20  //! // Get a specific provider
 21  //! if let Some(doppler) = manager.get("doppler") {
 22  //!     doppler.authenticate(credentials).await?;
 23  //!     let secrets = doppler.fetch_secrets().await?;
 24  //! }
 25  //!
 26  //! // Shutdown all providers
 27  //! manager.shutdown_all().await;
 28  //! ```
 29  
 30  use super::discovery::{DiscoveryConfig, ProviderDiscovery};
 31  use super::external::{ExternalProviderAdapter, ProviderState};
 32  use super::traits::RemoteSourceInfo;
 33  use crate::config::{ExternalProviderConfig, ProvidersConfig, SpawnStrategy};
 34  use crate::error::SourceError;
 35  use crate::source::traits::{AsyncEnvSource, SourceId};
 36  use parking_lot::RwLock;
 37  use std::collections::HashMap;
 38  use std::sync::Arc;
 39  use std::time::Duration;
 40  
/// Period of the background health-check loop: the task spawned by
/// `start_health_checks` ticks on a `tokio::time::interval` of this length
/// (30 seconds) and checks the providers once per tick.
const HEALTH_CHECK_INTERVAL: Duration = Duration::from_secs(30);
 43  
/// Provider manager state.
///
/// Every mutable piece of state is wrapped in a `parking_lot::RwLock`, so all
/// methods operate through `&self`.
pub struct ProviderManager {
    /// Provider configuration: the providers path (`path`) and the
    /// per-provider settings (`providers`), keyed by provider ID.
    config: ProvidersConfig,
    /// Provider discovery, used to locate provider binaries when registering
    /// and to answer the `list_installed` / `is_installed` queries.
    discovery: ProviderDiscovery,
    /// Registered providers by ID.
    providers: RwLock<HashMap<String, Arc<ExternalProviderAdapter>>>,
    /// Whether the manager is running. Set to `true` by `start` and to `false`
    /// by `shutdown_all` and on drop.
    running: RwLock<bool>,
    /// Handle of the background health-check task, if one has been started.
    /// `shutdown_all` takes the handle out and aborts the task.
    health_check_handle: RwLock<Option<tokio::task::JoinHandle<()>>>,
}
 57  
 58  impl ProviderManager {
 59      /// Creates a new provider manager.
 60      pub fn new(config: ProvidersConfig) -> Self {
 61          let discovery_config = DiscoveryConfig {
 62              providers_path: config.path.clone(),
 63              binary_overrides: config
 64                  .providers
 65                  .iter()
 66                  .filter_map(|(id, c)| c.binary.as_ref().map(|b| (id.clone(), b.clone())))
 67                  .collect(),
 68          };
 69  
 70          Self {
 71              config,
 72              discovery: ProviderDiscovery::new(discovery_config),
 73              providers: RwLock::new(HashMap::new()),
 74              running: RwLock::new(false),
 75              health_check_handle: RwLock::new(None),
 76          }
 77      }
 78  
 79      /// Creates a provider manager with custom discovery.
 80      pub fn with_discovery(config: ProvidersConfig, discovery: ProviderDiscovery) -> Self {
 81          Self {
 82              config,
 83              discovery,
 84              providers: RwLock::new(HashMap::new()),
 85              running: RwLock::new(false),
 86              health_check_handle: RwLock::new(None),
 87          }
 88      }
 89  
 90      /// Starts the provider manager.
 91      ///
 92      /// This registers all enabled providers and spawns eager providers.
 93      pub async fn start(&self) -> Result<(), SourceError> {
 94          *self.running.write() = true;
 95  
 96          // Register all enabled providers
 97          self.register_from_config().await?;
 98  
 99          // Start health check task
100          self.start_health_checks();
101  
102          Ok(())
103      }
104  
105      /// Registers providers based on configuration.
106      pub async fn register_from_config(&self) -> Result<(), SourceError> {
107          for (provider_id, provider_config) in &self.config.providers {
108              if !provider_config.enabled {
109                  continue;
110              }
111  
112              match self.register(provider_id, provider_config.clone()).await {
113                  Ok(_) => {
114                      tracing::info!("Registered provider: {}", provider_id);
115                  }
116                  Err(e) => {
117                      tracing::warn!("Failed to register provider {}: {}", provider_id, e);
118                      // Continue with other providers
119                  }
120              }
121          }
122  
123          Ok(())
124      }
125  
126      /// Registers a single provider.
127      pub async fn register(
128          &self,
129          provider_id: &str,
130          config: ExternalProviderConfig,
131      ) -> Result<Arc<ExternalProviderAdapter>, SourceError> {
132          // Create adapter
133          let adapter =
134              ExternalProviderAdapter::discover(provider_id, config.clone(), &self.discovery)?;
135          let adapter = Arc::new(adapter);
136  
137          // Spawn if eager
138          if config.spawn == SpawnStrategy::Eager {
139              adapter.spawn().await?;
140          }
141  
142          // Store
143          self.providers
144              .write()
145              .insert(provider_id.to_string(), Arc::clone(&adapter));
146  
147          Ok(adapter)
148      }
149  
150      /// Gets a provider by ID.
151      pub fn get(&self, provider_id: &str) -> Option<Arc<ExternalProviderAdapter>> {
152          self.providers.read().get(provider_id).cloned()
153      }
154  
155      /// Gets a provider by ID, spawning if not yet started.
156      pub async fn get_or_spawn(
157          &self,
158          provider_id: &str,
159      ) -> Result<Arc<ExternalProviderAdapter>, SourceError> {
160          let adapter = self.providers.read().get(provider_id).cloned();
161  
162          match adapter {
163              Some(adapter) => {
164                  // Spawn if not started
165                  if adapter.state() == ProviderState::NotStarted {
166                      adapter.spawn().await?;
167                  }
168                  Ok(adapter)
169              }
170              None => {
171                  // Check if configured
172                  if let Some(config) = self.config.providers.get(provider_id) {
173                      if config.enabled {
174                          // Register and spawn
175                          let adapter = self.register(provider_id, config.clone()).await?;
176                          adapter.spawn().await?;
177                          return Ok(adapter);
178                      }
179                  }
180                  Err(SourceError::UnknownProvider {
181                      provider: provider_id.into(),
182                  })
183              }
184          }
185      }
186  
187      /// Lists all registered providers.
188      pub fn list(&self) -> Vec<Arc<ExternalProviderAdapter>> {
189          self.providers.read().values().cloned().collect()
190      }
191  
192      /// Lists all provider IDs.
193      pub fn provider_ids(&self) -> Vec<String> {
194          self.providers.read().keys().cloned().collect()
195      }
196  
197      /// Lists providers with their info for UI display.
198      pub fn list_with_info(&self) -> Vec<RemoteSourceInfo> {
199          self.providers
200              .read()
201              .values()
202              .map(|p| p.info())
203              .collect()
204      }
205  
206      /// Returns provider info by ID.
207      pub fn info(&self, provider_id: &str) -> Option<RemoteSourceInfo> {
208          self.providers.read().get(provider_id).map(|p| p.info())
209      }
210  
211      /// Checks if a provider is registered.
212      pub fn has_provider(&self, provider_id: &str) -> bool {
213          self.providers.read().contains_key(provider_id)
214      }
215  
216      /// Checks if a provider is running.
217      pub fn is_running(&self, provider_id: &str) -> bool {
218          self.providers
219              .read()
220              .get(provider_id)
221              .map(|p| p.state() == ProviderState::Running)
222              .unwrap_or(false)
223      }
224  
225      /// Checks if a provider is authenticated.
226      pub fn is_authenticated(&self, provider_id: &str) -> bool {
227          self.providers
228              .read()
229              .get(provider_id)
230              .map(|p| p.auth_status().is_authenticated())
231              .unwrap_or(false)
232      }
233  
234      /// Unregisters a provider.
235      pub async fn unregister(&self, provider_id: &str) -> Result<(), SourceError> {
236          // Extract the adapter while holding the lock, then drop the lock before awaiting
237          let adapter = self.providers.write().remove(provider_id);
238          if let Some(adapter) = adapter {
239              adapter.shutdown().await?;
240          }
241          Ok(())
242      }
243  
244      /// Refreshes all running providers.
245      pub async fn refresh_all(&self) -> Result<(), SourceError> {
246          let providers: Vec<_> = self.providers.read().values().cloned().collect();
247  
248          for provider in providers {
249              if provider.state() == ProviderState::Running {
250                  if let Err(e) = provider.refresh().await {
251                      tracing::warn!(
252                          "Failed to refresh provider {}: {}",
253                          provider.provider_id(),
254                          e
255                      );
256                  }
257              }
258          }
259  
260          Ok(())
261      }
262  
263      /// Shuts down all providers.
264      pub async fn shutdown_all(&self) {
265          *self.running.write() = false;
266  
267          // Stop health check task
268          if let Some(handle) = self.health_check_handle.write().take() {
269              handle.abort();
270          }
271  
272          // Shutdown all providers
273          let providers: Vec<_> = self.providers.write().drain().collect();
274  
275          for (id, provider) in providers {
276              tracing::info!("Shutting down provider: {}", id);
277              if let Err(e) = provider.shutdown().await {
278                  tracing::warn!("Failed to shutdown provider {}: {}", id, e);
279              }
280          }
281      }
282  
283      /// Returns the providers path.
284      pub fn providers_path(&self) -> &std::path::Path {
285          &self.config.path
286      }
287  
288      /// Lists installed provider binaries.
289      pub fn list_installed(&self) -> Vec<super::discovery::ProviderBinaryInfo> {
290          self.discovery.list_installed()
291      }
292  
293      /// Checks if a provider binary is installed.
294      pub fn is_installed(&self, provider_id: &str) -> bool {
295          self.discovery.is_installed(provider_id)
296      }
297  
298      /// Returns the enabled provider IDs from configuration.
299      pub fn enabled_provider_ids(&self) -> Vec<String> {
300          self.config
301              .providers
302              .iter()
303              .filter(|(_, c)| c.enabled)
304              .map(|(id, _)| id.clone())
305              .collect()
306      }
307  
308      /// Starts the health check background task.
309      fn start_health_checks(&self) {
310          let providers = Arc::new(self.providers.read().clone());
311          let running = Arc::new(RwLock::new(true));
312  
313          let running_clone = Arc::clone(&running);
314          let handle = tokio::spawn(async move {
315              let mut interval = tokio::time::interval(HEALTH_CHECK_INTERVAL);
316  
317              loop {
318                  interval.tick().await;
319  
320                  if !*running_clone.read() {
321                      break;
322                  }
323  
324                  for (_id, provider) in providers.iter() {
325                      if provider.state() == ProviderState::Running {
326                          match provider.health_check().await {
327                              Ok(healthy) => {
328                                  if !healthy {
329                                      tracing::warn!(
330                                          "Provider {} health check failed, attempting restart",
331                                          provider.provider_id()
332                                      );
333                                      if let Err(e) = provider.restart_if_needed().await {
334                                          tracing::error!(
335                                              "Failed to restart provider {}: {}",
336                                              provider.provider_id(),
337                                              e
338                                          );
339                                      }
340                                  }
341                              }
342                              Err(e) => {
343                                  tracing::warn!(
344                                      "Provider {} health check error: {}",
345                                      provider.provider_id(),
346                                      e
347                                  );
348                              }
349                          }
350                      }
351                  }
352              }
353          });
354  
355          *self.health_check_handle.write() = Some(handle);
356      }
357  
358      /// Gets all providers as async env sources for registration.
359      pub fn as_async_sources(&self) -> Vec<Arc<dyn AsyncEnvSource>> {
360          self.providers
361              .read()
362              .values()
363              .map(|p| Arc::clone(p) as Arc<dyn AsyncEnvSource>)
364              .collect()
365      }
366  
367      /// Gets provider source IDs for filtering.
368      pub fn source_ids(&self) -> Vec<SourceId> {
369          self.providers
370              .read()
371              .values()
372              .map(|p| p.id().clone())
373              .collect()
374      }
375  }
376  
377  impl Drop for ProviderManager {
378      fn drop(&mut self) {
379          // Mark as not running to stop health checks
380          *self.running.write() = false;
381      }
382  }
383  
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds a manager over the given provider table, rooted at a dummy path.
    fn manager_with(providers: HashMap<String, ExternalProviderConfig>) -> ProviderManager {
        ProviderManager::new(ProvidersConfig {
            path: PathBuf::from("/tmp/providers"),
            providers,
        })
    }

    /// A freshly created manager has no registered providers.
    #[test]
    fn test_provider_manager_creation() {
        let manager = manager_with(HashMap::new());
        assert!(manager.list().is_empty());
    }

    /// Only providers flagged `enabled` in the configuration are reported.
    #[test]
    fn test_enabled_provider_ids() {
        let table = [("doppler", true), ("aws", false), ("vault", true)];
        let providers: HashMap<String, ExternalProviderConfig> = table
            .iter()
            .map(|&(id, enabled)| {
                let config = ExternalProviderConfig {
                    enabled,
                    ..Default::default()
                };
                (id.to_string(), config)
            })
            .collect();

        let manager = manager_with(providers);
        let enabled = manager.enabled_provider_ids();
        let reports = |wanted: &str| enabled.iter().any(|id| id == wanted);

        assert_eq!(enabled.len(), 2);
        assert!(reports("doppler"));
        assert!(reports("vault"));
        assert!(!reports("aws"));
    }
}