/ src / chacha20.rs
chacha20.rs
  1  extern crate quickcheck;
  2  
  3  use core::arch::aarch64::*;
  4  use quickcheck::{Arbitrary, Gen};
  5  
  6  /// ChaCha20 state consists of 16 words (u32)
  7  const STATE_LEN: usize = 16;
  8  
  9  /// Number of rounds in the ChaCha20 algorithm
 10  const NUM_ROUNDS: usize = 10; // ChaCha20 uses 20 rounds, each function call here represents 2 rounds
 11  
 12  /// The ChaCha20 state.
 13  pub struct ChaCha20State {
 14      state: [u32; STATE_LEN],
 15  }
 16  
 17  // Newtype structs for different array sizes
 18  #[derive(Clone, Debug)]
 19  pub struct Key(pub [u8; 32]);
 20  #[derive(Clone, Debug)]
 21  pub struct Nonce(pub [u8; 12]);
 22  #[derive(Clone, Debug)]
 23  pub struct Block(pub [u8; 64]);
 24  
 25  // Implement Arbitrary for each new type
 26  impl Arbitrary for Key {
 27      fn arbitrary(g: &mut Gen) -> Self {
 28          let mut arr = [0u8; 32];
 29          for byte in arr.iter_mut() {
 30              *byte = u8::arbitrary(g);
 31          }
 32          Key(arr)
 33      }
 34  }
 35  
 36  impl Arbitrary for Nonce {
 37      fn arbitrary(g: &mut Gen) -> Self {
 38          let mut arr = [0u8; 12];
 39          for byte in arr.iter_mut() {
 40              *byte = u8::arbitrary(g);
 41          }
 42          Nonce(arr)
 43      }
 44  }
 45  
 46  impl Arbitrary for Block {
 47      fn arbitrary(g: &mut Gen) -> Self {
 48          let mut arr = [0u8; 64];
 49          for byte in arr.iter_mut() {
 50              *byte = u8::arbitrary(g);
 51          }
 52          Block(arr)
 53      }
 54  }
 55  
 56  impl ChaCha20State {
 57      /// Creates a new ChaCha20 state initialized with the key, nonce, and block counter.
 58      pub fn new(key: &[u8; 32], nonce: &[u8; 12], counter: u32) -> Self {
 59          let mut state = [
 60              0x6170_7865,
 61              0x3320_646e,
 62              0x7962_2d32,
 63              0x6b20_6574, // Constants
 64              0,
 65              0,
 66              0,
 67              0, // 256-bit key
 68              0,
 69              0,
 70              0,
 71              0,
 72              counter,                                                      // Block counter
 73              u32::from_le_bytes([nonce[0], nonce[1], nonce[2], nonce[3]]), // Nonce
 74              u32::from_le_bytes([nonce[4], nonce[5], nonce[6], nonce[7]]),
 75              u32::from_le_bytes([nonce[8], nonce[9], nonce[10], nonce[11]]),
 76          ];
 77  
 78          for i in 0..8 {
 79              state[4 + i] =
 80                  u32::from_le_bytes([key[4 * i], key[4 * i + 1], key[4 * i + 2], key[4 * i + 3]]);
 81          }
 82  
 83          ChaCha20State { state }
 84      }
 85  
 86      /// Resets the state with a new key, nonce, and block counter.
 87      pub fn reset(&mut self, key: &[u8; 32], nonce: &[u8; 12], counter: u32) {
 88          self.state = [
 89              0x6170_7865,
 90              0x3320_646e,
 91              0x7962_2d32,
 92              0x6b20_6574, // Constants
 93              0,
 94              0,
 95              0,
 96              0, // 256-bit key
 97              0,
 98              0,
 99              0,
100              0,
101              counter,                                                      // Block counter
102              u32::from_le_bytes([nonce[0], nonce[1], nonce[2], nonce[3]]), // Nonce
103              u32::from_le_bytes([nonce[4], nonce[5], nonce[6], nonce[7]]),
104              u32::from_le_bytes([nonce[8], nonce[9], nonce[10], nonce[11]]),
105          ];
106  
107          for i in 0..8 {
108              self.state[4 + i] =
109                  u32::from_le_bytes([key[4 * i], key[4 * i + 1], key[4 * i + 2], key[4 * i + 3]]);
110          }
111      }
112  
113      /// Encrypt or decrypt data using the ChaCha20 block function.
114      pub unsafe fn process(&mut self, input: &[u8], output: &mut [u8]) {
115          assert_eq!(
116              input.len(),
117              output.len(),
118              "Input and output must be the same length"
119          );
120  
121          let mut x = [
122              vld1q_u32(&self.state[0]),
123              vld1q_u32(&self.state[4]),
124              vld1q_u32(&self.state[8]),
125              vld1q_u32(&self.state[12]),
126          ];
127  
128          // Perform rounds
129          for _ in 0..NUM_ROUNDS {
130              // Column rounds
131              self.quarter_round(&mut x, 0, 1, 2, 3);
132              // Diagonal rounds
133              self.diagonal_round(&mut x);
134          }
135  
136          // Add back original state and serialize the state to the output
137          for i in 0..4 {
138              x[i] = vaddq_u32(x[i], vld1q_u32(&self.state[i * 4]));
139              let output_bytes =
140                  core::slice::from_raw_parts((&x[i] as *const uint32x4_t) as *const u8, 16);
141              for j in 0..16 {
142                  output[i * 16 + j] = input[i * 16 + j] ^ output_bytes[j]; // XOR to produce output
143              }
144          }
145      }
146  
147      /// Perform the ChaCha20 quarter round operation
148      fn quarter_round(&self, x: &mut [uint32x4_t; 4], a: usize, b: usize, c: usize, d: usize) {
149          unsafe {
150              x[a] = vaddq_u32(x[a], x[b]);
151              x[d] = veorq_u32(x[d], x[a]);
152              x[d] = vorrq_u32(vshlq_n_u32(x[d], 16), vshrq_n_u32(x[d], 16)); // Rotate by 16 bits
153  
154              x[c] = vaddq_u32(x[c], x[d]);
155              x[b] = veorq_u32(x[b], x[c]);
156              x[b] = vorrq_u32(vshlq_n_u32(x[b], 12), vshrq_n_u32(x[b], 20)); // Rotate by 12 bits
157  
158              x[a] = vaddq_u32(x[a], x[b]);
159              x[d] = veorq_u32(x[d], x[a]);
160              x[d] = vorrq_u32(vshlq_n_u32(x[d], 8), vshrq_n_u32(x[d], 24)); // Rotate by 8 bits
161  
162              x[c] = vaddq_u32(x[c], x[d]);
163              x[b] = veorq_u32(x[b], x[c]);
164              x[b] = vorrq_u32(vshlq_n_u32(x[b], 7), vshrq_n_u32(x[b], 25)); // Rotate by 7 bits
165          }
166      }
167  
168      /// Perform the ChaCha20 diagonal round operation
169      fn diagonal_round(&mut self, x: &mut [uint32x4_t; 4]) {
170          self.quarter_round(x, 0, 1, 2, 3);
171  
172          let temp = x[1];
173          x[1] = x[2];
174          x[2] = x[3];
175          x[3] = temp;
176      }
177  }