framed.rs
1 //! Adapter that implements a message based protocol on top of a stream based 2 //! one 3 use std::fmt::Debug; 4 use std::io::{Read, Write}; 5 use std::marker::PhantomData; 6 use std::pin::Pin; 7 use std::task::{Context, Poll}; 8 9 use bytes::{Buf, BufMut, BytesMut}; 10 use fedimint_logging::LOG_NET_PEER; 11 use futures::{Sink, Stream}; 12 use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; 13 use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; 14 use tokio_util::codec::{FramedRead, FramedWrite}; 15 use tracing::{error, trace}; 16 17 /// Owned [`FramedTransport`] trait object 18 pub type AnyFramedTransport<M> = Box<dyn FramedTransport<M> + Send + Unpin + 'static>; 19 20 /// A bidirectional framed transport adapter that can be split into its read and 21 /// write half 22 pub trait FramedTransport<T>: 23 Sink<T, Error = anyhow::Error> + Stream<Item = Result<T, anyhow::Error>> 24 { 25 /// Split the framed transport into read and write half 26 fn borrow_split( 27 &mut self, 28 ) -> ( 29 &'_ mut (dyn Sink<T, Error = anyhow::Error> + Send + Unpin), 30 &'_ mut (dyn Stream<Item = Result<T, anyhow::Error>> + Send + Unpin), 31 ); 32 33 /// Transforms concrete `FramedTransport` object into an owned trait object 34 fn into_dyn(self) -> AnyFramedTransport<T> 35 where 36 Self: Sized + Send + Unpin + 'static, 37 { 38 Box::new(self) 39 } 40 } 41 42 /// Special case for tokio [`TcpStream`](tokio::net::TcpStream) based 43 /// [`BidiFramed`] instances 44 pub type TcpBidiFramed<T> = BidiFramed<T, OwnedWriteHalf, OwnedReadHalf>; 45 46 /// Sink (sending) half of [`BidiFramed`] 47 pub type FramedSink<S, T> = FramedWrite<S, BincodeCodec<T>>; 48 /// Stream (receiving) half of [`BidiFramed`] 49 pub type FramedStream<S, T> = FramedRead<S, BincodeCodec<T>>; 50 51 /// Framed transport codec for streams 52 /// 53 /// Wraps a stream `S` and allows sending packetized data of type `T` over it. 54 /// Data items are encoded using [`bincode`] and the bytes are sent over the 55 /// stream prepended with a length field. `BidiFramed` implements `Sink<T>` and 56 /// `Stream<Item=Result<T, _>>`. 57 #[derive(Debug)] 58 pub struct BidiFramed<T, WH, RH> { 59 sink: FramedSink<WH, T>, 60 stream: FramedStream<RH, T>, 61 } 62 63 /// Framed codec that uses [`bincode`] to encode structs with [`serde`] support 64 #[derive(Debug)] 65 pub struct BincodeCodec<T> { 66 _pd: PhantomData<T>, 67 } 68 69 impl<T, WH, RH> BidiFramed<T, WH, RH> 70 where 71 WH: AsyncWrite, 72 RH: AsyncRead, 73 T: serde::Serialize + serde::de::DeserializeOwned, 74 { 75 /// Builds a new `BidiFramed` codec around a stream `stream`. 76 /// 77 /// See [`TcpBidiFramed::new_from_tcp`] for a more efficient version in case 78 /// the stream is a tokio TCP stream. 79 pub fn new<S>(stream: S) -> BidiFramed<T, WriteHalf<S>, ReadHalf<S>> 80 where 81 S: AsyncRead + AsyncWrite, 82 { 83 let (read, write) = tokio::io::split(stream); 84 BidiFramed { 85 sink: FramedSink::new(write, BincodeCodec::new()), 86 stream: FramedStream::new(read, BincodeCodec::new()), 87 } 88 } 89 90 /// Splits the codec in its sending and receiving parts 91 /// 92 /// This can be useful in cases where potentially simultaneous read and 93 /// write operations are required. Otherwise a we would need a mutex to 94 /// guard access. 95 pub fn borrow_parts(&mut self) -> (&mut FramedSink<WH, T>, &mut FramedStream<RH, T>) { 96 (&mut self.sink, &mut self.stream) 97 } 98 } 99 100 impl<T> TcpBidiFramed<T> 101 where 102 T: serde::Serialize + serde::de::DeserializeOwned, 103 { 104 /// Special constructor for tokio TCP connections. 105 /// 106 /// Tokio [`TcpStream`](tokio::net::TcpStream) implements an efficient 107 /// method of splitting the stream into a read and a write half this 108 /// constructor takes advantage of. 109 pub fn new_from_tcp(stream: tokio::net::TcpStream) -> TcpBidiFramed<T> { 110 let (read, write) = stream.into_split(); 111 BidiFramed { 112 sink: FramedSink::new(write, BincodeCodec::new()), 113 stream: FramedStream::new(read, BincodeCodec::new()), 114 } 115 } 116 } 117 118 impl<T, WH, RH> Sink<T> for BidiFramed<T, WH, RH> 119 where 120 WH: tokio::io::AsyncWrite + Unpin, 121 RH: Unpin, 122 T: Debug + serde::Serialize, 123 { 124 type Error = anyhow::Error; 125 126 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 127 Sink::poll_ready(Pin::new(&mut self.sink), cx) 128 } 129 130 fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { 131 Sink::start_send(Pin::new(&mut self.sink), item) 132 } 133 134 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 135 Sink::poll_flush(Pin::new(&mut self.sink), cx) 136 } 137 138 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 139 Sink::poll_close(Pin::new(&mut self.sink), cx) 140 } 141 } 142 143 impl<T, WH, RH> Stream for BidiFramed<T, WH, RH> 144 where 145 T: serde::de::DeserializeOwned, 146 WH: Unpin, 147 RH: tokio::io::AsyncRead + Unpin, 148 { 149 type Item = Result<T, anyhow::Error>; 150 151 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 152 Stream::poll_next(Pin::new(&mut self.stream), cx) 153 } 154 } 155 156 impl<T, WH, RH> FramedTransport<T> for BidiFramed<T, WH, RH> 157 where 158 T: Debug + serde::Serialize + serde::de::DeserializeOwned + Send, 159 WH: tokio::io::AsyncWrite + Send + Unpin, 160 RH: tokio::io::AsyncRead + Send + Unpin, 161 { 162 fn borrow_split( 163 &mut self, 164 ) -> ( 165 &'_ mut (dyn Sink<T, Error = anyhow::Error> + Send + Unpin), 166 &'_ mut (dyn Stream<Item = Result<T, anyhow::Error>> + Send + Unpin), 167 ) { 168 let (sink, stream) = self.borrow_parts(); 169 (&mut *sink, &mut *stream) 170 } 171 } 172 173 impl<T> BincodeCodec<T> { 174 fn new() -> BincodeCodec<T> { 175 BincodeCodec { 176 _pd: Default::default(), 177 } 178 } 179 } 180 181 impl<T> tokio_util::codec::Encoder<T> for BincodeCodec<T> 182 where 183 T: serde::Serialize + Debug, 184 { 185 type Error = anyhow::Error; 186 187 fn encode(&mut self, item: T, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> { 188 // First, write a dummy length field and remember its position 189 let old_len = dst.len(); 190 dst.writer().write_all(&[0u8; 8]).unwrap(); 191 assert_eq!(dst.len(), old_len + 8); 192 193 // Then we serialize the message into the buffer 194 bincode::serialize_into(dst.writer(), &item).map_err(|e| { 195 error!( 196 target: LOG_NET_PEER, 197 "Serializing message failed: {:?}", item 198 ); 199 e 200 })?; 201 202 // Lastly we update the length field by counting how many bytes have been 203 // written 204 let new_len = dst.len(); 205 let encoded_len = new_len - old_len - 8; 206 dst[old_len..old_len + 8].copy_from_slice(&encoded_len.to_be_bytes()[..]); 207 208 Ok(()) 209 } 210 } 211 212 impl<T> tokio_util::codec::Decoder for BincodeCodec<T> 213 where 214 T: serde::de::DeserializeOwned, 215 { 216 type Item = T; 217 type Error = anyhow::Error; 218 219 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { 220 if src.len() < 8 { 221 return Ok(None); 222 } 223 224 let length = u64::from_be_bytes(src[0..8].try_into().expect("correct length")); 225 if src.len() < (length as usize) + 8 { 226 trace!(length, buffern_len = src.len(), "Received partial message"); 227 return Ok(None); 228 } 229 trace!(length, "Received full message"); 230 231 src.reader() 232 .read_exact(&mut [0u8; 8][..]) 233 .expect("minimum length checked"); 234 235 Ok(bincode::deserialize_from(src.reader()).map(Option::Some)?) 236 } 237 } 238 239 #[cfg(test)] 240 mod tests { 241 use std::time::Duration; 242 243 use futures::{SinkExt, StreamExt}; 244 use serde::{Deserialize, Serialize}; 245 use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf}; 246 247 use crate::net::framed::BidiFramed; 248 249 #[tokio::test] 250 async fn test_roundtrip() { 251 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] 252 enum TestEnum { 253 Foo, 254 Bar(u64), 255 } 256 257 let input = vec![TestEnum::Foo, TestEnum::Bar(42), TestEnum::Foo]; 258 let (sender, recipient) = tokio::io::duplex(1024); 259 260 let mut framed_sender = 261 BidiFramed::<TestEnum, WriteHalf<DuplexStream>, ReadHalf<DuplexStream>>::new(sender); 262 263 let mut framed_recipient = 264 BidiFramed::<TestEnum, WriteHalf<DuplexStream>, ReadHalf<DuplexStream>>::new(recipient); 265 266 for item in &input { 267 framed_sender.send(item.clone()).await.unwrap(); 268 } 269 270 for item in &input { 271 let received = framed_recipient.next().await.unwrap().unwrap(); 272 assert_eq!(&received, item); 273 } 274 drop(framed_sender); 275 276 assert!(framed_recipient.next().await.is_none()); 277 } 278 279 #[tokio::test] 280 async fn test_not_try_parse_partial() { 281 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] 282 enum TestEnum { 283 Foo, 284 Bar(u64), 285 } 286 287 let (sender_src, mut recipient_src) = tokio::io::duplex(1024); 288 let (mut sender_dst, recipient_dst) = tokio::io::duplex(1024); 289 290 let mut framed_sender = 291 BidiFramed::<TestEnum, WriteHalf<DuplexStream>, ReadHalf<DuplexStream>>::new( 292 sender_src, 293 ); 294 let mut framed_recipient = 295 BidiFramed::<TestEnum, WriteHalf<DuplexStream>, ReadHalf<DuplexStream>>::new( 296 recipient_dst, 297 ); 298 299 framed_sender 300 .send(TestEnum::Bar(0x4242_4242_4242_4242)) 301 .await 302 .unwrap(); 303 304 // Simulate a partial send 305 let mut buf = [0u8; 3]; 306 recipient_src.read_exact(&mut buf).await.unwrap(); 307 sender_dst.write_all(&buf).await.unwrap(); 308 309 // Try to read, should not return an error but block 310 let received = tokio::time::timeout(Duration::from_secs(1), framed_recipient.next()).await; 311 312 assert!(received.is_err()); 313 } 314 }