/ fedimint-server / src / net / framed.rs
framed.rs
//! Adapter that implements a message-based protocol on top of a stream-based
//! one.
  3  use std::fmt::Debug;
  4  use std::io::{Read, Write};
  5  use std::marker::PhantomData;
  6  use std::pin::Pin;
  7  use std::task::{Context, Poll};
  8  
  9  use bytes::{Buf, BufMut, BytesMut};
 10  use fedimint_logging::LOG_NET_PEER;
 11  use futures::{Sink, Stream};
 12  use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
 13  use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
 14  use tokio_util::codec::{FramedRead, FramedWrite};
 15  use tracing::{error, trace};
 16  
 17  /// Owned [`FramedTransport`] trait object
 18  pub type AnyFramedTransport<M> = Box<dyn FramedTransport<M> + Send + Unpin + 'static>;
 19  
 20  /// A bidirectional framed transport adapter that can be split into its read and
 21  /// write half
 22  pub trait FramedTransport<T>:
 23      Sink<T, Error = anyhow::Error> + Stream<Item = Result<T, anyhow::Error>>
 24  {
 25      /// Split the framed transport into read and write half
 26      fn borrow_split(
 27          &mut self,
 28      ) -> (
 29          &'_ mut (dyn Sink<T, Error = anyhow::Error> + Send + Unpin),
 30          &'_ mut (dyn Stream<Item = Result<T, anyhow::Error>> + Send + Unpin),
 31      );
 32  
 33      /// Transforms concrete `FramedTransport` object into an owned trait object
 34      fn into_dyn(self) -> AnyFramedTransport<T>
 35      where
 36          Self: Sized + Send + Unpin + 'static,
 37      {
 38          Box::new(self)
 39      }
 40  }
 41  
 42  /// Special case for tokio [`TcpStream`](tokio::net::TcpStream) based
 43  /// [`BidiFramed`] instances
 44  pub type TcpBidiFramed<T> = BidiFramed<T, OwnedWriteHalf, OwnedReadHalf>;
 45  
 46  /// Sink (sending) half of [`BidiFramed`]
 47  pub type FramedSink<S, T> = FramedWrite<S, BincodeCodec<T>>;
 48  /// Stream (receiving) half of [`BidiFramed`]
 49  pub type FramedStream<S, T> = FramedRead<S, BincodeCodec<T>>;
 50  
/// Framed transport codec for streams
///
/// Holds the write half `WH` and the read half `RH` of a byte stream and
/// allows sending and receiving packetized data of type `T` over them. Data
/// items are encoded using [`bincode`] and the bytes are sent over the stream
/// prepended with an 8-byte big-endian length field. `BidiFramed` implements
/// `Sink<T>` and `Stream<Item=Result<T, _>>`.
#[derive(Debug)]
pub struct BidiFramed<T, WH, RH> {
    /// Outgoing side: encodes `T` items and writes the frames to `WH`
    sink: FramedSink<WH, T>,
    /// Incoming side: reads frames from `RH` and decodes them into `T` items
    stream: FramedStream<RH, T>,
}
 62  
/// Framed codec that uses [`bincode`] to encode structs with [`serde`] support
///
/// Every frame is an 8-byte big-endian length prefix followed by that many
/// bytes of bincode-encoded payload.
#[derive(Debug)]
pub struct BincodeCodec<T> {
    // Zero-sized marker tying the codec to the message type `T` it handles
    _pd: PhantomData<T>,
}
 68  
 69  impl<T, WH, RH> BidiFramed<T, WH, RH>
 70  where
 71      WH: AsyncWrite,
 72      RH: AsyncRead,
 73      T: serde::Serialize + serde::de::DeserializeOwned,
 74  {
 75      /// Builds a new `BidiFramed` codec around a stream `stream`.
 76      ///
 77      /// See [`TcpBidiFramed::new_from_tcp`] for a more efficient version in case
 78      /// the stream is a tokio TCP stream.
 79      pub fn new<S>(stream: S) -> BidiFramed<T, WriteHalf<S>, ReadHalf<S>>
 80      where
 81          S: AsyncRead + AsyncWrite,
 82      {
 83          let (read, write) = tokio::io::split(stream);
 84          BidiFramed {
 85              sink: FramedSink::new(write, BincodeCodec::new()),
 86              stream: FramedStream::new(read, BincodeCodec::new()),
 87          }
 88      }
 89  
 90      /// Splits the codec in its sending and receiving parts
 91      ///
 92      /// This can be useful in cases where potentially simultaneous read and
 93      /// write operations are required. Otherwise a we would need a mutex to
 94      /// guard access.
 95      pub fn borrow_parts(&mut self) -> (&mut FramedSink<WH, T>, &mut FramedStream<RH, T>) {
 96          (&mut self.sink, &mut self.stream)
 97      }
 98  }
 99  
100  impl<T> TcpBidiFramed<T>
101  where
102      T: serde::Serialize + serde::de::DeserializeOwned,
103  {
104      /// Special constructor for tokio TCP connections.
105      ///
106      /// Tokio [`TcpStream`](tokio::net::TcpStream) implements an efficient
107      /// method of splitting the stream into a read and a write half this
108      /// constructor takes advantage of.
109      pub fn new_from_tcp(stream: tokio::net::TcpStream) -> TcpBidiFramed<T> {
110          let (read, write) = stream.into_split();
111          BidiFramed {
112              sink: FramedSink::new(write, BincodeCodec::new()),
113              stream: FramedStream::new(read, BincodeCodec::new()),
114          }
115      }
116  }
117  
118  impl<T, WH, RH> Sink<T> for BidiFramed<T, WH, RH>
119  where
120      WH: tokio::io::AsyncWrite + Unpin,
121      RH: Unpin,
122      T: Debug + serde::Serialize,
123  {
124      type Error = anyhow::Error;
125  
126      fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
127          Sink::poll_ready(Pin::new(&mut self.sink), cx)
128      }
129  
130      fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
131          Sink::start_send(Pin::new(&mut self.sink), item)
132      }
133  
134      fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
135          Sink::poll_flush(Pin::new(&mut self.sink), cx)
136      }
137  
138      fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139          Sink::poll_close(Pin::new(&mut self.sink), cx)
140      }
141  }
142  
143  impl<T, WH, RH> Stream for BidiFramed<T, WH, RH>
144  where
145      T: serde::de::DeserializeOwned,
146      WH: Unpin,
147      RH: tokio::io::AsyncRead + Unpin,
148  {
149      type Item = Result<T, anyhow::Error>;
150  
151      fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
152          Stream::poll_next(Pin::new(&mut self.stream), cx)
153      }
154  }
155  
156  impl<T, WH, RH> FramedTransport<T> for BidiFramed<T, WH, RH>
157  where
158      T: Debug + serde::Serialize + serde::de::DeserializeOwned + Send,
159      WH: tokio::io::AsyncWrite + Send + Unpin,
160      RH: tokio::io::AsyncRead + Send + Unpin,
161  {
162      fn borrow_split(
163          &mut self,
164      ) -> (
165          &'_ mut (dyn Sink<T, Error = anyhow::Error> + Send + Unpin),
166          &'_ mut (dyn Stream<Item = Result<T, anyhow::Error>> + Send + Unpin),
167      ) {
168          let (sink, stream) = self.borrow_parts();
169          (&mut *sink, &mut *stream)
170      }
171  }
172  
173  impl<T> BincodeCodec<T> {
174      fn new() -> BincodeCodec<T> {
175          BincodeCodec {
176              _pd: Default::default(),
177          }
178      }
179  }
180  
181  impl<T> tokio_util::codec::Encoder<T> for BincodeCodec<T>
182  where
183      T: serde::Serialize + Debug,
184  {
185      type Error = anyhow::Error;
186  
187      fn encode(&mut self, item: T, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
188          // First, write a dummy length field and remember its position
189          let old_len = dst.len();
190          dst.writer().write_all(&[0u8; 8]).unwrap();
191          assert_eq!(dst.len(), old_len + 8);
192  
193          // Then we serialize the message into the buffer
194          bincode::serialize_into(dst.writer(), &item).map_err(|e| {
195              error!(
196                  target: LOG_NET_PEER,
197                  "Serializing message failed: {:?}", item
198              );
199              e
200          })?;
201  
202          // Lastly we update the length field by counting how many bytes have been
203          // written
204          let new_len = dst.len();
205          let encoded_len = new_len - old_len - 8;
206          dst[old_len..old_len + 8].copy_from_slice(&encoded_len.to_be_bytes()[..]);
207  
208          Ok(())
209      }
210  }
211  
212  impl<T> tokio_util::codec::Decoder for BincodeCodec<T>
213  where
214      T: serde::de::DeserializeOwned,
215  {
216      type Item = T;
217      type Error = anyhow::Error;
218  
219      fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
220          if src.len() < 8 {
221              return Ok(None);
222          }
223  
224          let length = u64::from_be_bytes(src[0..8].try_into().expect("correct length"));
225          if src.len() < (length as usize) + 8 {
226              trace!(length, buffern_len = src.len(), "Received partial message");
227              return Ok(None);
228          }
229          trace!(length, "Received full message");
230  
231          src.reader()
232              .read_exact(&mut [0u8; 8][..])
233              .expect("minimum length checked");
234  
235          Ok(bincode::deserialize_from(src.reader()).map(Option::Some)?)
236      }
237  }
238  
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::{SinkExt, StreamExt};
    use serde::{Deserialize, Serialize};
    use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf};

    use crate::net::framed::BidiFramed;

    /// Message type exchanged by the tests in this module
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    enum TestEnum {
        Foo,
        Bar(u64),
    }

    /// `BidiFramed` running over one end of an in-memory duplex pipe
    type DuplexFramed = BidiFramed<TestEnum, WriteHalf<DuplexStream>, ReadHalf<DuplexStream>>;

    #[tokio::test]
    async fn test_roundtrip() {
        let messages = [TestEnum::Foo, TestEnum::Bar(42), TestEnum::Foo];
        let (left, right) = tokio::io::duplex(1024);

        let mut tx = DuplexFramed::new(left);
        let mut rx = DuplexFramed::new(right);

        for msg in messages.iter().cloned() {
            tx.send(msg).await.unwrap();
        }

        for expected in &messages {
            let got = rx.next().await.unwrap().unwrap();
            assert_eq!(&got, expected);
        }

        // Dropping the sending side must terminate the receiving stream
        drop(tx);
        assert!(rx.next().await.is_none());
    }

    #[tokio::test]
    async fn test_not_try_parse_partial() {
        // Two pipes: the framed sender writes into the first one, a prefix of
        // its output is copied by hand into the second one, which the framed
        // receiver reads from.
        let (framed_out, mut raw_out) = tokio::io::duplex(1024);
        let (mut raw_in, framed_in) = tokio::io::duplex(1024);

        let mut tx = DuplexFramed::new(framed_out);
        let mut rx = DuplexFramed::new(framed_in);

        tx.send(TestEnum::Bar(0x4242_4242_4242_4242)).await.unwrap();

        // Forward only the first three bytes to simulate a partial send
        let mut prefix = [0u8; 3];
        raw_out.read_exact(&mut prefix).await.unwrap();
        raw_in.write_all(&prefix).await.unwrap();

        // An incomplete frame must make the reader wait, not yield an error
        let outcome = tokio::time::timeout(Duration::from_secs(1), rx.next()).await;
        assert!(outcome.is_err());
    }
}