// crates/distributed/src/registry.rs
//! Worker registry with health tracking.
  2  use crate::proto::WorkerInfo;
  3  use anyhow::{Context, Result};
  4  use parking_lot::RwLock;
  5  use std::collections::HashMap;
  6  use std::sync::Arc;
  7  
  8  /// Worker registration entry
  9  #[derive(Debug, Clone)]
 10  pub struct WorkerEntry {
 11      pub info: WorkerInfo,
 12      pub last_heartbeat_ms: i64,
 13      pub healthy: bool,
 14  }
 15  
 16  /// Worker registry
 17  pub struct WorkerRegistry {
 18      workers: Arc<RwLock<HashMap<String, WorkerEntry>>>,
 19  }
 20  
 21  impl WorkerRegistry {
 22      /// Create a new worker registry
 23      pub fn new() -> Self {
 24          Self {
 25              workers: Arc::new(RwLock::new(HashMap::new())),
 26          }
 27      }
 28  
 29      /// Register a worker
 30      pub fn register(&self, info: WorkerInfo) -> Result<()> {
 31          let worker_id = info.worker_id.clone();
 32  
 33          let entry = WorkerEntry {
 34              info,
 35              last_heartbeat_ms: current_time_ms(),
 36              healthy: true,
 37          };
 38  
 39          self.workers.write().insert(worker_id, entry);
 40  
 41          Ok(())
 42      }
 43  
 44      /// Update worker heartbeat
 45      pub fn update_heartbeat(&self, worker_id: &str, healthy: bool) -> Result<()> {
 46          let mut workers = self.workers.write();
 47  
 48          let entry = workers
 49              .get_mut(worker_id)
 50              .context("Worker not found")?;
 51  
 52          entry.last_heartbeat_ms = current_time_ms();
 53          entry.healthy = healthy;
 54  
 55          Ok(())
 56      }
 57  
 58      /// Check for stale workers and mark as unhealthy
 59      pub fn check_stale_workers(&self, timeout_ms: i64) {
 60          let now = current_time_ms();
 61          let mut workers = self.workers.write();
 62  
 63          for entry in workers.values_mut() {
 64              if now - entry.last_heartbeat_ms > timeout_ms {
 65                  entry.healthy = false;
 66              }
 67          }
 68      }
 69  
 70      /// Get worker count
 71      pub fn worker_count(&self) -> usize {
 72          self.workers.read().len()
 73      }
 74  
 75      /// Get healthy worker count
 76      pub fn healthy_worker_count(&self) -> usize {
 77          self.workers
 78              .read()
 79              .values()
 80              .filter(|w| w.healthy)
 81              .count()
 82      }
 83  
 84      /// Get total capacity across all workers
 85      pub fn total_capacity(&self) -> u32 {
 86          self.workers
 87              .read()
 88              .values()
 89              .filter(|w| w.healthy)
 90              .map(|w| w.info.max_bots)
 91              .sum()
 92      }
 93  
 94      /// Find an available worker
 95      pub fn find_available_worker(&self) -> Option<String> {
 96          self.workers
 97              .read()
 98              .iter()
 99              .find(|(_, entry)| entry.healthy)
100              .map(|(id, _)| id.clone())
101      }
102  
103      /// List all workers
104      pub fn list_workers(&self) -> Vec<WorkerInfo> {
105          self.workers
106              .read()
107              .values()
108              .filter(|w| w.healthy)
109              .map(|w| w.info.clone())
110              .collect()
111      }
112  
113      /// Get worker by ID
114      pub fn get_worker(&self, worker_id: &str) -> Option<WorkerEntry> {
115          self.workers.read().get(worker_id).cloned()
116      }
117  
118      /// Remove a worker
119      pub fn remove_worker(&self, worker_id: &str) -> Result<()> {
120          self.workers
121              .write()
122              .remove(worker_id)
123              .context("Worker not found")?;
124          Ok(())
125      }
126  }
127  
128  impl Default for WorkerRegistry {
129      fn default() -> Self {
130          Self::new()
131      }
132  }
133  
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. The `u128`
/// millisecond count is converted with a checked conversion that saturates at
/// `i64::MAX`, instead of an `as` cast that would silently truncate.
fn current_time_ms() -> i64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    i64::try_from(since_epoch.as_millis()).unwrap_or(i64::MAX)
}
140  
#[cfg(test)]
mod tests {
    use super::*;

    fn create_worker_info(id: &str, max_bots: u32) -> WorkerInfo {
        WorkerInfo {
            worker_id: id.to_string(),
            cpu_cores: 8,
            memory_bytes: 8 * 1024 * 1024 * 1024,
            max_bots,
            capabilities: vec!["trader".to_string()],
            address: "localhost:50051".to_string(),
        }
    }

    #[test]
    fn test_worker_registration() {
        let registry = WorkerRegistry::new();

        let worker = create_worker_info("worker-1", 100);
        registry.register(worker).unwrap();

        assert_eq!(registry.worker_count(), 1);
        assert_eq!(registry.healthy_worker_count(), 1);
        assert_eq!(registry.total_capacity(), 100);
    }

    #[test]
    fn test_reregistration_replaces_entry() {
        let registry = WorkerRegistry::new();
        registry.register(create_worker_info("worker-1", 100)).unwrap();
        registry.register(create_worker_info("worker-1", 40)).unwrap();

        assert_eq!(registry.worker_count(), 1);
        assert_eq!(registry.total_capacity(), 40);
    }

    #[test]
    fn test_heartbeat_update() {
        let registry = WorkerRegistry::new();

        let worker = create_worker_info("worker-1", 100);
        registry.register(worker).unwrap();

        registry.update_heartbeat("worker-1", true).unwrap();

        let entry = registry.get_worker("worker-1").unwrap();
        assert!(entry.healthy);
    }

    #[test]
    fn test_unhealthy_heartbeat_excludes_worker() {
        let registry = WorkerRegistry::new();
        registry.register(create_worker_info("worker-1", 100)).unwrap();
        registry.register(create_worker_info("worker-2", 50)).unwrap();

        registry.update_heartbeat("worker-1", false).unwrap();

        assert_eq!(registry.worker_count(), 2);
        assert_eq!(registry.healthy_worker_count(), 1);
        assert_eq!(registry.total_capacity(), 50);
        assert_eq!(
            registry.find_available_worker().as_deref(),
            Some("worker-2")
        );

        let listed: Vec<String> = registry
            .list_workers()
            .into_iter()
            .map(|w| w.worker_id)
            .collect();
        assert_eq!(listed, vec!["worker-2".to_string()]);
    }

    #[test]
    fn test_find_available_worker() {
        let registry = WorkerRegistry::new();

        let worker1 = create_worker_info("worker-1", 100);
        let worker2 = create_worker_info("worker-2", 100);

        registry.register(worker1).unwrap();
        registry.register(worker2).unwrap();

        let available = registry.find_available_worker();
        assert!(available.is_some());
    }

    #[test]
    fn test_no_available_worker_when_none_healthy() {
        let registry = WorkerRegistry::new();
        assert!(registry.find_available_worker().is_none());

        registry.register(create_worker_info("worker-1", 100)).unwrap();
        registry.update_heartbeat("worker-1", false).unwrap();

        assert!(registry.find_available_worker().is_none());
        assert_eq!(registry.total_capacity(), 0);
        assert!(registry.list_workers().is_empty());
    }

    #[test]
    fn test_check_stale_workers() {
        let registry = WorkerRegistry::new();
        registry.register(create_worker_info("fresh", 10)).unwrap();
        registry.register(create_worker_info("stale", 10)).unwrap();

        // Backdate one heartbeat well past the timeout instead of sleeping.
        registry
            .workers
            .write()
            .get_mut("stale")
            .unwrap()
            .last_heartbeat_ms -= 60_000;

        registry.check_stale_workers(30_000);

        assert!(registry.get_worker("fresh").unwrap().healthy);
        assert!(!registry.get_worker("stale").unwrap().healthy);
        // Stale workers are flagged, not removed.
        assert_eq!(registry.worker_count(), 2);

        // A subsequent healthy heartbeat revives the worker.
        registry.update_heartbeat("stale", true).unwrap();
        assert_eq!(registry.healthy_worker_count(), 2);
    }

    #[test]
    fn test_remove_worker() {
        let registry = WorkerRegistry::new();

        let worker = create_worker_info("worker-1", 100);
        registry.register(worker).unwrap();

        assert_eq!(registry.worker_count(), 1);

        registry.remove_worker("worker-1").unwrap();

        assert_eq!(registry.worker_count(), 0);
    }

    #[test]
    fn test_unknown_worker_errors() {
        let registry = WorkerRegistry::new();

        assert!(registry.update_heartbeat("missing", true).is_err());
        assert!(registry.remove_worker("missing").is_err());
        assert!(registry.get_worker("missing").is_none());
    }
}