// scheduler.rs
1 /// Task scheduler for bot operations 2 /// 3 /// Provides tokio-based task scheduling with support for: 4 /// - One-time tasks 5 /// - Recurring tasks 6 /// - Delayed execution 7 8 use crate::{BotError, Result}; 9 use tokio::time::{Duration, Instant, interval_at, sleep}; 10 use std::sync::Arc; 11 use parking_lot::Mutex; 12 13 /// A scheduled task 14 pub struct Task { 15 /// Task identifier 16 pub id: String, 17 18 /// Task type 19 pub kind: TaskKind, 20 21 /// Task function 22 pub func: Arc<dyn Fn() -> tokio::task::JoinHandle<()> + Send + Sync>, 23 } 24 25 /// Type of scheduled task 26 #[derive(Debug, Clone)] 27 pub enum TaskKind { 28 /// Execute once immediately 29 Immediate, 30 31 /// Execute once after a delay 32 Delayed { delay: Duration }, 33 34 /// Execute repeatedly at intervals 35 Recurring { interval: Duration }, 36 } 37 38 /// Task scheduler for managing bot operations 39 pub struct Scheduler { 40 /// Active tasks 41 tasks: Arc<Mutex<Vec<Task>>>, 42 43 /// Shutdown signal 44 shutdown: Arc<Mutex<bool>>, 45 } 46 47 impl Scheduler { 48 /// Create a new scheduler 49 pub fn new() -> Self { 50 Self { 51 tasks: Arc::new(Mutex::new(Vec::new())), 52 shutdown: Arc::new(Mutex::new(false)), 53 } 54 } 55 56 /// Schedule a one-time immediate task 57 pub fn schedule_immediate<F>(&self, id: String, func: F) -> Result<()> 58 where 59 F: Fn() -> tokio::task::JoinHandle<()> + Send + Sync + 'static, 60 { 61 let task = Task { 62 id, 63 kind: TaskKind::Immediate, 64 func: Arc::new(func), 65 }; 66 67 self.tasks.lock().push(task); 68 Ok(()) 69 } 70 71 /// Schedule a one-time delayed task 72 pub fn schedule_delayed<F>(&self, id: String, delay: Duration, func: F) -> Result<()> 73 where 74 F: Fn() -> tokio::task::JoinHandle<()> + Send + Sync + 'static, 75 { 76 let task = Task { 77 id, 78 kind: TaskKind::Delayed { delay }, 79 func: Arc::new(func), 80 }; 81 82 self.tasks.lock().push(task); 83 Ok(()) 84 } 85 86 /// Schedule a recurring task 87 pub fn schedule_recurring<F>(&self, 
id: String, interval_duration: Duration, func: F) -> Result<()> 88 where 89 F: Fn() -> tokio::task::JoinHandle<()> + Send + Sync + 'static, 90 { 91 let task = Task { 92 id, 93 kind: TaskKind::Recurring { interval: interval_duration }, 94 func: Arc::new(func), 95 }; 96 97 self.tasks.lock().push(task); 98 Ok(()) 99 } 100 101 /// Run all scheduled tasks 102 pub async fn run(&self) -> Result<()> { 103 let tasks = self.tasks.lock().clone(); 104 105 for task in tasks { 106 if *self.shutdown.lock() { 107 break; 108 } 109 110 match task.kind { 111 TaskKind::Immediate => { 112 (task.func)(); 113 } 114 TaskKind::Delayed { delay } => { 115 let func = task.func.clone(); 116 let shutdown = self.shutdown.clone(); 117 118 tokio::spawn(async move { 119 sleep(delay).await; 120 if !*shutdown.lock() { 121 func(); 122 } 123 }); 124 } 125 TaskKind::Recurring { interval: interval_duration } => { 126 let func = task.func.clone(); 127 let shutdown = self.shutdown.clone(); 128 129 tokio::spawn(async move { 130 let start = Instant::now() + interval_duration; 131 let mut ticker = interval_at(start, interval_duration); 132 133 loop { 134 ticker.tick().await; 135 136 if *shutdown.lock() { 137 break; 138 } 139 140 func(); 141 } 142 }); 143 } 144 } 145 } 146 147 Ok(()) 148 } 149 150 /// Shutdown the scheduler 151 pub fn shutdown(&self) { 152 *self.shutdown.lock() = true; 153 } 154 155 /// Check if scheduler is shutdown 156 pub fn is_shutdown(&self) -> bool { 157 *self.shutdown.lock() 158 } 159 160 /// Get number of scheduled tasks 161 pub fn task_count(&self) -> usize { 162 self.tasks.lock().len() 163 } 164 } 165 166 impl Default for Scheduler { 167 fn default() -> Self { 168 Self::new() 169 } 170 } 171 172 #[cfg(test)] 173 mod tests { 174 use super::*; 175 use std::sync::atomic::{AtomicUsize, Ordering}; 176 177 #[tokio::test] 178 async fn test_immediate_task() { 179 let scheduler = Scheduler::new(); 180 let counter = Arc::new(AtomicUsize::new(0)); 181 182 let counter_clone = counter.clone(); 
183 scheduler 184 .schedule_immediate("test-task".to_string(), move || { 185 let c = counter_clone.clone(); 186 tokio::spawn(async move { 187 c.fetch_add(1, Ordering::SeqCst); 188 }) 189 }) 190 .unwrap(); 191 192 scheduler.run().await.unwrap(); 193 194 tokio::time::sleep(Duration::from_millis(100)).await; 195 assert_eq!(counter.load(Ordering::SeqCst), 1); 196 } 197 198 #[tokio::test] 199 async fn test_delayed_task() { 200 let scheduler = Scheduler::new(); 201 let counter = Arc::new(AtomicUsize::new(0)); 202 203 let counter_clone = counter.clone(); 204 scheduler 205 .schedule_delayed( 206 "delayed-task".to_string(), 207 Duration::from_millis(100), 208 move || { 209 let c = counter_clone.clone(); 210 tokio::spawn(async move { 211 c.fetch_add(1, Ordering::SeqCst); 212 }) 213 }, 214 ) 215 .unwrap(); 216 217 scheduler.run().await.unwrap(); 218 219 // Check before delay 220 assert_eq!(counter.load(Ordering::SeqCst), 0); 221 222 // Check after delay 223 tokio::time::sleep(Duration::from_millis(200)).await; 224 assert_eq!(counter.load(Ordering::SeqCst), 1); 225 } 226 227 #[tokio::test] 228 async fn test_recurring_task() { 229 let scheduler = Scheduler::new(); 230 let counter = Arc::new(AtomicUsize::new(0)); 231 232 let counter_clone = counter.clone(); 233 scheduler 234 .schedule_recurring( 235 "recurring-task".to_string(), 236 Duration::from_millis(50), 237 move || { 238 let c = counter_clone.clone(); 239 tokio::spawn(async move { 240 c.fetch_add(1, Ordering::SeqCst); 241 }) 242 }, 243 ) 244 .unwrap(); 245 246 scheduler.run().await.unwrap(); 247 248 // Wait for multiple intervals 249 tokio::time::sleep(Duration::from_millis(250)).await; 250 251 // Should have executed at least 3 times 252 assert!(counter.load(Ordering::SeqCst) >= 3); 253 254 scheduler.shutdown(); 255 } 256 257 #[test] 258 fn test_task_count() { 259 let scheduler = Scheduler::new(); 260 261 scheduler 262 .schedule_immediate("task1".to_string(), || { 263 tokio::spawn(async {}) 264 }) 265 .unwrap(); 266 267 
scheduler 268 .schedule_delayed("task2".to_string(), Duration::from_secs(1), || { 269 tokio::spawn(async {}) 270 }) 271 .unwrap(); 272 273 assert_eq!(scheduler.task_count(), 2); 274 } 275 }