// src/test/fuzz/crypto_diff_fuzz_chacha20.cpp
// Copyright (c) 2020-present The Bitcoin Core developers
// Distributed under the MIT software license, see the accompanying
// file COPYING or http://www.opensource.org/licenses/mit-license.php.
  4  
  5  #include <crypto/chacha20.h>
  6  #include <test/fuzz/FuzzedDataProvider.h>
  7  #include <test/fuzz/fuzz.h>
  8  #include <test/fuzz/util.h>
  9  
 10  #include <cstdint>
 11  #include <vector>
 12  
/*
From https://cr.yp.to/chacha.html
chacha-merged.c version 20080118
D. J. Bernstein
Public domain.

Used below as the reference implementation that ChaCha20 is differentially fuzzed against.
*/
 19  
 20  typedef unsigned int u32;
 21  typedef unsigned char u8;
 22  
 23  #define U8C(v) (v##U)
 24  #define U32C(v) (v##U)
 25  
 26  #define U8V(v) ((u8)(v)&U8C(0xFF))
 27  #define U32V(v) ((u32)(v)&U32C(0xFFFFFFFF))
 28  
 29  #define ROTL32(v, n) (U32V((v) << (n)) | ((v) >> (32 - (n))))
 30  
 31  #define U8TO32_LITTLE(p)                                              \
 32      (((u32)((p)[0])) | ((u32)((p)[1]) << 8) | ((u32)((p)[2]) << 16) | \
 33       ((u32)((p)[3]) << 24))
 34  
 35  #define U32TO8_LITTLE(p, v)      \
 36      do {                         \
 37          (p)[0] = U8V((v));       \
 38          (p)[1] = U8V((v) >> 8);  \
 39          (p)[2] = U8V((v) >> 16); \
 40          (p)[3] = U8V((v) >> 24); \
 41      } while (0)
 42  
 43  /* ------------------------------------------------------------------------- */
 44  /* Data structures */
 45  
 46  typedef struct
 47  {
 48      u32 input[16];
 49  } ECRYPT_ctx;
 50  
 51  /* ------------------------------------------------------------------------- */
 52  /* Mandatory functions */
 53  
 54  void ECRYPT_keysetup(
 55      ECRYPT_ctx* ctx,
 56      const u8* key,
 57      u32 keysize, /* Key size in bits. */
 58      u32 ivsize); /* IV size in bits. */
 59  
 60  void ECRYPT_ivsetup(
 61      ECRYPT_ctx* ctx,
 62      const u8* iv);
 63  
 64  void ECRYPT_encrypt_bytes(
 65      ECRYPT_ctx* ctx,
 66      const u8* plaintext,
 67      u8* ciphertext,
 68      u32 msglen); /* Message length in bytes. */
 69  
 70  /* ------------------------------------------------------------------------- */
 71  
 72  /* Optional features */
 73  
 74  void ECRYPT_keystream_bytes(
 75      ECRYPT_ctx* ctx,
 76      u8* keystream,
 77      u32 length); /* Length of keystream in bytes. */
 78  
 79  /* ------------------------------------------------------------------------- */
 80  
 81  #define ROTATE(v, c) (ROTL32(v, c))
 82  #define XOR(v, w) ((v) ^ (w))
 83  #define PLUS(v, w) (U32V((v) + (w)))
 84  #define PLUSONE(v) (PLUS((v), 1))
 85  
 86  #define QUARTERROUND(a, b, c, d) \
 87      a = PLUS(a, b); d = ROTATE(XOR(d, a), 16);   \
 88      c = PLUS(c, d); b = ROTATE(XOR(b, c), 12);   \
 89      a = PLUS(a, b); d = ROTATE(XOR(d, a), 8);    \
 90      c = PLUS(c, d); b = ROTATE(XOR(b, c), 7);
 91  
 92  static const char sigma[] = "expand 32-byte k";
 93  static const char tau[] = "expand 16-byte k";
 94  
 95  void ECRYPT_keysetup(ECRYPT_ctx* x, const u8* k, u32 kbits, u32 ivbits)
 96  {
 97      const char* constants;
 98  
 99      x->input[4] = U8TO32_LITTLE(k + 0);
100      x->input[5] = U8TO32_LITTLE(k + 4);
101      x->input[6] = U8TO32_LITTLE(k + 8);
102      x->input[7] = U8TO32_LITTLE(k + 12);
103      if (kbits == 256) { /* recommended */
104          k += 16;
105          constants = sigma;
106      } else { /* kbits == 128 */
107          constants = tau;
108      }
109      x->input[8] = U8TO32_LITTLE(k + 0);
110      x->input[9] = U8TO32_LITTLE(k + 4);
111      x->input[10] = U8TO32_LITTLE(k + 8);
112      x->input[11] = U8TO32_LITTLE(k + 12);
113      x->input[0] = U8TO32_LITTLE(constants + 0);
114      x->input[1] = U8TO32_LITTLE(constants + 4);
115      x->input[2] = U8TO32_LITTLE(constants + 8);
116      x->input[3] = U8TO32_LITTLE(constants + 12);
117  }
118  
119  void ECRYPT_ivsetup(ECRYPT_ctx* x, const u8* iv)
120  {
121      x->input[12] = 0;
122      x->input[13] = 0;
123      x->input[14] = U8TO32_LITTLE(iv + 0);
124      x->input[15] = U8TO32_LITTLE(iv + 4);
125  }
126  
127  void ECRYPT_encrypt_bytes(ECRYPT_ctx* x, const u8* m, u8* c, u32 bytes)
128  {
129      u32 x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15;
130      u32 j0, j1, j2, j3, j4, j5, j6, j7, j8, j9, j10, j11, j12, j13, j14, j15;
131      u8* ctarget = nullptr;
132      u8 tmp[64];
133      uint32_t i;
134  
135      if (!bytes) return;
136  
137      j0 = x->input[0];
138      j1 = x->input[1];
139      j2 = x->input[2];
140      j3 = x->input[3];
141      j4 = x->input[4];
142      j5 = x->input[5];
143      j6 = x->input[6];
144      j7 = x->input[7];
145      j8 = x->input[8];
146      j9 = x->input[9];
147      j10 = x->input[10];
148      j11 = x->input[11];
149      j12 = x->input[12];
150      j13 = x->input[13];
151      j14 = x->input[14];
152      j15 = x->input[15];
153  
154      for (;;) {
155          if (bytes < 64) {
156              for (i = 0; i < bytes; ++i)
157                  tmp[i] = m[i];
158              m = tmp;
159              ctarget = c;
160              c = tmp;
161          }
162          x0 = j0;
163          x1 = j1;
164          x2 = j2;
165          x3 = j3;
166          x4 = j4;
167          x5 = j5;
168          x6 = j6;
169          x7 = j7;
170          x8 = j8;
171          x9 = j9;
172          x10 = j10;
173          x11 = j11;
174          x12 = j12;
175          x13 = j13;
176          x14 = j14;
177          x15 = j15;
178          for (i = 20; i > 0; i -= 2) {
179              QUARTERROUND(x0, x4, x8, x12)
180              QUARTERROUND(x1, x5, x9, x13)
181              QUARTERROUND(x2, x6, x10, x14)
182              QUARTERROUND(x3, x7, x11, x15)
183              QUARTERROUND(x0, x5, x10, x15)
184              QUARTERROUND(x1, x6, x11, x12)
185              QUARTERROUND(x2, x7, x8, x13)
186              QUARTERROUND(x3, x4, x9, x14)
187          }
188          x0 = PLUS(x0, j0);
189          x1 = PLUS(x1, j1);
190          x2 = PLUS(x2, j2);
191          x3 = PLUS(x3, j3);
192          x4 = PLUS(x4, j4);
193          x5 = PLUS(x5, j5);
194          x6 = PLUS(x6, j6);
195          x7 = PLUS(x7, j7);
196          x8 = PLUS(x8, j8);
197          x9 = PLUS(x9, j9);
198          x10 = PLUS(x10, j10);
199          x11 = PLUS(x11, j11);
200          x12 = PLUS(x12, j12);
201          x13 = PLUS(x13, j13);
202          x14 = PLUS(x14, j14);
203          x15 = PLUS(x15, j15);
204  
205          x0 = XOR(x0, U8TO32_LITTLE(m + 0));
206          x1 = XOR(x1, U8TO32_LITTLE(m + 4));
207          x2 = XOR(x2, U8TO32_LITTLE(m + 8));
208          x3 = XOR(x3, U8TO32_LITTLE(m + 12));
209          x4 = XOR(x4, U8TO32_LITTLE(m + 16));
210          x5 = XOR(x5, U8TO32_LITTLE(m + 20));
211          x6 = XOR(x6, U8TO32_LITTLE(m + 24));
212          x7 = XOR(x7, U8TO32_LITTLE(m + 28));
213          x8 = XOR(x8, U8TO32_LITTLE(m + 32));
214          x9 = XOR(x9, U8TO32_LITTLE(m + 36));
215          x10 = XOR(x10, U8TO32_LITTLE(m + 40));
216          x11 = XOR(x11, U8TO32_LITTLE(m + 44));
217          x12 = XOR(x12, U8TO32_LITTLE(m + 48));
218          x13 = XOR(x13, U8TO32_LITTLE(m + 52));
219          x14 = XOR(x14, U8TO32_LITTLE(m + 56));
220          x15 = XOR(x15, U8TO32_LITTLE(m + 60));
221  
222          j12 = PLUSONE(j12);
223          if (!j12) {
224              j13 = PLUSONE(j13);
225              /* stopping at 2^70 bytes per nonce is user's responsibility */
226          }
227  
228          U32TO8_LITTLE(c + 0, x0);
229          U32TO8_LITTLE(c + 4, x1);
230          U32TO8_LITTLE(c + 8, x2);
231          U32TO8_LITTLE(c + 12, x3);
232          U32TO8_LITTLE(c + 16, x4);
233          U32TO8_LITTLE(c + 20, x5);
234          U32TO8_LITTLE(c + 24, x6);
235          U32TO8_LITTLE(c + 28, x7);
236          U32TO8_LITTLE(c + 32, x8);
237          U32TO8_LITTLE(c + 36, x9);
238          U32TO8_LITTLE(c + 40, x10);
239          U32TO8_LITTLE(c + 44, x11);
240          U32TO8_LITTLE(c + 48, x12);
241          U32TO8_LITTLE(c + 52, x13);
242          U32TO8_LITTLE(c + 56, x14);
243          U32TO8_LITTLE(c + 60, x15);
244  
245          if (bytes <= 64) {
246              if (bytes < 64) {
247                  for (i = 0; i < bytes; ++i)
248                      ctarget[i] = c[i];
249              }
250              x->input[12] = j12;
251              x->input[13] = j13;
252              return;
253          }
254          bytes -= 64;
255          c += 64;
256          m += 64;
257      }
258  }
259  
260  void ECRYPT_keystream_bytes(ECRYPT_ctx* x, u8* stream, u32 bytes)
261  {
262      u32 i;
263      for (i = 0; i < bytes; ++i)
264          stream[i] = 0;
265      ECRYPT_encrypt_bytes(x, stream, stream, bytes);
266  }
267  
268  FUZZ_TARGET(crypto_diff_fuzz_chacha20)
269  {
270      FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()};
271  
272      ECRYPT_ctx ctx;
273  
274      const std::vector<unsigned char> key = ConsumeFixedLengthByteVector(fuzzed_data_provider, 32);
275      ChaCha20 chacha20{MakeByteSpan(key)};
276      ECRYPT_keysetup(&ctx, key.data(), key.size() * 8, 0);
277  
278      // ECRYPT_keysetup() doesn't set the counter and nonce to 0 while SetKey() does
279      static const uint8_t iv[8] = {0, 0, 0, 0, 0, 0, 0, 0};
280      ChaCha20::Nonce96 nonce{0, 0};
281      uint32_t counter{0};
282      ECRYPT_ivsetup(&ctx, iv);
283  
284      LIMITED_WHILE (fuzzed_data_provider.ConsumeBool(), 3000) {
285          CallOneOf(
286              fuzzed_data_provider,
287              [&] {
288                  const std::vector<unsigned char> key = ConsumeFixedLengthByteVector(fuzzed_data_provider, 32);
289                  chacha20.SetKey(MakeByteSpan(key));
290                  nonce = {0, 0};
291                  counter = 0;
292                  ECRYPT_keysetup(&ctx, key.data(), key.size() * 8, 0);
293                  // ECRYPT_keysetup() doesn't set the counter and nonce to 0 while SetKey() does
294                  uint8_t iv[8] = {0, 0, 0, 0, 0, 0, 0, 0};
295                  ECRYPT_ivsetup(&ctx, iv);
296              },
297              [&] {
298                  uint32_t iv_prefix = fuzzed_data_provider.ConsumeIntegral<uint32_t>();
299                  uint64_t iv = fuzzed_data_provider.ConsumeIntegral<uint64_t>();
300                  nonce = {iv_prefix, iv};
301                  counter = fuzzed_data_provider.ConsumeIntegral<uint32_t>();
302                  chacha20.Seek(nonce, counter);
303                  ctx.input[12] = counter;
304                  ctx.input[13] = iv_prefix;
305                  ctx.input[14] = iv;
306                  ctx.input[15] = iv >> 32;
307              },
308              [&] {
309                  uint32_t integralInRange = fuzzed_data_provider.ConsumeIntegralInRange<size_t>(0, 4096);
310                  std::vector<uint8_t> output(integralInRange);
311                  chacha20.Keystream(MakeWritableByteSpan(output));
312                  std::vector<uint8_t> djb_output(integralInRange);
313                  ECRYPT_keystream_bytes(&ctx, djb_output.data(), djb_output.size());
314                  assert(output == djb_output);
315                  // DJB's version seeks forward to a multiple of 64 bytes after every operation. Correct for that.
316                  uint32_t old_counter = counter;
317                  counter += (integralInRange + 63) >> 6;
318                  if (counter < old_counter) ++nonce.first;
319                  if (integralInRange & 63) {
320                      chacha20.Seek(nonce, counter);
321                  }
322                  assert(counter == ctx.input[12]);
323              },
324              [&] {
325                  uint32_t integralInRange = fuzzed_data_provider.ConsumeIntegralInRange<size_t>(0, 4096);
326                  std::vector<uint8_t> output(integralInRange);
327                  const std::vector<uint8_t> input = ConsumeFixedLengthByteVector(fuzzed_data_provider, output.size());
328                  chacha20.Crypt(MakeByteSpan(input), MakeWritableByteSpan(output));
329                  std::vector<uint8_t> djb_output(integralInRange);
330                  ECRYPT_encrypt_bytes(&ctx, input.data(), djb_output.data(), input.size());
331                  assert(output == djb_output);
332                  // DJB's version seeks forward to a multiple of 64 bytes after every operation. Correct for that.
333                  uint32_t old_counter = counter;
334                  counter += (integralInRange + 63) >> 6;
335                  if (counter < old_counter) ++nonce.first;
336                  if (integralInRange & 63) {
337                      chacha20.Seek(nonce, counter);
338                  }
339                  assert(counter == ctx.input[12]);
340              });
341      }
342  }