// firmware/src/services/ssh/codec.rs
  1  use super::error::ProtocolError;
  2  use super::wire::{from_u32, into_u32, NameList};
  3  
  4  use core::str::from_utf8;
  5  use sha2::{digest::Output, Digest};
  6  
  7  #[derive(Debug)]
  8  pub struct ObjectWriter<'a> {
  9      buffer: &'a mut [u8],
 10      offset: usize,
 11  }
 12  
 13  impl<'a> ObjectWriter<'a> {
 14      pub fn new(buffer: &'a mut [u8]) -> Self {
 15          Self { buffer, offset: 0 }
 16      }
 17  
 18      pub fn write_byte(&mut self, value: u8) -> Result<(), ProtocolError> {
 19          self.write_byte_array(&[value])
 20      }
 21  
 22      pub fn write_byte_array(&mut self, value: &[u8]) -> Result<(), ProtocolError> {
 23          self.consume(value.len())?.copy_from_slice(value);
 24  
 25          Ok(())
 26      }
 27  
 28      pub fn write_boolean(&mut self, value: bool) -> Result<(), ProtocolError> {
 29          self.write_byte(if value { 1 } else { 0 })
 30      }
 31  
 32      pub fn write_uint32(&mut self, value: u32) -> Result<(), ProtocolError> {
 33          self.consume(4)?.copy_from_slice(&value.to_be_bytes());
 34  
 35          Ok(())
 36      }
 37  
 38      pub fn write_uint64(&mut self, value: u64) -> Result<(), ProtocolError> {
 39          self.consume(8)?.copy_from_slice(&value.to_be_bytes());
 40  
 41          Ok(())
 42      }
 43  
 44      pub fn write_string_len(&mut self, value: u32) -> Result<(), ProtocolError> {
 45          self.write_uint32(value)
 46      }
 47  
 48      pub fn write_string(&mut self, value: &[u8]) -> Result<(), ProtocolError> {
 49          self.write_string_len(into_u32(value.len()))?;
 50          self.write_byte_array(value)
 51      }
 52  
 53      pub fn write_name_list(&mut self, value: NameList) -> Result<(), ProtocolError> {
 54          self.write_string_utf8(value.as_str())
 55      }
 56  
 57      // Not an SSH data type, this is a "string" SSH type that has been validated
 58      // to be UTF-8 for convenience when interoperating with human-readable text.
 59  
 60      pub fn write_string_utf8(&mut self, value: &str) -> Result<(), ProtocolError> {
 61          self.write_string(value.as_bytes())
 62      }
 63  
 64      // Some objects contain nested objects which appear as ordinary string types
 65      // which themselves opaquely contain an (SSH) encoded object representation.
 66  
 67      pub fn write_nested<F>(&mut self, write_fn: F) -> Result<(), ProtocolError>
 68      where
 69          F: FnOnce(&mut ObjectWriter) -> Result<(), ProtocolError>,
 70      {
 71          if self.offset + 4 > self.buffer.len() {
 72              return Err(ProtocolError::BufferExhausted);
 73          }
 74  
 75          let mut writer = ObjectWriter::new(&mut self.buffer[self.offset + 4..]);
 76          write_fn(&mut writer)?; // let the lambda populate the buffer as desired
 77  
 78          let object_len = writer.into_written().len();
 79  
 80          self.write_string_len(into_u32(object_len))?;
 81          self.skip(object_len)?;
 82  
 83          Ok(())
 84      }
 85  
 86      pub fn skip(&mut self, len: usize) -> Result<(), ProtocolError> {
 87          let _ = self.consume(len)?;
 88  
 89          Ok(())
 90      }
 91  
 92      pub fn into_written(self) -> &'a [u8] {
 93          &self.buffer[..self.offset]
 94      }
 95  
 96      fn consume(&mut self, len: usize) -> Result<&mut [u8], ProtocolError> {
 97          if self.offset + len <= self.buffer.len() {
 98              let consumed = &mut self.buffer[self.offset..][..len];
 99              self.offset += len; // consider this slice as consumed
100  
101              Ok(consumed)
102          } else {
103              Err(ProtocolError::BufferExhausted)
104          }
105      }
106  }
107  
108  #[derive(Debug)]
109  pub struct ObjectReader<'a> {
110      buffer: &'a [u8],
111  }
112  
113  impl<'a> ObjectReader<'a> {
114      pub fn new(buffer: &'a [u8]) -> Self {
115          Self { buffer }
116      }
117  
118      pub fn read_byte(&mut self) -> Result<u8, ProtocolError> {
119          Ok(self.consume(1)?[0])
120      }
121  
122      pub fn read_byte_array<const N: usize>(&mut self) -> Result<&'a [u8; N], ProtocolError> {
123          Ok(super::unwrap_unreachable(self.consume(N)?.try_into().ok()))
124      }
125  
126      pub fn read_boolean(&mut self) -> Result<bool, ProtocolError> {
127          Ok(self.read_byte()? != 0)
128      }
129  
130      pub fn read_uint32(&mut self) -> Result<u32, ProtocolError> {
131          Ok(u32::from_be_bytes(*self.read_byte_array::<4>()?))
132      }
133  
134      pub fn read_uint64(&mut self) -> Result<u64, ProtocolError> {
135          Ok(u64::from_be_bytes(*self.read_byte_array::<8>()?))
136      }
137  
138      pub fn read_string(&mut self) -> Result<&'a [u8], ProtocolError> {
139          let len = from_u32(self.read_uint32()?);
140          self.consume(len) // read the bytes
141      }
142  
143      pub fn read_string_fixed<const N: usize>(&mut self) -> Result<&'a [u8; N], ProtocolError> {
144          let string = self.read_string()?; // and length check
145          string
146              .try_into()
147              .map_err(|_| ProtocolError::BadStringLength)
148      }
149  
150      // Not an SSH data type, this is a "string" SSH type that has been validated
151      // to be UTF-8 for convenience when interoperating with human-readable text.
152  
153      pub fn read_string_utf8(&mut self) -> Result<&'a str, ProtocolError> {
154          Ok(from_utf8(self.read_string()?)?)
155      }
156  
157      // Not an SSH data type, this is a "string" SSH type that has been validated
158      // to be US-ASCII with only printable characters, for use by internal names.
159  
160      pub fn read_internal_name(&mut self) -> Result<&'a str, ProtocolError> {
161          let string = self.read_string_utf8()?;
162  
163          for &byte in string.as_bytes() {
164              if byte >= 0x7F || byte <= 0x1F {
165                  return Err(ProtocolError::BadStringEncoding);
166              }
167          }
168  
169          Ok(string)
170      }
171  
172      pub fn read_remaining(&mut self) -> &'a [u8] {
173          core::mem::take(&mut self.buffer)
174      }
175  
176      fn consume(&mut self, len: usize) -> Result<&'a [u8], ProtocolError> {
177          if self.buffer.len() >= len {
178              let (consumed, remaining) = self.buffer.split_at(len);
179              self.buffer = remaining; // advance to remaining bytes
180  
181              Ok(consumed)
182          } else {
183              Err(ProtocolError::BufferExhausted)
184          }
185      }
186  }
187  
188  #[derive(Clone, Copy, Debug)]
189  pub struct ObjectHasher<H> {
190      hasher: H,
191  }
192  
193  impl<H> ObjectHasher<H> {
194      pub fn new(hasher: H) -> Self {
195          Self { hasher }
196      }
197  }
198  
199  impl<H: Digest> ObjectHasher<H> {
200      pub fn hash_byte(&mut self, value: u8) {
201          self.hasher.update([value]);
202      }
203  
204      pub fn hash_byte_array(&mut self, value: &[u8]) {
205          self.hasher.update(value);
206      }
207  
208      #[allow(dead_code)]
209      pub fn hash_boolean(&mut self, value: bool) {
210          self.hash_byte(if value { 1 } else { 0 })
211      }
212  
213      pub fn hash_uint32(&mut self, value: u32) {
214          self.hash_byte_array(&value.to_be_bytes());
215      }
216  
217      #[allow(dead_code)]
218      pub fn hash_uint64(&mut self, value: u64) {
219          self.hash_byte_array(&value.to_be_bytes());
220      }
221  
222      pub fn hash_string(&mut self, value: &[u8]) {
223          self.hash_uint32(into_u32(value.len()));
224          self.hash_byte_array(value);
225      }
226  
227      pub fn hash_mpint(&mut self, value: &[u8]) {
228          if value.is_empty() {
229              self.hash_uint32(0);
230          } else if value[0] & 0x80 != 0 {
231              self.hash_uint32(into_u32(1 + value.len()));
232              self.hash_byte(0x00);
233              self.hash_byte_array(value);
234          } else {
235              let offset = value.iter().position(|&b| b != 0).unwrap_or(0);
236              self.hash_uint32(into_u32(value.len() - offset));
237              self.hash_byte_array(&value[offset..]);
238          }
239      }
240  
241      #[allow(dead_code)]
242      pub fn hash_name_list(&mut self, value: NameList) {
243          self.hash_string_utf8(value.as_str());
244      }
245  
246      // Not an SSH data type, this is a "string" SSH type that has been validated
247      // to be UTF-8 for convenience when interoperating with human-readable text.
248  
249      pub fn hash_string_utf8(&mut self, value: &str) {
250          self.hash_string(value.as_bytes())
251      }
252  
253      pub fn into_digest(self) -> Output<H> {
254          self.hasher.finalize()
255      }
256  }