/ external / libecc / scripts / expand_libecc.py
expand_libecc.py
   1  #/*
   2  # *  Copyright (C) 2017 - This file is part of libecc project
   3  # *
   4  # *  Authors:
   5  # *      Ryad BENADJILA <ryadbenadjila@gmail.com>
   6  # *      Arnaud EBALARD <arnaud.ebalard@ssi.gouv.fr>
   7  # *      Jean-Pierre FLORI <jean-pierre.flori@ssi.gouv.fr>
   8  # *
   9  # *  Contributors:
  10  # *      Nicolas VIVET <nicolas.vivet@ssi.gouv.fr>
  11  # *      Karim KHALFALLAH <karim.khalfallah@ssi.gouv.fr>
  12  # *
  13  # *  This software is licensed under a dual BSD and GPL v2 license.
  14  # *  See LICENSE file at the root folder of the project.
  15  # */
  16  #! /usr/bin/env python
  17  
  18  import random, sys, re, math, os, getopt, glob, copy, hashlib, binascii, string, signal, base64
  19  
   20  # External dependency for SHA-3
  21  # It is an independent module, since hashlib has no support
  22  # for SHA-3 functions for now
  23  import sha3
  24  
  25  # Handle Python 2/3 issues
  26  def is_python_2():
  27      if sys.version_info[0] < 3:
  28          return True
  29      else:
  30          return False
  31  
  32  ### Ctrl-C handler
  33  def handler(signal, frame):
  34      print("\nSIGINT caught: exiting ...")
  35      exit(0)
  36  
  37  # Helper to ask the user for something
  38  def get_user_input(prompt):
  39      # Handle the Python 2/3 issue
  40      if is_python_2() == False:
  41          return input(prompt)
  42      else:
  43          return raw_input(prompt)
  44  
  45  ##########################################################
  46  #### Math helpers
  47  def egcd(b, n):
  48      x0, x1, y0, y1 = 1, 0, 0, 1
  49      while n != 0:
  50          q, b, n = b // n, n, b % n
  51          x0, x1 = x1, x0 - q * x1
  52          y0, y1 = y1, y0 - q * y1
  53      return  b, x0, y0
  54  
  55  def modinv(a, m):
  56      g, x, y = egcd(a, m)
  57      if g != 1:
  58          raise Exception("Error: modular inverse does not exist")
  59      else:
  60          return x % m
  61  
  62  def compute_monty_coef(prime, pbitlen, wlen):
  63      """
  64      Compute montgomery coeff r, r^2 and mpinv. pbitlen is the size
  65      of p in bits. It is expected to be a multiple of word
  66      bit size.
  67      """
  68      r = (1 << int(pbitlen)) % prime
  69      r_square = (1 << (2 * int(pbitlen))) % prime
  70      mpinv = 2**wlen - (modinv(prime, 2**wlen))
  71      return r, r_square, mpinv
  72  
  73  def compute_div_coef(prime, pbitlen, wlen):
  74      """
  75      Compute division coeffs p_normalized, p_shift and p_reciprocal.
  76      """
  77      tmp = prime
  78      cnt = 0
  79      while tmp != 0:
  80          tmp = tmp >> 1
  81          cnt += 1
  82      pshift = int(pbitlen - cnt)
  83      primenorm = prime << pshift
  84      B = 2**wlen
  85      prec = B**3 // ((primenorm >> int(pbitlen - 2*wlen)) + 1) - B
  86      return pshift, primenorm, prec
  87  
  88  def is_probprime(n):
  89      # ensure n is odd
  90      if n % 2 == 0:
  91          return False
  92      # write n-1 as 2**s * d
  93      # repeatedly try to divide n-1 by 2
  94      s = 0
  95      d = n-1
  96      while True:
  97          quotient, remainder = divmod(d, 2)
  98          if remainder == 1:
  99              break
 100          s += 1
 101          d = quotient
 102      assert(2**s * d == n-1)
 103      # test the base a to see whether it is a witness for the compositeness of n
 104      def try_composite(a):
 105          if pow(a, d, n) == 1:
 106              return False
 107          for i in range(s):
 108              if pow(a, 2**i * d, n) == n-1:
 109                  return False
 110          return True # n is definitely composite
 111      for i in range(5):
 112          a = random.randrange(2, n)
 113          if try_composite(a):
 114              return False
 115      return True # no base tested showed n as composite
 116  
 117  def legendre_symbol(a, p):
 118      ls = pow(a, (p - 1) // 2, p)
 119      return -1 if ls == p - 1 else ls
 120  
 121  # Tonelli-Shanks algorithm to find square roots
 122  # over prime fields
 123  def mod_sqrt(a, p):
 124      # Square root of 0 is 0
 125      if a == 0:
 126          return 0
 127      # Simple cases
 128      if legendre_symbol(a, p) != 1:
 129          # No square residue
 130          return None
 131      elif p == 2:
 132          return a
 133      elif p % 4 == 3:
 134          return pow(a, (p + 1) // 4, p)
 135      s = p - 1
 136      e = 0
 137      while s % 2 == 0:
 138          s = s // 2
 139          e += 1
 140      n = 2
 141      while legendre_symbol(n, p) != -1:
 142          n += 1
 143      x = pow(a, (s + 1) // 2, p)
 144      b = pow(a, s, p)
 145      g = pow(n, s, p)
 146      r = e
 147      while True:
 148          t = b
 149          m = 0
 150          if is_python_2():
 151              for m in xrange(r):
 152                  if t == 1:
 153                      break
 154                  t = pow(t, 2, p)
 155          else:
 156              for m in range(r):
 157                  if t == 1:
 158                      break
 159                  t = pow(t, 2, p)
 160          if m == 0:
 161              return x
 162          gs = pow(g, 2 ** (r - m - 1), p)
 163          g = (gs * gs) % p
 164          x = (x * gs) % p
 165          b = (b * g) % p
 166          r = m
 167  
 168  ##########################################################
 169  ### Math elliptic curves basic blocks
 170  
  171  # WARNING: these blocks are only here for testing purposes and
  172  # are not intended to be used in a security-oriented library!
  173  # This explains the usage of naive affine coordinate formulas.
 174  class Curve(object):
 175      def __init__(self, a, b, prime, order, cofactor, gx, gy, npoints, name, oid):
 176          self.a = a
 177          self.b = b
 178          self.p = prime
 179          self.q = order
 180          self.c = cofactor
 181          self.gx = gx
 182          self.gy = gy
 183          self.n = npoints
 184          self.name = name
 185          self.oid = oid
 186      # Equality testing
 187      def __eq__(self, other):
 188          return self.__dict__ == other.__dict__
 189      # Deep copy is implemented using the ~X operator
 190      def __invert__(self):
 191          return copy.deepcopy(self)
 192  
 193  
 194  class Point(object):
 195      # Affine coordinates (x, y), infinity point is (None, None)
 196      def __init__(self, curve, x, y):
 197          self.curve = curve
 198          if x != None:
 199              self.x = (x % curve.p)
 200          else:
 201              self.x = None
 202          if y != None:
 203              self.y = (y % curve.p)
 204          else:
 205              self.y = None
 206          # Check that the point is indeed on the curve
 207          if (x != None):
 208              if (pow(y, 2, curve.p) != ((pow(x, 3, curve.p) + (curve.a * x) + curve.b ) % curve.p)):
 209                  raise Exception("Error: point is not on curve!")
 210      # Addition
 211      def __add__(self, Q):
 212          x1 = self.x
 213          y1 = self.y
 214          x2 = Q.x
 215          y2 = Q.y
 216          curve = self.curve
 217          # Check that we are on the same curve
 218          if Q.curve != curve:
 219              raise Exception("Point add error: two point don't have the same curve")
 220          # If Q is infinity point, return ourself
 221          if Q.x == None:
 222              return Point(self.curve, self.x, self.y)
 223          # If we are the infinity point return Q
 224          if self.x == None:
 225              return Q
 226          # Infinity point or Doubling
 227          if (x1 == x2):
 228              if (((y1 + y2) % curve.p) == 0):
 229                  # Return infinity point
 230                  return Point(self.curve, None, None)
 231              else:
 232                  # Doubling
 233                  L = ((3*pow(x1, 2, curve.p) + curve.a) * modinv(2*y1, curve.p)) % curve.p
 234          # Addition
 235          else:
 236              L = ((y2 - y1) * modinv((x2 - x1) % curve.p, curve.p)) % curve.p
 237          resx = (pow(L, 2, curve.p) - x1 - x2) % curve.p
 238          resy = ((L * (x1 - resx)) - y1) % curve.p
 239          # Return the point
 240          return Point(self.curve, resx, resy)
 241      # Negation
 242      def __neg__(self):
 243          if (self.x == None):
 244              return Point(self.curve, None, None)
 245          else:
 246              return Point(self.curve, self.x, -self.y)
 247      # Subtraction
 248      def __sub__(self, other):
 249          return self + (-other)
 250      # Scalar mul
 251      def __rmul__(self, scalar):
 252          # Implement simple double and add algorithm
 253          P = self
 254          Q = Point(P.curve, None, None)
 255          for i in range(getbitlen(scalar), 0, -1):
 256              Q = Q + Q
 257              if (scalar >> (i-1)) & 0x1 == 0x1:
 258                  Q = Q + P
 259          return Q
 260      # Equality testing
 261      def __eq__(self, other):
 262          return self.__dict__ == other.__dict__
 263      # Deep copy is implemented using the ~X operator
 264      def __invert__(self):
 265          return copy.deepcopy(self)
 266      def __str__(self):
 267          if self.x == None:
 268              return "Inf"
 269          else:
 270              return ("(x = %s, y = %s)" % (hex(self.x), hex(self.y)))
 271  
 272  ##########################################################
 273  ### Private and public keys structures
 274  class PrivKey(object):
 275      def __init__(self, curve, x):
 276          self.curve = curve
 277          self.x = x
 278  
 279  class PubKey(object):
 280      def __init__(self, curve, Y):
 281          # Sanity check
 282          if Y.curve != curve:
 283              raise Exception("Error: curve and point curve differ in public key!")
 284          self.curve = curve
 285          self.Y = Y
 286  
 287  class KeyPair(object):
 288      def __init__(self, pubkey, privkey):
 289          self.pubkey = pubkey
 290          self.privkey = privkey
 291  
 292  
 293  def fromprivkey(privkey, is_eckcdsa=False):
 294      curve = privkey.curve
 295      q = curve.q
 296      gx = curve.gx
 297      gy = curve.gy
 298      G = Point(curve, gx, gy)
 299      if is_eckcdsa == False:
 300          return PubKey(curve, privkey.x * G)
 301      else:
 302          return PubKey(curve, modinv(privkey.x, q) * G)
 303  
 304  def genKeyPair(curve, is_eckcdsa=False):
 305      p = curve.p
 306      q = curve.q
 307      gx = curve.gx
 308      gy = curve.gy
 309      G = Point(curve, gx, gy)
 310      OK = False
 311      while OK == False:
 312          x = getrandomint(q)
 313          if x == 0:
 314              continue
 315          OK = True
 316      privkey = PrivKey(curve, x)
 317      pubkey = fromprivkey(privkey, is_eckcdsa)
 318      return KeyPair(pubkey, privkey)
 319  
 320  ##########################################################
 321  ### Signature algorithms helpers
 322  def getrandomint(modulo):
 323      return random.randrange(0, modulo+1)
 324  
 325  def getbitlen(bint):
 326      """
 327      Returns the number of bits encoding an integer
 328      """
 329      if bint == None:
 330          return 0
 331      if bint == 0:
 332          # Zero is encoded on one bit
 333          return 1
 334      else:
 335          return int(bint).bit_length()
 336  
 337  def getbytelen(bint):
 338      """
 339      Returns the number of bytes encoding an integer
 340      """
 341      bitsize = getbitlen(bint)
 342      bytesize = int(bitsize // 8)
 343      if bitsize % 8 != 0:
 344          bytesize += 1
 345      return bytesize
 346  
 347  def stringtoint(bitstring):
 348      acc = 0
 349      size = len(bitstring)
 350      for i in range(0, size):
 351          acc = acc + (ord(bitstring[i]) * (2**(8*(size - 1 - i))))
 352      return acc
 353  
 354  def inttostring(a):
 355      size = int(getbytelen(a))
 356      outstr = ""
 357      for i in range(0, size):
 358          outstr = outstr + chr((a >> (8*(size - 1 - i))) & 0xFF)
 359      return outstr
 360  
 361  def expand(bitstring, bitlen, direction):
 362      bytelen = int(math.ceil(bitlen / 8.))
 363      if len(bitstring) >= bytelen:
 364          return bitstring
 365      else:
 366          if direction == "LEFT":
 367              return ((bytelen-len(bitstring))*"\x00") + bitstring
 368          elif direction == "RIGHT":
 369              return bitstring + ((bytelen-len(bitstring))*"\x00")
 370          else:
 371              raise Exception("Error: unknown direction "+direction+" in expand")
 372  
 373  def truncate(bitstring, bitlen, keep):
 374      """
 375      Takes a bit string and truncates it to keep the left
 376      most or the right most bits
 377      """
 378      strbitlen = 8*len(bitstring)
 379      # Check if truncation is needed
 380      if strbitlen > bitlen:
 381          if keep == "LEFT":
 382              return expand(inttostring(stringtoint(bitstring) >> int(strbitlen - bitlen)), bitlen, "LEFT")
 383          elif keep == "RIGHT":
 384              mask = (2**bitlen)-1
 385              return expand(inttostring(stringtoint(bitstring) & mask), bitlen, "LEFT")
 386          else:
 387              raise Exception("Error: unknown direction "+keep+" in truncate")
 388      else:
 389          # No need to truncate!
 390          return bitstring
 391  
 392  ##########################################################
 393  ### Hash algorithms
 394  def sha224(message):
 395      ctx = hashlib.sha224()
 396      if(is_python_2() == True):
 397          ctx.update(message)
 398          digest = ctx.digest()
 399      else:
 400          ctx.update(message.encode('latin-1'))
 401          digest = ctx.digest().decode('latin-1')
 402      return (digest, ctx.digest_size, ctx.block_size)
 403  
 404  def sha256(message):
 405      ctx = hashlib.sha256()
 406      if(is_python_2() == True):
 407          ctx.update(message)
 408          digest = ctx.digest()
 409      else:
 410          ctx.update(message.encode('latin-1'))
 411          digest = ctx.digest().decode('latin-1')
 412      return (digest, ctx.digest_size, ctx.block_size)
 413  
 414  def sha384(message):
 415      ctx = hashlib.sha384()
 416      if(is_python_2() == True):
 417          ctx.update(message)
 418          digest = ctx.digest()
 419      else:
 420          ctx.update(message.encode('latin-1'))
 421          digest = ctx.digest().decode('latin-1')
 422      return (digest, ctx.digest_size, ctx.block_size)
 423  
 424  def sha512(message):
 425      ctx = hashlib.sha512()
 426      if(is_python_2() == True):
 427          ctx.update(message)
 428          digest = ctx.digest()
 429      else:
 430          ctx.update(message.encode('latin-1'))
 431          digest = ctx.digest().decode('latin-1')
 432      return (digest, ctx.digest_size, ctx.block_size)
 433  
 434  def sha3_224(message):
 435      ctx = sha3.Sha3_ctx(224)
 436      if(is_python_2() == True):
 437          ctx.update(message)
 438          digest = ctx.digest()
 439      else:
 440          ctx.update(message.encode('latin-1'))
 441          digest = ctx.digest().decode('latin-1')
 442      return (digest, ctx.digest_size, ctx.block_size)
 443  
 444  def sha3_256(message):
 445      ctx = sha3.Sha3_ctx(256)
 446      if(is_python_2() == True):
 447          ctx.update(message)
 448          digest = ctx.digest()
 449      else:
 450          ctx.update(message.encode('latin-1'))
 451          digest = ctx.digest().decode('latin-1')
 452      return (digest, ctx.digest_size, ctx.block_size)
 453  
 454  def sha3_384(message):
 455      ctx = sha3.Sha3_ctx(384)
 456      if(is_python_2() == True):
 457          ctx.update(message)
 458          digest = ctx.digest()
 459      else:
 460          ctx.update(message.encode('latin-1'))
 461          digest = ctx.digest().decode('latin-1')
 462      return (digest, ctx.digest_size, ctx.block_size)
 463  
 464  def sha3_512(message):
 465      ctx = sha3.Sha3_ctx(512)
 466      if(is_python_2() == True):
 467          ctx.update(message)
 468          digest = ctx.digest()
 469      else:
 470          ctx.update(message.encode('latin-1'))
 471          digest = ctx.digest().decode('latin-1')
 472      return (digest, ctx.digest_size, ctx.block_size)
 473  
 474  ##########################################################
 475  ### Signature algorithms
 476  
 477  # *| IUF  - ECDSA signature
 478  # *|
 479  # *|  UF  1. Compute h = H(m)
 480  # *|   F  2. If |h| > bitlen(q), set h to bitlen(q)
 481  # *|         leftmost (most significant) bits of h
 482  # *|   F  3. e = OS2I(h) mod q
 483  # *|   F  4. Get a random value k in ]0,q[
 484  # *|   F  5. Compute W = (W_x,W_y) = kG
 485  # *|   F  6. Compute r = W_x mod q
 486  # *|   F  7. If r is 0, restart the process at step 4.
 487  # *|   F  8. If e == rx, restart the process at step 4.
 488  # *|   F  9. Compute s = k^-1 * (xr + e) mod q
 489  # *|   F 10. If s is 0, restart the process at step 4.
 490  # *|   F 11. Return (r,s)
 491  def ecdsa_sign(hashfunc, keypair, message, k=None):
 492      privkey = keypair.privkey
 493      # Get important parameters from the curve
 494      p = privkey.curve.p
 495      q = privkey.curve.q
 496      gx = privkey.curve.gx
 497      gy = privkey.curve.gy
 498      G = Point(privkey.curve, gx, gy)
 499      q_limit_len = getbitlen(q)
 500      # Compute the hash
 501      (h, _, _) = hashfunc(message)
 502      # Truncate hash value
 503      h = truncate(h, q_limit_len, "LEFT")
 504      # Convert the hash value to an int
 505      e = stringtoint(h) % q
 506      OK = False
 507      while OK == False:
 508          if k == None:
 509              k = getrandomint(q)
 510          if k == 0:
 511              continue
 512          W = k * G
 513          r = W.x % q
 514          if r == 0:
 515              continue
 516          if e == r * privkey.x:
 517              continue
 518          s = (modinv(k, q) * ((privkey.x * r) + e)) % q
 519          if s == 0:
 520              continue
 521          OK = True
 522      return ((expand(inttostring(r), 8*getbytelen(q), "LEFT") + expand(inttostring(s), 8*getbytelen(q), "LEFT")), k)
 523  
 524  # *| IUF  - ECDSA verification
 525  # *|
 526  # *| I    1. Reject the signature if r or s is 0.
 527  # *|  UF  2. Compute h = H(m)
 528  # *|   F  3. If |h| > bitlen(q), set h to bitlen(q)
 529  # *|         leftmost (most significant) bits of h
 530  # *|   F  4. Compute e = OS2I(h) mod q
 531  # *|   F  5. Compute u = (s^-1)e mod q
 532  # *|   F  6. Compute v = (s^-1)r mod q
 533  # *|   F  7. Compute W' = uG + vY
 534  # *|   F  8. If W' is the point at infinity, reject the signature.
 535  # *|   F  9. Compute r' = W'_x mod q
 536  # *|   F 10. Accept the signature if and only if r equals r'
 537  def ecdsa_verify(hashfunc, keypair, message, sig):
 538      pubkey = keypair.pubkey
 539      # Get important parameters from the curve
 540      p = pubkey.curve.p
 541      q = pubkey.curve.q
 542      gx = pubkey.curve.gx
 543      gy = pubkey.curve.gy
 544      q_limit_len = getbitlen(q)
 545      G = Point(pubkey.curve, gx, gy)
 546      # Extract r and s
 547      if len(sig) != 2*getbytelen(q):
 548          raise Exception("ECDSA verify: bad signature length!")
 549      r = stringtoint(sig[0:int(len(sig)/2)])
 550      s = stringtoint(sig[int(len(sig)/2):])
 551      if r == 0 or s == 0:
 552          return False
 553      # Compute the hash
 554      (h, _, _) = hashfunc(message)
 555      # Truncate hash value
 556      h = truncate(h, q_limit_len, "LEFT")
 557      # Convert the hash value to an int
 558      e = stringtoint(h) % q
 559      u = (modinv(s, q) * e) % q
 560      v = (modinv(s, q) * r) % q
 561      W_ = (u * G) + (v * pubkey.Y)
 562      if W_.x == None:
 563          return False
 564      r_ = W_.x % q
 565      if r == r_:
 566          return True
 567      else:
 568          return False
 569  
 570  def eckcdsa_genKeyPair(curve):
 571      return genKeyPair(curve, True)
 572  
 573  # *| IUF  - ECKCDSA signature
 574  # *|
 575  # *| IUF  1. Compute h = H(z||m)
 576  # *|   F  2. If hsize > bitlen(q), set h to bitlen(q)
 577  # *|         rightmost (less significant) bits of h.
 578  # *|   F  3. Get a random value k in ]0,q[
 579  # *|   F  4. Compute W = (W_x,W_y) = kG
 580  # *|   F  5. Compute r = h(FE2OS(W_x)).
 581  # *|   F  6. If hsize > bitlen(q), set r to bitlen(q)
 582  # *|         rightmost (less significant) bits of r.
 583  # *|   F  7. Compute e = OS2I(r XOR h) mod q
 584  # *|   F  8. Compute s = x(k - e) mod q
 585  # *|   F  9. if s == 0, restart at step 3.
 586  # *|   F 10. return (r,s)
 587  def eckcdsa_sign(hashfunc, keypair, message, k=None):
 588      privkey = keypair.privkey
 589      # Get important parameters from the curve
 590      p = privkey.curve.p
 591      q = privkey.curve.q
 592      gx = privkey.curve.gx
 593      gy = privkey.curve.gy
 594      G = Point(privkey.curve, gx, gy)
 595      q_limit_len = getbitlen(q)
 596      # Compute the certificate data
 597      (_, _, hblocksize) = hashfunc("")
 598      z = expand(inttostring(keypair.pubkey.Y.x), 8*getbytelen(p), "LEFT")
 599      z = z + expand(inttostring(keypair.pubkey.Y.y), 8*getbytelen(p), "LEFT")
 600      if len(z) > hblocksize:
 601          # Truncate
 602          z = truncate(z, 8*hblocksize, "LEFT")
 603      else:
 604          # Expand
 605          z = expand(z, 8*hblocksize, "RIGHT")
 606      # Compute the hash
 607      (h, _, _) = hashfunc(z + message)
 608      # Truncate hash value
 609      h = truncate(h, 8 * int(math.ceil(q_limit_len / 8)), "RIGHT")
 610      OK = False
 611      while OK == False:
 612          if k == None:
 613              k = getrandomint(q)
 614          if k == 0:
 615              continue
 616          W = k * G
 617          (r, _, _) = hashfunc(expand(inttostring(W.x), 8*getbytelen(p), "LEFT"))
 618          r = truncate(r, 8 * int(math.ceil(q_limit_len / 8)), "RIGHT")
 619          e = (stringtoint(r) ^ stringtoint(h)) % q
 620          s = (privkey.x * (k - e)) % q
 621          if s == 0:
 622              continue
 623          OK = True
 624      return (r + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k)
 625  
 626  # *| IUF - ECKCDSA verification
 627  # *|
 628  # *| I   1. Check the length of r:
 629  # *|         - if hsize > bitlen(q), r must be of
 630  # *|           length bitlen(q)
 631  # *|         - if hsize <= bitlen(q), r must be of
 632  # *|           length hsize
 633  # *| I   2. Check that s is in ]0,q[
 634  # *| IUF 3. Compute h = H(z||m)
 635  # *|   F 4. If hsize > bitlen(q), set h to bitlen(q)
 636  # *|        rightmost (less significant) bits of h.
 637  # *|   F 5. Compute e = OS2I(r XOR h) mod q
 638  # *|   F 6. Compute W' = sY + eG, where Y is the public key
 639  # *|   F 7. Compute r' = h(FE2OS(W'x))
 640  # *|   F 8. If hsize > bitlen(q), set r' to bitlen(q)
 641  # *|        rightmost (less significant) bits of r'.
 642  # *|   F 9. Check if r == r'
 643  def eckcdsa_verify(hashfunc, keypair, message, sig):
 644      pubkey = keypair.pubkey
 645      # Get important parameters from the curve
 646      p = pubkey.curve.p
 647      q = pubkey.curve.q
 648      gx = pubkey.curve.gx
 649      gy = pubkey.curve.gy
 650      G = Point(pubkey.curve, gx, gy)
 651      q_limit_len = getbitlen(q)
 652      (_, hsize, hblocksize) = hashfunc("")
 653      # Extract r and s
 654      if (8*hsize) > q_limit_len:
 655          r_len = int(math.ceil(q_limit_len / 8.))
 656      else:
 657          r_len = hsize
 658      r = stringtoint(sig[0:int(r_len)])
 659      s = stringtoint(sig[int(r_len):])
 660      if (s >= q) or (s < 0):
 661          return False
 662      # Compute the certificate data
 663      z = expand(inttostring(keypair.pubkey.Y.x), 8*getbytelen(p), "LEFT")
 664      z = z + expand(inttostring(keypair.pubkey.Y.y), 8*getbytelen(p), "LEFT")
 665      if len(z) > hblocksize:
 666          # Truncate
 667          z = truncate(z, 8*hblocksize, "LEFT")
 668      else:
 669          # Expand
 670          z = expand(z, 8*hblocksize, "RIGHT")
 671      # Compute the hash
 672      (h, _, _) = hashfunc(z + message)
 673      # Truncate hash value
 674      h = truncate(h, 8 * int(math.ceil(q_limit_len / 8)), "RIGHT")
 675      e = (r ^ stringtoint(h)) % q
 676      W_ = (s * pubkey.Y) + (e * G)
 677      (h, _, _) = hashfunc(expand(inttostring(W_.x), 8*getbytelen(p), "LEFT"))
 678      r_ = truncate(h, 8 * int(math.ceil(q_limit_len / 8)), "RIGHT")
 679      if stringtoint(r_) == r:
 680          return True
 681      else:
 682          return False
 683  
 684  # *| IUF - ECFSDSA signature
 685  # *|
 686  # *| I   1. Get a random value k in ]0,q[
 687  # *| I   2. Compute W = (W_x,W_y) = kG
 688  # *| I   3. Compute r = FE2OS(W_x)||FE2OS(W_y)
 689  # *| I   4. If r is an all zero string, restart the process at step 1.
 690  # *| IUF 5. Compute h = H(r||m)
 691  # *|   F 6. Compute e = OS2I(h) mod q
 692  # *|   F 7. Compute s = (k + ex) mod q
 693  # *|   F 8. If s is 0, restart the process at step 1 (see c. below)
 694  # *|   F 9. Return (r,s)
 695  def ecfsdsa_sign(hashfunc, keypair, message, k=None):
 696      privkey = keypair.privkey
 697      # Get important parameters from the curve
 698      p = privkey.curve.p
 699      q = privkey.curve.q
 700      gx = privkey.curve.gx
 701      gy = privkey.curve.gy
 702      G = Point(privkey.curve, gx, gy)
 703      OK = False
 704      while OK == False:
 705          if k == None:
 706              k = getrandomint(q)
 707          if k == 0:
 708              continue
 709          W = k * G
 710          r = expand(inttostring(W.x), 8*getbytelen(p), "LEFT") + expand(inttostring(W.y), 8*getbytelen(p), "LEFT")
 711          if stringtoint(r) == 0:
 712              continue
 713          (h, _, _) = hashfunc(r + message)
 714          e = stringtoint(h) % q
 715          s = (k + e * privkey.x) % q
 716          if s == 0:
 717              continue
 718          OK = True
 719      return (r + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k)
 720  
 721  
 722  # *| IUF - ECFSDSA verification
 723  # *|
 724  # *| I   1. Reject the signature if r is not a valid point on the curve.
 725  # *| I   2. Reject the signature if s is not in ]0,q[
 726  # *| IUF 3. Compute h = H(r||m)
 727  # *|   F 4. Convert h to an integer and then compute e = -h mod q
 728  # *|   F 5. compute W' = sG + eY, where Y is the public key
 729  # *|   F 6. Compute r' = FE2OS(W'_x)||FE2OS(W'_y)
 730  # *|   F 7. Accept the signature if and only if r equals r'
 731  def ecfsdsa_verify(hashfunc, keypair, message, sig):
 732      pubkey = keypair.pubkey
 733      # Get important parameters from the curve
 734      p = pubkey.curve.p
 735      q = pubkey.curve.q
 736      gx = pubkey.curve.gx
 737      gy = pubkey.curve.gy
 738      G = Point(pubkey.curve, gx, gy)
 739      # Extract coordinates from r and s from signature
 740      if len(sig) != (2*getbytelen(p)) + getbytelen(q):
 741          raise Exception("ECFSDSA verify: bad signature length!")
 742      wx = sig[:int(getbytelen(p))]
 743      wy = sig[int(getbytelen(p)):int(2*getbytelen(p))]
 744      r = wx + wy
 745      s = stringtoint(sig[int(2*getbytelen(p)):int((2*getbytelen(p))+getbytelen(q))])
 746      # Check r is on the curve
 747      W = Point(pubkey.curve, stringtoint(wx), stringtoint(wy))
 748      # Check s is in ]0,q[
 749      if s == 0 or s > q:
 750          raise Exception("ECFSDSA verify: s not in ]0,q[")
 751      (h, _, _) = hashfunc(r + message)
 752      e = (-stringtoint(h)) % q
 753      W_ = s * G + e * pubkey.Y
 754      r_ = expand(inttostring(W_.x), 8*getbytelen(p), "LEFT") + expand(inttostring(W_.y), 8*getbytelen(p), "LEFT")
 755      if r == r_:
 756          return True
 757      else:
 758          return False
 759  
 760  
 761  # NOTE: ISO/IEC 14888-3 standard seems to diverge from the existing implementations
 762  # of ECRDSA when treating the message hash, and from the examples of certificates provided
  763  # in RFC 7091 and draft-deremin-rfc4491-bis. While in ISO/IEC 14888-3 it is explicitly asked
 764  # to proceed with the hash of the message as big endian, the RFCs derived from the Russian
 765  # standard expect the hash value to be treated as little endian when importing it as an integer
 766  # (this discrepancy is exhibited and confirmed by test vectors present in ISO/IEC 14888-3, and
 767  # by X.509 certificates present in the RFCs). This seems (to be confirmed) to be a discrepancy of
 768  # ISO/IEC 14888-3 algorithm description that must be fixed there.
 769  #
 770  # In order to be conservative, libecc uses the Russian standard behavior as expected to be in line with
  771  # other implementations, but keeps the ISO/IEC 14888-3 behavior if forced/asked by the user using
 772  # the USE_ISO14888_3_ECRDSA toggle. This allows to keep backward compatibility with previous versions of the
 773  # library if needed.
 774  
 775  # *| IUF - ECRDSA signature
 776  # *|
 777  # *|  UF  1. Compute h = H(m)
 778  # *|   F  2. Get a random value k in ]0,q[
 779  # *|   F  3. Compute W = (W_x,W_y) = kG
 780  # *|   F  4. Compute r = W_x mod q
 781  # *|   F  5. If r is 0, restart the process at step 2.
 782  # *|   F  6. Compute e = OS2I(h) mod q. If e is 0, set e to 1.
 783  # *|         NOTE: here, ISO/IEC 14888-3 and RFCs differ in the way e treated.
 784  # *|         e = OS2I(h) for ISO/IEC 14888-3, or e = OS2I(reversed(h)) when endianness of h
 785  # *|         is reversed for RFCs.
 786  # *|   F  7. Compute s = (rx + ke) mod q
 787  # *|   F  8. If s is 0, restart the process at step 2.
 788  # *|   F 11. Return (r,s)
 789  def ecrdsa_sign(hashfunc, keypair, message, k=None, use_iso14888_divergence=False):
 790      privkey = keypair.privkey
 791      # Get important parameters from the curve
 792      p = privkey.curve.p
 793      q = privkey.curve.q
 794      gx = privkey.curve.gx
 795      gy = privkey.curve.gy
 796      G = Point(privkey.curve, gx, gy)
 797      (h, _, _) = hashfunc(message)
 798      if use_iso14888_divergence == False:
 799          # Reverse the endianness for Russian standard RFC ECRDSA (contrary to ISO/IEC 14888-3 case)
 800          h = h[::-1]
 801      OK = False
 802      while OK == False:
 803          if k == None:
 804              k = getrandomint(q)
 805          if k == 0:
 806              continue
 807          W = k * G
 808          r = W.x % q
 809          if r == 0:
 810              continue
 811          e = stringtoint(h) % q
 812          if e == 0:
 813              e = 1
 814          s = ((r * privkey.x) + (k * e)) % q
 815          if s == 0:
 816              continue
 817          OK = True
 818      return (expand(inttostring(r), 8*getbytelen(q), "LEFT") + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k)
 819  
 820  # *| IUF - ECRDSA verification
 821  # *|
 822  # *|  UF 1. Check that r and s are both in ]0,q[
 823  # *|   F 2. Compute h = H(m)
 824  # *|   F 3. Compute e = OS2I(h)^-1 mod q
 825  # *|         NOTE: here, ISO/IEC 14888-3 and RFCs differ in the way e treated.
 826  # *|         e = OS2I(h) for ISO/IEC 14888-3, or e = OS2I(reversed(h)) when endianness of h
 827  # *|         is reversed for RFCs.
 828  # *|   F 4. Compute u = es mod q
 829  # *|   F 4. Compute v = -er mod q
 830  # *|   F 5. Compute W' = uG + vY = (W'_x, W'_y)
 831  # *|   F 6. Let's now compute r' = W'_x mod q
 832  # *|   F 7. Check r and r' are the same
 833  def ecrdsa_verify(hashfunc, keypair, message, sig, use_iso14888_divergence=False):
 834      pubkey = keypair.pubkey
 835      # Get important parameters from the curve
 836      p = pubkey.curve.p
 837      q = pubkey.curve.q
 838      gx = pubkey.curve.gx
 839      gy = pubkey.curve.gy
 840      G = Point(pubkey.curve, gx, gy)
 841      # Extract coordinates from r and s from signature
 842      if len(sig) != 2*getbytelen(q):
 843          raise Exception("ECRDSA verify: bad signature length!")
 844      r = stringtoint(sig[:int(getbytelen(q))])
 845      s = stringtoint(sig[int(getbytelen(q)):int(2*getbytelen(q))])
 846      if r == 0 or r > q:
 847          raise Exception("ECRDSA verify: r not in ]0,q[")
 848      if s == 0 or s > q:
 849          raise Exception("ECRDSA verify: s not in ]0,q[")
 850      (h, _, _) = hashfunc(message)
 851      if use_iso14888_divergence == False:
 852          # Reverse the endianness for Russian standard RFC ECRDSA (contrary to ISO/IEC 14888-3 case)
 853          h = h[::-1]
 854      e = modinv(stringtoint(h) % q, q)
 855      u = (e * s) % q
 856      v = (-e * r) % q
 857      W_ = u * G + v * pubkey.Y
 858      r_ = W_.x % q
 859      if r == r_:
 860          return True
 861      else:
 862          return False
 863  
 864  
 865  # *| IUF - ECGDSA signature
 866  # *|
 867  # *|  UF 1. Compute h = H(m). If |h| > bitlen(q), set h to bitlen(q)
 868  # *|         leftmost (most significant) bits of h
 869  # *|   F 2. Convert e = - OS2I(h) mod q
 870  # *|   F 3. Get a random value k in ]0,q[
 871  # *|   F 4. Compute W = (W_x,W_y) = kG
 872  # *|   F 5. Compute r = W_x mod q
 873  # *|   F 6. If r is 0, restart the process at step 4.
 874  # *|   F 7. Compute s = x(kr + e) mod q
 875  # *|   F 8. If s is 0, restart the process at step 4.
 876  # *|   F 9. Return (r,s)
 877  def ecgdsa_sign(hashfunc, keypair, message, k=None):
 878      privkey = keypair.privkey
 879      # Get important parameters from the curve
 880      p = privkey.curve.p
 881      q = privkey.curve.q
 882      gx = privkey.curve.gx
 883      gy = privkey.curve.gy
 884      G = Point(privkey.curve, gx, gy)
 885      (h, _, _) = hashfunc(message)
 886      q_limit_len = getbitlen(q)
 887      # Truncate hash value
 888      h = truncate(h, q_limit_len, "LEFT")
 889      e = (-stringtoint(h)) % q
 890      OK = False
 891      while OK == False:
 892          if k == None:
 893              k = getrandomint(q)
 894          if k == 0:
 895              continue
 896          W = k * G
 897          r = W.x % q
 898          if r == 0:
 899              continue
 900          s = (privkey.x * ((k * r) + e)) % q
 901          if s == 0:
 902              continue
 903          OK = True
 904      return (expand(inttostring(r), 8*getbytelen(q), "LEFT") + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k)
 905  
 906  # *| IUF - ECGDSA verification
 907  # *|
 908  # *| I   1. Reject the signature if r or s is 0.
 909  # *|  UF 2. Compute h = H(m). If |h| > bitlen(q), set h to bitlen(q)
 910  # *|         leftmost (most significant) bits of h
 911  # *|   F 3. Compute e = OS2I(h) mod q
 912  # *|   F 4. Compute u = ((r^-1)e mod q)
 913  # *|   F 5. Compute v = ((r^-1)s mod q)
 914  # *|   F 6. Compute W' = uG + vY
 915  # *|   F 7. Compute r' = W'_x mod q
 916  # *|   F 8. Accept the signature if and only if r equals r'
 917  def ecgdsa_verify(hashfunc, keypair, message, sig):
 918      pubkey = keypair.pubkey
 919      # Get important parameters from the curve
 920      p = pubkey.curve.p
 921      q = pubkey.curve.q
 922      gx = pubkey.curve.gx
 923      gy = pubkey.curve.gy
 924      G = Point(pubkey.curve, gx, gy)
 925      # Extract coordinates from r and s from signature
 926      if len(sig) != 2*getbytelen(q):
 927          raise Exception("ECGDSA verify: bad signature length!")
 928      r = stringtoint(sig[:int(getbytelen(q))])
 929      s = stringtoint(sig[int(getbytelen(q)):int(2*getbytelen(q))])
 930      if r == 0 or r > q:
 931          raise Exception("ECGDSA verify: r not in ]0,q[")
 932      if s == 0 or s > q:
 933          raise Exception("ECGDSA verify: s not in ]0,q[")
 934      (h, _, _) = hashfunc(message)
 935      q_limit_len = getbitlen(q)
 936      # Truncate hash value
 937      h = truncate(h, q_limit_len, "LEFT")
 938      e = stringtoint(h) % q
 939      r_inv = modinv(r, q)
 940      u = (r_inv * e) % q
 941      v = (r_inv * s) % q
 942      W_ = u * G + v * pubkey.Y
 943      r_ = W_.x % q
 944      if r == r_:
 945          return True
 946      else:
 947          return False
 948  
 949  # *| IUF - ECSDSA/ECOSDSA signature
 950  # *|
 951  # *| I   1. Get a random value k in ]0, q[
 952  # *| I   2. Compute W = kG = (Wx, Wy)
 953  # *| IUF 3. Compute r = H(Wx [|| Wy] || m)
 954  # *|        - In the normal version (ECSDSA), r = h(Wx || Wy || m).
 955  # *|        - In the optimized version (ECOSDSA), r = h(Wx || m).
 956  # *|   F 4. Compute e = OS2I(r) mod q
 957  # *|   F 5. if e == 0, restart at step 1.
 958  # *|   F 6. Compute s = (k + ex) mod q.
 959  # *|   F 7. if s == 0, restart at step 1.
 960  # *|   F 8. Return (r, s)
 961  def ecsdsa_common_sign(hashfunc, keypair, message, optimized, k=None):
 962      privkey = keypair.privkey
 963      # Get important parameters from the curve
 964      p = privkey.curve.p
 965      q = privkey.curve.q
 966      gx = privkey.curve.gx
 967      gy = privkey.curve.gy
 968      G = Point(privkey.curve, gx, gy)
 969      OK = False
 970      while OK == False:
 971          if k == None:
 972              k = getrandomint(q)
 973          if k == 0:
 974              continue
 975          W = k * G
 976          if optimized == False:
 977              (r, _, _) = hashfunc(expand(inttostring(W.x), 8*getbytelen(p), "LEFT") + expand(inttostring(W.y), 8*getbytelen(p), "LEFT") + message)
 978          else:
 979              (r, _, _) = hashfunc(expand(inttostring(W.x), 8*getbytelen(p), "LEFT") + message)
 980          e = stringtoint(r) % q
 981          if e == 0:
 982              continue
 983          s = (k + (e * privkey.x)) % q
 984          if s == 0:
 985              continue
 986          OK = True
 987      return (r + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k)
 988  
 989  def ecsdsa_sign(hashfunc, keypair, message, k=None):
 990      return ecsdsa_common_sign(hashfunc, keypair, message, False, k)
 991  
 992  def ecosdsa_sign(hashfunc, keypair, message, k=None):
 993      return ecsdsa_common_sign(hashfunc, keypair, message, True, k)
 994  
 995  # *| IUF - ECSDSA/ECOSDSA verification
 996  # *|
 997  # *| I   1. if s is not in ]0,q[, reject the signature.x
 998  # *| I   2. Compute e = -r mod q
 999  # *| I   3. If e == 0, reject the signature.
1000  # *| I   4. Compute W' = sG + eY
1001  # *| IUF 5. Compute r' = H(W'x [|| W'y] || m)
1002  # *|        - In the normal version (ECSDSA), r = h(W'x || W'y || m).
1003  # *|        - In the optimized version (ECOSDSA), r = h(W'x || m).
1004  # *|   F 6. Accept the signature if and only if r and r' are the same
def ecsdsa_common_verify(hashfunc, keypair, message, sig, optimized):
    """
    Verify an EC-SDSA (optimized == False) or EC-OSDSA (optimized == True) signature.

    sig is r || s, where r is a raw hash output and s is left-padded to the
    byte length of q.

    Returns True when the recomputed hash r' equals the r bytes of sig, and
    False otherwise.
    Raises Exception when sig has the wrong length, s is not in ]0,q[, or
    e = -r mod q is null.
    """
    pubkey = keypair.pubkey
    # Get important parameters from the curve
    p = pubkey.curve.p
    q = pubkey.curve.q
    G = Point(pubkey.curve, pubkey.curve.gx, pubkey.curve.gy)
    # Hash of the empty string: only used to learn the digest length
    (_, hlen, _) = hashfunc("")
    qlen = getbytelen(q)
    # Extract r and s from the signature
    if len(sig) != hlen + qlen:
        raise Exception("EC[O]SDSA verify: bad signature length!")
    r = stringtoint(sig[:int(hlen)])
    s = stringtoint(sig[int(hlen):int(hlen+qlen)])
    # ]0,q[ is an open interval: q itself must be rejected too
    if s == 0 or s >= q:
        raise Exception("EC[O]SDSA verify: s not in ]0,q[")
    e = (-r) % q
    if e == 0:
        raise Exception("EC[O]SDSA verify: e is null")
    W_ = s * G + e * pubkey.Y
    pbitlen = 8*getbytelen(p)
    Wx = expand(inttostring(W_.x), pbitlen, "LEFT")
    if optimized == False:
        (r_, _, _) = hashfunc(Wx + expand(inttostring(W_.y), pbitlen, "LEFT") + message)
    else:
        (r_, _, _) = hashfunc(Wx + message)
    return sig[:int(hlen)] == r_
1033  
def ecsdsa_verify(hashfunc, keypair, message, sig):
    """EC-SDSA verification, where r' = H(W'x || W'y || m). Returns True if sig is valid."""
    return ecsdsa_common_verify(hashfunc, keypair, message, sig, optimized=False)

def ecosdsa_verify(hashfunc, keypair, message, sig):
    """EC-OSDSA (optimized) verification, where r' = H(W'x || m). Returns True if sig is valid."""
    return ecsdsa_common_verify(hashfunc, keypair, message, sig, optimized=True)
1039  
1040  
1041  ##########################################################
1042  ### Generate self-tests for all the algorithms
1043  
# Hash functions covered by the self-tests:
# (hash function, name used in the generated C code)
all_hash_funcs = [
    (sha224, "SHA224"),
    (sha256, "SHA256"),
    (sha384, "SHA384"),
    (sha512, "SHA512"),
    (sha3_224, "SHA3_224"),
    (sha3_256, "SHA3_256"),
    (sha3_384, "SHA3_384"),
    (sha3_512, "SHA3_512"),
]

# Signature schemes covered by the self-tests:
# (sign, verify, key pair generator, name used in the generated C code)
all_sig_algs = [
    (ecdsa_sign, ecdsa_verify, genKeyPair, "ECDSA"),
    (eckcdsa_sign, eckcdsa_verify, eckcdsa_genKeyPair, "ECKCDSA"),
    (ecfsdsa_sign, ecfsdsa_verify, genKeyPair, "ECFSDSA"),
    (ecrdsa_sign, ecrdsa_verify, genKeyPair, "ECRDSA"),
    (ecgdsa_sign, ecgdsa_verify, eckcdsa_genKeyPair, "ECGDSA"),
    (ecsdsa_sign, ecsdsa_verify, genKeyPair, "ECSDSA"),
    (ecosdsa_sign, ecosdsa_verify, genKeyPair, "ECOSDSA"),
]
1053  
1054  
# Progress counter shared by gen_self_tests() / gen_self_test()
curr_test = 0
def pretty_print_curr_test(num_test, total_gen_tests):
    """
    Redraw an in-place "NN/TT" progress counter on stdout.

    The previous counter is erased with backspaces. Both numbers are
    zero-padded to the number of digits of total_gen_tests. A newline is
    printed once num_test reaches total_gen_tests.
    """
    # Exact digit count. int(math.log10(x))+1 is subject to float rounding
    # for large values and raises ValueError for 0.
    num_decimal = len(str(total_gen_tests))
    format_buf = "%0"+str(num_decimal)+"d/%0"+str(num_decimal)+"d"
    # Erase the previous "NN/TT" (two numbers plus the slash)
    sys.stdout.write('\b'*((2*num_decimal)+1))
    sys.stdout.flush()
    sys.stdout.write(format_buf % (num_test, total_gen_tests))
    if num_test == total_gen_tests:
        print("")
1065  
def _hexlify_for_msg(buf):
    """
    Return the hexadecimal encoding of buf as text, for use in error messages.

    binascii.hexlify() only accepts bytes-like objects and, under Python 3,
    returns bytes, which cannot be concatenated to a str. Text buffers are
    therefore latin-1 encoded first, and the result is decoded back.
    """
    if not isinstance(buf, bytes):
        buf = buf.encode('latin-1')
    return binascii.hexlify(buf).decode('ascii')

def _sign_and_check(sig_alg_sign, sig_alg_verify, hashfunc, keypair, message, test_name, **sig_kwargs):
    """
    Sign message, check that the signature verifies, and return (sig, k).

    sig_kwargs are forwarded to both sig_alg_sign and sig_alg_verify (this is
    used to select the ISO/IEC 14888-3 ECRDSA variant).
    Raises Exception when verification fails.
    """
    (sig, k) = sig_alg_sign(hashfunc, keypair, message, **sig_kwargs)
    if sig_alg_verify(hashfunc, keypair, message, sig, **sig_kwargs) != True:
        raise Exception("Error during self test generation: sig verify failed! "+test_name+ "   /  msg="+message+"   /   sig="+_hexlify_for_msg(sig)+"    /    k="+hex(k)+"   /   privkey.x="+hex(keypair.privkey.x))
    return (sig, k)

def _format_test_case(curve, hashfunc_name, sig_alg_name, test_name, keypair, message, sig, k, guard=None):
    """
    Build the C snippets describing one known-answer test.

    Returns (out_name, out_vectors):
    - out_name is the entry to add to the test-case pointer table;
    - out_vectors holds the definitions: nonce callback, private key,
      expected signature and ec_test_case structure.
    Both are wrapped in WITH_HASH/WITH_CURVE/WITH_SIG guards. guard, when
    given, is an (opening directive, closing directive) pair wrapped around
    both snippets as an extra outer guard.
    """
    out_vectors = "".join([
        "#ifdef WITH_HASH_"+hashfunc_name.upper()+"\n",
        "#ifdef WITH_CURVE_"+curve.name.upper()+"\n",
        "#ifdef WITH_SIG_"+sig_alg_name.upper()+"\n",
        "/* "+test_name+" known test vectors */\n",
        "static int "+test_name+"_test_vectors_get_random(nn_t out, nn_src_t q)\n{\n",
        # k_buf MUST be exported padded to the length of q
        "\tconst u8 k_buf[] = "+bigint_to_C_array(k, getbytelen(curve.q)),
        "\tint ret, cmp;\n\tret = nn_init_from_buf(out, k_buf, sizeof(k_buf)); EG(ret, err);\n\tret = nn_cmp(out, q, &cmp); EG(ret, err);\n\tret = (cmp >= 0) ? -1 : 0;\nerr:\n\treturn ret;\n}\n",
        "static const u8 "+test_name+"_test_vectors_priv_key[] = \n"+bigint_to_C_array(keypair.privkey.x, getbytelen(keypair.privkey.x)),
        "static const u8 "+test_name+"_test_vectors_expected_sig[] = \n"+bigint_to_C_array(stringtoint(sig), len(sig)),
        "static const ec_test_case "+test_name+"_test_case = {\n",
        "\t.name = \""+test_name+"\",\n",
        "\t.ec_str_p = &"+curve.name+"_str_params,\n",
        "\t.priv_key = "+test_name+"_test_vectors_priv_key,\n",
        "\t.priv_key_len = sizeof("+test_name+"_test_vectors_priv_key),\n",
        "\t.nn_random = "+test_name+"_test_vectors_get_random,\n",
        "\t.hash_type = "+hashfunc_name+",\n",
        "\t.msg = \""+message+"\",\n",
        "\t.msglen = "+str(len(message))+",\n",
        "\t.sig_type = "+sig_alg_name+",\n",
        "\t.exp_sig = "+test_name+"_test_vectors_expected_sig,\n",
        "\t.exp_siglen = sizeof("+test_name+"_test_vectors_expected_sig),\n};\n",
        "#endif /* WITH_HASH_"+hashfunc_name+" */\n",
        "#endif /* WITH_CURVE_"+curve.name+" */\n",
        "#endif /* WITH_SIG_"+sig_alg_name+" */\n",
    ])
    out_name = "".join([
        "#ifdef WITH_HASH_"+hashfunc_name.upper()+"/* For "+test_name+" */\n",
        "#ifdef WITH_CURVE_"+curve.name.upper()+"/* For "+test_name+" */\n",
        "#ifdef WITH_SIG_"+sig_alg_name.upper()+"/* For "+test_name+" */\n",
        "\t&"+test_name+"_test_case,\n",
        "#endif /* WITH_HASH_"+hashfunc_name+" for "+test_name+" */\n",
        "#endif /* WITH_CURVE_"+curve.name+" for "+test_name+" */\n",
        "#endif /* WITH_SIG_"+sig_alg_name+" for "+test_name+" */",
    ])
    if guard is not None:
        (guard_open, guard_close) = guard
        out_vectors = guard_open+"\n"+out_vectors+guard_close+"\n"
        out_name = guard_open+"/* For "+test_name+" */\n"+out_name+"\n"+guard_close+"/* For "+test_name+" */"
    return (out_name, out_vectors)

def gen_self_test(curve, hashfunc, sig_alg_sign, sig_alg_verify, sig_alg_genkeypair, num, hashfunc_name, sig_alg_name, total_gen_tests):
    """
    Generate num random known-answer tests for one (curve, hash, signature) triple.

    Each test uses a fresh key pair and a random alphanumeric message of
    random size. The signature is checked with sig_alg_verify before it is
    exported. For ECRDSA, every test is emitted twice: once for the RFC
    flavour (#ifndef USE_ISO14888_3_ECRDSA) and once for the ISO/IEC 14888-3
    flavour (#ifdef USE_ISO14888_3_ECRDSA).

    Returns a list of (out_name, out_vectors) C snippet pairs.
    Also increments the global curr_test progress counter.
    """
    global curr_test
    curr_test = curr_test + 1
    if num != 0:
        pretty_print_curr_test(curr_test, total_gen_tests)
    alphabet = string.ascii_letters + string.digits
    output_list = []
    for test_num in range(0, num):
        # Generate a random key pair
        keypair = sig_alg_genkeypair(curve)
        # Generate a random message with a random size
        size = getrandomint(256)
        message = ''.join([random.choice(alphabet) for _ in range(size)])
        test_name = sig_alg_name + "_" + hashfunc_name + "_" + curve.name.upper() + "_" + str(test_num)
        # Sign the message and check that everything is OK with a verify
        (sig, k) = _sign_and_check(sig_alg_sign, sig_alg_verify, hashfunc, keypair, message, test_name)
        if sig_alg_name == "ECRDSA":
            guard = ("#ifndef USE_ISO14888_3_ECRDSA", "#endif /* !USE_ISO14888_3_ECRDSA */")
        else:
            guard = None
        output_list.append(_format_test_case(curve, hashfunc_name, sig_alg_name, test_name, keypair, message, sig, k, guard))
        # In the specific case of ECRDSA, we also generate an ISO/IEC compatible test vector
        if sig_alg_name == "ECRDSA":
            (sig, k) = _sign_and_check(sig_alg_sign, sig_alg_verify, hashfunc, keypair, message, test_name, use_iso14888_divergence=True)
            iso_guard = ("#ifdef USE_ISO14888_3_ECRDSA", "#endif /* USE_ISO14888_3_ECRDSA */")
            output_list.append(_format_test_case(curve, hashfunc_name, sig_alg_name, test_name, keypair, message, sig, k, iso_guard))
    return output_list
1179  
def gen_self_tests(curve, num):
    """
    Generate num self-tests on curve for every (signature, hash) combination.

    Resets the global curr_test progress counter, then returns a list with
    one entry per signature scheme of all_sig_algs. Each entry is itself a
    list with one gen_self_test() result per hash function of all_hash_funcs.
    """
    global curr_test
    curr_test = 0
    total_gen_tests = len(all_hash_funcs) * len(all_sig_algs)
    vectors = []
    for (sign, verify, genkp, sig_alg_name) in all_sig_algs:
        per_alg = []
        for (hashf, hash_name) in all_hash_funcs:
            per_alg.append(gen_self_test(curve, hashf, sign, verify, genkp, num, hash_name, sig_alg_name, total_gen_tests))
        vectors.append(per_alg)
    return vectors
1187  
1188  ##########################################################
1189  ### ASN.1 stuff
def parse_DER_extract_size(derbuf):
    """
    Decode the DER length field at the start of derbuf (just after the tag).

    Short form: a single byte below 0x80 holding the length.
    Long form: 0x80|n followed by n big-endian length bytes.

    Returns (True, header_len, length), where header_len is the number of
    bytes used by the length field itself, provided derbuf holds at least
    header_len + length bytes. Returns (False, 0, 0) otherwise: empty or
    truncated buffer, or the indefinite-length form 0x80, which DER forbids.
    """
    if len(derbuf) < 1:
        return (False, 0, 0)
    first = ord(derbuf[0])
    if first & 0x80 != 0:
        # Long form: the low 7 bits give the number of length bytes
        num_len_bytes = first & 0x7f
        if num_len_bytes == 0:
            return (False, 0, 0)
        base = 1
    else:
        # Short form: the byte itself is the length
        num_len_bytes = 1
        base = 0
    header_len = base + num_len_bytes
    if len(derbuf) < header_len:
        return (False, 0, 0)
    # Big-endian decoding of the length bytes
    length = 0
    for c in derbuf[base:header_len]:
        length = (length << 8) | ord(c)
    # The whole content must be present after the length field
    if len(derbuf) < header_len + length:
        return (False, 0, 0)
    return (True, header_len, length)
1207  
def extract_DER_object(derbuf, object_tag):
    """
    Extract the DER object at the start of derbuf, expecting tag object_tag.

    Returns (True, total_len, content), where total_len counts the tag byte,
    the length field and the content, or (False, 0, "") when derbuf is empty,
    starts with another tag, or is malformed or truncated.
    """
    # Check type (an empty buffer cannot hold any object)
    if len(derbuf) < 1 or ord(derbuf[0]) != object_tag:
        return (False, 0, "")
    derbuf = derbuf[1:]
    # Extract the size
    (check, encoding_len, size) = parse_DER_extract_size(derbuf)
    if check == False:
        return (False, 0, "")
    if len(derbuf) < encoding_len + size:
        return (False, 0, "")
    # +1 accounts for the tag byte stripped above
    return (True, size+encoding_len+1, derbuf[encoding_len:encoding_len+size])
1224  
def extract_DER_sequence(derbuf):
    """Extract a DER SEQUENCE (tag 0x30); same return value as extract_DER_object."""
    return extract_DER_object(derbuf, object_tag=0x30)

def extract_DER_integer(derbuf):
    """Extract a DER INTEGER (tag 0x02); same return value as extract_DER_object."""
    return extract_DER_object(derbuf, object_tag=0x02)

def extract_DER_octetstring(derbuf):
    """Extract a DER OCTET STRING (tag 0x04); same return value as extract_DER_object."""
    return extract_DER_object(derbuf, object_tag=0x04)

def extract_DER_bitstring(derbuf):
    """Extract a DER BIT STRING (tag 0x03); same return value as extract_DER_object."""
    return extract_DER_object(derbuf, object_tag=0x03)

def extract_DER_oid(derbuf):
    """Extract a DER OBJECT IDENTIFIER (tag 0x06); same return value as extract_DER_object."""
    return extract_DER_object(derbuf, object_tag=0x06)
1239  
1240  # See ECParameters sequence in RFC 3279
def parse_DER_ECParameters(derbuf):
    """
    Parse a DER encoded explicit ECParameters structure over a prime field.

    Returns (True, (a, b, prime, order, cofactor, gx, gy)) on success. On any
    failure (malformed encoding, non prime field, unsupported or invalid base
    point encoding), returns (False, (0, 0, 0, 0, 0, 0, 0)).
    """
    # XXX: this is a very ugly way of extracting the information
    # regarding an EC curve, but since the ASN.1 structure is quite
    # "static", this might be sufficient without embedding a full
    # ASN.1 parser ...
    # Default return (a, b, prime, order, cofactor, gx, gy)
    default_ret = (0, 0, 0, 0, 0, 0, 0)
    # Get ECParameters wrapping sequence
    (check, _, ECParameters) = extract_DER_sequence(derbuf)
    if check == False:
        return (False, default_ret)
    # offset: position in ECParameters of the next field to parse.
    # Version (integer, value not checked)
    (check, offset, _) = extract_DER_integer(ECParameters)
    if check == False:
        return (False, default_ret)
    # FieldID (sequence): field type OID followed by the field parameters
    (check, size_FieldID, FieldID) = extract_DER_sequence(ECParameters[offset:])
    if check == False:
        return (False, default_ret)
    offset += size_FieldID
    (check, size_Oid, Oid) = extract_DER_oid(FieldID)
    if check == False:
        return (False, default_ret)
    # Does the OID correspond to a prime field (1.2.840.10045.1.1)?
    if Oid != "\x2A\x86\x48\xCE\x3D\x01\x01":
        print("DER parse error: only prime fields are supported ...")
        return (False, default_ret)
    # Get prime p of prime field
    (check, _, P) = extract_DER_integer(FieldID[size_Oid:])
    if check == False:
        return (False, default_ret)
    # Curve (sequence): coefficients a and b as octet strings
    (check, size_Curve, Curve) = extract_DER_sequence(ECParameters[offset:])
    if check == False:
        return (False, default_ret)
    offset += size_Curve
    (check, size_A, A) = extract_DER_octetstring(Curve)
    if check == False:
        return (False, default_ret)
    (check, _, B) = extract_DER_octetstring(Curve[size_A:])
    if check == False:
        return (False, default_ret)
    # Base point (octet string)
    (check, size_ECPoint, ECPoint) = extract_DER_octetstring(ECParameters[offset:])
    if check == False:
        return (False, default_ret)
    offset += size_ECPoint
    # Order (integer)
    (check, size_Order, Order) = extract_DER_integer(ECParameters[offset:])
    if check == False:
        return (False, default_ret)
    offset += size_Order
    # Cofactor (integer)
    (check, _, Cofactor) = extract_DER_integer(ECParameters[offset:])
    if check == False:
        return (False, default_ret)
    # If we end up here, everything is OK, we can extract all our elements
    prime = stringtoint(P)
    a = stringtoint(A)
    b = stringtoint(B)
    order = stringtoint(Order)
    cofactor = stringtoint(Cofactor)
    # Extract Gx and Gy, see X9.62-1998: one type byte, then coordinate data
    if len(ECPoint) < 2:
        return (False, default_ret)
    ECPoint_type = ord(ECPoint[0])
    ECPoint = ECPoint[1:]
    if ECPoint_type in (0x04, 0x06, 0x07):
        # Uncompressed (0x04) and hybrid (0x06/0x07) points: X || Y, same sizes
        if len(ECPoint) % 2 != 0:
            return (False, default_ret)
        half = len(ECPoint) // 2
        gx = stringtoint(ECPoint[:half])
        gy = stringtoint(ECPoint[half:])
    elif ECPoint_type in (0x02, 0x03):
        # Compressed point: uncompress it, see X9.62-1998 section 4.2.1
        gx = stringtoint(ECPoint)
        alpha = (pow(gx, 3, prime) + (a * gx) + b) % prime
        beta = mod_sqrt(alpha, prime)
        if (beta is None) or ((beta == 0) and (alpha != 0)):
            return (False, default_ret)
        # The low bit of the type byte gives the parity of y
        if (beta & 0x1) == (ECPoint_type & 0x1):
            gy = beta
        else:
            gy = prime - beta
    else:
        print("DER parse error: unsupported point encoding!")
        return (False, default_ret)
    return (True, (a, b, prime, order, cofactor, gx, gy))
1329  
1330  ##########################################################
1331  ### Text and format helpers
def bigint_to_C_array(bint, size):
    """
    Format a python big int as a C hex array initializer.

    The value is written big-endian and left-padded with zero bytes up to
    size bytes (never truncated). Rows hold eight "0xNN, " entries each and
    are tab-indented. The returned text runs from "{" to the final "};\n".
    """
    digits = format(int(bint), 'x')
    # Left pad to the size!
    missing = int((2*size) - len(digits))
    if missing > 0:
        digits = "0"*missing + digits
    # Make sure we only have whole bytes
    if len(digits) % 2 == 1:
        digits = "0" + digits
    byte_strs = [digits[pos:pos + 2] for pos in range(0, len(digits), 2)]
    rows = []
    for first in range(0, len(byte_strs), 8):
        rows.append("\t" + "".join(["0x" + b + ", " for b in byte_strs[first:first + 8]]))
    return "{\n" + "\n".join(rows) + "\n};\n"
1349  
def check_in_file(fname, pat):
    """Return True if regex pat matches somewhere on at least one line of file fname."""
    with open(fname) as handle:
        return any(re.search(pat, line) for line in handle)
1357  
def num_patterns_in_file(fname, pat):
    """Count the lines of file fname on which regex pat matches (each line counts once)."""
    with open(fname) as handle:
        return sum(1 for line in handle if re.search(pat, line))
1365  
def file_replace_pattern(fname, pat, s_after):
    """
    Replace every match of regex pat in file fname with s_after, line by line.

    The file is left untouched when pat matches nowhere. Otherwise the result
    is written to fname + ".tmp", which then replaces fname.
    """
    # First, see if the pattern is even in the file
    with open(fname) as f:
        if not any(re.search(pat, line) for line in f):
            return # pattern does not occur in file so we are done.
    # Pattern is in the file, so perform replace operation.
    # Both files are closed (even on error) before the rename below.
    out_fname = fname + ".tmp"
    with open(fname) as f, open(out_fname, "w") as out:
        for line in f:
            out.write(re.sub(pat, s_after, line))
    # os.rename() refuses to overwrite an existing destination on Windows
    if os.path.exists(fname):
        os.remove(fname)
    os.rename(out_fname, fname)
1380  
def file_remove_pattern(fname, pat):
    """
    Delete from file fname every line on which regex pat matches.

    The file is left untouched when pat matches nowhere. Otherwise the kept
    lines are written to fname + ".tmp", which then replaces fname.
    """
    # First, see if the pattern is even in the file
    with open(fname) as f:
        if not any(re.search(pat, line) for line in f):
            return # pattern does not occur in file so we are done.
    # Pattern is in the file, so perform remove operation.
    # Both files are closed (even on error) before the rename below.
    out_fname = fname + ".tmp"
    with open(fname) as f, open(out_fname, "w") as out:
        for line in f:
            if not re.search(pat, line):
                out.write(line)
    # os.rename() refuses to overwrite an existing destination on Windows
    if os.path.exists(fname):
        os.remove(fname)
    os.rename(out_fname, fname)
1399  
def remove_file(fname):
    """Delete the file at path fname; errors raised by the OS propagate."""
    os.unlink(fname)
1403  
def remove_files_pattern(fpattern):
    """Remove every file whose path matches the glob pattern fpattern."""
    for fname in glob.glob(fpattern):
        remove_file(fname)
1406  
def buffer_remove_pattern(buff, pat):
    """
    Remove every match of regex pat from buff.

    Under Python 3, buff is first decoded as latin-1 (byte-preserving).
    Returns (found, new_buff), where found tells whether pat matched at all;
    new_buff is the (possibly decoded) buffer, unchanged when nothing matched.
    """
    if not is_python_2():
        buff = buff.decode('latin-1')
    (stripped, count) = re.subn(pat, "", buff)
    if count == 0:
        # pattern does not occur in the buffer so we are done
        return (False, buff)
    return (True, stripped)
1415  
def is_base64(s):
    """
    Return True if s is canonical base64 text, False otherwise.

    s may be split over several lines; whitespace around each line is
    ignored. The joined text is accepted only when decoding it and encoding
    the result back yields exactly the same text.
    """
    joined = ''.join([line.strip() for line in s.split("\n")])
    try:
        enc = base64.b64encode(base64.b64decode(joined)).strip()
        if type(enc) is bytes:
            return enc == joined.encode('latin-1')
        else:
            return enc == joined
    except (TypeError, ValueError):
        # Python 2 reports malformed input with TypeError. Python 3 raises
        # binascii.Error (a ValueError subclass), and ValueError for non-ASCII
        # str input; Unicode errors are ValueError subclasses too.
        return False
1426  
1427  ### Curve helpers
def export_curve_int(curvename, intname, bigint, size):
    """
    Emit the C definition of integer parameter <curvename>_<intname>.

    When bigint is None, a one-byte zero array declared with
    TO_EC_STR_PARAM_FIXED_SIZE(..., 0) is emitted. Otherwise the value is
    exported via bigint_to_C_array(bigint, size) and declared with
    TO_EC_STR_PARAM.
    """
    symbol = curvename + "_" + intname
    if bigint is None:
        return ("static const u8 " + symbol + "[] = {\n\t0x00,\n};\n" +
                "TO_EC_STR_PARAM_FIXED_SIZE(" + symbol + ", 0);\n\n")
    return ("static const u8 " + symbol + "[] = " + bigint_to_C_array(bigint, size) + "\n" +
            "TO_EC_STR_PARAM(" + symbol + ");\n\n")
1436  
def export_curve_string(curvename, stringname, stringvalue):
    """Emit the C definition of string parameter <curvename>_<stringname> holding stringvalue."""
    symbol = curvename + "_" + stringname
    return "".join([
        "static const u8 ", symbol, "[] = \"", stringvalue, "\";\n",
        "TO_EC_STR_PARAM(", symbol, ");\n\n",
    ])
1441  
def export_curve_struct(curvename, paramname, paramnamestr):
    """Emit one designated-initializer line pointing field .<paramname> at <curvename>_<paramnamestr>_str_param."""
    target = curvename + "_" + paramnamestr + "_str_param"
    return "".join(["\t.", paramname, " = &", target, ", \n"])
1444  
def curve_params(name, prime, pbitlen, a, b, gx, gy, order, cofactor, oid, alpha_montgomery, gamma_montgomery, alpha_edwards):
    """
    Generate the content of the C header describing a curve.

    The returned string is the text of the ec_params_<name>.h file. It holds:
      - the exported big numbers (p, a, b, generator coordinates, orders,
        cofactor, transfer coefficients);
      - the Montgomery and division helper constants for the 64, 32 and
        16-bit word configurations;
      - the ec_str_params structure gathering all of them;
      - the preprocessor logic tracking the maximum p/q/curve order bit
        lengths and the maximum name/OID lengths.
    """
    upname = name.upper()

    # Byte length of p, rounded up to a whole byte
    bytesize = int(pbitlen / 8)
    if pbitlen % 8 != 0:
        bytesize += 1

    def rounded_bitsize(word_bytes):
        # Bit size of p once its byte length is rounded up to a whole
        # number of words of word_bytes bytes
        if bytesize % word_bytes == 0:
            return 8 * bytesize
        return 8 * ((int(bytesize / word_bytes) + 1) * word_bytes)

    # (word bit size, preprocessor line opening the section for that size)
    word_configs = [
        (64, "#if (WORD_BYTES == 8)     /* 64-bit words */\n"),
        (32, "#elif (WORD_BYTES == 4)   /* 32-bit words */\n"),
        (16, "#elif (WORD_BYTES == 2)   /* 16-bit words */\n"),
    ]
    # Montgomery coefficients (r, r_square, mpinv) for each word size
    monty_coefs = [compute_monty_coef(prime, rounded_bitsize(wbits // 8), wbits)
                   for (wbits, _) in word_configs]
    # Division helpers (shift, normalized p, reciprocal) for each word size
    div_coefs = [compute_div_coef(prime, rounded_bitsize(wbits // 8), wbits)
                 for (wbits, _) in word_configs]
    # Total number of points on the curve
    npoints = order * cofactor

    chunks = [
        "#include <libecc/lib_ecc_config.h>\n",
        "#ifdef WITH_CURVE_" + upname + "\n\n",
        "#ifndef __EC_PARAMS_" + upname + "_H__\n",
        "#define __EC_PARAMS_" + upname + "_H__\n",
        "#include <libecc/curves/known/ec_params_external.h>\n",
        export_curve_int(name, "p", prime, bytesize),
        "#define CURVE_" + upname + "_P_BITLEN " + str(pbitlen) + "\n",
        export_curve_int(name, "p_bitlen", pbitlen, getbytelen(pbitlen)),
    ]

    # Word size dependent constants, one preprocessor branch per word size
    coef_labels = ("r", "r_square", "mpinv", "p_shift", "p_normalized", "p_reciprocal")
    for (_, guard), monty, div in zip(word_configs, monty_coefs, div_coefs):
        (r, r_square, mpinv) = monty
        (pshift, primenorm, p_reciprocal) = div
        chunks.append(guard)
        for label, value in zip(coef_labels, (r, r_square, mpinv, pshift, primenorm, p_reciprocal)):
            chunks.append(export_curve_int(name, label, value, getbytelen(value)))
    chunks.append("#else                     /* unknown word size */\n")
    chunks.append("#error \"Unsupported word size\"\n")
    chunks.append("#endif\n\n")

    # Curve equation coefficients
    chunks.append(export_curve_int(name, "a", a, bytesize))
    chunks.append(export_curve_int(name, "b", b, bytesize))

    curve_order_bitlen = getbitlen(npoints)
    chunks.append("#define CURVE_" + upname + "_CURVE_ORDER_BITLEN " + str(curve_order_bitlen) + "\n")
    chunks.append(export_curve_int(name, "curve_order", npoints, getbytelen(npoints)))

    # Generator in projective coordinates (z = 1)
    for label, value in (("gx", gx), ("gy", gy), ("gz", 0x01)):
        chunks.append(export_curve_int(name, label, value, bytesize))

    qbitlen = getbitlen(order)
    chunks.append(export_curve_int(name, "gen_order", order, getbytelen(order)))
    chunks.append("#define CURVE_" + upname + "_Q_BITLEN " + str(qbitlen) + "\n")
    chunks.append(export_curve_int(name, "gen_order_bitlen", qbitlen, getbytelen(qbitlen)))

    for label, value in (("cofactor", cofactor),
                         ("alpha_montgomery", alpha_montgomery),
                         ("gamma_montgomery", gamma_montgomery),
                         ("alpha_edwards", alpha_edwards)):
        chunks.append(export_curve_int(name, label, value, getbytelen(value)))

    chunks.append(export_curve_string(name, "name", upname))

    if oid is None:
        oid = ""
    chunks.append(export_curve_string(name, "oid", oid))

    # Structure referencing every exported element
    struct_fields = ("p", "p_bitlen", "r", "r_square", "mpinv", "p_shift",
                     "p_normalized", "p_reciprocal", "a", "b", "curve_order",
                     "gx", "gy", "gz", "gen_order", "gen_order_bitlen",
                     "cofactor", "alpha_montgomery", "gamma_montgomery",
                     "alpha_edwards", "oid", "name")
    chunks.append("static const ec_str_params " + name + "_str_params = {\n")
    for field in struct_fields:
        chunks.append(export_curve_struct(name, field, field))
    chunks.append("};\n\n")

    # Keep track of the maximum bit lengths among all the curves
    chunks.append("/*\n * Compute max bit length of all curves for p and q\n */\n")
    for macro, suffix in (("CURVES_MAX_P_BIT_LEN", "_P_BITLEN"),
                          ("CURVES_MAX_Q_BIT_LEN", "_Q_BITLEN"),
                          ("CURVES_MAX_CURVE_ORDER_BIT_LEN", "_CURVE_ORDER_BITLEN")):
        chunks.append("#ifndef " + macro + "\n" +
                      "#define " + macro + "    0\n" +
                      "#endif\n" +
                      "#if (" + macro + " < CURVE_" + upname + suffix + ")\n" +
                      "#undef " + macro + "\n" +
                      "#define " + macro + " CURVE_" + upname + suffix + "\n" +
                      "#endif\n")
    chunks.append("\n")

    # Keep track of the maximum name and OID lengths (with trailing NUL)
    sizes = (("MAX_CURVE_OID_LEN", len(oid) + 1),
             ("MAX_CURVE_NAME_LEN", len(upname) + 1))
    chunks.append("/*\n * Compute and adapt max name and oid length\n */\n")
    for macro, _ in sizes:
        chunks.append("#ifndef " + macro + "\n#define " + macro + " 0\n#endif\n")
    for macro, size in sizes:
        chunks.append("#if (" + macro + " < " + str(size) + ")\n" +
                      "#undef " + macro + "\n" +
                      "#define " + macro + " " + str(size) + "\n" +
                      "#endif\n")
    chunks.append("\n")

    chunks.append("#endif /* __EC_PARAMS_" + upname + "_H__ */\n\n" + "#endif /* WITH_CURVE_" + upname + " */\n")

    return "".join(chunks)
1613  
def usage():
    """
    Print the help of the script on the standard output.

    The text describes:
      - the two ways of adding a user defined curve (explicit parameters,
        or an RFC3279 ASN.1 file);
      - how to remove one or all user defined curves;
      - how to generate self test vectors.
    """
    print("This script is intended to *statically* expand the ECC library with user defined curves.")
    print("By statically we mean that the source code of libecc is expanded with new curve parameters through")
    print("automatic code generation filling place holders in the existing code base of the library. Though the")
    print("choice of static code generation versus dynamic curve import (such as what OpenSSL does) might be")
    print("argued, this choice has been driven by simplicity and security design decisions: we want libecc to have")
    print("all its parameters (such as memory consumption) set at compile time and statically adapted to the curves.")
    print("Since libecc only supports curves over prime fields, the script can only add this kind of curve.")
    print("This script implements elliptic curves and ISO signature algorithms from scratch over Python's multi-precision")
    print("integers. Addition and doubling over curves use naive formulas. Please DO NOT use the functions of this")
    print("script for production code: they are not securely implemented and are very inefficient. Their only purpose is to expand")
    print("libecc and produce test vectors.")
    print("")
    print("In order to add a curve, there are two ways:")
    print("Adding a user defined curve with explicit parameters:")
    print("-----------------------------------------------------")
    print(sys.argv[0]+" --name=\"YOURCURVENAME\" --prime=... --order=... --a=... --b=... --gx=... --gy=... --cofactor=... --oid=THEOID")
    print("\t> name: name of the curve in the form of a string")
    print("\t> prime: prime number representing the curve prime field")
    print("\t> order: prime number representing the generator order")
    print("\t> cofactor: cofactor of the curve")
    print("\t> a: 'a' coefficient of the short Weierstrass equation of the curve")
    print("\t> b: 'b' coefficient of the short Weierstrass equation of the curve")
    print("\t> gx: x coordinate of the generator G")
    print("\t> gy: y coordinate of the generator G")
    print("\t> generator: alternative to gx and gy: the generator G given as the hexadecimal concatenation")
    print("\t  of its coordinates (gx||gy), optionally prefixed with '04' (uncompressed point encoding)")
    print("\t> alpha_montgomery, gamma_montgomery: optional Montgomery transfer coefficients (both must be given if used)")
    print("\t> alpha_edwards: optional Edwards transfer coefficient (requires alpha_montgomery and gamma_montgomery)")
    print("\t> oid: optional OID of the curve")
    print("  Notes:")
    print("  ******")
    print("\t1) These elements are checked: prime and order must be prime numbers, and the generator must satisfy the curve equation.")
    print("\t2) All the numbers can be given either in decimal or hexadecimal format with a prepending '0x'.")
    print("\t3) The script automatically generates all the necessary files for the curve to be included in the library." )
    print("\tYou will find the new curve definition in the usual 'lib_ecc_config.h' file (one can activate it or not at compile time).")
    print("")
    print("Adding a user defined curve through RFC3279 ASN.1 parameters:")
    print("-------------------------------------------------------------")
    print(sys.argv[0]+" --name=\"YOURCURVENAME\" --ECfile=... --oid=THEOID")
    print("\t> ECfile: the DER or PEM encoded file containing the curve parameters (see RFC3279)")
    print("  Notes:")
    print("\tCurve parameters encoded in DER or PEM format can be generated with tools like OpenSSL (among others). As an illustrative example,")
    print("\tone can list all the supported curves under OpenSSL with:")
    print("\t  $ openssl ecparam -list_curves")
    print("\tOnly the listed so-called \"prime\" curves are supported. Then, one can extract an explicit curve representation in ASN.1")
    print("\tas defined in RFC3279, for example for BRAINPOOLP320R1:")
    print("\t  $ openssl ecparam -param_enc explicit -outform DER -name brainpoolP320r1 -out brainpoolP320r1.der")
    print("")
    print("Removing user defined curves:")
    print("-----------------------------")
    print("\t*All the user defined curves can be removed with the --remove-all toggle.")
    print("\t*A specific named user defined curve can be removed with the --remove toggle: in this case the --name option is used to ")
    print("\tlocate which named curve must be deleted.")
    print("")
    print("Test vectors:")
    print("-------------")
    print("\tTest vectors can be automatically generated and added to the library self tests when providing the --add-test-vectors=X toggle.")
    print("\tIn this case, X test vectors will be generated for *each* (curve, sign algorithm, hash algorithm) triplet (beware of combinatorial")
    print("\tissues when X is big!). These tests are transparently added and compiled with the self tests.")
    return
1671  
def get_int(instring):
    """
    Convert a numeric command line string to an integer.

    The string is read as hexadecimal when it carries a '0x' prefix (in any
    letter case), optionally preceded by a '+' or '-' sign. Otherwise it is
    read as decimal. An empty string yields 0.

    Raises:
        ValueError: if the string is not a valid number in the detected base.
    """
    if len(instring) == 0:
        return 0
    # Look past an optional sign so that "-0x3" is also recognized as hex;
    # int(..., 16) itself accepts the sign and the 0x/0X prefix
    digits = instring.lstrip("+-")
    if digits[:2].lower() == "0x":
        return int(instring, 16)
    return int(instring)
1679  
def parse_cmd_line(args):
    """
    Parse the command line and perform the requested action.

    Depending on the options, this does one of the following:
      - prints the help (-h/--help);
      - removes one user defined curve (--remove with --name);
      - removes all user defined curves (--remove-all);
      - adds a new user defined curve, described either explicitly
        (--prime, --a, --b, --gx/--gy or --generator, --order, --cofactor)
        or through an RFC3279 ASN.1 DER/PEM file (--ECfile).
    When --add-test-vectors=X is given, self test vectors are also generated
    for the new curve.

    Args:
        args: the command line arguments without the program name
              (typically sys.argv[1:]).

    Returns:
        True on success, when help was printed or when the user cancelled a
        removal; False on any error.
    """
    name = oid = prime = a = b = gx = gy = g = order = cofactor = ECfile = remove = remove_all = add_test_vectors = None
    alpha_montgomery = gamma_montgomery = alpha_edwards = None
    try:
        # Parse the arguments we were given (not sys.argv directly)
        opts, _ = getopt.getopt(args, "h", ["help", "remove", "remove-all", "name=", "prime=", "a=", "b=", "generator=", "gx=", "gy=", "order=", "cofactor=", "alpha_montgomery=", "gamma_montgomery=", "alpha_edwards=", "ECfile=", "oid=", "add-test-vectors="])
    except getopt.GetoptError as err:
        # print help information and exit:
        print(err) # will print something like "option -a not recognized"
        usage()
        return False
    # getopt expands abbreviations and always returns the full option name,
    # so exact comparisons are used below
    for o, arg in opts:
        if o in ("-h", "--help"):
            usage()
            return True
        elif o == "--name":
            # Prepend the custom string before name to avoid any collision,
            # and replace the dashes that are not valid in C identifiers
            name = "user_defined_" + arg.replace("-", "_")
        elif o == "--oid":
            oid = arg
        elif o == "--prime":
            prime = get_int(arg.replace(' ', ''))
        elif o == "--a":
            a = get_int(arg.replace(' ', ''))
        elif o == "--b":
            b = get_int(arg.replace(' ', ''))
        elif o == "--gx":
            gx = get_int(arg.replace(' ', ''))
        elif o == "--gy":
            gy = get_int(arg.replace(' ', ''))
        elif o == "--generator":
            g = arg.replace(' ', '')
        elif o == "--order":
            order = get_int(arg.replace(' ', ''))
        elif o == "--cofactor":
            cofactor = get_int(arg.replace(' ', ''))
        elif o == "--alpha_montgomery":
            alpha_montgomery = get_int(arg.replace(' ', ''))
        elif o == "--gamma_montgomery":
            gamma_montgomery = get_int(arg.replace(' ', ''))
        elif o == "--alpha_edwards":
            alpha_edwards = get_int(arg.replace(' ', ''))
        elif o == "--remove":
            remove = True
        elif o == "--remove-all":
            remove_all = True
        elif o == "--add-test-vectors":
            add_test_vectors = get_int(arg.replace(' ', ''))
        elif o == "--ECfile":
            ECfile = arg
        else:
            print("unhandled option")
            usage()
            return False

    # File paths
    script_path = os.path.abspath(os.path.dirname(sys.argv[0])) + "/"
    ec_params_path = script_path + "../include/libecc/curves/user_defined/"
    curves_list_path = script_path + "../include/libecc/curves/"
    lib_ecc_types_path = script_path + "../include/libecc/"
    lib_ecc_config_path = script_path + "../include/libecc/"
    ec_self_tests_path = script_path + "../src/tests/"
    meson_options_path = script_path + "../"

    # If remove is True, we have been asked to remove already existing user defined curves
    if remove:
        if name is None:
            print("--remove option expects a curve name provided with --name")
            return False
        short_name = name.replace("user_defined_", "")
        asked = ""
        while asked != "y" and asked != "n":
            asked = get_user_input("You asked to remove everything related to user defined "+short_name+" curve. Enter y to confirm, n to cancel [y/n]. ")
        if asked == "n":
            print("NOT removing curve "+short_name+" (cancelled).")
            return True
        # Remove any user defined stuff with given name
        print("Removing user defined curve "+short_name+" ...")
        file_remove_pattern(curves_list_path + "curves_list.h", ".*"+name+".*")
        file_remove_pattern(curves_list_path + "curves_list.h", ".*"+name.upper()+".*")
        file_remove_pattern(lib_ecc_types_path + "lib_ecc_types.h", ".*"+name.upper()+".*")
        file_remove_pattern(lib_ecc_config_path + "lib_ecc_config.h", ".*"+name.upper()+".*")
        file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*"+name+".*")
        file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*"+name.upper()+".*")
        file_remove_pattern(meson_options_path + "meson.options", ".*"+name.lower()+".*")
        try:
            remove_file(ec_params_path + "ec_params_"+name+".h")
        except Exception:
            print("Error: curve name "+name+" does not seem to be present in the sources!")
            return False
        try:
            remove_file(ec_self_tests_path + "ec_self_tests_core_"+name+".h")
        except Exception:
            # Self tests are optional: missing ones are only worth a warning
            print("Warning: curve name "+name+" self tests do not seem to be present ...")
        return True
    if remove_all:
        asked = ""
        while asked != "y" and asked != "n":
            asked = get_user_input("You asked to remove everything related to ALL user defined curves. Enter y to confirm, n to cancel [y/n]. ")
        if asked == "n":
            print("NOT removing user defined curves (cancelled).")
            return True
        print("Removing ALL user defined curves ...")
        # Remove any user defined stuff (whatever name)
        file_remove_pattern(curves_list_path + "curves_list.h", ".*user_defined.*")
        file_remove_pattern(curves_list_path + "curves_list.h", ".*USER_DEFINED.*")
        file_remove_pattern(lib_ecc_types_path + "lib_ecc_types.h", ".*USER_DEFINED.*")
        file_remove_pattern(lib_ecc_config_path + "lib_ecc_config.h", ".*USER_DEFINED.*")
        file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*USER_DEFINED.*")
        file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*user_defined.*")
        file_remove_pattern(meson_options_path + "meson.options", ".*user_defined.*")
        remove_files_pattern(ec_params_path + "ec_params_user_defined_*.h")
        remove_files_pattern(ec_self_tests_path + "ec_self_tests_core_user_defined_*.h")
        return True

    # If a g is provided (hexadecimal string), split it in two gx and gy
    if g is not None:
        # Accept (and drop) an optional 0x prefix
        if g[:2].lower() == "0x":
            g = g[2:]
        # When each half holds an even number of hex digits (whole bytes),
        # g is gx||gy; otherwise it is probably a generator encapsulated in
        # a bit string, i.e. an uncompressed point 04||gx||gy
        if (len(g) // 2) % 2 != 0:
            if g[0:2] != "04":
                print("Error: provided generator g is not conforming!")
                return False
            g = g[2:]
        half = len(g) // 2
        try:
            gx = int(g[:half], 16)
            gy = int(g[half:], 16)
        except ValueError:
            print("Error: provided generator g is not conforming!")
            return False
    if ECfile is not None:
        # ASN.1 DER input incompatible with other options
        if any(v is not None for v in (prime, a, b, gx, gy, order, cofactor)):
            print("Error: option ECfile incompatible with explicit (prime, a, b, gx, gy, order, cofactor) options!")
            return False
        # We need at least a name
        if name is None:
            print("Error: option ECfile needs a curve name!")
            return False
        # Read the file (closing it in all cases)
        try:
            with open(ECfile, 'rb') as ecf:
                buf = ecf.read()
        except (IOError, OSError):
            print("Error: cannot open ECfile file "+ECfile)
            return False
        # Check if we have a PEM or a DER file
        (check, derbuf) = buffer_remove_pattern(buf, "-----.*-----")
        if (check == True):
            # This a PEM file, proceed with base64 decoding
            if(is_base64(derbuf) == False):
                print("Error: error when decoding ECfile file "+ECfile+" (seems to be PEM, but failed to decode)")
                return False
            derbuf = base64.b64decode(derbuf)
        (check, (a, b, prime, order, cofactor, gx, gy)) = parse_DER_ECParameters(derbuf)
        if (check == False):
            print("Error: error when parsing ECfile file "+ECfile+" (malformed or unsupported ASN.1)")
            return False

    else:
        if (prime is None) or (a is None) or (b is None) or (gx is None) or (gy is None) or (order is None) or (cofactor is None) or (name is None):
            err_string = (prime is None)*"prime "+(a is None)*"a "+(b is None)*"b "+(gx is None)*"gx "+(gy is None)*"gy "+(order is None)*"order "+(cofactor is None)*"cofactor "+(name is None)*"name "
            print("Error: missing "+err_string+" in explicit curve definition (name, prime, a, b, gx, gy, order, cofactor)!")
            print("See the help with -h or --help")
            return False

    # Some sanity checks here
    # Check that prime is indeed a prime
    if is_probprime(prime) == False:
        print("Error: given prime is *NOT* prime!")
        return False
    if is_probprime(order) == False:
        print("Error: given order is *NOT* prime!")
        return False
    # Field elements must be strictly smaller than the prime
    if (a >= prime) or (b >= prime) or (gx >= prime) or (gy >= prime):
        err_string = (a >= prime)*"a "+(b >= prime)*"b "+(gx >= prime)*"gx "+(gy >= prime)*"gy "
        print("Error: "+err_string+"is >= prime")
        return False
    # Check that the provided generator is on the curve
    if pow(gy, 2, prime) != ((pow(gx, 3, prime) + (a*gx) + b) % prime):
        print("Error: the given parameters (prime, a, b, gx, gy) do not verify the elliptic curve equation!")
        return False

    # Check Montgomery and Edwards transfer coefficients
    if (alpha_montgomery is None) != (gamma_montgomery is None):
        print("Error: alpha_montgomery and gamma_montgomery must be both defined if used!")
        return False
    if alpha_edwards is not None:
        if (alpha_montgomery is None) or (gamma_montgomery is None):
            print("Error: alpha_edwards needs alpha_montgomery and gamma_montgomery to be both defined if used!")
            return False

    # Now that we have our parameters, call the function to get bitlen
    pbitlen = getbitlen(prime)
    ec_params = curve_params(name, prime, pbitlen, a, b, gx, gy, order, cofactor, oid, alpha_montgomery, gamma_montgomery, alpha_edwards)
    # Check if there is a name collision somewhere
    ec_params_file = ec_params_path + "ec_params_"+name+".h"
    if os.path.exists(ec_params_file):
        print("Error: file %s already exists!" % ec_params_file)
        return False
    # The curves mapping references the structure as "&<name>_str_params"
    if (check_in_file(curves_list_path + "curves_list.h", "&"+name+"_str_params") == True) or (check_in_file(curves_list_path + "curves_list.h", "WITH_CURVE_"+name.upper()+"\n") == True) or (check_in_file(lib_ecc_types_path + "lib_ecc_types.h", "WITH_CURVE_"+name.upper()+"\n") == True):
        print("Error: name %s already exists in files" % ("ec_params_"+name))
        return False
    # Create a new file with the parameters
    if not os.path.exists(ec_params_path):
        # Create the "user_defined" folder if it does not exist
        os.mkdir(ec_params_path)
    with open(ec_params_file, 'w') as f:
        f.write(ec_params)

    def _placeholder(marker):
        # Regex matching the "/* marker */" C comment place holder, and the
        # literal comment put back after the inserted text
        return (r"\/\* " + marker + r" \*\/", "/* " + marker + " */")

    # Include the file in curves_list.h
    (magic_re, magic_back) = _placeholder("ADD curves header here")
    file_replace_pattern(curves_list_path + "curves_list.h", magic_re, "#include <libecc/curves/user_defined/ec_params_"+name+".h>\n"+magic_back)
    # Add the curve mapping
    (magic_re, magic_back) = _placeholder("ADD curves mapping here")
    file_replace_pattern(curves_list_path + "curves_list.h", magic_re, "#ifdef WITH_CURVE_"+name.upper()+"\n\t{ .type = "+name.upper()+", .params = &"+name+"_str_params },\n#endif /* WITH_CURVE_"+name.upper()+" */\n"+magic_back)
    # Add the new curve type in the enum
    # First we get the number of already defined curves so that we increment the enum counter
    num_with_curve = num_patterns_in_file(lib_ecc_types_path + "lib_ecc_types.h", "#ifdef WITH_CURVE_")
    (magic_re, magic_back) = _placeholder("ADD curves type here")
    file_replace_pattern(lib_ecc_types_path + "lib_ecc_types.h", magic_re, "#ifdef WITH_CURVE_"+name.upper()+"\n\t"+name.upper()+" = "+str(num_with_curve+1)+",\n#endif /* WITH_CURVE_"+name.upper()+" */\n"+magic_back)
    # Add the new curve define in the config
    (magic_re, magic_back) = _placeholder("ADD curves define here")
    file_replace_pattern(lib_ecc_config_path + "lib_ecc_config.h", magic_re, "#define WITH_CURVE_"+name.upper()+"\n"+magic_back)
    # Add the new curve meson option in the meson.options file
    magic = "ADD curves meson option here"
    magic_re = "# " + magic
    magic_back = "# " + magic
    file_replace_pattern(meson_options_path + "meson.options", magic_re, "\t'"+name.lower()+"',\n"+magic_back)

    # Do we need to add some test vectors?
    if add_test_vectors is not None:
        print("Test vectors generation asked: this can take some time! Please wait ...")
        # Create curve
        c = Curve(a, b, prime, order, cofactor, gx, gy, cofactor * order, name, oid)
        # Generate key pair for the algorithm
        vectors = gen_self_tests(c, add_test_vectors)
        (case_magic_re, case_magic_back) = _placeholder("ADD curve test case here")
        # Iterate through all the tests
        with open(ec_self_tests_path + "ec_self_tests_core_"+name+".h", 'w') as f:
            for l in vectors:
                for v in l:
                    for case in v:
                        (case_name, case_vector) = case
                        # Add the new test case
                        file_replace_pattern(ec_self_tests_path + "ec_self_tests_core.h", case_magic_re, case_name+"\n"+case_magic_back)
                        # Create/Increment the header file
                        f.write(case_vector)
        # Add the new test cases header
        (magic_re, magic_back) = _placeholder("ADD curve test vectors header here")
        file_replace_pattern(ec_self_tests_path + "ec_self_tests_core.h", magic_re, "#include \"ec_self_tests_core_"+name+".h\"\n"+magic_back)
    return True
1950      return True
1951  
1952  
1953  #### Main
1954  if __name__ == "__main__":
1955      signal.signal(signal.SIGINT, handler)
1956      parse_cmd_line(sys.argv[1:])