/ tests / math_arbitrary_precision / t_bigints_powmod_vs_gmp.nim
t_bigints_powmod_vs_gmp.nim
  1  # Constantine
  2  # Copyright (c) 2018-2019    Status Research & Development GmbH
  3  # Copyright (c) 2020-Present Mamy André-Ratsimbazafy
  4  # Licensed and distributed under either of
  5  #   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
  6  #   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
  7  # at your option. This file may not be copied, modified, or distributed except according to those terms.
  8  
  9  import
 10    ../../constantine/math_arbitrary_precision/arithmetic/[bigints_views, limbs_views],
 11    ../../constantine/platforms/abstractions,
 12    ../../constantine/serialization/codecs,
 13    ../../helpers/prng_unsafe,
 14  
 15    std/[times, strformat],
 16    gmp
 17  
 18  const # https://gmplib.org/manual/Integer-Import-and-Export.html
 19    GMP_WordLittleEndian = -1'i32
 20    GMP_WordNativeEndian = 0'i32
 21    GMP_WordBigEndian = 1'i32
 22  
 23    GMP_MostSignificantWordFirst = 1'i32
 24    GMP_LeastSignificantWordFirst = -1'i32
 25  
 26  const
 27    moduleName = "t_powmod_vs_gmp"
 28    Iters = 100
 29  
 30  var rng: RngState
 31  let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
 32  rng.seed(seed)
 33  echo "\n------------------------------------------------------\n"
 34  echo moduleName, " xoshiro512** seed: ", seed
 35  
 36  proc fromHex(T: typedesc, hex: string): T =
 37    result.unmarshal(array[sizeof(T), byte].fromHex(hex), WordBitWidth, bigEndian)
 38  
 39  proc toHex(a: mpz_t): string =
 40    let size = mpz_sizeinbase(a, 16)
 41    result.setLen(size+2)
 42  
 43    result[0] = '0'
 44    result[1] = 'x'
 45    discard mpz_get_str(cast[cstring](result[2].addr), 16, a)
 46  
 47  proc test(rng: var RngState) =
 48    let
 49      aLen = rng.random_unsafe(1..100)
 50      eLen = rng.random_unsafe(1..400)
 51      mLen = rng.random_unsafe(1..100)
 52  
 53    var
 54      a = newSeq[SecretWord](aLen)
 55      e = newSeq[byte](eLen)
 56      M = newSeq[SecretWord](mLen)
 57  
 58      rGMP = newSeq[SecretWord](mLen)
 59      rCtt = newSeq[SecretWord](mLen)
 60  
 61    for word in a.mitems():
 62      word = SecretWord rng.next()
 63    for octet in e.mitems():
 64      octet = byte rng.next()
 65    for word in M.mitems():
 66      word = SecretWord rng.next()
 67  
 68    var aa, ee, mm, rr: mpz_t
 69    mpz_init(aa)
 70    mpz_init(ee)
 71    mpz_init(mm)
 72    mpz_init(rr)
 73  
 74    aa.mpz_import(aLen, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, a[0].addr)
 75    ee.mpz_import(eLen, GMP_MostSignificantWordFirst, sizeof(byte), GMP_WordNativeEndian, 0, e[0].addr)
 76    mm.mpz_import(mLen, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, M[0].addr)
 77  
 78    rr.mpz_powm(aa, ee, mm)
 79  
 80    var rWritten: csize
 81    discard rGMP[0].addr.mpz_export(rWritten.addr, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, rr)
 82  
 83    mpz_clear(aa)
 84    mpz_clear(ee)
 85    mpz_clear(mm)
 86    mpz_clear(rr)
 87  
 88    let
 89      aBits = a.getBits_LE_vartime()
 90      eBits = e.getBits_BE_vartime()
 91      mBits = M.getBits_LE_vartime()
 92  
 93    rCtt.powMod_vartime(a, e, M, window = 4)
 94  
 95    doAssert (seq[BaseType])(rGMP) == (seq[BaseType])(rCtt), block:
 96      "\nModular exponentiation failure:\n" &
 97      &"  a.len (word): {a.len:>3}, a.bits: {aBits:>4}\n" &
 98      &"  e.len (byte): {e.len:>3}, e.bits: {eBits:>4}\n" &
 99      &"  M.len (word): {M.len:>3}, M.bits: {mBits:>4}\n" &
100      "  ------------------------------------------------\n" &
101      &"  a: {aa.toHex()}\n" &
102      &"  e: {ee.toHex()}\n" &
103      &"  M: {mm.toHex()}\n" &
104      "  ------------------------------------------------\n" &
105      &"  r (GMP): {rGMP.toString()}\n" &
106      &"  r (Ctt): {rCtt.toString()}\n"
107  
108  
# Driver: run `Iters` independent randomized trials. Each trial prints one
# '.' as a progress marker, and a final newline closes the progress line.
for _ in 1 .. Iters:
  test(rng)
  stdout.write('.')
stdout.write('\n')