// broadcaststream.rs
1 use std::fmt; 2 use std::future::Future; 3 use std::pin::Pin; 4 use std::task::{Context, Poll}; 5 6 use futures::{ready, Stream}; 7 use tokio::sync::broadcast::error::RecvError; 8 use tokio::sync::broadcast::Receiver; 9 10 use crate::task::MaybeSend; 11 use crate::util::BoxFuture; 12 13 /// A wrapper around [`tokio::sync::broadcast::Receiver`] that implements 14 /// [`Stream`]. 15 /// 16 /// [`tokio::sync::broadcast::Receiver`]: struct@tokio::sync::broadcast::Receiver 17 /// [`Stream`]: trait@futures::Stream 18 pub struct BroadcastStream<T> { 19 inner: BoxFuture<'static, (Result<T, RecvError>, Receiver<T>)>, 20 } 21 22 /// An error returned from the inner stream of a [`BroadcastStream`]. 23 #[derive(Debug, PartialEq, Eq, Clone)] 24 pub enum BroadcastStreamRecvError { 25 /// The receiver lagged too far behind. Attempting to receive again will 26 /// return the oldest message still retained by the channel. 27 /// 28 /// Includes the number of skipped messages. 29 Lagged(u64), 30 } 31 32 impl fmt::Display for BroadcastStreamRecvError { 33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 34 match self { 35 BroadcastStreamRecvError::Lagged(amt) => write!(f, "channel lagged by {amt}"), 36 } 37 } 38 } 39 40 impl std::error::Error for BroadcastStreamRecvError {} 41 42 async fn make_future<T: Clone>(mut rx: Receiver<T>) -> (Result<T, RecvError>, Receiver<T>) { 43 let result = rx.recv().await; 44 (result, rx) 45 } 46 47 impl<T: 'static + Clone + MaybeSend> BroadcastStream<T> { 48 /// Create a new `BroadcastStream`. 
49 pub fn new(rx: Receiver<T>) -> Self { 50 Self { 51 inner: Box::pin(make_future(rx)), 52 } 53 } 54 } 55 56 impl<T: 'static + Clone + MaybeSend> Stream for BroadcastStream<T> { 57 type Item = Result<T, BroadcastStreamRecvError>; 58 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 59 let (result, rx) = ready!(Pin::new(&mut self.inner).poll(cx)); 60 self.inner = Box::pin(make_future(rx)); 61 match result { 62 Ok(item) => Poll::Ready(Some(Ok(item))), 63 Err(RecvError::Closed) => Poll::Ready(None), 64 Err(RecvError::Lagged(n)) => { 65 Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(n)))) 66 } 67 } 68 } 69 } 70 71 impl<T> fmt::Debug for BroadcastStream<T> { 72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 73 f.debug_struct("BroadcastStream").finish() 74 } 75 } 76 77 impl<T: 'static + Clone + MaybeSend> From<Receiver<T>> for BroadcastStream<T> { 78 fn from(recv: Receiver<T>) -> Self { 79 Self::new(recv) 80 } 81 }