// crates/bot/src/scheduler.rs
//! Task scheduler for bot operations.
//!
//! Provides tokio-based task scheduling with support for:
//! - One-time tasks
//! - Recurring tasks
//! - Delayed execution

use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use parking_lot::Mutex;
use tokio::time::{Duration, Instant, interval_at, sleep};

use crate::{BotError, Result};

 13  /// A scheduled task
 14  pub struct Task {
 15      /// Task identifier
 16      pub id: String,
 17  
 18      /// Task type
 19      pub kind: TaskKind,
 20  
 21      /// Task function
 22      pub func: Arc<dyn Fn() -> tokio::task::JoinHandle<()> + Send + Sync>,
 23  }
 24  
 25  /// Type of scheduled task
 26  #[derive(Debug, Clone)]
 27  pub enum TaskKind {
 28      /// Execute once immediately
 29      Immediate,
 30  
 31      /// Execute once after a delay
 32      Delayed { delay: Duration },
 33  
 34      /// Execute repeatedly at intervals
 35      Recurring { interval: Duration },
 36  }
 37  
 38  /// Task scheduler for managing bot operations
 39  pub struct Scheduler {
 40      /// Active tasks
 41      tasks: Arc<Mutex<Vec<Task>>>,
 42  
 43      /// Shutdown signal
 44      shutdown: Arc<Mutex<bool>>,
 45  }
 46  
 47  impl Scheduler {
 48      /// Create a new scheduler
 49      pub fn new() -> Self {
 50          Self {
 51              tasks: Arc::new(Mutex::new(Vec::new())),
 52              shutdown: Arc::new(Mutex::new(false)),
 53          }
 54      }
 55  
 56      /// Schedule a one-time immediate task
 57      pub fn schedule_immediate<F>(&self, id: String, func: F) -> Result<()>
 58      where
 59          F: Fn() -> tokio::task::JoinHandle<()> + Send + Sync + 'static,
 60      {
 61          let task = Task {
 62              id,
 63              kind: TaskKind::Immediate,
 64              func: Arc::new(func),
 65          };
 66  
 67          self.tasks.lock().push(task);
 68          Ok(())
 69      }
 70  
 71      /// Schedule a one-time delayed task
 72      pub fn schedule_delayed<F>(&self, id: String, delay: Duration, func: F) -> Result<()>
 73      where
 74          F: Fn() -> tokio::task::JoinHandle<()> + Send + Sync + 'static,
 75      {
 76          let task = Task {
 77              id,
 78              kind: TaskKind::Delayed { delay },
 79              func: Arc::new(func),
 80          };
 81  
 82          self.tasks.lock().push(task);
 83          Ok(())
 84      }
 85  
 86      /// Schedule a recurring task
 87      pub fn schedule_recurring<F>(&self, id: String, interval_duration: Duration, func: F) -> Result<()>
 88      where
 89          F: Fn() -> tokio::task::JoinHandle<()> + Send + Sync + 'static,
 90      {
 91          let task = Task {
 92              id,
 93              kind: TaskKind::Recurring { interval: interval_duration },
 94              func: Arc::new(func),
 95          };
 96  
 97          self.tasks.lock().push(task);
 98          Ok(())
 99      }
100  
101      /// Run all scheduled tasks
102      pub async fn run(&self) -> Result<()> {
103          let tasks = self.tasks.lock().clone();
104  
105          for task in tasks {
106              if *self.shutdown.lock() {
107                  break;
108              }
109  
110              match task.kind {
111                  TaskKind::Immediate => {
112                      (task.func)();
113                  }
114                  TaskKind::Delayed { delay } => {
115                      let func = task.func.clone();
116                      let shutdown = self.shutdown.clone();
117  
118                      tokio::spawn(async move {
119                          sleep(delay).await;
120                          if !*shutdown.lock() {
121                              func();
122                          }
123                      });
124                  }
125                  TaskKind::Recurring { interval: interval_duration } => {
126                      let func = task.func.clone();
127                      let shutdown = self.shutdown.clone();
128  
129                      tokio::spawn(async move {
130                          let start = Instant::now() + interval_duration;
131                          let mut ticker = interval_at(start, interval_duration);
132  
133                          loop {
134                              ticker.tick().await;
135  
136                              if *shutdown.lock() {
137                                  break;
138                              }
139  
140                              func();
141                          }
142                      });
143                  }
144              }
145          }
146  
147          Ok(())
148      }
149  
150      /// Shutdown the scheduler
151      pub fn shutdown(&self) {
152          *self.shutdown.lock() = true;
153      }
154  
155      /// Check if scheduler is shutdown
156      pub fn is_shutdown(&self) -> bool {
157          *self.shutdown.lock()
158      }
159  
160      /// Get number of scheduled tasks
161      pub fn task_count(&self) -> usize {
162          self.tasks.lock().len()
163      }
164  }
165  
166  impl Default for Scheduler {
167      fn default() -> Self {
168          Self::new()
169      }
170  }
171  
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Task body that increments `counter` from a freshly spawned tokio task.
    fn counting_task(
        counter: &Arc<AtomicUsize>,
    ) -> impl Fn() -> tokio::task::JoinHandle<()> + Send + Sync + 'static {
        let counter = Arc::clone(counter);
        move || {
            let c = Arc::clone(&counter);
            tokio::spawn(async move {
                c.fetch_add(1, Ordering::SeqCst);
            })
        }
    }

    #[tokio::test]
    async fn test_immediate_task() {
        let scheduler = Scheduler::new();
        let counter = Arc::new(AtomicUsize::new(0));

        scheduler
            .schedule_immediate("test-task".to_string(), counting_task(&counter))
            .unwrap();

        scheduler.run().await.unwrap();

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_delayed_task() {
        let scheduler = Scheduler::new();
        let counter = Arc::new(AtomicUsize::new(0));

        scheduler
            .schedule_delayed(
                "delayed-task".to_string(),
                Duration::from_millis(100),
                counting_task(&counter),
            )
            .unwrap();

        scheduler.run().await.unwrap();

        // Check before delay
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Check after delay
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_recurring_task() {
        let scheduler = Scheduler::new();
        let counter = Arc::new(AtomicUsize::new(0));

        scheduler
            .schedule_recurring(
                "recurring-task".to_string(),
                Duration::from_millis(50),
                counting_task(&counter),
            )
            .unwrap();

        scheduler.run().await.unwrap();

        // Wait for multiple intervals
        tokio::time::sleep(Duration::from_millis(250)).await;

        // Should have executed at least 3 times
        assert!(counter.load(Ordering::SeqCst) >= 3);

        scheduler.shutdown();
    }

    #[test]
    fn test_task_count() {
        let scheduler = Scheduler::new();

        scheduler
            .schedule_immediate("task1".to_string(), || tokio::spawn(async {}))
            .unwrap();

        scheduler
            .schedule_delayed("task2".to_string(), Duration::from_secs(1), || {
                tokio::spawn(async {})
            })
            .unwrap();

        assert_eq!(scheduler.task_count(), 2);
    }

    #[test]
    fn test_shutdown_flag() {
        let scheduler = Scheduler::new();
        assert!(!scheduler.is_shutdown());

        scheduler.shutdown();
        assert!(scheduler.is_shutdown());
    }

    #[tokio::test]
    async fn test_run_after_shutdown_fires_nothing() {
        let scheduler = Scheduler::new();
        let counter = Arc::new(AtomicUsize::new(0));

        scheduler
            .schedule_immediate("task".to_string(), counting_task(&counter))
            .unwrap();

        scheduler.shutdown();
        scheduler.run().await.unwrap();

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_shutdown_cancels_pending_delayed_task() {
        let scheduler = Scheduler::new();
        let counter = Arc::new(AtomicUsize::new(0));

        scheduler
            .schedule_delayed(
                "delayed-task".to_string(),
                Duration::from_millis(50),
                counting_task(&counter),
            )
            .unwrap();

        scheduler.run().await.unwrap();
        scheduler.shutdown();

        // The timer task wakes after the delay, sees the flag, and skips the callback.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_shutdown_stops_recurring_task() {
        let scheduler = Scheduler::new();
        let counter = Arc::new(AtomicUsize::new(0));

        scheduler
            .schedule_recurring(
                "recurring-task".to_string(),
                Duration::from_millis(20),
                counting_task(&counter),
            )
            .unwrap();

        scheduler.run().await.unwrap();
        tokio::time::sleep(Duration::from_millis(70)).await;
        scheduler.shutdown();

        // Let any callback body spawned before shutdown finish, then expect no growth.
        tokio::time::sleep(Duration::from_millis(10)).await;
        let after_shutdown = counter.load(Ordering::SeqCst);
        tokio::time::sleep(Duration::from_millis(80)).await;
        assert_eq!(counter.load(Ordering::SeqCst), after_shutdown);
    }
}