manager.rs
1 //! Provider manager for coordinating multiple external providers. 2 //! 3 //! The `ProviderManager` is the central coordinator for all external providers. 4 //! It handles: 5 //! 6 //! - Provider discovery and registration 7 //! - Lazy/eager spawning based on configuration 8 //! - Health check scheduling 9 //! - Crash recovery with exponential backoff 10 //! - Graceful shutdown of all providers 11 //! 12 //! ## Usage 13 //! 14 //! ```ignore 15 //! let manager = ProviderManager::new(config, discovery); 16 //! 17 //! // Register providers from configuration 18 //! manager.register_from_config().await?; 19 //! 20 //! // Get a specific provider 21 //! if let Some(doppler) = manager.get("doppler") { 22 //! doppler.authenticate(credentials).await?; 23 //! let secrets = doppler.fetch_secrets().await?; 24 //! } 25 //! 26 //! // Shutdown all providers 27 //! manager.shutdown_all().await; 28 //! ``` 29 30 use super::discovery::{DiscoveryConfig, ProviderDiscovery}; 31 use super::external::{ExternalProviderAdapter, ProviderState}; 32 use super::traits::RemoteSourceInfo; 33 use crate::config::{ExternalProviderConfig, ProvidersConfig, SpawnStrategy}; 34 use crate::error::SourceError; 35 use crate::source::traits::{AsyncEnvSource, SourceId}; 36 use parking_lot::RwLock; 37 use std::collections::HashMap; 38 use std::sync::Arc; 39 use std::time::Duration; 40 41 /// Health check interval. 42 const HEALTH_CHECK_INTERVAL: Duration = Duration::from_secs(30); 43 44 /// Provider manager state. 45 pub struct ProviderManager { 46 /// Provider configuration. 47 config: ProvidersConfig, 48 /// Provider discovery. 49 discovery: ProviderDiscovery, 50 /// Registered providers by ID. 51 providers: RwLock<HashMap<String, Arc<ExternalProviderAdapter>>>, 52 /// Whether the manager is running. 53 running: RwLock<bool>, 54 /// Health check task handle. 55 health_check_handle: RwLock<Option<tokio::task::JoinHandle<()>>>, 56 } 57 58 impl ProviderManager { 59 /// Creates a new provider manager. 60 pub fn new(config: ProvidersConfig) -> Self { 61 let discovery_config = DiscoveryConfig { 62 providers_path: config.path.clone(), 63 binary_overrides: config 64 .providers 65 .iter() 66 .filter_map(|(id, c)| c.binary.as_ref().map(|b| (id.clone(), b.clone()))) 67 .collect(), 68 }; 69 70 Self { 71 config, 72 discovery: ProviderDiscovery::new(discovery_config), 73 providers: RwLock::new(HashMap::new()), 74 running: RwLock::new(false), 75 health_check_handle: RwLock::new(None), 76 } 77 } 78 79 /// Creates a provider manager with custom discovery. 80 pub fn with_discovery(config: ProvidersConfig, discovery: ProviderDiscovery) -> Self { 81 Self { 82 config, 83 discovery, 84 providers: RwLock::new(HashMap::new()), 85 running: RwLock::new(false), 86 health_check_handle: RwLock::new(None), 87 } 88 } 89 90 /// Starts the provider manager. 91 /// 92 /// This registers all enabled providers and spawns eager providers. 93 pub async fn start(&self) -> Result<(), SourceError> { 94 *self.running.write() = true; 95 96 // Register all enabled providers 97 self.register_from_config().await?; 98 99 // Start health check task 100 self.start_health_checks(); 101 102 Ok(()) 103 } 104 105 /// Registers providers based on configuration. 106 pub async fn register_from_config(&self) -> Result<(), SourceError> { 107 for (provider_id, provider_config) in &self.config.providers { 108 if !provider_config.enabled { 109 continue; 110 } 111 112 match self.register(provider_id, provider_config.clone()).await { 113 Ok(_) => { 114 tracing::info!("Registered provider: {}", provider_id); 115 } 116 Err(e) => { 117 tracing::warn!("Failed to register provider {}: {}", provider_id, e); 118 // Continue with other providers 119 } 120 } 121 } 122 123 Ok(()) 124 } 125 126 /// Registers a single provider. 127 pub async fn register( 128 &self, 129 provider_id: &str, 130 config: ExternalProviderConfig, 131 ) -> Result<Arc<ExternalProviderAdapter>, SourceError> { 132 // Create adapter 133 let adapter = 134 ExternalProviderAdapter::discover(provider_id, config.clone(), &self.discovery)?; 135 let adapter = Arc::new(adapter); 136 137 // Spawn if eager 138 if config.spawn == SpawnStrategy::Eager { 139 adapter.spawn().await?; 140 } 141 142 // Store 143 self.providers 144 .write() 145 .insert(provider_id.to_string(), Arc::clone(&adapter)); 146 147 Ok(adapter) 148 } 149 150 /// Gets a provider by ID. 151 pub fn get(&self, provider_id: &str) -> Option<Arc<ExternalProviderAdapter>> { 152 self.providers.read().get(provider_id).cloned() 153 } 154 155 /// Gets a provider by ID, spawning if not yet started. 156 pub async fn get_or_spawn( 157 &self, 158 provider_id: &str, 159 ) -> Result<Arc<ExternalProviderAdapter>, SourceError> { 160 let adapter = self.providers.read().get(provider_id).cloned(); 161 162 match adapter { 163 Some(adapter) => { 164 // Spawn if not started 165 if adapter.state() == ProviderState::NotStarted { 166 adapter.spawn().await?; 167 } 168 Ok(adapter) 169 } 170 None => { 171 // Check if configured 172 if let Some(config) = self.config.providers.get(provider_id) { 173 if config.enabled { 174 // Register and spawn 175 let adapter = self.register(provider_id, config.clone()).await?; 176 adapter.spawn().await?; 177 return Ok(adapter); 178 } 179 } 180 Err(SourceError::UnknownProvider { 181 provider: provider_id.into(), 182 }) 183 } 184 } 185 } 186 187 /// Lists all registered providers. 188 pub fn list(&self) -> Vec<Arc<ExternalProviderAdapter>> { 189 self.providers.read().values().cloned().collect() 190 } 191 192 /// Lists all provider IDs. 193 pub fn provider_ids(&self) -> Vec<String> { 194 self.providers.read().keys().cloned().collect() 195 } 196 197 /// Lists providers with their info for UI display. 198 pub fn list_with_info(&self) -> Vec<RemoteSourceInfo> { 199 self.providers 200 .read() 201 .values() 202 .map(|p| p.info()) 203 .collect() 204 } 205 206 /// Returns provider info by ID. 207 pub fn info(&self, provider_id: &str) -> Option<RemoteSourceInfo> { 208 self.providers.read().get(provider_id).map(|p| p.info()) 209 } 210 211 /// Checks if a provider is registered. 212 pub fn has_provider(&self, provider_id: &str) -> bool { 213 self.providers.read().contains_key(provider_id) 214 } 215 216 /// Checks if a provider is running. 217 pub fn is_running(&self, provider_id: &str) -> bool { 218 self.providers 219 .read() 220 .get(provider_id) 221 .map(|p| p.state() == ProviderState::Running) 222 .unwrap_or(false) 223 } 224 225 /// Checks if a provider is authenticated. 226 pub fn is_authenticated(&self, provider_id: &str) -> bool { 227 self.providers 228 .read() 229 .get(provider_id) 230 .map(|p| p.auth_status().is_authenticated()) 231 .unwrap_or(false) 232 } 233 234 /// Unregisters a provider. 235 pub async fn unregister(&self, provider_id: &str) -> Result<(), SourceError> { 236 // Extract the adapter while holding the lock, then drop the lock before awaiting 237 let adapter = self.providers.write().remove(provider_id); 238 if let Some(adapter) = adapter { 239 adapter.shutdown().await?; 240 } 241 Ok(()) 242 } 243 244 /// Refreshes all running providers. 245 pub async fn refresh_all(&self) -> Result<(), SourceError> { 246 let providers: Vec<_> = self.providers.read().values().cloned().collect(); 247 248 for provider in providers { 249 if provider.state() == ProviderState::Running { 250 if let Err(e) = provider.refresh().await { 251 tracing::warn!( 252 "Failed to refresh provider {}: {}", 253 provider.provider_id(), 254 e 255 ); 256 } 257 } 258 } 259 260 Ok(()) 261 } 262 263 /// Shuts down all providers. 264 pub async fn shutdown_all(&self) { 265 *self.running.write() = false; 266 267 // Stop health check task 268 if let Some(handle) = self.health_check_handle.write().take() { 269 handle.abort(); 270 } 271 272 // Shutdown all providers 273 let providers: Vec<_> = self.providers.write().drain().collect(); 274 275 for (id, provider) in providers { 276 tracing::info!("Shutting down provider: {}", id); 277 if let Err(e) = provider.shutdown().await { 278 tracing::warn!("Failed to shutdown provider {}: {}", id, e); 279 } 280 } 281 } 282 283 /// Returns the providers path. 284 pub fn providers_path(&self) -> &std::path::Path { 285 &self.config.path 286 } 287 288 /// Lists installed provider binaries. 289 pub fn list_installed(&self) -> Vec<super::discovery::ProviderBinaryInfo> { 290 self.discovery.list_installed() 291 } 292 293 /// Checks if a provider binary is installed. 294 pub fn is_installed(&self, provider_id: &str) -> bool { 295 self.discovery.is_installed(provider_id) 296 } 297 298 /// Returns the enabled provider IDs from configuration. 299 pub fn enabled_provider_ids(&self) -> Vec<String> { 300 self.config 301 .providers 302 .iter() 303 .filter(|(_, c)| c.enabled) 304 .map(|(id, _)| id.clone()) 305 .collect() 306 } 307 308 /// Starts the health check background task. 309 fn start_health_checks(&self) { 310 let providers = Arc::new(self.providers.read().clone()); 311 let running = Arc::new(RwLock::new(true)); 312 313 let running_clone = Arc::clone(&running); 314 let handle = tokio::spawn(async move { 315 let mut interval = tokio::time::interval(HEALTH_CHECK_INTERVAL); 316 317 loop { 318 interval.tick().await; 319 320 if !*running_clone.read() { 321 break; 322 } 323 324 for (_id, provider) in providers.iter() { 325 if provider.state() == ProviderState::Running { 326 match provider.health_check().await { 327 Ok(healthy) => { 328 if !healthy { 329 tracing::warn!( 330 "Provider {} health check failed, attempting restart", 331 provider.provider_id() 332 ); 333 if let Err(e) = provider.restart_if_needed().await { 334 tracing::error!( 335 "Failed to restart provider {}: {}", 336 provider.provider_id(), 337 e 338 ); 339 } 340 } 341 } 342 Err(e) => { 343 tracing::warn!( 344 "Provider {} health check error: {}", 345 provider.provider_id(), 346 e 347 ); 348 } 349 } 350 } 351 } 352 } 353 }); 354 355 *self.health_check_handle.write() = Some(handle); 356 } 357 358 /// Gets all providers as async env sources for registration. 359 pub fn as_async_sources(&self) -> Vec<Arc<dyn AsyncEnvSource>> { 360 self.providers 361 .read() 362 .values() 363 .map(|p| Arc::clone(p) as Arc<dyn AsyncEnvSource>) 364 .collect() 365 } 366 367 /// Gets provider source IDs for filtering. 368 pub fn source_ids(&self) -> Vec<SourceId> { 369 self.providers 370 .read() 371 .values() 372 .map(|p| p.id().clone()) 373 .collect() 374 } 375 } 376 377 impl Drop for ProviderManager { 378 fn drop(&mut self) { 379 // Mark as not running to stop health checks 380 *self.running.write() = false; 381 } 382 } 383 384 #[cfg(test)] 385 mod tests { 386 use super::*; 387 use std::path::PathBuf; 388 389 #[test] 390 fn test_provider_manager_creation() { 391 let config = ProvidersConfig { 392 path: PathBuf::from("/tmp/providers"), 393 providers: HashMap::new(), 394 }; 395 396 let manager = ProviderManager::new(config); 397 assert!(manager.list().is_empty()); 398 } 399 400 #[test] 401 fn test_enabled_provider_ids() { 402 let mut providers = HashMap::new(); 403 providers.insert( 404 "doppler".to_string(), 405 ExternalProviderConfig { 406 enabled: true, 407 ..Default::default() 408 }, 409 ); 410 providers.insert( 411 "aws".to_string(), 412 ExternalProviderConfig { 413 enabled: false, 414 ..Default::default() 415 }, 416 ); 417 providers.insert( 418 "vault".to_string(), 419 ExternalProviderConfig { 420 enabled: true, 421 ..Default::default() 422 }, 423 ); 424 425 let config = ProvidersConfig { 426 path: PathBuf::from("/tmp/providers"), 427 providers, 428 }; 429 430 let manager = ProviderManager::new(config); 431 let enabled = manager.enabled_provider_ids(); 432 433 assert_eq!(enabled.len(), 2); 434 assert!(enabled.contains(&"doppler".to_string())); 435 assert!(enabled.contains(&"vault".to_string())); 436 assert!(!enabled.contains(&"aws".to_string())); 437 } 438 }