codec.rs
 1  //! Protocol Codec implementation
 2  //!
 3  //! The protocol messages look like this:
 4  //!
 5  //! - 4 bytes for message protocol version (for this implementation 1), LE
 6  //! - 8 bytes of message size, LE
 7  //! - N bytes (as defined by previous field) for JSON object
 8  //!
 9  
10  use tokio_util::bytes::Buf;
11  use tokio_util::bytes::BufMut;
12  
13  use super::message::ProtocolMessage;
14  
15  pub struct Decoder;
16  
17  impl tokio_util::codec::Decoder for Decoder {
18      type Item = ProtocolMessage;
19      type Error = DecoderError;
20  
21      fn decode(
22          &mut self,
23          src: &mut tokio_util::bytes::BytesMut,
24      ) -> Result<Option<Self::Item>, Self::Error> {
25          if src.len() < 12 {
26              src.reserve(12 - src.len());
27              return Ok(None);
28          }
29  
30          let version = u32::from_le_bytes((&src[0..4]).try_into().expect("Invariant failed"));
31  
32          if version != 1 {
33              // TODO: Don't hardcode protocol version as number
34              return Err(DecoderError::ProtocolVersion(version));
35          }
36  
37          let message_size = u64::from_le_bytes((&src[4..12]).try_into().expect("Invariant failed"));
38  
39          let buffer_len_without_header = (src.len() - 12) as u64;
40  
41          if buffer_len_without_header < message_size {
42              let additional = message_size - buffer_len_without_header;
43              src.reserve(additional as usize);
44              return Ok(None);
45          }
46  
47          src.advance(12);
48  
49          serde_json::from_reader(src.take(message_size as usize).reader()).map_err(Self::Error::from)
50      }
51  }
52  
53  #[derive(Debug, thiserror::Error)]
54  pub enum DecoderError {
55      #[error(transparent)]
56      Io(#[from] std::io::Error),
57  
58      #[error("Wrong protocol version: Expected 1, got {}", .0)]
59      ProtocolVersion(u32),
60  
61      #[error(transparent)]
62      SerdeJson(#[from] serde_json::Error),
63  }
64  
65  pub struct Encoder;
66  
67  impl tokio_util::codec::Encoder<ProtocolMessage> for Encoder {
68      type Error = EncoderError;
69  
70      fn encode(
71          &mut self,
72          item: ProtocolMessage,
73          dst: &mut tokio_util::bytes::BytesMut,
74      ) -> Result<(), Self::Error> {
75          let buf = serde_json::to_vec(&item)
76              .inspect_err(|error| tracing::error!(?error, "failed to encode as JSON"))?;
77          dst.put_u32_le(1); // TODO: Don't hardcode protocol version as number
78          dst.put_u64_le(buf.len() as u64);
79          dst.put_slice(&buf);
80          Ok(())
81      }
82  }
83  
84  #[derive(Debug, thiserror::Error)]
85  pub enum EncoderError {
86      #[error(transparent)]
87      Io(#[from] std::io::Error),
88  
89      #[error(transparent)]
90      SerdeJson(#[from] serde_json::Error),
91  }