/ src / noise / cipher.rs
cipher.rs
  1  //! ChaCha20-Poly1305 cipher state for Noise
  2  
  3  use alloc::vec::Vec;
  4  use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce, aead::{Aead, KeyInit}};
  5  
  6  use super::{NoiseError, Result};
  7  
  8  /// Cipher state for encryption/decryption
  9  #[derive(Clone)]
 10  pub struct CipherState {
 11      key: [u8; 32],
 12      nonce: u64,
 13  }
 14  
 15  impl CipherState {
 16      /// Create new cipher state with key
 17      pub fn new(key: [u8; 32]) -> Self {
 18          Self { key, nonce: 0 }
 19      }
 20  
 21      /// Get current nonce value
 22      pub fn nonce(&self) -> u64 {
 23          self.nonce
 24      }
 25  
 26      /// Check if cipher has a key set
 27      pub fn has_key(&self) -> bool {
 28          self.key != [0u8; 32]
 29      }
 30  
 31      /// Rekey the cipher (for long-lived connections)
 32      pub fn rekey(&mut self) {
 33          let cipher = ChaCha20Poly1305::new(Key::from_slice(&self.key));
 34          let nonce = Self::encode_nonce(u64::MAX);
 35  
 36          // Encrypt 32 zeros to get new key
 37          let zeros = [0u8; 32];
 38          if let Ok(new_key_material) = cipher.encrypt(&nonce, zeros.as_ref()) {
 39              // Take first 32 bytes as new key
 40              self.key.copy_from_slice(&new_key_material[..32]);
 41              self.nonce = 0;
 42          }
 43      }
 44  
 45      /// Encode nonce as 12 bytes (4 zeros + 8 byte LE counter)
 46      fn encode_nonce(n: u64) -> Nonce {
 47          let mut nonce_bytes = [0u8; 12];
 48          nonce_bytes[4..].copy_from_slice(&n.to_le_bytes());
 49          *Nonce::from_slice(&nonce_bytes)
 50      }
 51  
 52      /// Encrypt plaintext with associated data
 53      pub fn encrypt_with_ad(&mut self, ad: &[u8], plaintext: &[u8]) -> Result<Vec<u8>> {
 54          let cipher = ChaCha20Poly1305::new(Key::from_slice(&self.key));
 55          let nonce = Self::encode_nonce(self.nonce);
 56  
 57          let ciphertext = cipher
 58              .encrypt(&nonce, chacha20poly1305::aead::Payload { msg: plaintext, aad: ad })
 59              .map_err(|_| NoiseError::EncryptionFailed)?;
 60  
 61          self.nonce += 1;
 62          Ok(ciphertext)
 63      }
 64  
 65      /// Decrypt ciphertext with associated data
 66      pub fn decrypt_with_ad(&mut self, ad: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>> {
 67          if ciphertext.len() < 16 {
 68              return Err(NoiseError::InvalidMessage);
 69          }
 70  
 71          let cipher = ChaCha20Poly1305::new(Key::from_slice(&self.key));
 72          let nonce = Self::encode_nonce(self.nonce);
 73  
 74          let plaintext = cipher
 75              .decrypt(&nonce, chacha20poly1305::aead::Payload { msg: ciphertext, aad: ad })
 76              .map_err(|_| NoiseError::DecryptionFailed)?;
 77  
 78          self.nonce += 1;
 79          Ok(plaintext)
 80      }
 81  
 82      /// Encrypt without associated data (for transport)
 83      pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
 84          self.encrypt_with_ad(&[], plaintext)
 85      }
 86  
 87      /// Decrypt without associated data (for transport)
 88      pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>> {
 89          self.decrypt_with_ad(&[], ciphertext)
 90      }
 91  }
 92  
 93  #[cfg(test)]
 94  mod tests {
 95      use super::*;
 96  
 97      #[test]
 98      fn test_encrypt_decrypt() {
 99          let key = [0x42u8; 32];
100          let mut cipher1 = CipherState::new(key);
101          let mut cipher2 = CipherState::new(key);
102  
103          let plaintext = b"hello world";
104          let ciphertext = cipher1.encrypt(plaintext).unwrap();
105          let decrypted = cipher2.decrypt(&ciphertext).unwrap();
106  
107          assert_eq!(plaintext.as_slice(), decrypted.as_slice());
108      }
109  
110      #[test]
111      fn test_nonce_increment() {
112          let key = [0x42u8; 32];
113          let mut cipher = CipherState::new(key);
114  
115          assert_eq!(cipher.nonce(), 0);
116          let _ = cipher.encrypt(b"test").unwrap();
117          assert_eq!(cipher.nonce(), 1);
118          let _ = cipher.encrypt(b"test").unwrap();
119          assert_eq!(cipher.nonce(), 2);
120      }
121  }