/ radicle-node / src / deserializer.rs
deserializer.rs
  1  use std::io;
  2  use std::marker::PhantomData;
  3  
  4  use crate::service::message::Message;
  5  use crate::wire;
  6  
/// Message stream deserializer.
///
/// Buffers incoming bytes and decodes them into items of type `D`
/// ([`Message`] by default). Used, for example, to turn a byte stream into
/// network messages.
#[derive(Debug)]
pub struct Deserializer<D = Message> {
    /// Bytes that have been fed in but not yet decoded into an item.
    unparsed: Vec<u8>,
    /// Marker binding the deserializer to its item type `D`; no `D` is stored.
    item: PhantomData<D>,
}
 15  
 16  impl<D: wire::Decode> Default for Deserializer<D> {
 17      fn default() -> Self {
 18          Self::new(wire::Size::MAX as usize + 1)
 19      }
 20  }
 21  
 22  impl<D> From<Vec<u8>> for Deserializer<D> {
 23      fn from(unparsed: Vec<u8>) -> Self {
 24          Self {
 25              unparsed,
 26              item: PhantomData,
 27          }
 28      }
 29  }
 30  
 31  impl<D: wire::Decode> Deserializer<D> {
 32      /// Create a new stream decoder.
 33      pub fn new(capacity: usize) -> Self {
 34          Self {
 35              unparsed: Vec::with_capacity(capacity),
 36              item: PhantomData,
 37          }
 38      }
 39  
 40      /// Input bytes into the decoder.
 41      pub fn input(&mut self, bytes: &[u8]) {
 42          self.unparsed.extend_from_slice(bytes);
 43      }
 44  
 45      /// Decode and return the next message. Returns [`None`] if nothing was decoded.
 46      pub fn deserialize_next(&mut self) -> Result<Option<D>, wire::Error> {
 47          let mut reader = io::Cursor::new(self.unparsed.as_mut_slice());
 48  
 49          match D::decode(&mut reader) {
 50              Ok(msg) => {
 51                  let pos = reader.position() as usize;
 52                  self.unparsed.drain(..pos);
 53  
 54                  Ok(Some(msg))
 55              }
 56              Err(err) if err.is_eof() => Ok(None),
 57              Err(err) => Err(err),
 58          }
 59      }
 60  
 61      /// Drain the unparsed buffer.
 62      pub fn unparsed(&mut self) -> impl ExactSizeIterator<Item = u8> + '_ {
 63          self.unparsed.drain(..)
 64      }
 65  
 66      /// Return whether there are unparsed bytes.
 67      pub fn is_empty(&self) -> bool {
 68          self.unparsed.is_empty()
 69      }
 70  }
 71  
 72  impl<D: wire::Decode> io::Write for Deserializer<D> {
 73      fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
 74          self.input(buf);
 75  
 76          Ok(buf.len())
 77      }
 78  
 79      fn flush(&mut self) -> io::Result<()> {
 80          Ok(())
 81      }
 82  }
 83  
 84  impl<D: wire::Decode> Iterator for Deserializer<D> {
 85      type Item = Result<D, wire::Error>;
 86  
 87      fn next(&mut self) -> Option<Self::Item> {
 88          self.deserialize_next().transpose()
 89      }
 90  }
 91  
 92  #[cfg(test)]
 93  mod test {
 94      use super::*;
 95      use qcheck_macros::quickcheck;
 96  
 97      use crate::test::assert_matches;
 98  
 99      const MSG_HELLO: &[u8] = &[5, b'h', b'e', b'l', b'l', b'o'];
100      const MSG_BYE: &[u8] = &[3, b'b', b'y', b'e'];
101  
102      #[test]
103      fn test_decode_next() {
104          let mut decoder = Deserializer::<String>::new(8);
105  
106          decoder.input(&[3, b'b']);
107          assert_matches!(decoder.deserialize_next(), Ok(None));
108          assert_eq!(decoder.unparsed.len(), 2);
109  
110          decoder.input(&[b'y']);
111          assert_matches!(decoder.deserialize_next(), Ok(None));
112          assert_eq!(decoder.unparsed.len(), 3);
113  
114          decoder.input(&[b'e']);
115          assert_matches!(decoder.deserialize_next(), Ok(Some(s)) if s.as_str() == "bye");
116          assert_eq!(decoder.unparsed.len(), 0);
117          assert!(decoder.is_empty());
118      }
119  
120      #[test]
121      fn test_unparsed() {
122          let mut decoder = Deserializer::<String>::new(8);
123  
124          decoder.input(&[3, b'b', b'y']);
125          assert_eq!(decoder.unparsed().collect::<Vec<_>>(), vec![3, b'b', b'y']);
126          assert!(decoder.is_empty());
127      }
128  
129      #[quickcheck]
130      fn prop_decode_next(chunk_size: usize) {
131          let mut bytes = vec![];
132          let mut msgs = vec![];
133          let mut decoder = Deserializer::<String>::new(8);
134  
135          let chunk_size = 1 + chunk_size % MSG_HELLO.len() + MSG_BYE.len();
136  
137          bytes.extend_from_slice(MSG_HELLO);
138          bytes.extend_from_slice(MSG_BYE);
139  
140          for chunk in bytes.as_slice().chunks(chunk_size) {
141              decoder.input(chunk);
142  
143              while let Some(msg) = decoder.deserialize_next().unwrap() {
144                  msgs.push(msg);
145              }
146          }
147  
148          assert_eq!(decoder.unparsed.len(), 0);
149          assert_eq!(msgs.len(), 2);
150          assert_eq!(msgs[0], String::from("hello"));
151          assert_eq!(msgs[1], String::from("bye"));
152      }
153  }