/ src / system / stoppable_task.rs
stoppable_task.rs
  1  /* This file is part of DarkFi (https://dark.fi)
  2   *
  3   * Copyright (C) 2020-2025 Dyne.org foundation
  4   *
  5   * This program is free software: you can redistribute it and/or modify
  6   * it under the terms of the GNU Affero General Public License as
  7   * published by the Free Software Foundation, either version 3 of the
  8   * License, or (at your option) any later version.
  9   *
 10   * This program is distributed in the hope that it will be useful,
 11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
 12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 13   * GNU Affero General Public License for more details.
 14   *
 15   * You should have received a copy of the GNU Affero General Public License
 16   * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 17   */
 18  
 19  use rand::{rngs::OsRng, Rng};
 20  use smol::{
 21      future::{self, Future},
 22      Executor,
 23  };
 24  use std::sync::Arc;
 25  use tracing::trace;
 26  
 27  use super::CondVar;
 28  
 29  pub type StoppableTaskPtr = Arc<StoppableTask>;
 30  
 31  pub struct StoppableTask {
 32      /// Used to signal to the main running process that it should stop.
 33      signal: CondVar,
 34      /// When we call `stop()`, we wait until the process is finished. This is used to prevent
 35      /// `stop()` from exiting until the task has closed.
 36      barrier: CondVar,
 37  
 38      /// Used so we can keep StoppableTask in HashMap/HashSet
 39      pub task_id: u32,
 40  }
 41  
 42  /// A task that can be prematurely stopped at any time.
 43  ///
 44  /// ```rust
 45  ///     let task = StoppableTask::new();
 46  ///     task.clone().start(
 47  ///         my_method(),
 48  ///         |result| self_.handle_stop(result),
 49  ///         Error::MyStopError,
 50  ///         executor,
 51  ///     );
 52  /// ```
 53  ///
 54  /// Then at any time we can call `task.stop()` to close the task.
 55  impl StoppableTask {
 56      pub fn new() -> Arc<Self> {
 57          Arc::new(Self { signal: CondVar::new(), barrier: CondVar::new(), task_id: OsRng.gen() })
 58      }
 59  
 60      /// Starts the task.
 61      ///
 62      /// * `main` is a function of the type `async fn foo() -> ()`
 63      /// * `stop_handler` is a function of the type `async fn handle_stop(result: Result<()>) -> ()`
 64      /// * `stop_value` is the Error code passed to `stop_handler` when `task.stop()` is called
 65      pub fn start<'a, MainFut, StopFut, StopFn, Error>(
 66          self: Arc<Self>,
 67          main: MainFut,
 68          stop_handler: StopFn,
 69          stop_value: Error,
 70          executor: Arc<Executor<'a>>,
 71      ) where
 72          MainFut: Future<Output = std::result::Result<(), Error>> + Send + 'a,
 73          StopFut: Future<Output = ()> + Send,
 74          StopFn: FnOnce(std::result::Result<(), Error>) -> StopFut + Send + 'a,
 75          Error: std::error::Error + Send + 'a,
 76      {
 77          // NOTE: we could send the error code from stop() instead of having it specified in start()
 78          trace!(target: "system::StoppableTask", "Starting task {}", self.task_id);
 79          // Allow stopping and starting task again.
 80          // NOTE: maybe we should disallow this with a panic?
 81          self.signal.reset();
 82          self.barrier.reset();
 83  
 84          executor
 85              .spawn(async move {
 86                  // Task which waits for a stop signal
 87                  let stop_fut = async {
 88                      self.signal.wait().await;
 89                      trace!(
 90                          target: "system::StoppableTask",
 91                          "Stop signal received for task {}",
 92                          self.task_id
 93                      );
 94                      Err(stop_value)
 95                  };
 96  
 97                  // Wait on our main task or stop task - whichever finishes first
 98                  let result = future::or(main, stop_fut).await;
 99  
100                  trace!(
101                      target: "system::StoppableTask",
102                      "Closing task {} with result: {:?}",
103                      self.task_id,
104                      result
105                  );
106  
107                  stop_handler(result).await;
108                  // Allow `stop()` to finish
109                  self.barrier.notify();
110              })
111              .detach();
112      }
113  
114      /// Stops the task. On completion, guarantees the process has stopped.
115      /// Can be called multiple times. After the first call, this does nothing.
116      pub async fn stop(&self) {
117          trace!(target: "system::StoppableTask", "Stopping task {}", self.task_id);
118          self.signal.notify();
119          self.barrier.wait().await;
120          trace!(target: "system::StoppableTask", "Stopped task {}", self.task_id);
121      }
122  
123      /// Sends a stop signal and returns immediately. Doesn't guarantee the task
124      /// stopped on completion.
125      pub fn stop_nowait(&self) {
126          trace!(target: "system::StoppableTask", "Stopping task (nowait) {}", self.task_id);
127          self.signal.notify();
128      }
129  }
130  
131  impl std::hash::Hash for StoppableTask {
132      fn hash<H>(&self, state: &mut H)
133      where
134          H: std::hash::Hasher,
135      {
136          self.task_id.hash(state);
137      }
138  }
139  
140  impl std::cmp::PartialEq for StoppableTask {
141      fn eq(&self, other: &Self) -> bool {
142          self.task_id == other.task_id
143      }
144  }
145  
146  impl std::cmp::Eq for StoppableTask {}
147  
148  impl Drop for StoppableTask {
149      fn drop(&mut self) {
150          self.stop_nowait()
151      }
152  }
153  
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        error::Error,
        system::sleep_forever,
        util::logger::{setup_test_logger, Level},
    };
    use tracing::warn;

    /// Starts a task whose main future never completes, stops it, and checks
    /// that the stop handler receives the configured stop error.
    #[test]
    fn stoppit_mom() {
        // Tests in the same file may run in parallel, so only the first one gets
        // to install the logger. A failure here is expected and harmless.
        // Swap in Level::Info, Level::Verbose or Level::Debug for quieter output.
        let logger = setup_test_logger(&["async_io", "polling"], false, Level::Trace);
        if logger.is_err() {
            warn!(target: "test_harness", "Logger already initialized");
        }

        let ex = Arc::new(Executor::new());
        let spawn_on = Arc::clone(&ex);

        let scenario = async move {
            let stoppable = StoppableTask::new();
            stoppable.clone().start(
                // Never returns, so only the stop signal can end this task
                async {
                    sleep_forever().await;
                    unreachable!()
                },
                // The handler must observe the stop error, not a main result
                |outcome| async move {
                    assert!(matches!(outcome, Err(Error::DetachedTaskStopped)));
                },
                Error::DetachedTaskStopped,
                spawn_on,
            );
            stoppable.stop().await;
        };

        smol::block_on(ex.run(scenario))
    }
}