codec.rs
1 use super::error::ProtocolError; 2 use super::wire::{from_u32, into_u32, NameList}; 3 4 use core::str::from_utf8; 5 use sha2::{digest::Output, Digest}; 6 7 #[derive(Debug)] 8 pub struct ObjectWriter<'a> { 9 buffer: &'a mut [u8], 10 offset: usize, 11 } 12 13 impl<'a> ObjectWriter<'a> { 14 pub fn new(buffer: &'a mut [u8]) -> Self { 15 Self { buffer, offset: 0 } 16 } 17 18 pub fn write_byte(&mut self, value: u8) -> Result<(), ProtocolError> { 19 self.write_byte_array(&[value]) 20 } 21 22 pub fn write_byte_array(&mut self, value: &[u8]) -> Result<(), ProtocolError> { 23 self.consume(value.len())?.copy_from_slice(value); 24 25 Ok(()) 26 } 27 28 pub fn write_boolean(&mut self, value: bool) -> Result<(), ProtocolError> { 29 self.write_byte(if value { 1 } else { 0 }) 30 } 31 32 pub fn write_uint32(&mut self, value: u32) -> Result<(), ProtocolError> { 33 self.consume(4)?.copy_from_slice(&value.to_be_bytes()); 34 35 Ok(()) 36 } 37 38 pub fn write_uint64(&mut self, value: u64) -> Result<(), ProtocolError> { 39 self.consume(8)?.copy_from_slice(&value.to_be_bytes()); 40 41 Ok(()) 42 } 43 44 pub fn write_string_len(&mut self, value: u32) -> Result<(), ProtocolError> { 45 self.write_uint32(value) 46 } 47 48 pub fn write_string(&mut self, value: &[u8]) -> Result<(), ProtocolError> { 49 self.write_string_len(into_u32(value.len()))?; 50 self.write_byte_array(value) 51 } 52 53 pub fn write_name_list(&mut self, value: NameList) -> Result<(), ProtocolError> { 54 self.write_string_utf8(value.as_str()) 55 } 56 57 // Not an SSH data type, this is a "string" SSH type that has been validated 58 // to be UTF-8 for convenience when interoperating with human-readable text. 59 60 pub fn write_string_utf8(&mut self, value: &str) -> Result<(), ProtocolError> { 61 self.write_string(value.as_bytes()) 62 } 63 64 // Some objects contain nested objects which appear as ordinary string types 65 // which themselves opaquely contain an (SSH) encoded object representation. 66 67 pub fn write_nested<F>(&mut self, write_fn: F) -> Result<(), ProtocolError> 68 where 69 F: FnOnce(&mut ObjectWriter) -> Result<(), ProtocolError>, 70 { 71 if self.offset + 4 > self.buffer.len() { 72 return Err(ProtocolError::BufferExhausted); 73 } 74 75 let mut writer = ObjectWriter::new(&mut self.buffer[self.offset + 4..]); 76 write_fn(&mut writer)?; // let the lambda populate the buffer as desired 77 78 let object_len = writer.into_written().len(); 79 80 self.write_string_len(into_u32(object_len))?; 81 self.skip(object_len)?; 82 83 Ok(()) 84 } 85 86 pub fn skip(&mut self, len: usize) -> Result<(), ProtocolError> { 87 let _ = self.consume(len)?; 88 89 Ok(()) 90 } 91 92 pub fn into_written(self) -> &'a [u8] { 93 &self.buffer[..self.offset] 94 } 95 96 fn consume(&mut self, len: usize) -> Result<&mut [u8], ProtocolError> { 97 if self.offset + len <= self.buffer.len() { 98 let consumed = &mut self.buffer[self.offset..][..len]; 99 self.offset += len; // consider this slice as consumed 100 101 Ok(consumed) 102 } else { 103 Err(ProtocolError::BufferExhausted) 104 } 105 } 106 } 107 108 #[derive(Debug)] 109 pub struct ObjectReader<'a> { 110 buffer: &'a [u8], 111 } 112 113 impl<'a> ObjectReader<'a> { 114 pub fn new(buffer: &'a [u8]) -> Self { 115 Self { buffer } 116 } 117 118 pub fn read_byte(&mut self) -> Result<u8, ProtocolError> { 119 Ok(self.consume(1)?[0]) 120 } 121 122 pub fn read_byte_array<const N: usize>(&mut self) -> Result<&'a [u8; N], ProtocolError> { 123 Ok(super::unwrap_unreachable(self.consume(N)?.try_into().ok())) 124 } 125 126 pub fn read_boolean(&mut self) -> Result<bool, ProtocolError> { 127 Ok(self.read_byte()? != 0) 128 } 129 130 pub fn read_uint32(&mut self) -> Result<u32, ProtocolError> { 131 Ok(u32::from_be_bytes(*self.read_byte_array::<4>()?)) 132 } 133 134 pub fn read_uint64(&mut self) -> Result<u64, ProtocolError> { 135 Ok(u64::from_be_bytes(*self.read_byte_array::<8>()?)) 136 } 137 138 pub fn read_string(&mut self) -> Result<&'a [u8], ProtocolError> { 139 let len = from_u32(self.read_uint32()?); 140 self.consume(len) // read the bytes 141 } 142 143 pub fn read_string_fixed<const N: usize>(&mut self) -> Result<&'a [u8; N], ProtocolError> { 144 let string = self.read_string()?; // and length check 145 string 146 .try_into() 147 .map_err(|_| ProtocolError::BadStringLength) 148 } 149 150 // Not an SSH data type, this is a "string" SSH type that has been validated 151 // to be UTF-8 for convenience when interoperating with human-readable text. 152 153 pub fn read_string_utf8(&mut self) -> Result<&'a str, ProtocolError> { 154 Ok(from_utf8(self.read_string()?)?) 155 } 156 157 // Not an SSH data type, this is a "string" SSH type that has been validated 158 // to be US-ASCII with only printable characters, for use by internal names. 159 160 pub fn read_internal_name(&mut self) -> Result<&'a str, ProtocolError> { 161 let string = self.read_string_utf8()?; 162 163 for &byte in string.as_bytes() { 164 if byte >= 0x7F || byte <= 0x1F { 165 return Err(ProtocolError::BadStringEncoding); 166 } 167 } 168 169 Ok(string) 170 } 171 172 pub fn read_remaining(&mut self) -> &'a [u8] { 173 core::mem::take(&mut self.buffer) 174 } 175 176 fn consume(&mut self, len: usize) -> Result<&'a [u8], ProtocolError> { 177 if self.buffer.len() >= len { 178 let (consumed, remaining) = self.buffer.split_at(len); 179 self.buffer = remaining; // advance to remaining bytes 180 181 Ok(consumed) 182 } else { 183 Err(ProtocolError::BufferExhausted) 184 } 185 } 186 } 187 188 #[derive(Clone, Copy, Debug)] 189 pub struct ObjectHasher<H> { 190 hasher: H, 191 } 192 193 impl<H> ObjectHasher<H> { 194 pub fn new(hasher: H) -> Self { 195 Self { hasher } 196 } 197 } 198 199 impl<H: Digest> ObjectHasher<H> { 200 pub fn hash_byte(&mut self, value: u8) { 201 self.hasher.update([value]); 202 } 203 204 pub fn hash_byte_array(&mut self, value: &[u8]) { 205 self.hasher.update(value); 206 } 207 208 #[allow(dead_code)] 209 pub fn hash_boolean(&mut self, value: bool) { 210 self.hash_byte(if value { 1 } else { 0 }) 211 } 212 213 pub fn hash_uint32(&mut self, value: u32) { 214 self.hash_byte_array(&value.to_be_bytes()); 215 } 216 217 #[allow(dead_code)] 218 pub fn hash_uint64(&mut self, value: u64) { 219 self.hash_byte_array(&value.to_be_bytes()); 220 } 221 222 pub fn hash_string(&mut self, value: &[u8]) { 223 self.hash_uint32(into_u32(value.len())); 224 self.hash_byte_array(value); 225 } 226 227 pub fn hash_mpint(&mut self, value: &[u8]) { 228 if value.is_empty() { 229 self.hash_uint32(0); 230 } else if value[0] & 0x80 != 0 { 231 self.hash_uint32(into_u32(1 + value.len())); 232 self.hash_byte(0x00); 233 self.hash_byte_array(value); 234 } else { 235 let offset = value.iter().position(|&b| b != 0).unwrap_or(0); 236 self.hash_uint32(into_u32(value.len() - offset)); 237 self.hash_byte_array(&value[offset..]); 238 } 239 } 240 241 #[allow(dead_code)] 242 pub fn hash_name_list(&mut self, value: NameList) { 243 self.hash_string_utf8(value.as_str()); 244 } 245 246 // Not an SSH data type, this is a "string" SSH type that has been validated 247 // to be UTF-8 for convenience when interoperating with human-readable text. 248 249 pub fn hash_string_utf8(&mut self, value: &str) { 250 self.hash_string(value.as_bytes()) 251 } 252 253 pub fn into_digest(self) -> Output<H> { 254 self.hasher.finalize() 255 } 256 }