/ test / test_math.c
test_math.c
  1  /*
  2   * Copyright (c) 2016 Thomas Pornin <pornin@bolet.org>
  3   *
  4   * Permission is hereby granted, free of charge, to any person obtaining 
  5   * a copy of this software and associated documentation files (the
  6   * "Software"), to deal in the Software without restriction, including
  7   * without limitation the rights to use, copy, modify, merge, publish,
  8   * distribute, sublicense, and/or sell copies of the Software, and to
  9   * permit persons to whom the Software is furnished to do so, subject to
 10   * the following conditions:
 11   *
 12   * The above copyright notice and this permission notice shall be 
 13   * included in all copies or substantial portions of the Software.
 14   *
 15   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 
 16   * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 17   * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
 18   * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 19   * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 20   * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 21   * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 22   * SOFTWARE.
 23   */
 24  
 25  #include <stdio.h>
 26  #include <stdlib.h>
 27  #include <string.h>
 28  #include <stdarg.h>
 29  #include <time.h>
 30  
 31  #include <gmp.h>
 32  
 33  #include "bearssl.h"
 34  #include "inner.h"
 35  
 36  /*
 37   * Pointers to implementations.
 38   */
 39  typedef struct {
 40  	uint32_t word_size;
 41  	void (*zero)(uint32_t *x, uint32_t bit_len);
 42  	void (*decode)(uint32_t *x, const void *src, size_t len);
 43  	uint32_t (*decode_mod)(uint32_t *x,
 44  		const void *src, size_t len, const uint32_t *m);
 45  	void (*reduce)(uint32_t *x, const uint32_t *a, const uint32_t *m);
 46  	void (*decode_reduce)(uint32_t *x,
 47  		const void *src, size_t len, const uint32_t *m);
 48  	void (*encode)(void *dst, size_t len, const uint32_t *x);
 49  	uint32_t (*add)(uint32_t *a, const uint32_t *b, uint32_t ctl);
 50  	uint32_t (*sub)(uint32_t *a, const uint32_t *b, uint32_t ctl);
 51  	uint32_t (*ninv)(uint32_t x);
 52  	void (*montymul)(uint32_t *d, const uint32_t *x, const uint32_t *y,
 53  		const uint32_t *m, uint32_t m0i);
 54  	void (*to_monty)(uint32_t *x, const uint32_t *m);
 55  	void (*from_monty)(uint32_t *x, const uint32_t *m, uint32_t m0i);
 56  	void (*modpow)(uint32_t *x, const unsigned char *e, size_t elen,
 57  		const uint32_t *m, uint32_t m0i, uint32_t *t1, uint32_t *t2);
 58  } int_impl;
 59  
 60  static const int_impl i31_impl = {
 61  	31,
 62  	&br_i31_zero,
 63  	&br_i31_decode,
 64  	&br_i31_decode_mod,
 65  	&br_i31_reduce,
 66  	&br_i31_decode_reduce,
 67  	&br_i31_encode,
 68  	&br_i31_add,
 69  	&br_i31_sub,
 70  	&br_i31_ninv31,
 71  	&br_i31_montymul,
 72  	&br_i31_to_monty,
 73  	&br_i31_from_monty,
 74  	&br_i31_modpow
 75  };
 76  static const int_impl i32_impl = {
 77  	32,
 78  	&br_i32_zero,
 79  	&br_i32_decode,
 80  	&br_i32_decode_mod,
 81  	&br_i32_reduce,
 82  	&br_i32_decode_reduce,
 83  	&br_i32_encode,
 84  	&br_i32_add,
 85  	&br_i32_sub,
 86  	&br_i32_ninv32,
 87  	&br_i32_montymul,
 88  	&br_i32_to_monty,
 89  	&br_i32_from_monty,
 90  	&br_i32_modpow
 91  };
 92  
/*
 * Implementation currently under test; main() points it at one of the
 * dispatch tables before each call to test_modint().
 */
static const int_impl *impl;

/*
 * GMP random state shared by the tests; initialised and seeded (from the
 * current time) at the start of test_modint().
 */
static gmp_randstate_t RNG;
 96  
 97  /*
 98   * Get a random prime of length 'size' bits. This function also guarantees
 99   * that x-1 is not a multiple of 65537.
100   */
101  static void
102  rand_prime(mpz_t x, int size)
103  {
104  	for (;;) {
105  		mpz_urandomb(x, RNG, size - 1);
106  		mpz_setbit(x, 0);
107  		mpz_setbit(x, size - 1);
108  		if (mpz_probab_prime_p(x, 50)) {
109  			mpz_sub_ui(x, x, 1);
110  			if (mpz_divisible_ui_p(x, 65537)) {
111  				continue;
112  			}
113  			mpz_add_ui(x, x, 1);
114  			return;
115  		}
116  	}
117  }
118  
119  /*
120   * Print out a GMP integer (for debug).
121   */
122  static void
123  print_z(mpz_t z)
124  {
125  	unsigned char zb[1000];
126  	size_t zlen, k;
127  
128  	mpz_export(zb, &zlen, 1, 1, 0, 0, z);
129  	if (zlen == 0) {
130  		printf(" 00");
131  		return;
132  	}
133  	if ((zlen & 3) != 0) {
134  		k = 4 - (zlen & 3);
135  		memmove(zb + k, zb, zlen);
136  		memset(zb, 0, k);
137  		zlen += k;
138  	}
139  	for (k = 0; k < zlen; k += 4) {
140  		printf(" %02X%02X%02X%02X",
141  			zb[k], zb[k + 1], zb[k + 2], zb[k + 3]);
142  	}
143  }
144  
/*
 * Debug helper: print an i31 or i32 integer. The value words x[1..n]
 * are printed from the highest index down to 1, where n is x[0] rounded
 * up to a multiple of 32 and divided by 32; the header word x[0] is
 * then shown split as (x[0] >> 5, x[0] & 31).
 */
static void
print_u(uint32_t *x)
{
	uint32_t hdr;
	size_t idx;

	hdr = x[0];
	if (hdr == 0) {
		printf(" 00000000 (0, 0)");
		return;
	}

	idx = (hdr + 31) >> 5;
	while (idx != 0) {
		printf(" %08lX", (unsigned long)x[idx]);
		idx --;
	}
	printf(" (%u, %u)", (unsigned)(hdr >> 5), (unsigned)(hdr & 31));
}
162  
163  /*
164   * Check that an i31/i32 number and a GMP number are equal.
165   */
166  static void
167  check_eqz(uint32_t *x, mpz_t z)
168  {
169  	unsigned char xb[1000];
170  	unsigned char zb[1000];
171  	size_t xlen, zlen;
172  	int good;
173  
174  	xlen = ((x[0] + 31) & ~(uint32_t)31) >> 3;
175  	impl->encode(xb, xlen, x);
176  	mpz_export(zb, &zlen, 1, 1, 0, 0, z);
177  	good = 1;
178  	if (xlen < zlen) {
179  		good = 0;
180  	} else if (xlen > zlen) {
181  		size_t u;
182  
183  		for (u = xlen; u > zlen; u --) {
184  			if (xb[xlen - u] != 0) {
185  				good = 0;
186  				break;
187  			}
188  		}
189  	}
190  	good = good && memcmp(xb + xlen - zlen, zb, zlen) == 0;
191  	if (!good) {
192  		size_t u;
193  
194  		printf("Mismatch:\n");
195  		printf("  x = ");
196  		print_u(x);
197  		printf("\n");
198  		printf("  ex = ");
199  		for (u = 0; u < xlen; u ++) {
200  			printf("%02X", xb[u]);
201  		}
202  		printf("\n");
203  		printf("  z = ");
204  		print_z(z);
205  		printf("\n");
206  		exit(EXIT_FAILURE);
207  	}
208  }
209  
210  /* obsolete
211  static void
212  mp_to_br(uint32_t *mx, uint32_t x_bitlen, mpz_t x)
213  {
214  	uint32_t x_ebitlen;
215  	size_t xlen;
216  
217  	if (mpz_sizeinbase(x, 2) > x_bitlen) {
218  		abort();
219  	}
220  	x_ebitlen = ((x_bitlen / 31) << 5) + (x_bitlen % 31);
221  	br_i31_zero(mx, x_ebitlen);
222  	mpz_export(mx + 1, &xlen, -1, sizeof *mx, 0, 1, x);
223  }
224  */
225  
226  static void
227  test_modint(void)
228  {
229  	int i, j, k;
230  	mpz_t p, a, b, v, t1;
231  
232  	printf("Test modular integers: ");
233  	fflush(stdout);
234  
235  	gmp_randinit_mt(RNG);
236  	mpz_init(p);
237  	mpz_init(a);
238  	mpz_init(b);
239  	mpz_init(v);
240  	mpz_init(t1);
241  	mpz_set_ui(t1, (unsigned long)time(NULL));
242  	gmp_randseed(RNG, t1);
243  	for (k = 2; k <= 128; k ++) {
244  		for (i = 0; i < 10; i ++) {
245  			unsigned char ep[100], ea[100], eb[100], ev[100];
246  			size_t plen, alen, blen, vlen;
247  			uint32_t mp[40], ma[40], mb[40], mv[60], mx[100];
248  			uint32_t mt1[40], mt2[40], mt3[40];
249  			uint32_t ctl;
250  			uint32_t mp0i;
251  
252  			rand_prime(p, k);
253  			mpz_urandomm(a, RNG, p);
254  			mpz_urandomm(b, RNG, p);
255  			mpz_urandomb(v, RNG, k + 60);
256  			if (mpz_sgn(b) == 0) {
257  				mpz_set_ui(b, 1);
258  			}
259  			mpz_export(ep, &plen, 1, 1, 0, 0, p);
260  			mpz_export(ea, &alen, 1, 1, 0, 0, a);
261  			mpz_export(eb, &blen, 1, 1, 0, 0, b);
262  			mpz_export(ev, &vlen, 1, 1, 0, 0, v);
263  
264  			impl->decode(mp, ep, plen);
265  			if (impl->decode_mod(ma, ea, alen, mp) != 1) {
266  				printf("Decode error\n");
267  				printf("  ea = ");
268  				print_z(a);
269  				printf("\n");
270  				printf("  p = ");
271  				print_u(mp);
272  				printf("\n");
273  				exit(EXIT_FAILURE);
274  			}
275  			mp0i = impl->ninv(mp[1]);
276  			if (impl->decode_mod(mb, eb, blen, mp) != 1) {
277  				printf("Decode error\n");
278  				printf("  eb = ");
279  				print_z(b);
280  				printf("\n");
281  				printf("  p = ");
282  				print_u(mp);
283  				printf("\n");
284  				exit(EXIT_FAILURE);
285  			}
286  			impl->decode(mv, ev, vlen);
287  			check_eqz(mp, p);
288  			check_eqz(ma, a);
289  			check_eqz(mb, b);
290  			check_eqz(mv, v);
291  
292  			impl->decode_mod(ma, ea, alen, mp);
293  			impl->decode_mod(mb, eb, blen, mp);
294  			ctl = impl->add(ma, mb, 1);
295  			ctl |= impl->sub(ma, mp, 0) ^ (uint32_t)1;
296  			impl->sub(ma, mp, ctl);
297  			mpz_add(t1, a, b);
298  			mpz_mod(t1, t1, p);
299  			check_eqz(ma, t1);
300  
301  			impl->decode_mod(ma, ea, alen, mp);
302  			impl->decode_mod(mb, eb, blen, mp);
303  			impl->add(ma, mp, impl->sub(ma, mb, 1));
304  			mpz_sub(t1, a, b);
305  			mpz_mod(t1, t1, p);
306  			check_eqz(ma, t1);
307  
308  			impl->decode_reduce(ma, ev, vlen, mp);
309  			mpz_mod(t1, v, p);
310  			check_eqz(ma, t1);
311  
312  			impl->decode(mv, ev, vlen);
313  			impl->reduce(ma, mv, mp);
314  			mpz_mod(t1, v, p);
315  			check_eqz(ma, t1);
316  
317  			impl->decode_mod(ma, ea, alen, mp);
318  			impl->to_monty(ma, mp);
319  			mpz_mul_2exp(t1, a, ((k + impl->word_size - 1)
320  				/ impl->word_size) * impl->word_size);
321  			mpz_mod(t1, t1, p);
322  			check_eqz(ma, t1);
323  			impl->from_monty(ma, mp, mp0i);
324  			check_eqz(ma, a);
325  
326  			impl->decode_mod(ma, ea, alen, mp);
327  			impl->decode_mod(mb, eb, blen, mp);
328  			impl->to_monty(ma, mp);
329  			impl->montymul(mt1, ma, mb, mp, mp0i);
330  			mpz_mul(t1, a, b);
331  			mpz_mod(t1, t1, p);
332  			check_eqz(mt1, t1);
333  
334  			impl->decode_mod(ma, ea, alen, mp);
335  			impl->modpow(ma, ev, vlen, mp, mp0i, mt1, mt2);
336  			mpz_powm(t1, a, v, p);
337  			check_eqz(ma, t1);
338  
339  			/*
340  			br_modint_decode(ma, mp, ea, alen);
341  			br_modint_decode(mb, mp, eb, blen);
342  			if (!br_modint_div(ma, mb, mp, mt1, mt2, mt3)) {
343  				fprintf(stderr, "division failed\n");
344  				exit(EXIT_FAILURE);
345  			}
346  			mpz_sub_ui(t1, p, 2);
347  			mpz_powm(t1, b, t1, p);
348  			mpz_mul(t1, a, t1);
349  			mpz_mod(t1, t1, p);
350  			check_eqz(ma, t1);
351  
352  			br_modint_decode(ma, mp, ea, alen);
353  			br_modint_decode(mb, mp, eb, blen);
354  			for (j = 0; j <= (2 * k + 5); j ++) {
355  				br_int_add(mx, j, ma, mb);
356  				mpz_add(t1, a, b);
357  				mpz_tdiv_r_2exp(t1, t1, j);
358  				check_eqz(mx, t1);
359  
360  				br_int_mul(mx, j, ma, mb);
361  				mpz_mul(t1, a, b);
362  				mpz_tdiv_r_2exp(t1, t1, j);
363  				check_eqz(mx, t1);
364  			}
365  			*/
366  		}
367  		printf(".");
368  		fflush(stdout);
369  	}
370  	mpz_clear(p);
371  	mpz_clear(a);
372  	mpz_clear(b);
373  	mpz_clear(v);
374  	mpz_clear(t1);
375  
376  	printf(" done.\n");
377  	fflush(stdout);
378  }
379  
#if 0
/*
 * Cross-check the RSA private key core against GMP (currently compiled
 * out; it relies on mp_to_br(), which is commented out as obsolete
 * earlier in this file). Small RSA keys are generated with GMP, random
 * values are encrypted with GMP, and the result of
 * br_rsa_private_core() is compared with the original value.
 */
static void
test_RSA_core(void)
{
	int i, j, k;
	mpz_t n, e, d, p, q, dp, dq, iq, t1, t2, phi;

	printf("Test RSA core: ");
	fflush(stdout);

	gmp_randinit_mt(RNG);
	mpz_init(n);
	mpz_init(e);
	mpz_init(d);
	mpz_init(p);
	mpz_init(q);
	mpz_init(dp);
	mpz_init(dq);
	mpz_init(iq);
	mpz_init(t1);
	mpz_init(t2);
	mpz_init(phi);
	mpz_set_ui(t1, (unsigned long)time(NULL));
	gmp_randseed(RNG, t1);

	/*
	 * To test corner cases, we want to try RSA keys such that the
	 * lengths of both factors can be arbitrary modulo 2^32. Factors
	 * p and q need not be of the same length; p can be greater than
	 * q and q can be greater than p.
	 *
	 * To keep computation time reasonable, we use p and q factors of
	 * less than 128 bits; this is way too small for secure RSA,
	 * but enough to exercise all code paths (since we work only with
	 * 32-bit words).
	 */
	for (i = 64; i <= 96; i ++) {
		rand_prime(p, i);
		for (j = i - 33; j <= i + 33; j ++) {
			uint32_t mp[40], mq[40], mdp[40], mdq[40], miq[40];

			/*
			 * Generate a RSA key pair, with p of length i bits,
			 * and q of length j bits.
			 */
			do {
				rand_prime(q, j);
			} while (mpz_cmp(p, q) == 0);
			mpz_mul(n, p, q);
			mpz_set_ui(e, 65537);
			mpz_sub_ui(t1, p, 1);
			mpz_sub_ui(t2, q, 1);
			mpz_mul(phi, t1, t2);
			mpz_invert(d, e, phi);
			mpz_mod(dp, d, t1);
			mpz_mod(dq, d, t2);
			mpz_invert(iq, q, p);

			/*
			 * Convert the key pair elements to BearSSL arrays.
			 */
			mp_to_br(mp, mpz_sizeinbase(p, 2), p);
			mp_to_br(mq, mpz_sizeinbase(q, 2), q);
			mp_to_br(mdp, mpz_sizeinbase(dp, 2), dp);
			mp_to_br(mdq, mpz_sizeinbase(dq, 2), dq);
			mp_to_br(miq, mp[0], iq);

			/*
			 * Compute and check ten public/private operations.
			 */
			for (k = 0; k < 10; k ++) {
				uint32_t mx[40];

				mpz_urandomm(t1, RNG, n);
				mpz_powm(t2, t1, e, n);
				mp_to_br(mx, mpz_sizeinbase(n, 2), t2);
				br_rsa_private_core(mx, mp, mq, mdp, mdq, miq);
				check_eqz(mx, t1);
			}
		}
		printf(".");
		fflush(stdout);
	}

	/*
	 * Release every GMP object initialised above, and the RNG state.
	 */
	mpz_clear(n);
	mpz_clear(e);
	mpz_clear(d);
	mpz_clear(p);
	mpz_clear(q);
	mpz_clear(dp);
	mpz_clear(dq);
	mpz_clear(iq);
	mpz_clear(t1);
	mpz_clear(t2);
	mpz_clear(phi);
	gmp_randclear(RNG);

	printf(" done.\n");
	fflush(stdout);
}
#endif
468  
469  int
470  main(void)
471  {
472  	printf("===== i32 ======\n");
473  	impl = &i32_impl;
474  	test_modint();
475  	printf("===== i31 ======\n");
476  	impl = &i31_impl;
477  	test_modint();
478  	/*
479  	test_RSA_core();
480  	*/
481  	return 0;
482  }