stream_peek.rs
1 //! [`StreamUnobtrusivePeeker`] 2 //! 3 //! The memory tracker needs a way to look at the next item of a stream 4 //! (if there is one, or there can immediately be one), 5 //! *without* getting involved with the async tasks. 6 7 use educe::Educe; 8 use futures::stream::FusedStream; 9 use futures::task::noop_waker_ref; 10 use futures::Stream; 11 use pin_project::pin_project; 12 13 use crate::peekable_stream::{PeekableStream, UnobtrusivePeekableStream}; 14 15 use std::fmt::Debug; 16 use std::future::Future; 17 use std::pin::Pin; 18 use std::task::{Context, Poll, Poll::*, Waker}; 19 20 /// Wraps [`Stream`] and provides `\[poll_]peek` and `unobtrusive_peek` 21 /// 22 /// [`unobtrusive_peek`](StreamUnobtrusivePeeker::unobtrusive_peek) 23 /// is callable in sync contexts, outside the reading task. 24 /// 25 /// Like [`futures::stream::Peekable`], 26 /// this has an async `peek` method, and `poll_peek`, 27 /// for use from the task that is also reading (via the [`Stream`] impl). 28 /// But, that type doesn't have `unobtrusive_peek`. 29 /// 30 /// One way to conceptualise this is that `StreamUnobtrusivePeeker` is dual-ported: 31 /// the two sets of APIs, while provided on the same type, 32 /// are typically called from different contexts. 33 // 34 // It wasn't particularly easy to think of a good name for this type. 35 // We intend, probably: 36 // struct StreamUnobtrusivePeeker 37 // trait StreamUnobtrusivePeekable 38 // trait StreamPeekable (impl for StreamUnobtrusivePeeker and futures::stream::Peekable) 39 // 40 // Searching a thesaurus produced these suggested words: 41 // unobtrusive subtle discreet inconspicuous cautious furtive 42 // Asking in MR review also suggested 43 // quick 44 // 45 // It's awkward because "peek" already has significant connotations of not disturbing things. 46 // That's why it was used in Iterator::peek. 47 // 48 // But when we translate this into async context, 49 // we have the poll_peek method on futures::stream::Peekable, 50 // which doesn't remove items from the stream, 51 // but *does* *wait* for items and therefore engages with the async context, 52 // and therefore involves *mutating* the Peekable (to store the new waker). 53 // 54 // Now we end up needing a word for an *even less disturbing* kind of interaction. 55 // 56 // `quick` (and synonyms) isn't quite right either because it's not necessarily faster, 57 // and certainly not more performant. 58 #[derive(Debug)] 59 #[pin_project(project = PeekerProj)] 60 pub struct StreamUnobtrusivePeeker<S: Stream> { 61 /// An item that we have peeked. 62 /// 63 /// (If we peeked EOF, that's represented by `None` in inner.) 64 buffered: Option<S::Item>, 65 66 /// The `Waker` from the last time we were polled and returned `Pending` 67 /// 68 /// "polled" includes any of our `poll_` methods 69 /// but *not* `unobtrusive_peek`. 70 /// 71 /// `None` if we haven't been polled, or the last poll returned `Ready`. 72 poll_waker: Option<Waker>, 73 74 /// The inner stream 75 /// 76 /// `None if it has yielded `None` meaning EOF. We don't require S: FusedStream. 77 #[pin] 78 inner: Option<S>, 79 } 80 81 impl<S: Stream> StreamUnobtrusivePeeker<S> { 82 /// Create a new `StreamUnobtrusivePeeker` from a `Stream` 83 pub fn new(inner: S) -> Self { 84 StreamUnobtrusivePeeker { 85 buffered: None, 86 poll_waker: None, 87 inner: Some(inner), 88 } 89 } 90 } 91 92 impl<S: Stream> UnobtrusivePeekableStream for StreamUnobtrusivePeeker<S> { 93 fn unobtrusive_peek_mut<'s>(mut self: Pin<&'s mut Self>) -> Option<&'s mut S::Item> { 94 #[allow(clippy::question_mark)] // We use explicit control flow here for clarity 95 if self.as_mut().project().buffered.is_none() { 96 // We don't have a buffered item, but the stream may have an item available. 97 // We must poll it to find out. 98 // 99 // We need to pass a Context to poll_next. 100 // inner may store this context, replacing one provided via poll_*. 101 // 102 // Despite that, we need to make sure that wakeups will happen as expected. 103 // To achieve this we have retained a copy of the caller's Waker. 104 // 105 // When a future or stream returns Pending, it proposes to wake `waker` 106 // when it wants to be polled again. 107 // 108 // We uphold that promise by 109 // - only returning Pending from our poll methods if inner also returned Pending 110 // - when one of our poll methods returns Pending, saving the caller-supplied 111 // waker, so that we can make the intermediate poll call here. 112 // 113 // If the inner poll returns Ready, inner no longer guarantees to wake anyone. 114 // In principle, if our user is waiting (we returned Pending), 115 // then inner ought to have called `wake` on the caller's `Waker`. 116 // But I don't think we can guarantee that an executor won't defer a wakeup, 117 // and respond to a dropped Waker by cancelling that wakeup; 118 // or to put it another way, the wakeup might be "in flight" on entry, 119 // but the call to inner's poll_next returning Ready 120 // might somehow "cancel" the wakeup. 121 // 122 // So just to be sure, if we get a Ready here, we wake the stored waker. 123 124 let mut self_ = self.as_mut().project(); 125 126 let Some(inner) = self_.inner.as_mut().as_pin_mut() else { 127 return None; 128 }; 129 130 let waker = if let Some(waker) = self_.poll_waker.as_ref() { 131 waker 132 } else { 133 noop_waker_ref() 134 }; 135 136 match inner.poll_next(&mut Context::from_waker(waker)) { 137 Pending => {} 138 Ready(item_or_eof) => { 139 if let Some(waker) = self_.poll_waker.take() { 140 waker.wake(); 141 } 142 match item_or_eof { 143 None => self_.inner.set(None), 144 Some(item) => *self_.buffered = Some(item), 145 } 146 } 147 }; 148 } 149 150 self.project().buffered.as_mut() 151 } 152 } 153 154 impl<S: Stream> PeekableStream for StreamUnobtrusivePeeker<S> { 155 fn poll_peek<'s>(self: Pin<&'s mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> { 156 self.impl_poll_next_or_peek(cx, |buffered| buffered.as_ref()) 157 } 158 159 fn poll_peek_mut<'s>( 160 self: Pin<&'s mut Self>, 161 cx: &mut Context<'_>, 162 ) -> Poll<Option<&'s mut S::Item>> { 163 self.impl_poll_next_or_peek(cx, |buffered| buffered.as_mut()) 164 } 165 } 166 167 impl<S: Stream> StreamUnobtrusivePeeker<S> { 168 /// Implementation of `poll_{peek,next}` 169 /// 170 /// This takes care of 171 /// * examining the state of our buffer, and polling inner if needed 172 /// * ensuring that we store a waker, if needed 173 /// * dealing with some borrowck awkwardness 174 /// 175 /// The `Ready` value is always calculated from `buffer`. 176 /// `return_value_obtainer` is called only if we are going to return `Ready`. 177 /// It's given `buffer` and should either: 178 /// * [`take`](Option::take) the contained value (for `poll_next`) 179 /// * return a reference using [`Option::as_ref`] (for `poll_peek`) 180 fn impl_poll_next_or_peek<'s, R: 's>( 181 self: Pin<&'s mut Self>, 182 cx: &mut Context<'_>, 183 return_value_obtainer: impl FnOnce(&'s mut Option<S::Item>) -> Option<R>, 184 ) -> Poll<Option<R>> { 185 let mut self_ = self.project(); 186 let r = Self::next_or_peek_inner(&mut self_, cx); 187 let r = r.map(|()| return_value_obtainer(self_.buffered)); 188 Self::return_from_poll(self_.poll_waker, cx, r) 189 } 190 191 /// Try to populate `buffer`, and calculate if we're `Ready` 192 /// 193 /// Returns `Ready` iff `poll_next` or `poll_peek` should return `Ready`. 194 /// The actual `Ready` value (an `Option`) will be calculated later. 195 fn next_or_peek_inner(self_: &mut PeekerProj<S>, cx: &mut Context<'_>) -> Poll<()> { 196 if let Some(_item) = self_.buffered.as_ref() { 197 // `return_value_obtainer` will find `Some` in `buffered`; 198 // overall, we'll return `Ready(Some(..))`. 199 return Ready(()); 200 } 201 let Some(inner) = self_.inner.as_mut().as_pin_mut() else { 202 // `return_value_obtainer` will find `None` in `buffered`; 203 // overall, we'll return `Ready(None)`, ie EOF. 204 return Ready(()); 205 }; 206 match inner.poll_next(cx) { 207 Ready(None) => { 208 self_.inner.set(None); 209 // `buffered` is `None`, still. 210 // overall, we'll return `Ready(None)`, ie EOF. 211 Ready(()) 212 } 213 Ready(Some(item)) => { 214 *self_.buffered = Some(item); 215 // return_value_obtainer` will find `Some` in `buffered` 216 Ready(()) 217 } 218 Pending => { 219 // `return_value_obtainer` won't be called. 220 // overall, we'll return Pending 221 Pending 222 } 223 } 224 } 225 226 /// Wait for an item to be ready, and then inspect it 227 /// 228 /// Equivalent to [`futures::stream::Peekable::peek`]. 229 /// 230 /// # Tasks, waking, and calling context 231 /// 232 /// This should be called by the task that is reading from the stream. 233 /// If it is called by another task, the reading task would miss notifications. 234 // 235 // This ^ docs section is triplicated for poll_peek, poll_peek_mut, and peek 236 // 237 // TODO this should be a method on the `PeekableStream` trait? Or a 238 // `PeekableStreamExt` trait? 239 // TODO should there be peek_mut ? 240 #[allow(dead_code)] // TODO remove this allow if and when we make this module public 241 pub fn peek(self: Pin<&mut Self>) -> PeekFuture<Self> { 242 PeekFuture { peeker: Some(self) } 243 } 244 245 /// Return from a `poll_*` function, setting the stored waker appropriately 246 /// 247 /// Our `poll` functions always use this. 248 /// The rule is that if a future returns `Pending`, it has stored the waker. 249 fn return_from_poll<R>( 250 poll_waker: &mut Option<Waker>, 251 cx: &mut Context<'_>, 252 r: Poll<R>, 253 ) -> Poll<R> { 254 *poll_waker = match &r { 255 Ready(_) => { 256 // No need to wake this task up any more. 257 None 258 } 259 Pending => { 260 // try_peek must use the same waker to poll later 261 Some(cx.waker().clone()) 262 } 263 }; 264 r 265 } 266 267 /// Obtain a raw reference to the inner stream 268 /// 269 /// ### Correctness! 270 /// 271 /// This method must be used with care! 272 /// Whatever you do mustn't interfere with polling and peeking. 273 /// Careless use can result in wrong behaviour including deadlocks. 274 pub fn as_raw_inner_pin_mut<'s>(self: Pin<&'s mut Self>) -> Option<Pin<&'s mut S>> { 275 self.project().inner.as_pin_mut() 276 } 277 } 278 279 impl<S: Stream> Stream for StreamUnobtrusivePeeker<S> { 280 type Item = S::Item; 281 282 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 283 self.impl_poll_next_or_peek(cx, |buffered| buffered.take()) 284 } 285 286 fn size_hint(&self) -> (usize, Option<usize>) { 287 let buf = self.buffered.iter().count(); 288 let (imin, imax) = match &self.inner { 289 Some(inner) => inner.size_hint(), 290 None => (0, Some(0)), 291 }; 292 (imin + buf, imax.and_then(|imap| imap.checked_add(buf))) 293 } 294 } 295 296 impl<S: Stream> FusedStream for StreamUnobtrusivePeeker<S> { 297 fn is_terminated(&self) -> bool { 298 self.buffered.is_none() && self.inner.is_none() 299 } 300 } 301 302 /// Future from [`StreamUnobtrusivePeeker::peek`] 303 // TODO: Move to tor_async_utils::peekable_stream. 304 #[derive(Educe)] 305 #[educe(Debug(bound("S: Debug")))] 306 #[must_use = "peek() return a Future, which does nothing unless awaited"] 307 pub struct PeekFuture<'s, S> { 308 /// The underlying stream. 309 /// 310 /// `Some` until we have returned `Ready`, then `None`. 311 /// See comment in `poll`. 312 peeker: Option<Pin<&'s mut S>>, 313 } 314 315 impl<'s, S: PeekableStream> PeekFuture<'s, S> { 316 /// Create a new `PeekFuture`. 317 // TODO: replace with a trait method. 318 pub fn new(stream: Pin<&'s mut S>) -> Self { 319 Self { 320 peeker: Some(stream), 321 } 322 } 323 } 324 325 impl<'s, S: PeekableStream> Future for PeekFuture<'s, S> { 326 type Output = Option<&'s S::Item>; 327 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> { 328 let self_ = self.get_mut(); 329 let peeker = self_ 330 .peeker 331 .as_mut() 332 .expect("PeekFuture polled after Ready"); 333 match peeker.as_mut().poll_peek(cx) { 334 Pending => return Pending, 335 Ready(_y) => { 336 // Ideally we would have returned `y` here, but it's borrowed from PeekFuture 337 // not from the original StreamUnobtrusivePeeker, and there's no way 338 // to get a value with the right lifetime. (In non-async code, 339 // this is usually handled by the special magic for reborrowing &mut.) 340 // 341 // So we must redo the poll, but this time consuming `peeker`, 342 // which gets us the right lifetime. That's why it has to be `Option`. 343 // Because we own &mut ... Self, we know that repeating the poll 344 // gives the same answer. 345 } 346 } 347 let peeker = self_.peeker.take().expect("it was Some before!"); 348 let r = peeker.poll_peek(cx); 349 assert!(r.is_ready(), "it was Ready before!"); 350 r 351 } 352 } 353 354 #[cfg(test)] 355 mod test { 356 // @@ begin test lint list maintained by maint/add_warning @@ 357 #![allow(clippy::bool_assert_comparison)] 358 #![allow(clippy::clone_on_copy)] 359 #![allow(clippy::dbg_macro)] 360 #![allow(clippy::mixed_attributes_style)] 361 #![allow(clippy::print_stderr)] 362 #![allow(clippy::print_stdout)] 363 #![allow(clippy::single_char_pattern)] 364 #![allow(clippy::unwrap_used)] 365 #![allow(clippy::unchecked_duration_subtraction)] 366 #![allow(clippy::useless_vec)] 367 #![allow(clippy::needless_pass_by_value)] 368 //! <!-- @@ end test lint list maintained by maint/add_warning @@ --> 369 370 use super::*; 371 use futures::channel::mpsc; 372 use futures::{SinkExt as _, StreamExt as _}; 373 use std::pin::pin; 374 use std::sync::{Arc, Mutex}; 375 use std::time::Duration; 376 use tor_rtcompat::SleepProvider as _; 377 use tor_rtmock::MockRuntime; 378 379 fn ms(ms: u64) -> Duration { 380 Duration::from_millis(ms) 381 } 382 383 #[test] 384 fn wakeups() { 385 MockRuntime::test_with_various(|rt| async move { 386 let (mut tx, rx) = mpsc::unbounded(); 387 let ended = Arc::new(Mutex::new(false)); 388 389 rt.spawn_identified("rxr", { 390 let rt = rt.clone(); 391 let ended = ended.clone(); 392 393 async move { 394 let rx = StreamUnobtrusivePeeker::new(rx); 395 let mut rx = pin!(rx); 396 397 let mut next = 0; 398 loop { 399 rt.sleep(ms(50)).await; 400 eprintln!("rx peek... "); 401 let peeked = rx.as_mut().unobtrusive_peek_mut(); 402 eprintln!("rx peeked {peeked:?}"); 403 404 if let Some(peeked) = peeked { 405 assert_eq!(*peeked, next); 406 } 407 408 rt.sleep(ms(50)).await; 409 eprintln!("rx next... "); 410 let eaten = rx.next().await; 411 eprintln!("rx eaten {eaten:?}"); 412 if let Some(eaten) = eaten { 413 assert_eq!(eaten, next); 414 next += 1; 415 } else { 416 break; 417 } 418 } 419 420 *ended.lock().unwrap() = true; 421 eprintln!("rx ended"); 422 } 423 }); 424 425 rt.spawn_identified("tx", { 426 let rt = rt.clone(); 427 428 async move { 429 let mut numbers = 0..; 430 for wait in [125, 1, 125, 45, 1, 1, 1, 1000, 20, 1, 125, 125, 1000] { 431 eprintln!("tx sleep {wait}"); 432 rt.sleep(ms(wait)).await; 433 let num = numbers.next().unwrap(); 434 eprintln!("tx sending {num}"); 435 tx.send(num).await.unwrap(); 436 } 437 438 // This schedule arranges that, when we send EOF, the rx task 439 // has *peeked* rather than *polled* most recently, 440 // demonstrating that we can wake up the subsequent poll on EOF too. 441 eprintln!("tx final #1"); 442 rt.sleep(ms(75)).await; 443 eprintln!("tx EOF"); 444 drop(tx); 445 eprintln!("tx final #2"); 446 rt.sleep(ms(10)).await; 447 assert!(!*ended.lock().unwrap()); 448 eprintln!("tx final #3"); 449 rt.sleep(ms(50)).await; 450 eprintln!("tx final #4"); 451 assert!(*ended.lock().unwrap()); 452 } 453 }); 454 455 rt.advance_until_stalled().await; 456 }); 457 } 458 459 #[test] 460 fn poll_peek_paths() { 461 MockRuntime::test_with_various(|rt| async move { 462 let (mut tx, rx) = mpsc::unbounded(); 463 let ended = Arc::new(Mutex::new(false)); 464 465 rt.spawn_identified("rxr", { 466 let rt = rt.clone(); 467 let ended = ended.clone(); 468 469 async move { 470 let rx = StreamUnobtrusivePeeker::new(rx); 471 let mut rx = pin!(rx); 472 473 while let Some(peeked) = rx.as_mut().peek().await.copied() { 474 eprintln!("rx peeked {peeked}"); 475 let eaten = rx.next().await.unwrap(); 476 eprintln!("rx eaten {eaten}"); 477 assert_eq!(peeked, eaten); 478 rt.sleep(ms(10)).await; 479 eprintln!("rx slept, peeking"); 480 } 481 *ended.lock().unwrap() = true; 482 eprintln!("rx ended"); 483 } 484 }); 485 486 rt.spawn_identified("tx", { 487 let rt = rt.clone(); 488 489 async move { 490 let mut numbers = 0..; 491 492 // macro because we don't have proper async closures 493 macro_rules! send { {} => { 494 let num = numbers.next().unwrap(); 495 eprintln!("tx send {num}"); 496 tx.send(num).await.unwrap(); 497 } } 498 499 eprintln!("tx starting"); 500 rt.sleep(ms(100)).await; 501 send!(); 502 rt.sleep(ms(100)).await; 503 send!(); 504 send!(); 505 rt.sleep(ms(100)).await; 506 eprintln!("tx dropping"); 507 drop(tx); 508 rt.sleep(ms(5)).await; 509 eprintln!("tx ending"); 510 assert!(*ended.lock().unwrap()); 511 } 512 }); 513 514 rt.advance_until_stalled().await; 515 }); 516 } 517 }