/ src / gateway / compression.rs
compression.rs
  1  //! Efficiently decompress gateway events.
  2  
  3  use std::{error::Error, fmt};
  4  
  5  /// Decompressed event buffer.
  6  const BUFFER_SIZE: usize = 32 * 1024;
  7  
  8  /// An operation relating to compression failed.
  9  #[derive(Debug)]
 10  pub struct CompressionError {
 11    /// Type of error.
 12    pub kind: CompressionErrorType,
 13    /// Source error if available.
 14    pub source: Option<Box<dyn Error + Send + Sync>>,
 15  }
 16  
 17  impl CompressionError {
 18    /// Shortcut to create a new error for an erroneous status code.
 19    pub fn from_code(code: usize) -> Self {
 20      Self {
 21        kind: CompressionErrorType::Decompressing,
 22        source: Some(zstd_safe::get_error_name(code).into()),
 23      }
 24    }
 25  }
 26  
 27  impl fmt::Display for CompressionError {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29      match self.kind {
 30        CompressionErrorType::Decompressing => {
 31          f.write_str("message could not be decompressed")
 32        }
 33      }
 34    }
 35  }
 36  
 37  impl Error for CompressionError {
 38    fn source(&self) -> Option<&(dyn Error + 'static)> {
 39      self
 40        .source
 41        .as_ref()
 42        .map(|source| &**source as &(dyn Error + 'static))
 43    }
 44  }
 45  
 46  /// Type of [`CompressionError`] that occurred.
 47  #[derive(Debug)]
 48  pub enum CompressionErrorType {
 49    /// Decompressing a frame failed.
 50    Decompressing,
 51  }
 52  
 53  pub struct Decompressor {
 54    /// Common decompressed message buffer.
 55    buffer: Box<[u8]>,
 56    /// Reusable zstd decompression context.
 57    ctx: zstd_safe::DCtx<'static>,
 58  }
 59  
 60  impl Decompressor {
 61    /// Create a new decompressor for a shard.
 62    pub fn new() -> Self {
 63      Self {
 64        buffer: vec![0; BUFFER_SIZE].into_boxed_slice(),
 65        ctx: zstd_safe::DCtx::create(),
 66      }
 67    }
 68  
 69    /// Decompress a message.
 70    ///
 71    /// # Errors
 72    ///
 73    /// Returns a [`CompressionErrorType::Decompressing`] error type if the
 74    /// message could not be decompressed.
 75    pub fn decompress(
 76      &mut self,
 77      message: &[u8],
 78    ) -> Result<Vec<u8>, CompressionError> {
 79      let mut input = zstd_safe::InBuffer::around(message);
 80  
 81      // Decompressed message. `Vec::extend_from_slice` efficiently allocates
 82      // only what's necessary.
 83      let mut decompressed = Vec::new();
 84  
 85      loop {
 86        let mut output = zstd_safe::OutBuffer::around(self.buffer.as_mut());
 87  
 88        self
 89          .ctx
 90          .decompress_stream(&mut output, &mut input)
 91          .map_err(CompressionError::from_code)?;
 92  
 93        decompressed.extend_from_slice(output.as_slice());
 94  
 95        // Break when message has been fully decompressed.
 96        if input.pos == input.src.len() && output.pos() != output.capacity() {
 97          break;
 98        }
 99      }
100  
101      Ok(decompressed)
102    }
103  
104    /// Reset the decompressor's internal state.
105    pub fn reset(&mut self) {
106      self
107        .ctx
108        .reset(zstd_safe::ResetDirective::SessionOnly)
109        .expect("resetting session is infallible");
110    }
111  }
112  
113  impl fmt::Debug for Decompressor {
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115      f.debug_struct("Decompressor")
116        .field("buffer", &self.buffer)
117        .field("ctx", &"<decompression context>")
118        .finish()
119    }
120  }
121  
#[cfg(test)]
mod tests {
  use super::Decompressor;

  /// A single zstd frame (it starts with the zstd magic number
  /// `28 B5 2F FD`) whose decompressed form is `OUTPUT`.
  const MESSAGE: [u8; 117] = [
    40, 181, 47, 253, 0, 64, 100, 3, 0, 66, 7, 25, 28, 112, 137, 115, 116, 40,
    208, 203, 85, 255, 167, 74, 75, 126, 203, 222, 231, 255, 151, 18, 211, 212,
    171, 144, 151, 210, 255, 51, 4, 49, 34, 71, 98, 2, 36, 253, 122, 141, 99,
    203, 225, 11, 162, 47, 133, 241, 6, 201, 82, 245, 91, 206, 247, 164, 226,
    156, 92, 108, 130, 123, 11, 95, 199, 15, 61, 179, 117, 157, 28, 37, 65, 64,
    25, 250, 182, 8, 199, 205, 44, 73, 47, 19, 218, 45, 27, 14, 245, 202, 81,
    82, 122, 167, 121, 71, 173, 61, 140, 190, 15, 3, 1, 0, 36, 74, 18,
  ];

  /// Expected decompression of `MESSAGE`.
  const OUTPUT: &str = r#"{"t":null,"s":null,"op":10,"d":{"heartbeat_interval":41250,"_trace":["[\"gateway-prd-us-east1-c-7s4x\",{\"micros\":0.0}]"]}}"#;

  #[test]
  fn message() {
    let mut decompressor = Decompressor::new();
    let decompressed = decompressor.decompress(&MESSAGE).unwrap();

    assert_eq!(decompressed, OUTPUT.as_bytes());
  }

  #[test]
  fn reset() {
    let mut decompressor = Decompressor::new();

    // Feed everything except the final two bytes of the frame.
    let truncated = &MESSAGE[..MESSAGE.len() - 2];
    decompressor.decompress(truncated).unwrap();

    // The leftover stream state makes a complete message fail...
    assert!(decompressor.decompress(&MESSAGE).is_err());

    // ...until the session is reset.
    decompressor.reset();
    let decompressed = decompressor.decompress(&MESSAGE).unwrap();
    assert_eq!(decompressed, OUTPUT.as_bytes());
  }
}