/ crates / tor-async-utils / src / stream_peek.rs
stream_peek.rs
  1  //! [`StreamUnobtrusivePeeker`]
  2  //!
  3  //! The memory tracker needs a way to look at the next item of a stream
  4  //! (if there is one, or there can immediately be one),
  5  //! *without* getting involved with the async tasks.
  6  
  7  use educe::Educe;
  8  use futures::stream::FusedStream;
  9  use futures::task::noop_waker_ref;
 10  use futures::Stream;
 11  use pin_project::pin_project;
 12  
 13  use crate::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
 14  
 15  use std::fmt::Debug;
 16  use std::future::Future;
 17  use std::pin::Pin;
 18  use std::task::{Context, Poll, Poll::*, Waker};
 19  
/// Wraps [`Stream`] and provides `peek`, `poll_peek`, and `unobtrusive_peek`
///
/// [`unobtrusive_peek`](StreamUnobtrusivePeeker::unobtrusive_peek)
/// is callable in sync contexts, outside the reading task.
///
/// Like [`futures::stream::Peekable`],
/// this has an async `peek` method, and `poll_peek`,
/// for use from the task that is also reading (via the [`Stream`] impl).
/// But, that type doesn't have `unobtrusive_peek`.
///
/// One way to conceptualise this is that `StreamUnobtrusivePeeker` is dual-ported:
/// the two sets of APIs, while provided on the same type,
/// are typically called from different contexts.
//
// It wasn't particularly easy to think of a good name for this type.
// We intend, probably:
//     struct StreamUnobtrusivePeeker
//     trait StreamUnobtrusivePeekable
//     trait StreamPeekable (impl for StreamUnobtrusivePeeker and futures::stream::Peekable)
//
// Searching a thesaurus produced these suggested words:
//     unobtrusive subtle discreet inconspicuous cautious furtive
// Asking in MR review also suggested
//     quick
//
// It's awkward because "peek" already has significant connotations of not disturbing things.
// That's why it was used in Iterator::peek.
//
// But when we translate this into async context,
// we have the poll_peek method on futures::stream::Peekable,
// which doesn't remove items from the stream,
// but *does* *wait* for items and therefore engages with the async context,
// and therefore involves *mutating* the Peekable (to store the new waker).
//
// Now we end up needing a word for an *even less disturbing* kind of interaction.
//
// `quick` (and synonyms) isn't quite right either because it's not necessarily faster,
// and certainly not more performant.
#[derive(Debug)]
#[pin_project(project = PeekerProj)]
pub struct StreamUnobtrusivePeeker<S: Stream> {
    /// An item that we have peeked, but which `poll_next` has not yet handed out.
    ///
    /// (If we peeked EOF, that's represented by `None` in `inner`, not here.)
    buffered: Option<S::Item>,

    /// The `Waker` from the last time we were polled and returned `Pending`
    ///
    /// "polled" includes any of our `poll_` methods
    /// but *not* `unobtrusive_peek`.
    ///
    /// `None` if we haven't been polled, or the last poll returned `Ready`,
    /// or `unobtrusive_peek_mut` has since taken it (to wake it)
    /// because the inner stream became `Ready`.
    poll_waker: Option<Waker>,

    /// The inner stream
    ///
    /// `None` if it has yielded `None`, meaning EOF.  We don't require `S: FusedStream`.
    #[pin]
    inner: Option<S>,
}
 80  
 81  impl<S: Stream> StreamUnobtrusivePeeker<S> {
 82      /// Create a new `StreamUnobtrusivePeeker` from a `Stream`
 83      pub fn new(inner: S) -> Self {
 84          StreamUnobtrusivePeeker {
 85              buffered: None,
 86              poll_waker: None,
 87              inner: Some(inner),
 88          }
 89      }
 90  }
 91  
 92  impl<S: Stream> UnobtrusivePeekableStream for StreamUnobtrusivePeeker<S> {
 93      fn unobtrusive_peek_mut<'s>(mut self: Pin<&'s mut Self>) -> Option<&'s mut S::Item> {
 94          #[allow(clippy::question_mark)] // We use explicit control flow here for clarity
 95          if self.as_mut().project().buffered.is_none() {
 96              // We don't have a buffered item, but the stream may have an item available.
 97              // We must poll it to find out.
 98              //
 99              // We need to pass a Context to poll_next.
100              // inner may store this context, replacing one provided via poll_*.
101              //
102              // Despite that, we need to make sure that wakeups will happen as expected.
103              // To achieve this we have retained a copy of the caller's Waker.
104              //
105              // When a future or stream returns Pending, it proposes to wake `waker`
106              // when it wants to be polled again.
107              //
108              // We uphold that promise by
109              // - only returning Pending from our poll methods if inner also returned Pending
110              // - when one of our poll methods returns Pending, saving the caller-supplied
111              //   waker, so that we can make the intermediate poll call here.
112              //
113              // If the inner poll returns Ready, inner no longer guarantees to wake anyone.
114              // In principle, if our user is waiting (we returned Pending),
115              // then inner ought to have called `wake` on the caller's `Waker`.
116              // But I don't think we can guarantee that an executor won't defer a wakeup,
117              // and respond to a dropped Waker by cancelling that wakeup;
118              // or to put it another way, the wakeup might be "in flight" on entry,
119              // but the call to inner's poll_next returning Ready
120              // might somehow "cancel" the wakeup.
121              //
122              // So just to be sure, if we get a Ready here, we wake the stored waker.
123  
124              let mut self_ = self.as_mut().project();
125  
126              let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
127                  return None;
128              };
129  
130              let waker = if let Some(waker) = self_.poll_waker.as_ref() {
131                  waker
132              } else {
133                  noop_waker_ref()
134              };
135  
136              match inner.poll_next(&mut Context::from_waker(waker)) {
137                  Pending => {}
138                  Ready(item_or_eof) => {
139                      if let Some(waker) = self_.poll_waker.take() {
140                          waker.wake();
141                      }
142                      match item_or_eof {
143                          None => self_.inner.set(None),
144                          Some(item) => *self_.buffered = Some(item),
145                      }
146                  }
147              };
148          }
149  
150          self.project().buffered.as_mut()
151      }
152  }
153  
154  impl<S: Stream> PeekableStream for StreamUnobtrusivePeeker<S> {
155      fn poll_peek<'s>(self: Pin<&'s mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
156          self.impl_poll_next_or_peek(cx, |buffered| buffered.as_ref())
157      }
158  
159      fn poll_peek_mut<'s>(
160          self: Pin<&'s mut Self>,
161          cx: &mut Context<'_>,
162      ) -> Poll<Option<&'s mut S::Item>> {
163          self.impl_poll_next_or_peek(cx, |buffered| buffered.as_mut())
164      }
165  }
166  
167  impl<S: Stream> StreamUnobtrusivePeeker<S> {
168      /// Implementation of `poll_{peek,next}`
169      ///
170      /// This takes care of
171      ///   * examining the state of our buffer, and polling inner if needed
172      ///   * ensuring that we store a waker, if needed
173      ///   * dealing with some borrowck awkwardness
174      ///
175      /// The `Ready` value is always calculated from `buffer`.
176      /// `return_value_obtainer` is called only if we are going to return `Ready`.
177      /// It's given `buffer` and should either:
178      ///   * [`take`](Option::take) the contained value (for `poll_next`)
179      ///   * return a reference using [`Option::as_ref`] (for `poll_peek`)
180      fn impl_poll_next_or_peek<'s, R: 's>(
181          self: Pin<&'s mut Self>,
182          cx: &mut Context<'_>,
183          return_value_obtainer: impl FnOnce(&'s mut Option<S::Item>) -> Option<R>,
184      ) -> Poll<Option<R>> {
185          let mut self_ = self.project();
186          let r = Self::next_or_peek_inner(&mut self_, cx);
187          let r = r.map(|()| return_value_obtainer(self_.buffered));
188          Self::return_from_poll(self_.poll_waker, cx, r)
189      }
190  
191      /// Try to populate `buffer`, and calculate if we're `Ready`
192      ///
193      /// Returns `Ready` iff `poll_next` or `poll_peek` should return `Ready`.
194      /// The actual `Ready` value (an `Option`) will be calculated later.
195      fn next_or_peek_inner(self_: &mut PeekerProj<S>, cx: &mut Context<'_>) -> Poll<()> {
196          if let Some(_item) = self_.buffered.as_ref() {
197              // `return_value_obtainer` will find `Some` in `buffered`;
198              // overall, we'll return `Ready(Some(..))`.
199              return Ready(());
200          }
201          let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
202              // `return_value_obtainer` will find `None` in `buffered`;
203              // overall, we'll return `Ready(None)`, ie EOF.
204              return Ready(());
205          };
206          match inner.poll_next(cx) {
207              Ready(None) => {
208                  self_.inner.set(None);
209                  // `buffered` is `None`, still.
210                  // overall, we'll return `Ready(None)`, ie EOF.
211                  Ready(())
212              }
213              Ready(Some(item)) => {
214                  *self_.buffered = Some(item);
215                  // return_value_obtainer` will find `Some` in `buffered`
216                  Ready(())
217              }
218              Pending => {
219                  // `return_value_obtainer` won't be called.
220                  // overall, we'll return Pending
221                  Pending
222              }
223          }
224      }
225  
226      /// Wait for an item to be ready, and then inspect it
227      ///
228      /// Equivalent to [`futures::stream::Peekable::peek`].
229      ///
230      /// # Tasks, waking, and calling context
231      ///
232      /// This should be called by the task that is reading from the stream.
233      /// If it is called by another task, the reading task would miss notifications.
234      //
235      // This ^ docs section is triplicated for poll_peek, poll_peek_mut, and peek
236      //
237      // TODO this should be a method on the `PeekableStream` trait? Or a
238      // `PeekableStreamExt` trait?
239      // TODO should there be peek_mut ?
240      #[allow(dead_code)] // TODO remove this allow if and when we make this module public
241      pub fn peek(self: Pin<&mut Self>) -> PeekFuture<Self> {
242          PeekFuture { peeker: Some(self) }
243      }
244  
245      /// Return from a `poll_*` function, setting the stored waker appropriately
246      ///
247      /// Our `poll` functions always use this.
248      /// The rule is that if a future returns `Pending`, it has stored the waker.
249      fn return_from_poll<R>(
250          poll_waker: &mut Option<Waker>,
251          cx: &mut Context<'_>,
252          r: Poll<R>,
253      ) -> Poll<R> {
254          *poll_waker = match &r {
255              Ready(_) => {
256                  // No need to wake this task up any more.
257                  None
258              }
259              Pending => {
260                  // try_peek must use the same waker to poll later
261                  Some(cx.waker().clone())
262              }
263          };
264          r
265      }
266  
267      /// Obtain a raw reference to the inner stream
268      ///
269      /// ### Correctness!
270      ///
271      /// This method must be used with care!
272      /// Whatever you do mustn't interfere with polling and peeking.
273      /// Careless use can result in wrong behaviour including deadlocks.
274      pub fn as_raw_inner_pin_mut<'s>(self: Pin<&'s mut Self>) -> Option<Pin<&'s mut S>> {
275          self.project().inner.as_pin_mut()
276      }
277  }
278  
279  impl<S: Stream> Stream for StreamUnobtrusivePeeker<S> {
280      type Item = S::Item;
281  
282      fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
283          self.impl_poll_next_or_peek(cx, |buffered| buffered.take())
284      }
285  
286      fn size_hint(&self) -> (usize, Option<usize>) {
287          let buf = self.buffered.iter().count();
288          let (imin, imax) = match &self.inner {
289              Some(inner) => inner.size_hint(),
290              None => (0, Some(0)),
291          };
292          (imin + buf, imax.and_then(|imap| imap.checked_add(buf)))
293      }
294  }
295  
296  impl<S: Stream> FusedStream for StreamUnobtrusivePeeker<S> {
297      fn is_terminated(&self) -> bool {
298          self.buffered.is_none() && self.inner.is_none()
299      }
300  }
301  
302  /// Future from [`StreamUnobtrusivePeeker::peek`]
303  // TODO: Move to tor_async_utils::peekable_stream.
304  #[derive(Educe)]
305  #[educe(Debug(bound("S: Debug")))]
306  #[must_use = "peek() return a Future, which does nothing unless awaited"]
307  pub struct PeekFuture<'s, S> {
308      /// The underlying stream.
309      ///
310      /// `Some` until we have returned `Ready`, then `None`.
311      /// See comment in `poll`.
312      peeker: Option<Pin<&'s mut S>>,
313  }
314  
315  impl<'s, S: PeekableStream> PeekFuture<'s, S> {
316      /// Create a new `PeekFuture`.
317      // TODO: replace with a trait method.
318      pub fn new(stream: Pin<&'s mut S>) -> Self {
319          Self {
320              peeker: Some(stream),
321          }
322      }
323  }
324  
325  impl<'s, S: PeekableStream> Future for PeekFuture<'s, S> {
326      type Output = Option<&'s S::Item>;
327      fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
328          let self_ = self.get_mut();
329          let peeker = self_
330              .peeker
331              .as_mut()
332              .expect("PeekFuture polled after Ready");
333          match peeker.as_mut().poll_peek(cx) {
334              Pending => return Pending,
335              Ready(_y) => {
336                  // Ideally we would have returned `y` here, but it's borrowed from PeekFuture
337                  // not from the original StreamUnobtrusivePeeker, and there's no way
338                  // to get a value with the right lifetime.  (In non-async code,
339                  // this is usually handled by the special magic for reborrowing &mut.)
340                  //
341                  // So we must redo the poll, but this time consuming `peeker`,
342                  // which gets us the right lifetime.  That's why it has to be `Option`.
343                  // Because we own &mut ... Self, we know that repeating the poll
344                  // gives the same answer.
345              }
346          }
347          let peeker = self_.peeker.take().expect("it was Some before!");
348          let r = peeker.poll_peek(cx);
349          assert!(r.is_ready(), "it was Ready before!");
350          r
351      }
352  }
353  
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use futures::channel::mpsc;
    use futures::{SinkExt as _, StreamExt as _};
    use std::pin::pin;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tor_rtcompat::SleepProvider as _;
    use tor_rtmock::MockRuntime;

    /// Shorthand: a `Duration` of `ms` milliseconds.
    fn ms(ms: u64) -> Duration {
        Duration::from_millis(ms)
    }

    /// Interleave `unobtrusive_peek_mut` with `next().await` on a channel receiver.
    ///
    /// The receiving task alternates: sleep 50ms, unobtrusively peek, sleep 50ms,
    /// await the next item. Items are checked to arrive in order, and any peeked
    /// item must equal the item that `next()` then yields. The sending task uses an
    /// irregular delay schedule, then drops the sender (EOF). It asserts that the
    /// receiver has not finished 10ms after EOF, but has finished 60ms after EOF,
    /// i.e. the EOF wakeup actually reached the receiving task.
    #[test]
    fn wakeups() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            // Set by the receiving task once it has seen EOF.
            let ended = Arc::new(Mutex::new(false));

            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let ended = ended.clone();

                async move {
                    let rx = StreamUnobtrusivePeeker::new(rx);
                    let mut rx = pin!(rx);

                    // The value we expect `next()` to yield next.
                    let mut next = 0;
                    loop {
                        rt.sleep(ms(50)).await;
                        eprintln!("rx peek... ");
                        // Synchronous peek (no Context passed); may find nothing yet.
                        let peeked = rx.as_mut().unobtrusive_peek_mut();
                        eprintln!("rx peeked {peeked:?}");

                        // Anything peeked must be exactly what `next()` will yield.
                        if let Some(peeked) = peeked {
                            assert_eq!(*peeked, next);
                        }

                        rt.sleep(ms(50)).await;
                        eprintln!("rx next... ");
                        let eaten = rx.next().await;
                        eprintln!("rx eaten {eaten:?}");
                        if let Some(eaten) = eaten {
                            assert_eq!(eaten, next);
                            next += 1;
                        } else {
                            break;
                        }
                    }

                    *ended.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });

            rt.spawn_identified("tx", {
                let rt = rt.clone();

                async move {
                    let mut numbers = 0..;
                    // Irregular delays (ms) before each send; presumably chosen so that
                    // items arrive at varied points in the receiver's peek/next cycle.
                    for wait in [125, 1, 125, 45, 1, 1, 1, 1000, 20, 1, 125, 125, 1000] {
                        eprintln!("tx sleep {wait}");
                        rt.sleep(ms(wait)).await;
                        let num = numbers.next().unwrap();
                        eprintln!("tx sending {num}");
                        tx.send(num).await.unwrap();
                    }

                    // This schedule arranges that, when we send EOF, the rx task
                    // has *peeked* rather than *polled* most recently,
                    // demonstrating that we can wake up the subsequent poll on EOF too.
                    eprintln!("tx final #1");
                    rt.sleep(ms(75)).await;
                    eprintln!("tx EOF");
                    drop(tx);
                    eprintln!("tx final #2");
                    rt.sleep(ms(10)).await;
                    assert!(!*ended.lock().unwrap());
                    eprintln!("tx final #3");
                    rt.sleep(ms(50)).await;
                    eprintln!("tx final #4");
                    assert!(*ended.lock().unwrap());
                }
            });

            rt.advance_until_stalled().await;
        });
    }

    /// Alternate `peek().await` with `next().await` on a channel receiver.
    ///
    /// Each peeked value must equal the value `next()` then yields. The sender sends
    /// one item after 100ms, two back-to-back after another 100ms, then waits 100ms
    /// and drops the sender. This gives the receiver both a peek that has to wait and
    /// (presumably) one that finds an item already queued. The sender asserts that
    /// the receiver has finished 5ms after EOF.
    #[test]
    fn poll_peek_paths() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            // Set by the receiving task once `peek()` reports EOF.
            let ended = Arc::new(Mutex::new(false));

            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let ended = ended.clone();

                async move {
                    let rx = StreamUnobtrusivePeeker::new(rx);
                    let mut rx = pin!(rx);

                    while let Some(peeked) = rx.as_mut().peek().await.copied() {
                        eprintln!("rx peeked {peeked}");
                        // The peeked item must still be there for `next()` to consume.
                        let eaten = rx.next().await.unwrap();
                        eprintln!("rx eaten  {eaten}");
                        assert_eq!(peeked, eaten);
                        rt.sleep(ms(10)).await;
                        eprintln!("rx slept, peeking");
                    }
                    *ended.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });

            rt.spawn_identified("tx", {
                let rt = rt.clone();

                async move {
                    let mut numbers = 0..;

                    // macro because we don't have proper async closures
                    macro_rules! send { {} => {
                        let num = numbers.next().unwrap();
                        eprintln!("tx send   {num}");
                        tx.send(num).await.unwrap();
                    } }

                    eprintln!("tx starting");
                    rt.sleep(ms(100)).await;
                    send!();
                    rt.sleep(ms(100)).await;
                    // Two sends with no delay between them.
                    send!();
                    send!();
                    rt.sleep(ms(100)).await;
                    eprintln!("tx dropping");
                    drop(tx);
                    rt.sleep(ms(5)).await;
                    eprintln!("tx ending");
                    assert!(*ended.lock().unwrap());
                }
            });

            rt.advance_until_stalled().await;
        });
    }
}