compression.rs
1 //! Efficiently decompress gateway events. 2 3 use std::{error::Error, fmt}; 4 5 /// Decompressed event buffer. 6 const BUFFER_SIZE: usize = 32 * 1024; 7 8 /// An operation relating to compression failed. 9 #[derive(Debug)] 10 pub struct CompressionError { 11 /// Type of error. 12 pub kind: CompressionErrorType, 13 /// Source error if available. 14 pub source: Option<Box<dyn Error + Send + Sync>>, 15 } 16 17 impl CompressionError { 18 /// Shortcut to create a new error for an erroneous status code. 19 pub fn from_code(code: usize) -> Self { 20 Self { 21 kind: CompressionErrorType::Decompressing, 22 source: Some(zstd_safe::get_error_name(code).into()), 23 } 24 } 25 } 26 27 impl fmt::Display for CompressionError { 28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 29 match self.kind { 30 CompressionErrorType::Decompressing => { 31 f.write_str("message could not be decompressed") 32 } 33 } 34 } 35 } 36 37 impl Error for CompressionError { 38 fn source(&self) -> Option<&(dyn Error + 'static)> { 39 self 40 .source 41 .as_ref() 42 .map(|source| &**source as &(dyn Error + 'static)) 43 } 44 } 45 46 /// Type of [`CompressionError`] that occurred. 47 #[derive(Debug)] 48 pub enum CompressionErrorType { 49 /// Decompressing a frame failed. 50 Decompressing, 51 } 52 53 pub struct Decompressor { 54 /// Common decompressed message buffer. 55 buffer: Box<[u8]>, 56 /// Reusable zstd decompression context. 57 ctx: zstd_safe::DCtx<'static>, 58 } 59 60 impl Decompressor { 61 /// Create a new decompressor for a shard. 62 pub fn new() -> Self { 63 Self { 64 buffer: vec![0; BUFFER_SIZE].into_boxed_slice(), 65 ctx: zstd_safe::DCtx::create(), 66 } 67 } 68 69 /// Decompress a message. 70 /// 71 /// # Errors 72 /// 73 /// Returns a [`CompressionErrorType::Decompressing`] error type if the 74 /// message could not be decompressed. 75 pub fn decompress( 76 &mut self, 77 message: &[u8], 78 ) -> Result<Vec<u8>, CompressionError> { 79 let mut input = zstd_safe::InBuffer::around(message); 80 81 // Decompressed message. `Vec::extend_from_slice` efficiently allocates 82 // only what's necessary. 83 let mut decompressed = Vec::new(); 84 85 loop { 86 let mut output = zstd_safe::OutBuffer::around(self.buffer.as_mut()); 87 88 self 89 .ctx 90 .decompress_stream(&mut output, &mut input) 91 .map_err(CompressionError::from_code)?; 92 93 decompressed.extend_from_slice(output.as_slice()); 94 95 // Break when message has been fully decompressed. 96 if input.pos == input.src.len() && output.pos() != output.capacity() { 97 break; 98 } 99 } 100 101 Ok(decompressed) 102 } 103 104 /// Reset the decompressor's internal state. 105 pub fn reset(&mut self) { 106 self 107 .ctx 108 .reset(zstd_safe::ResetDirective::SessionOnly) 109 .expect("resetting session is infallible"); 110 } 111 } 112 113 impl fmt::Debug for Decompressor { 114 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 115 f.debug_struct("Decompressor") 116 .field("buffer", &self.buffer) 117 .field("ctx", &"<decompression context>") 118 .finish() 119 } 120 } 121 122 #[cfg(test)] 123 mod tests { 124 use super::Decompressor; 125 126 const MESSAGE: [u8; 117] = [ 127 40, 181, 47, 253, 0, 64, 100, 3, 0, 66, 7, 25, 28, 112, 137, 115, 116, 40, 128 208, 203, 85, 255, 167, 74, 75, 126, 203, 222, 231, 255, 151, 18, 211, 212, 129 171, 144, 151, 210, 255, 51, 4, 49, 34, 71, 98, 2, 36, 253, 122, 141, 99, 130 203, 225, 11, 162, 47, 133, 241, 6, 201, 82, 245, 91, 206, 247, 164, 226, 131 156, 92, 108, 130, 123, 11, 95, 199, 15, 61, 179, 117, 157, 28, 37, 65, 64, 132 25, 250, 182, 8, 199, 205, 44, 73, 47, 19, 218, 45, 27, 14, 245, 202, 81, 133 82, 122, 167, 121, 71, 173, 61, 140, 190, 15, 3, 1, 0, 36, 74, 18, 134 ]; 135 136 const OUTPUT: &str = r#"{"t":null,"s":null,"op":10,"d":{"heartbeat_interval":41250,"_trace":["[\"gateway-prd-us-east1-c-7s4x\",{\"micros\":0.0}]"]}}"#; 137 138 #[test] 139 fn message() { 140 let mut decompressor = Decompressor::new(); 141 assert_eq!( 142 decompressor.decompress(&MESSAGE).unwrap(), 143 OUTPUT.as_bytes() 144 ); 145 } 146 147 #[test] 148 fn reset() { 149 let mut decompressor = Decompressor::new(); 150 decompressor 151 .decompress(&MESSAGE[..MESSAGE.len() - 2]) 152 .unwrap(); 153 154 assert!(decompressor.decompress(&MESSAGE).is_err()); 155 decompressor.reset(); 156 assert_eq!( 157 decompressor.decompress(&MESSAGE).unwrap(), 158 OUTPUT.as_bytes() 159 ); 160 } 161 }