stoppable_task.rs
1 /* This file is part of DarkFi (https://dark.fi) 2 * 3 * Copyright (C) 2020-2025 Dyne.org foundation 4 * 5 * This program is free software: you can redistribute it and/or modify 6 * it under the terms of the GNU Affero General Public License as 7 * published by the Free Software Foundation, either version 3 of the 8 * License, or (at your option) any later version. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU Affero General Public License for more details. 14 * 15 * You should have received a copy of the GNU Affero General Public License 16 * along with this program. If not, see <https://www.gnu.org/licenses/>. 17 */ 18 19 use rand::{rngs::OsRng, Rng}; 20 use smol::{ 21 future::{self, Future}, 22 Executor, 23 }; 24 use std::sync::Arc; 25 use tracing::trace; 26 27 use super::CondVar; 28 29 pub type StoppableTaskPtr = Arc<StoppableTask>; 30 31 pub struct StoppableTask { 32 /// Used to signal to the main running process that it should stop. 33 signal: CondVar, 34 /// When we call `stop()`, we wait until the process is finished. This is used to prevent 35 /// `stop()` from exiting until the task has closed. 36 barrier: CondVar, 37 38 /// Used so we can keep StoppableTask in HashMap/HashSet 39 pub task_id: u32, 40 } 41 42 /// A task that can be prematurely stopped at any time. 43 /// 44 /// ```rust 45 /// let task = StoppableTask::new(); 46 /// task.clone().start( 47 /// my_method(), 48 /// |result| self_.handle_stop(result), 49 /// Error::MyStopError, 50 /// executor, 51 /// ); 52 /// ``` 53 /// 54 /// Then at any time we can call `task.stop()` to close the task. 55 impl StoppableTask { 56 pub fn new() -> Arc<Self> { 57 Arc::new(Self { signal: CondVar::new(), barrier: CondVar::new(), task_id: OsRng.gen() }) 58 } 59 60 /// Starts the task. 61 /// 62 /// * `main` is a function of the type `async fn foo() -> ()` 63 /// * `stop_handler` is a function of the type `async fn handle_stop(result: Result<()>) -> ()` 64 /// * `stop_value` is the Error code passed to `stop_handler` when `task.stop()` is called 65 pub fn start<'a, MainFut, StopFut, StopFn, Error>( 66 self: Arc<Self>, 67 main: MainFut, 68 stop_handler: StopFn, 69 stop_value: Error, 70 executor: Arc<Executor<'a>>, 71 ) where 72 MainFut: Future<Output = std::result::Result<(), Error>> + Send + 'a, 73 StopFut: Future<Output = ()> + Send, 74 StopFn: FnOnce(std::result::Result<(), Error>) -> StopFut + Send + 'a, 75 Error: std::error::Error + Send + 'a, 76 { 77 // NOTE: we could send the error code from stop() instead of having it specified in start() 78 trace!(target: "system::StoppableTask", "Starting task {}", self.task_id); 79 // Allow stopping and starting task again. 80 // NOTE: maybe we should disallow this with a panic? 81 self.signal.reset(); 82 self.barrier.reset(); 83 84 executor 85 .spawn(async move { 86 // Task which waits for a stop signal 87 let stop_fut = async { 88 self.signal.wait().await; 89 trace!( 90 target: "system::StoppableTask", 91 "Stop signal received for task {}", 92 self.task_id 93 ); 94 Err(stop_value) 95 }; 96 97 // Wait on our main task or stop task - whichever finishes first 98 let result = future::or(main, stop_fut).await; 99 100 trace!( 101 target: "system::StoppableTask", 102 "Closing task {} with result: {:?}", 103 self.task_id, 104 result 105 ); 106 107 stop_handler(result).await; 108 // Allow `stop()` to finish 109 self.barrier.notify(); 110 }) 111 .detach(); 112 } 113 114 /// Stops the task. On completion, guarantees the process has stopped. 115 /// Can be called multiple times. After the first call, this does nothing. 116 pub async fn stop(&self) { 117 trace!(target: "system::StoppableTask", "Stopping task {}", self.task_id); 118 self.signal.notify(); 119 self.barrier.wait().await; 120 trace!(target: "system::StoppableTask", "Stopped task {}", self.task_id); 121 } 122 123 /// Sends a stop signal and returns immediately. Doesn't guarantee the task 124 /// stopped on completion. 125 pub fn stop_nowait(&self) { 126 trace!(target: "system::StoppableTask", "Stopping task (nowait) {}", self.task_id); 127 self.signal.notify(); 128 } 129 } 130 131 impl std::hash::Hash for StoppableTask { 132 fn hash<H>(&self, state: &mut H) 133 where 134 H: std::hash::Hasher, 135 { 136 self.task_id.hash(state); 137 } 138 } 139 140 impl std::cmp::PartialEq for StoppableTask { 141 fn eq(&self, other: &Self) -> bool { 142 self.task_id == other.task_id 143 } 144 } 145 146 impl std::cmp::Eq for StoppableTask {} 147 148 impl Drop for StoppableTask { 149 fn drop(&mut self) { 150 self.stop_nowait() 151 } 152 } 153 154 #[cfg(test)] 155 mod tests { 156 use super::*; 157 use crate::{ 158 error::Error, 159 system::sleep_forever, 160 util::logger::{setup_test_logger, Level}, 161 }; 162 use tracing::warn; 163 164 #[test] 165 fn stoppit_mom() { 166 // We check this error so we can execute same file tests in parallel, 167 // otherwise second one fails to init logger here. 168 if setup_test_logger( 169 &["async_io", "polling"], 170 false, 171 //Level::Info, 172 //Level::Verbose, 173 //Level::Debug 174 Level::Trace, 175 ) 176 .is_err() 177 { 178 warn!(target: "test_harness", "Logger already initialized"); 179 } 180 181 let executor = Arc::new(Executor::new()); 182 let executor_ = executor.clone(); 183 smol::block_on(executor.run(async move { 184 let task = StoppableTask::new(); 185 task.clone().start( 186 // Main process is an infinite loop 187 async { 188 sleep_forever().await; 189 unreachable!() 190 }, 191 // Handle stop 192 |result| async move { 193 assert!(matches!(result, Err(Error::DetachedTaskStopped))); 194 }, 195 Error::DetachedTaskStopped, 196 executor_, 197 ); 198 task.stop().await; 199 })) 200 } 201 }