/ crates / tor-async-utils / src / sink_try_send.rs
sink_try_send.rs
  1  //! [`SinkTrySend`]
  2  
  3  use std::error::Error;
  4  use std::pin::Pin;
  5  use std::sync::Arc;
  6  
  7  use futures::channel::mpsc;
  8  use futures::Sink;
  9  
 10  use derive_deftly::{define_derive_deftly, Deftly};
 11  use thiserror::Error;
 12  
 13  //---------- principal API ----------
 14  
 15  /// A [`Sink`] with a `try_send` method like [`futures::channel::mpsc::Sender`'s]
 16  pub trait SinkTrySend<T>: Sink<T> {
 17      /// Errors that is not disconnected, or full
 18      type Error: SinkTrySendError;
 19  
 20      /// Try to send a message `msg`
 21      ///
 22      /// If this returns with an error indicating that the stream is full,
 23      /// *No* arrangements will have been made for a wakeup when space becomes available.
 24      ///
 25      /// If the send fails, `item` is dropped.
 26      /// If you need it back, use [`try_send_or_return`](SinkTrySend::try_send_or_return),
 27      ///
 28      /// (When implementing the trait, implement `try_send_or_return`, *not* this method.)
 29      fn try_send(self: Pin<&mut Self>, item: T) -> Result<(), <Self as SinkTrySend<T>>::Error> {
 30          self.try_send_or_return(item)
 31              .map_err(|(error, _item)| error)
 32      }
 33  
 34      /// Try to send a message `msg`
 35      ///
 36      /// Like [`try_send`](SinkTrySend::try_send),
 37      /// but if the send fails, the item is returned.
 38      ///
 39      /// (When implementing the trait, implement this method.)
 40      fn try_send_or_return(
 41          self: Pin<&mut Self>,
 42          item: T,
 43      ) -> Result<(), (<Self as SinkTrySend<T>>::Error, T)>;
 44  }
 45  
 46  /// Error from [`SinkTrySend::try_send`]
 47  ///
 48  /// See also [`ErasedSinkTrySendError`] which can often
 49  /// be usefully used when an implementation of `SinkTrySendError` is needed.
 50  pub trait SinkTrySendError: Error + 'static {
 51      /// The stream was full.
 52      ///
 53      /// *No* arrangements will have been made for a wakeup when space becomes available.
 54      ///
 55      /// Corresponds to [`futures::channel::mpsc::TrySendError::is_full`]
 56      fn is_full(&self) -> bool;
 57  
 58      /// The stream has disconnected
 59      ///
 60      /// Corresponds to [`futures::channel::mpsc::TrySendError::is_disconnected`]
 61      fn is_disconnected(&self) -> bool;
 62  }
 63  
 64  //---------- macrology - this has to come here, ideally all in one go ----------
 65  
#[rustfmt::skip] // rustfmt makes a complete hash of this
define_derive_deftly! {
    /// Implements various things which handle `full` and `disconnected`
    ///
    /// # Generates
    ///
    ///  * `impl SinkTrySendError for ErasedSinkTrySendError`
    ///  * `ErasedSinkTrySendError::from<E: SinkTrySendError>`
    ///    (an inherent associated function, *not* a `From` impl)
    ///  * [`handle_mpsc_error`]
    ///
    /// Each of these is driven by the variants marked `#[deftly(predicate)]`;
    /// a variant `Foo` corresponds to the predicate method `is_foo`.
    /// Use of macros avoids copypaste errors like
    /// `fn is_full(..) { self.is_disconnected() }`.
    ErasedSinkTrySendError expect items:

    // As a condition, `PREDICATE` selects the variants marked `deftly(predicate)`.
    // As an expansion, it is the matching method name:
    // `is_` followed by the snake_case variant name.
    ${defcond PREDICATE vmeta(predicate)}
    ${define PREDICATE { $<is_ ${snake_case $vname}> }}

    impl SinkTrySendError for ErasedSinkTrySendError {
        $(
            ${when PREDICATE}

            // True iff `self` is exactly this predicate's variant.
            fn $PREDICATE(&self) -> bool {
                matches!(self, $vtype)
            }
        )
    }

    impl ErasedSinkTrySendError {
        /// Obtain an `ErasedSinkTrySendError` from a concrete `SinkTrySendError`
        ///
        /// The predicates are consulted in variant declaration order,
        /// and the first one which is true selects the corresponding variant.
        /// Otherwise `e` is wrapped in `Other`,
        /// unless it is *already* an `ErasedSinkTrySendError`,
        /// in which case it is returned unchanged rather than nested.
        //
        // (Can't be a `From` impl because it conflicts with the identity `From<T> for T`.)
        pub fn from<E>(e: E) -> ErasedSinkTrySendError
        where E: SinkTrySendError + Send + Sync
        {
            // Expands to an `if .. else if ..` chain, one arm per predicate variant;
            // the block after the repetition is the final `else`.
            $(
                ${when PREDICATE}
                if e.$PREDICATE() {
                    $vtype
                } else
            )
                /* else */ {
                    let e = Arc::new(e);
                    // Avoid generating a nested ErasedSinkTrySendError.
                    // Is it *already* an ESTSE (necessarily, then, an `Other`?)
                    //
                    // TODO replace this with a call to `downcast_value` from arti!2460
                    let e2 = e.clone();
                    match Arc::downcast(e2) {
                        Ok::<Arc<ErasedSinkTrySendError>, _>(y2) => {
                            drop(e); // Drop the original
                            // `e` has just been dropped, so `y2` is now the sole owner
                            // of the allocation and `into_inner` yields the value.
                            let inner: ErasedSinkTrySendError =
                                Arc::into_inner(y2).expect(
              "somehow we weren't the only owner, despite us just having made an Arc!"
                                );
                            return inner;
                        }
                        Err(other_e2) => {
                            drop(other_e2);
                            // We need to use e, not other_e2, because Arc::downcast
                            // returns dyn Any but we need dyn SinkTrySendError.
                            ErasedSinkTrySendError::Other(e)
                        },
                    }
                }
        }
    }

    /// Categorise an error from `mpsc::Sender::try_send`
    ///
    /// Returns the erased error together with the item that could not be sent.
    /// An error for which no predicate is true becomes
    /// `Other(MpscOtherSinkTrySendError)`.
    fn handle_mpsc_error<T>(me: mpsc::TrySendError<T>) -> (ErasedSinkTrySendError, T) {
        // Same `if .. else` chain shape as in `from`, above.
        let error = $(
            ${when PREDICATE}

            if me.$PREDICATE() {
                $vtype
            } else
        )
            /* else */ {
                $ttype::Other(Arc::new(MpscOtherSinkTrySendError {}))
            };
        (error, me.into_inner())
    }
}
147  
148  //---------- helper - erased error ----------
149  
150  /// Type-erased error for [`SinkTrySend::try_send`]
151  ///
152  /// Provided for situations where providing a concrete error type is awkward.
153  ///
154  /// `futures::channel::mpsc::Sender` wants this because when its `try_send` method fails,
155  /// it is not possible to extract both the sent item, and the error!
156  ///
157  /// `tor_memquota::mq_queue::Sender` wants this because the types of the error return
158  /// from `its `try_send` would otherwise be tainted by complex generics,
159  /// including its private `Entry` type.
160  #[derive(Debug, Error, Clone, Deftly)]
161  #[derive_deftly(ErasedSinkTrySendError)]
162  #[allow(clippy::exhaustive_enums)] // Adding other variants would be a breaking change anyway
163  pub enum ErasedSinkTrySendError {
164      /// The stream was full.
165      ///
166      /// *No* arrangements will have been made for a wakeup when space becomes available.
167      ///
168      /// Corresponds to [`SinkTrySendError::is_full`]
169      #[error("stream full (backpressure)")]
170      #[deftly(predicate)]
171      Full,
172  
173      /// The stream has disconnected
174      ///
175      /// Corresponds to [`SinkTrySendError::is_disconnected`]
176      #[error("stream disconnected")]
177      #[deftly(predicate)]
178      Disconnected,
179  
180      /// Something else went wrong
181      #[error("failed to convey data")]
182      Other(#[source] Arc<dyn Error + Send + Sync + 'static>),
183  }
184  
185  //---------- impl for futures::channel::mpsc ----------
186  
187  /// [`mpsc::Sender::try_send`] returned an uncategorisable error
188  ///
189  /// Both `.full()` and `.disconnected()` returned `false`.
190  /// We could call [`mpsc::TrySendError::into_send_error`] but then we don't get the payload.
191  /// In the future, we might replace this type with a type alias for [`mpsc::SendError`].
192  ///
193  /// When returned from `<mpsc::Sender::SinkTrySend::try_send`,
194  /// this is wrapped in [`ErasedSinkTrySendError::Other`].
195  #[derive(Debug, Error)]
196  #[error("mpsc::Sender::try_send returned an error which is neither .full() nor .disconnected()")]
197  #[non_exhaustive]
198  pub struct MpscOtherSinkTrySendError {}
199  
200  impl<T> SinkTrySend<T> for mpsc::Sender<T> {
201      // Ideally we would just use [`mpsc::SendError`].
202      // But `mpsc::TrySendError` lacks an `into_parts` method that gives both `SendError` and `T`.
203      type Error = ErasedSinkTrySendError;
204  
205      fn try_send_or_return(
206          self: Pin<&mut Self>,
207          item: T,
208      ) -> Result<(), (ErasedSinkTrySendError, T)> {
209          let self_: &mut Self = Pin::into_inner(self);
210          mpsc::Sender::try_send(self_, item).map_err(handle_mpsc_error)
211      }
212  }
213  
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    #![allow(clippy::arithmetic_side_effects)] // don't mind potential panicking ops in tests
    #![allow(clippy::useless_format)] // srsly

    use super::*;
    use derive_deftly::derive_deftly_adhoc;
    use tor_error::ErrorReport as _;

    /// Erasing a concrete error (and re-erasing an erased one) keeps
    /// the right variant, predicates, and report text, without nesting.
    #[test]
    fn chk_erased_sink() {
        #[derive(Error, Clone, Debug, Deftly)]
        #[error("concrete {is_full} {is_disconnected}")]
        #[derive_deftly_adhoc]
        struct Concrete {
            is_full: bool,
            is_disconnected: bool,
        }

        derive_deftly_adhoc! {
            Concrete:

            impl SinkTrySendError for Concrete { $(
                fn $fname(&self) -> bool { self.$fname }
            ) }
        }

        for is_full in [false, true] {
            for is_disconnected in [false, true] {
                let c = Concrete {
                    is_full,
                    is_disconnected,
                };
                let e = ErasedSinkTrySendError::from(c.clone());
                let e2 = ErasedSinkTrySendError::from(e.clone());

                let cs = format!("concrete {is_full} {is_disconnected}");

                let es = if is_full {
                    format!("stream full (backpressure)")
                } else if is_disconnected {
                    format!("stream disconnected")
                } else {
                    format!("failed to convey data: {cs}")
                };

                assert_eq!(c.report().to_string(), format!("error: {cs}"));
                assert_eq!(e.report().to_string(), format!("error: {es}"));
                assert_eq!(e2.report().to_string(), format!("error: {es}"));

                // The predicates must agree with the chosen variant.
                // `is_full` is checked first, so `Full` wins when both are set.
                for erased in [&e, &e2] {
                    assert_eq!(erased.is_full(), is_full);
                    assert_eq!(erased.is_disconnected(), !is_full && is_disconnected);
                }
            }
        }
    }

    /// `mpsc::Sender`'s `try_send_or_return` categorises errors and returns the item.
    #[test]
    fn chk_mpsc_sender() {
        // With `buffer` 0 and a single sender, the channel has room for exactly one item.
        let (mut tx, _rx) = mpsc::channel::<u32>(0);
        Pin::new(&mut tx).try_send_or_return(1).unwrap();
        let (error, item) = Pin::new(&mut tx).try_send_or_return(2).unwrap_err();
        assert!(error.is_full());
        assert!(!error.is_disconnected());
        assert_eq!(item, 2);

        // Once the receiver is gone, sends report disconnection.
        let (mut tx, rx) = mpsc::channel::<u32>(0);
        drop(rx);
        let (error, item) = Pin::new(&mut tx).try_send_or_return(3).unwrap_err();
        assert!(error.is_disconnected());
        assert!(!error.is_full());
        assert_eq!(item, 3);
    }
}