io.rs
1 //! Mocking helpers for testing with futures::io types. 2 //! 3 //! Note that some of this code might be of general use, but for now 4 //! we're only trying it for testing. 5 6 #![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri) 7 8 use crate::util::mpsc_channel; 9 use futures::channel::mpsc; 10 use futures::io::{AsyncRead, AsyncWrite}; 11 use futures::sink::{Sink, SinkExt}; 12 use futures::stream::Stream; 13 use std::io::{Error as IoError, ErrorKind, Result as IoResult}; 14 use std::pin::Pin; 15 use std::task::{Context, Poll}; 16 use tor_rtcompat::{StreamOps, UnsupportedStreamOp}; 17 18 /// Channel capacity for our internal MPSC channels. 19 /// 20 /// We keep this intentionally low to make sure that some blocking 21 /// will occur occur. 22 const CAPACITY: usize = 4; 23 24 /// Maximum size for a queued buffer on a local chunk. 25 /// 26 /// This size is deliberately weird, to try to find errors. 27 const CHUNKSZ: usize = 213; 28 29 /// Construct a new pair of linked LocalStream objects. 30 /// 31 /// Any bytes written to one will be readable on the other, and vice 32 /// versa. These streams will behave more or less like a socketpair, 33 /// except without actually going through the operating system. 34 /// 35 /// Note that this implementation is intended for testing only, and 36 /// isn't optimized. 37 pub fn stream_pair() -> (LocalStream, LocalStream) { 38 let (w1, r2) = mpsc_channel(CAPACITY); 39 let (w2, r1) = mpsc_channel(CAPACITY); 40 let s1 = LocalStream { 41 w: w1, 42 r: r1, 43 pending_bytes: Vec::new(), 44 tls_cert: None, 45 }; 46 let s2 = LocalStream { 47 w: w2, 48 r: r2, 49 pending_bytes: Vec::new(), 50 tls_cert: None, 51 }; 52 (s1, s2) 53 } 54 55 /// One half of a pair of linked streams returned by [`stream_pair`]. 56 // 57 // Implementation notes: linked streams are made out a pair of mpsc 58 // channels. There's one channel for sending bytes in each direction. 59 // Bytes are sent as IoResult<Vec<u8>>: sending an error causes an error 60 // to occur on the other side. 61 pub struct LocalStream { 62 /// The writing side of the channel that we use to implement this 63 /// stream. 64 /// 65 /// The reading side is held by the other linked stream. 66 w: mpsc::Sender<IoResult<Vec<u8>>>, 67 /// The reading side of the channel that we use to implement this 68 /// stream. 69 /// 70 /// The writing side is held by the other linked stream. 71 r: mpsc::Receiver<IoResult<Vec<u8>>>, 72 /// Bytes that we have read from `r` but not yet delivered. 73 pending_bytes: Vec<u8>, 74 /// Data about the other side of this stream's fake TLS certificate, if any. 75 /// If this is present, I/O operations will fail with an error. 76 /// 77 /// How this is intended to work: things that return `LocalStream`s that could potentially 78 /// be connected to a fake TLS listener should set this field. Then, a fake TLS wrapper 79 /// type would clear this field (after checking its contents are as expected). 80 /// 81 /// FIXME(eta): this is a bit of a layering violation, but it's hard to do otherwise 82 pub(crate) tls_cert: Option<Vec<u8>>, 83 } 84 85 /// Helper: pull bytes off the front of `pending_bytes` and put them 86 /// onto `buf. Return the number of bytes moved. 87 fn drain_helper(buf: &mut [u8], pending_bytes: &mut Vec<u8>) -> usize { 88 let n_to_drain = std::cmp::min(buf.len(), pending_bytes.len()); 89 buf[..n_to_drain].copy_from_slice(&pending_bytes[..n_to_drain]); 90 pending_bytes.drain(..n_to_drain); 91 n_to_drain 92 } 93 94 impl AsyncRead for LocalStream { 95 fn poll_read( 96 mut self: Pin<&mut Self>, 97 cx: &mut Context<'_>, 98 buf: &mut [u8], 99 ) -> Poll<IoResult<usize>> { 100 if buf.is_empty() { 101 return Poll::Ready(Ok(0)); 102 } 103 if self.tls_cert.is_some() { 104 return Poll::Ready(Err(std::io::Error::new( 105 std::io::ErrorKind::Other, 106 "attempted to treat a TLS stream as non-TLS!", 107 ))); 108 } 109 if !self.pending_bytes.is_empty() { 110 return Poll::Ready(Ok(drain_helper(buf, &mut self.pending_bytes))); 111 } 112 113 match futures::ready!(Pin::new(&mut self.r).poll_next(cx)) { 114 Some(Err(e)) => Poll::Ready(Err(e)), 115 Some(Ok(bytes)) => { 116 self.pending_bytes = bytes; 117 let n = drain_helper(buf, &mut self.pending_bytes); 118 Poll::Ready(Ok(n)) 119 } 120 None => Poll::Ready(Ok(0)), // This is an EOF 121 } 122 } 123 } 124 125 impl AsyncWrite for LocalStream { 126 fn poll_write( 127 mut self: Pin<&mut Self>, 128 cx: &mut Context<'_>, 129 buf: &[u8], 130 ) -> Poll<IoResult<usize>> { 131 if self.tls_cert.is_some() { 132 return Poll::Ready(Err(std::io::Error::new( 133 std::io::ErrorKind::Other, 134 "attempted to treat a TLS stream as non-TLS!", 135 ))); 136 } 137 138 match futures::ready!(Pin::new(&mut self.w).poll_ready(cx)) { 139 Ok(()) => (), 140 Err(e) => return Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))), 141 } 142 143 let buf = if buf.len() > CHUNKSZ { 144 &buf[..CHUNKSZ] 145 } else { 146 buf 147 }; 148 let len = buf.len(); 149 match Pin::new(&mut self.w).start_send(Ok(buf.to_vec())) { 150 Ok(()) => Poll::Ready(Ok(len)), 151 Err(e) => Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))), 152 } 153 } 154 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> { 155 Pin::new(&mut self.w) 156 .poll_flush(cx) 157 .map_err(|e| IoError::new(ErrorKind::BrokenPipe, e)) 158 } 159 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> { 160 Pin::new(&mut self.w) 161 .poll_close(cx) 162 .map_err(|e| IoError::new(ErrorKind::Other, e)) 163 } 164 } 165 166 impl StreamOps for LocalStream { 167 fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> { 168 Err( 169 UnsupportedStreamOp::new("set_tcp_notsent_lowat", "unsupported on local streams") 170 .into(), 171 ) 172 } 173 } 174 175 /// An error generated by [`LocalStream::send_err`]. 176 #[derive(Debug, Clone, Eq, PartialEq)] 177 #[non_exhaustive] 178 pub struct SyntheticError; 179 impl std::error::Error for SyntheticError {} 180 impl std::fmt::Display for SyntheticError { 181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 182 write!(f, "Synthetic error") 183 } 184 } 185 186 impl LocalStream { 187 /// Send an error to the other linked local stream. 188 /// 189 /// When the other stream reads this message, it will generate a 190 /// [`std::io::Error`] with the provided `ErrorKind`. 191 pub async fn send_err(&mut self, kind: ErrorKind) { 192 let _ignore = self.w.send(Err(IoError::new(kind, SyntheticError))).await; 193 } 194 } 195 196 #[cfg(all(test, not(miri)))] // These tests are very slow under miri 197 mod test { 198 // @@ begin test lint list maintained by maint/add_warning @@ 199 #![allow(clippy::bool_assert_comparison)] 200 #![allow(clippy::clone_on_copy)] 201 #![allow(clippy::dbg_macro)] 202 #![allow(clippy::mixed_attributes_style)] 203 #![allow(clippy::print_stderr)] 204 #![allow(clippy::print_stdout)] 205 #![allow(clippy::single_char_pattern)] 206 #![allow(clippy::unwrap_used)] 207 #![allow(clippy::unchecked_duration_subtraction)] 208 #![allow(clippy::useless_vec)] 209 #![allow(clippy::needless_pass_by_value)] 210 //! <!-- @@ end test lint list maintained by maint/add_warning @@ --> 211 use super::*; 212 213 use futures::io::{AsyncReadExt, AsyncWriteExt}; 214 use futures_await_test::async_test; 215 use rand::Rng; 216 use tor_basic_utils::test_rng::testing_rng; 217 218 #[async_test] 219 async fn basic_rw() { 220 let (mut s1, mut s2) = stream_pair(); 221 let mut text1 = vec![0_u8; 9999]; 222 testing_rng().fill(&mut text1[..]); 223 224 let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!( 225 async { 226 for _ in 0_u8..10 { 227 s1.write_all(&text1[..]).await?; 228 } 229 s1.close().await?; 230 Ok(()) 231 }, 232 async { 233 let mut text2: Vec<u8> = Vec::new(); 234 let mut buf = [0_u8; 33]; 235 loop { 236 let n = s2.read(&mut buf[..]).await?; 237 if n == 0 { 238 break; 239 } 240 text2.extend(&buf[..n]); 241 } 242 for ch in text2[..].chunks(text1.len()) { 243 assert_eq!(ch, &text1[..]); 244 } 245 Ok(()) 246 } 247 ); 248 249 v1.unwrap(); 250 v2.unwrap(); 251 } 252 253 #[async_test] 254 async fn send_error() { 255 let (mut s1, mut s2) = stream_pair(); 256 257 let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!( 258 async { 259 s1.write_all(b"hello world").await?; 260 s1.send_err(ErrorKind::PermissionDenied).await; 261 Ok(()) 262 }, 263 async { 264 let mut buf = [0_u8; 33]; 265 loop { 266 let n = s2.read(&mut buf[..]).await?; 267 if n == 0 { 268 break; 269 } 270 } 271 Ok(()) 272 } 273 ); 274 275 v1.unwrap(); 276 let e = v2.err().unwrap(); 277 assert_eq!(e.kind(), ErrorKind::PermissionDenied); 278 let synth = e.into_inner().unwrap(); 279 assert_eq!(synth.to_string(), "Synthetic error"); 280 } 281 282 #[async_test] 283 async fn drop_reader() { 284 let (mut s1, s2) = stream_pair(); 285 286 let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!( 287 async { 288 for _ in 0_u16..1000 { 289 s1.write_all(&[9_u8; 9999]).await?; 290 } 291 Ok(()) 292 }, 293 async { 294 drop(s2); 295 Ok(()) 296 } 297 ); 298 299 v2.unwrap(); 300 let e = v1.err().unwrap(); 301 assert_eq!(e.kind(), ErrorKind::BrokenPipe); 302 } 303 }