/ fedimint-core / src / util / broadcaststream.rs
broadcaststream.rs
 1  use std::fmt;
 2  use std::future::Future;
 3  use std::pin::Pin;
 4  use std::task::{Context, Poll};
 5  
 6  use futures::{ready, Stream};
 7  use tokio::sync::broadcast::error::RecvError;
 8  use tokio::sync::broadcast::Receiver;
 9  
10  use crate::task::MaybeSend;
11  use crate::util::BoxFuture;
12  
13  /// A wrapper around [`tokio::sync::broadcast::Receiver`] that implements
14  /// [`Stream`].
15  ///
16  /// [`tokio::sync::broadcast::Receiver`]: struct@tokio::sync::broadcast::Receiver
17  /// [`Stream`]: trait@futures::Stream
18  pub struct BroadcastStream<T> {
19      inner: BoxFuture<'static, (Result<T, RecvError>, Receiver<T>)>,
20  }
21  
/// Error yielded in place of an item by the inner stream of a
/// [`BroadcastStream`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BroadcastStreamRecvError {
    /// The receiver fell too far behind. The next receive attempt will
    /// return the oldest message the channel still retains.
    ///
    /// Carries the number of messages that were skipped.
    Lagged(u64),
}

impl fmt::Display for BroadcastStreamRecvError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive match on purpose: adding a variant must force an update
        // of this message table.
        match *self {
            Self::Lagged(skipped) => write!(f, "channel lagged by {skipped}"),
        }
    }
}

// The default `Error` methods are sufficient; no underlying source is stored.
impl std::error::Error for BroadcastStreamRecvError {}
41  
42  async fn make_future<T: Clone>(mut rx: Receiver<T>) -> (Result<T, RecvError>, Receiver<T>) {
43      let result = rx.recv().await;
44      (result, rx)
45  }
46  
47  impl<T: 'static + Clone + MaybeSend> BroadcastStream<T> {
48      /// Create a new `BroadcastStream`.
49      pub fn new(rx: Receiver<T>) -> Self {
50          Self {
51              inner: Box::pin(make_future(rx)),
52          }
53      }
54  }
55  
56  impl<T: 'static + Clone + MaybeSend> Stream for BroadcastStream<T> {
57      type Item = Result<T, BroadcastStreamRecvError>;
58      fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
59          let (result, rx) = ready!(Pin::new(&mut self.inner).poll(cx));
60          self.inner = Box::pin(make_future(rx));
61          match result {
62              Ok(item) => Poll::Ready(Some(Ok(item))),
63              Err(RecvError::Closed) => Poll::Ready(None),
64              Err(RecvError::Lagged(n)) => {
65                  Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(n))))
66              }
67          }
68      }
69  }
70  
71  impl<T> fmt::Debug for BroadcastStream<T> {
72      fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73          f.debug_struct("BroadcastStream").finish()
74      }
75  }
76  
77  impl<T: 'static + Clone + MaybeSend> From<Receiver<T>> for BroadcastStream<T> {
78      fn from(recv: Receiver<T>) -> Self {
79          Self::new(recv)
80      }
81  }