/ crates / tor-rtmock / src / io.rs
io.rs
  1  //! Mocking helpers for testing with futures::io types.
  2  //!
//! Note that some of this code might be of general use, but for now
//! we're only using it in tests.
  5  
  6  #![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri)
  7  
  8  use crate::util::mpsc_channel;
  9  use futures::channel::mpsc;
 10  use futures::io::{AsyncRead, AsyncWrite};
 11  use futures::sink::{Sink, SinkExt};
 12  use futures::stream::Stream;
 13  use std::io::{Error as IoError, ErrorKind, Result as IoResult};
 14  use std::pin::Pin;
 15  use std::task::{Context, Poll};
 16  use tor_rtcompat::{StreamOps, UnsupportedStreamOp};
 17  
 18  /// Channel capacity for our internal MPSC channels.
 19  ///
 20  /// We keep this intentionally low to make sure that some blocking
 21  /// will occur occur.
 22  const CAPACITY: usize = 4;
 23  
 24  /// Maximum size for a queued buffer on a local chunk.
 25  ///
 26  /// This size is deliberately weird, to try to find errors.
 27  const CHUNKSZ: usize = 213;
 28  
 29  /// Construct a new pair of linked LocalStream objects.
 30  ///
 31  /// Any bytes written to one will be readable on the other, and vice
 32  /// versa.  These streams will behave more or less like a socketpair,
 33  /// except without actually going through the operating system.
 34  ///
 35  /// Note that this implementation is intended for testing only, and
 36  /// isn't optimized.
 37  pub fn stream_pair() -> (LocalStream, LocalStream) {
 38      let (w1, r2) = mpsc_channel(CAPACITY);
 39      let (w2, r1) = mpsc_channel(CAPACITY);
 40      let s1 = LocalStream {
 41          w: w1,
 42          r: r1,
 43          pending_bytes: Vec::new(),
 44          tls_cert: None,
 45      };
 46      let s2 = LocalStream {
 47          w: w2,
 48          r: r2,
 49          pending_bytes: Vec::new(),
 50          tls_cert: None,
 51      };
 52      (s1, s2)
 53  }
 54  
/// One half of a pair of linked streams returned by [`stream_pair`].
//
// Implementation notes: linked streams are made out of a pair of mpsc
// channels.  There's one channel for sending bytes in each direction.
// Bytes are sent as IoResult<Vec<u8>>: sending an error causes an error
// to occur on the other side.
pub struct LocalStream {
    /// The writing side of the channel that we use to implement this
    /// stream.
    ///
    /// The reading side is held by the other linked stream.
    w: mpsc::Sender<IoResult<Vec<u8>>>,
    /// The reading side of the channel that we use to implement this
    /// stream.
    ///
    /// The writing side is held by the other linked stream.
    r: mpsc::Receiver<IoResult<Vec<u8>>>,
    /// Bytes that we have read from `r` but not yet delivered.
    pending_bytes: Vec<u8>,
    /// Data about the fake TLS certificate presented by the other side of
    /// this stream, if any.
    /// While this is present, `poll_read` and `poll_write` fail with an error
    /// (`poll_flush` and `poll_close` do not check it).
    ///
    /// How this is intended to work: things that return `LocalStream`s that could potentially
    /// be connected to a fake TLS listener should set this field. Then, a fake TLS wrapper
    /// type would clear this field (after checking its contents are as expected).
    ///
    /// FIXME(eta): this is a bit of a layering violation, but it's hard to do otherwise
    pub(crate) tls_cert: Option<Vec<u8>>,
}
 84  
 85  /// Helper: pull bytes off the front of `pending_bytes` and put them
 86  /// onto `buf.  Return the number of bytes moved.
 87  fn drain_helper(buf: &mut [u8], pending_bytes: &mut Vec<u8>) -> usize {
 88      let n_to_drain = std::cmp::min(buf.len(), pending_bytes.len());
 89      buf[..n_to_drain].copy_from_slice(&pending_bytes[..n_to_drain]);
 90      pending_bytes.drain(..n_to_drain);
 91      n_to_drain
 92  }
 93  
 94  impl AsyncRead for LocalStream {
 95      fn poll_read(
 96          mut self: Pin<&mut Self>,
 97          cx: &mut Context<'_>,
 98          buf: &mut [u8],
 99      ) -> Poll<IoResult<usize>> {
100          if buf.is_empty() {
101              return Poll::Ready(Ok(0));
102          }
103          if self.tls_cert.is_some() {
104              return Poll::Ready(Err(std::io::Error::new(
105                  std::io::ErrorKind::Other,
106                  "attempted to treat a TLS stream as non-TLS!",
107              )));
108          }
109          if !self.pending_bytes.is_empty() {
110              return Poll::Ready(Ok(drain_helper(buf, &mut self.pending_bytes)));
111          }
112  
113          match futures::ready!(Pin::new(&mut self.r).poll_next(cx)) {
114              Some(Err(e)) => Poll::Ready(Err(e)),
115              Some(Ok(bytes)) => {
116                  self.pending_bytes = bytes;
117                  let n = drain_helper(buf, &mut self.pending_bytes);
118                  Poll::Ready(Ok(n))
119              }
120              None => Poll::Ready(Ok(0)), // This is an EOF
121          }
122      }
123  }
124  
125  impl AsyncWrite for LocalStream {
126      fn poll_write(
127          mut self: Pin<&mut Self>,
128          cx: &mut Context<'_>,
129          buf: &[u8],
130      ) -> Poll<IoResult<usize>> {
131          if self.tls_cert.is_some() {
132              return Poll::Ready(Err(std::io::Error::new(
133                  std::io::ErrorKind::Other,
134                  "attempted to treat a TLS stream as non-TLS!",
135              )));
136          }
137  
138          match futures::ready!(Pin::new(&mut self.w).poll_ready(cx)) {
139              Ok(()) => (),
140              Err(e) => return Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))),
141          }
142  
143          let buf = if buf.len() > CHUNKSZ {
144              &buf[..CHUNKSZ]
145          } else {
146              buf
147          };
148          let len = buf.len();
149          match Pin::new(&mut self.w).start_send(Ok(buf.to_vec())) {
150              Ok(()) => Poll::Ready(Ok(len)),
151              Err(e) => Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))),
152          }
153      }
154      fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
155          Pin::new(&mut self.w)
156              .poll_flush(cx)
157              .map_err(|e| IoError::new(ErrorKind::BrokenPipe, e))
158      }
159      fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
160          Pin::new(&mut self.w)
161              .poll_close(cx)
162              .map_err(|e| IoError::new(ErrorKind::Other, e))
163      }
164  }
165  
166  impl StreamOps for LocalStream {
167      fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
168          Err(
169              UnsupportedStreamOp::new("set_tcp_notsent_lowat", "unsupported on local streams")
170                  .into(),
171          )
172      }
173  }
174  
/// The error payload that [`LocalStream::send_err`] wraps in a
/// [`std::io::Error`].
///
/// Its `Display` output is always the fixed text `Synthetic error`.
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub struct SyntheticError;

impl std::fmt::Display for SyntheticError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Synthetic error")
    }
}

impl std::error::Error for SyntheticError {}
185  
186  impl LocalStream {
187      /// Send an error to the other linked local stream.
188      ///
189      /// When the other stream reads this message, it will generate a
190      /// [`std::io::Error`] with the provided `ErrorKind`.
191      pub async fn send_err(&mut self, kind: ErrorKind) {
192          let _ignore = self.w.send(Err(IoError::new(kind, SyntheticError))).await;
193      }
194  }
195  
#[cfg(all(test, not(miri)))] // These tests are very slow under miri
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures_await_test::async_test;
    use rand::Rng;
    use tor_basic_utils::test_rng::testing_rng;

    /// Write several copies of a random buffer into one end of a pair and
    /// check that exactly those bytes come out of the other end.
    #[async_test]
    async fn basic_rw() {
        const N_COPIES: usize = 10;
        let (mut s1, mut s2) = stream_pair();
        let mut text1 = vec![0_u8; 9999];
        testing_rng().fill(&mut text1[..]);

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                for _ in 0..N_COPIES {
                    s1.write_all(&text1[..]).await?;
                }
                s1.close().await?;
                Ok(())
            },
            async {
                let mut text2: Vec<u8> = Vec::new();
                let mut buf = [0_u8; 33];
                loop {
                    let n = s2.read(&mut buf[..]).await?;
                    if n == 0 {
                        break;
                    }
                    text2.extend(&buf[..n]);
                }
                // Without this, a reader that received too few (or zero)
                // bytes would pass the per-chunk comparison vacuously.
                assert_eq!(text2.len(), text1.len() * N_COPIES);
                for ch in text2[..].chunks(text1.len()) {
                    assert_eq!(ch, &text1[..]);
                }
                Ok(())
            }
        );

        v1.unwrap();
        v2.unwrap();
    }

    /// An error sent with `send_err` surfaces on the reader with the same
    /// kind and a `SyntheticError` payload.
    #[async_test]
    async fn send_error() {
        let (mut s1, mut s2) = stream_pair();

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                s1.write_all(b"hello world").await?;
                s1.send_err(ErrorKind::PermissionDenied).await;
                Ok(())
            },
            async {
                let mut buf = [0_u8; 33];
                loop {
                    let n = s2.read(&mut buf[..]).await?;
                    if n == 0 {
                        break;
                    }
                }
                Ok(())
            }
        );

        v1.unwrap();
        let e = v2.err().unwrap();
        assert_eq!(e.kind(), ErrorKind::PermissionDenied);
        let synth = e.into_inner().unwrap();
        assert_eq!(synth.to_string(), "Synthetic error");
    }

    /// Writing to a stream whose peer has been dropped fails with
    /// `BrokenPipe`.
    #[async_test]
    async fn drop_reader() {
        let (mut s1, s2) = stream_pair();

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                for _ in 0_u16..1000 {
                    s1.write_all(&[9_u8; 9999]).await?;
                }
                Ok(())
            },
            async {
                drop(s2);
                Ok(())
            }
        );

        v2.unwrap();
        let e = v1.err().unwrap();
        assert_eq!(e.kind(), ErrorKind::BrokenPipe);
    }
}