cipher.rs
1 //! ChaCha20-Poly1305 cipher state for Noise 2 3 use alloc::vec::Vec; 4 use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce, aead::{Aead, KeyInit}}; 5 6 use super::{NoiseError, Result}; 7 8 /// Cipher state for encryption/decryption 9 #[derive(Clone)] 10 pub struct CipherState { 11 key: [u8; 32], 12 nonce: u64, 13 } 14 15 impl CipherState { 16 /// Create new cipher state with key 17 pub fn new(key: [u8; 32]) -> Self { 18 Self { key, nonce: 0 } 19 } 20 21 /// Get current nonce value 22 pub fn nonce(&self) -> u64 { 23 self.nonce 24 } 25 26 /// Check if cipher has a key set 27 pub fn has_key(&self) -> bool { 28 self.key != [0u8; 32] 29 } 30 31 /// Rekey the cipher (for long-lived connections) 32 pub fn rekey(&mut self) { 33 let cipher = ChaCha20Poly1305::new(Key::from_slice(&self.key)); 34 let nonce = Self::encode_nonce(u64::MAX); 35 36 // Encrypt 32 zeros to get new key 37 let zeros = [0u8; 32]; 38 if let Ok(new_key_material) = cipher.encrypt(&nonce, zeros.as_ref()) { 39 // Take first 32 bytes as new key 40 self.key.copy_from_slice(&new_key_material[..32]); 41 self.nonce = 0; 42 } 43 } 44 45 /// Encode nonce as 12 bytes (4 zeros + 8 byte LE counter) 46 fn encode_nonce(n: u64) -> Nonce { 47 let mut nonce_bytes = [0u8; 12]; 48 nonce_bytes[4..].copy_from_slice(&n.to_le_bytes()); 49 *Nonce::from_slice(&nonce_bytes) 50 } 51 52 /// Encrypt plaintext with associated data 53 pub fn encrypt_with_ad(&mut self, ad: &[u8], plaintext: &[u8]) -> Result<Vec<u8>> { 54 let cipher = ChaCha20Poly1305::new(Key::from_slice(&self.key)); 55 let nonce = Self::encode_nonce(self.nonce); 56 57 let ciphertext = cipher 58 .encrypt(&nonce, chacha20poly1305::aead::Payload { msg: plaintext, aad: ad }) 59 .map_err(|_| NoiseError::EncryptionFailed)?; 60 61 self.nonce += 1; 62 Ok(ciphertext) 63 } 64 65 /// Decrypt ciphertext with associated data 66 pub fn decrypt_with_ad(&mut self, ad: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>> { 67 if ciphertext.len() < 16 { 68 return Err(NoiseError::InvalidMessage); 69 } 70 71 let cipher = ChaCha20Poly1305::new(Key::from_slice(&self.key)); 72 let nonce = Self::encode_nonce(self.nonce); 73 74 let plaintext = cipher 75 .decrypt(&nonce, chacha20poly1305::aead::Payload { msg: ciphertext, aad: ad }) 76 .map_err(|_| NoiseError::DecryptionFailed)?; 77 78 self.nonce += 1; 79 Ok(plaintext) 80 } 81 82 /// Encrypt without associated data (for transport) 83 pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> { 84 self.encrypt_with_ad(&[], plaintext) 85 } 86 87 /// Decrypt without associated data (for transport) 88 pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>> { 89 self.decrypt_with_ad(&[], ciphertext) 90 } 91 } 92 93 #[cfg(test)] 94 mod tests { 95 use super::*; 96 97 #[test] 98 fn test_encrypt_decrypt() { 99 let key = [0x42u8; 32]; 100 let mut cipher1 = CipherState::new(key); 101 let mut cipher2 = CipherState::new(key); 102 103 let plaintext = b"hello world"; 104 let ciphertext = cipher1.encrypt(plaintext).unwrap(); 105 let decrypted = cipher2.decrypt(&ciphertext).unwrap(); 106 107 assert_eq!(plaintext.as_slice(), decrypted.as_slice()); 108 } 109 110 #[test] 111 fn test_nonce_increment() { 112 let key = [0x42u8; 32]; 113 let mut cipher = CipherState::new(key); 114 115 assert_eq!(cipher.nonce(), 0); 116 let _ = cipher.encrypt(b"test").unwrap(); 117 assert_eq!(cipher.nonce(), 1); 118 let _ = cipher.encrypt(b"test").unwrap(); 119 assert_eq!(cipher.nonce(), 2); 120 } 121 }