/ sage / derive_endomorphisms.sage
derive_endomorphisms.sage
  1  #!/usr/bin/sage
  2  # vim: syntax=python
  3  # vim: set ts=2 sw=2 et:
  4  
  5  # Constantine
  6  # Copyright (c) 2018-2019    Status Research & Development GmbH
  7  # Copyright (c) 2020-Present Mamy André-Ratsimbazafy
  8  # Licensed and distributed under either of
  9  #   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
 10  #   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
 11  # at your option. This file may not be copied, modified, or distributed except according to those terms.
 12  
 13  # ############################################################
 14  #
 15  #             Endomorphism acceleration constants
 16  #
 17  # ############################################################
 18  
 19  # Imports
 20  # ---------------------------------------------------------
 21  
 22  import os
 23  import inspect, textwrap
 24  
 25  # Working directory
 26  # ---------------------------------------------------------
 27  
 28  os.chdir(os.path.dirname(__file__))
 29  
 30  # Sage imports
 31  # ---------------------------------------------------------
 32  # Accelerate arithmetic by accepting probabilistic proofs
 33  from sage.structure.proof.all import arithmetic
 34  arithmetic(False)
 35  
 36  load('curves.sage')
 37  
 38  # Utilities
 39  # ---------------------------------------------------------
 40  
 41  def fp2_to_hex(a):
 42    v = vector(a)
 43    return '0x' + Integer(v[0]).hex() + ' + β * ' + '0x' + Integer(v[1]).hex()
 44  
 45  def pretty_print_lattice(Lat):
 46    print('Lattice:')
 47    latHex = [['0x' + x.hex() if x >= 0 else '-0x' + (-x).hex() for x in vec] for vec in Lat]
 48    maxlen = max([len(cell) for row in latHex for cell in row])
 49    for row in latHex:
 50      row = ' '.join(cell.rjust(maxlen + 2) for cell in row)
 51      print(row)
 52  
 53  def pretty_print_babai(Basis):
 54    print('Babai:')
 55    for i, v in enumerate(Basis):
 56      if v < 0:
 57        print(f'  𝛼\u0305{i}: -0x{Integer(int(-v)).hex()}')
 58      else:
 59        print(f'  𝛼\u0305{i}:  0x{Integer(int(v)).hex()}')
 60  
 61  def derive_lattice(r, lambdaR, m):
 62    lat = Matrix(matrix.identity(m))
 63    lat[0, 0] = r
 64    for i in range(1, m):
 65       lat[i, 0] = -lambdaR^i
 66  
 67    return lat.LLL()
 68  
 69  def derive_babai(r, lattice, m):
 70    basis = m * [0]
 71    basis[0] = r
 72  
 73    ahat = vector(basis) * lattice.inverse()
 74    v = int(r).bit_length()
 75    v = int(((v + 64 - 1) // 64) * 64)
 76  
 77    return [(a << v) // r for a in ahat]
 78  
 79  # TODO: maximum infinity norm
 80  
 81  # G1 Endomorphism
 82  # ---------------------------------------------------------
 83  
 84  def check_cubic_root_endo(G1, Fp, r, cofactor, lambdaR, phiP):
 85    ## Check the Endomorphism for p mod 3 == 1
 86    ## Endomorphism can be field multiplication by one of the non-trivial cube root of unity 𝜑
 87    ##   Rationale:
 88    ##     curve equation is y² = x³ + b, and y² = (x𝜑)³ + b <=> y² = x³ + b (with 𝜑³ == 1) so we are still on the curve
 89    ##     this means that multiplying by 𝜑 the x-coordinate is equivalent to a scalar multiplication by some λᵩ
 90    ##     with λᵩ² + λᵩ + 1 ≡ 0 (mod r) and 𝜑² + 𝜑 + 1 ≡ 0 (mod p), see below.
 91    ##     Hence we have a 2 dimensional decomposition of the scalar multiplication
 92    ##     i.e. For any [s]P, we can find a corresponding [k1]P + [k2][λᵩ]P with [λᵩ]P being a simple field multiplication by 𝜑
 93    ##   Finding cube roots:
 94    ##      x³−1=0 <=> (x−1)(x²+x+1) = 0, if x != 1, x solves (x²+x+1) = 0 <=> x = (-1±√3)/2
 95  
 96    assert phiP^3 == Fp(1)
 97    assert lambdaR^3 % r == 1
 98  
 99    Prand = G1.random_point()
100    P = Prand * cofactor
101    assert P != G1([0, 1, 0])
102  
103    (Px, Py, Pz) = P
104  
105    Qendo = G1([Px*phiP, Py, Pz])
106    Qlambda = lambdaR * P
107  
108    assert P != Qendo
109    assert P != Qlambda
110  
111    assert Qendo == Qlambda
112    print('Endomorphism OK')
113  
def genCubicRootEndo(curve_name, curve_config):
  """Derive the G1 cube-root-of-unity endomorphism constants for `curve_name`.

  Builds G1: y² = x³ + b over GF(p), obtains the non-trivial cube roots of
  unity 𝜑 (mod p) and λᵩ (mod r), pairs them so that (x, y) -> (𝜑·x, y)
  equals [λᵩ], then derives the 2-dimensional lattice and Babai constants.

  Returns:
    (phi, lattice, babai): the 𝜑 matching the chosen λᵩ, the LLL-reduced
    lattice and the Babai rounding constants.

  Raises:
    AssertionError: if neither 𝜑 candidate matches the chosen λᵩ.
  """
  p = curve_config[curve_name]['field']['modulus']
  r = curve_config[curve_name]['field']['order']
  b = curve_config[curve_name]['curve']['b']

  print('Constructing G1')
  Fp = GF(p)
  G1 = EllipticCurve(Fp, [0, b])
  print('Computing cofactor')
  cofactor = G1.order() // r
  print('cofactor: 0x' + Integer(cofactor).hex())

  # slow for large inputs - https://pari.math.u-bordeaux.fr/archives/pari-dev-0412/msg00020.html
  if curve_name != 'BW6_761':
    print('Finding cube roots')
    (phi1, phi2) = (Fp(root) for root in Fp(1).nth_root(3, all=True) if root != 1)
    (lambda1, lambda2) = (GF(r)(root) for root in GF(r)(1).nth_root(3, all=True) if root != 1)
  else:
    print('Skip finding cube roots for BW6_761, too slow, use values from paper https://eprint.iacr.org/2020/351')
    phi1 = Integer('0x531dc16c6ecd27aa846c61024e4cca6c1f31e53bd9603c2d17be416c5e4426ee4a737f73b6f952ab5e57926fa701848e0a235a0a398300c65759fc45183151f2f082d4dcb5e37cb6290012d96f8819c547ba8a4000002f962140000000002a')
    phi2 = Integer('0xcfca638f1500e327035cdf02acb2744d06e68545f7e64c256ab7ae14297a1a823132b971cdefc65870636cb60d217ff87fa59308c07a8fab8579e02ed3cddca5b093ed79b1c57b5fe3f89c11811c1e214983de300000535e7bc00000000060')
    lambda1 = Integer('0x9b3af05dd14f6ec619aaf7d34594aabc5ed1347970dec00452217cc900000008508c00000000001')
    lambda2 = Integer('-0x9b3af05dd14f6ec619aaf7d34594aabc5ed1347970dec00452217cc900000008508c00000000002')

  print('𝜑1 (mod p):  0x' + Integer(phi1).hex())
  print('λᵩ1 (mod r): 0x' + Integer(lambda1).hex())
  print('𝜑2 (mod p):  0x' + Integer(phi2).hex())
  print('λᵩ2 (mod r): 0x' + Integer(lambda2).hex())

  # TODO: is there a better way than spray-and-pray?
  # TODO: Should we maximize or minimize lambda
  #       to maximize/minimize the scalar norm?
  # TODO: Or is there a way to ensure
  #       that the Babai basis is mostly positive?
  # Keep the larger λᵩ candidate; the matching 𝜑 is determined below.
  if lambda1 < lambda2:
    lambda1, lambda2 = lambda2, lambda1

  # Exactly one 𝜑 candidate matches λᵩ1: try phi1, and on a failed check
  # (AssertionError) fall back to phi2. Any other exception, or a failure of
  # the fallback check, propagates to the caller.
  try:
    check_cubic_root_endo(G1, Fp, r, cofactor, int(lambda1), phi1)
  except AssertionError:
    print('Failure with:')
    print('  𝜑 (mod p): 0x' + Integer(phi1).hex())
    print('  λᵩ (mod r): 0x' + Integer(lambda1).hex())
    phi1, phi2 = phi2, phi1
    check_cubic_root_endo(G1, Fp, r, cofactor, int(lambda1), phi1)

  # Reached only when one of the checks above passed
  print('Success with:')
  print('  𝜑 (mod p):  0x' + Integer(phi1).hex())
  print('  λᵩ (mod r): 0x' + Integer(lambda1).hex())

  print('Deriving Lattice')
  lattice = derive_lattice(r, lambda1, 2)
  pretty_print_lattice(lattice)

  print('Deriving Babai basis')
  babai = derive_babai(r, lattice, 2)
  pretty_print_babai(babai)

  return phi1, lattice, babai
173  
174  # G2 Endomorphism
175  # ---------------------------------------------------------
176  
def genPsiEndo(curve_name, curve_config):
  """Derive the G2 ψ (untwist-Frobenius-twist) lattice and Babai constants.

  The decomposition dimension is the degree of the k-th cyclotomic field,
  with k the embedding degree. The eigenvalue used is λψ = t - 1, with t
  the trace of Frobenius.

  Returns:
    (lattice, babai)
  """
  field = curve_config[curve_name]['field']
  t = field['trace']
  r = field['order']
  k = curve_config[curve_name]['tower']['embedding_degree']

  # The decomposition dimension depends on the embedding degree
  dimension = CyclotomicField(k).degree()
  # λψ is the trace of Frobenius - 1
  lambda_psi = t - 1

  print('Deriving Lattice')
  lattice = derive_lattice(r, lambda_psi, dimension)
  pretty_print_lattice(lattice)

  print('Deriving Babai basis')
  babai = derive_babai(r, lattice, dimension)
  pretty_print_babai(babai)

  return lattice, babai
196  
197  # Dump
198  # ---------------------------------------------------------
199  
def dumpLattice(lattice):
  """Render a lattice as the body of a Nim tuple-of-tuples constant.

  Each entry becomes `(BigInt[bits].fromHex"0x<|val|>", isNeg)`, where bits
  is the bit length of |val| (at least 1). Each row is wrapped in an extra
  pair of parentheses, and every entry except the final one ends with a comma.
  """
  lastRow = lattice.nrows() - 1
  lastCol = lattice.ncols() - 1
  chunks = ['  # (BigInt, isNeg)\n']

  for i, row in enumerate(lattice):
    for j, val in enumerate(row):
      magnitude = int(abs(val))
      bits = max(1, magnitude.bit_length())
      digits = Integer(magnitude).hex()
      is_neg = 'false' if val >= 0 else 'true'
      opener = '(' if j == 0 else ' '
      closer = ')' if j == lastCol else ''
      is_final = i == lastRow and j == lastCol
      entry = f'  {opener}(BigInt[{bits}].fromHex"0x{digits}", {is_neg}){closer}'
      chunks.append(entry + ('\n' if is_final else ',\n'))

  return ''.join(chunks)
215  
def dumpBabai(vec):
  """Render Babai constants as the body of a Nim tuple constant.

  Emits one `(BigInt[bits].fromHex"0x<|val|>", isNeg)` entry per line,
  comma-separated, with no comma after the last entry.
  """
  entries = []
  for val in vec:
    magnitude = int(abs(val))
    bits = max(1, magnitude.bit_length())
    is_neg = 'false' if val >= 0 else 'true'
    entries.append(f'  (BigInt[{bits}].fromHex"0x{Integer(magnitude).hex()}", {is_neg})')

  body = ',\n'.join(entries)
  trailer = '\n' if entries else ''
  return '  # (BigInt, isNeg)\n' + body + trailer
227  
def dumpConst(name, inner):
  """Wrap `inner` in an exported Nim constant: `const <name>* = (` ... `)`."""
  header = f'const {name}* = (\n'
  return ''.join([header, inner, ')\n'])
234  
235  # CLI
236  # ---------------------------------------------------------
237  
if __name__ == "__main__":
  # Usage
  # BLS12-381
  # sage sage/derive_endomorphisms.sage BLS12_381
  #
  # Several curves may be given at once; each one gets its own
  # `<curve>_endomorphisms.nim` file:
  # sage sage/derive_endomorphisms.sage BLS12_381 BW6_761

  from argparse import ArgumentParser

  parser = ArgumentParser()
  parser.add_argument("curve", nargs="+")
  args = parser.parse_args()

  # Validate every requested curve up-front so that a typo in a later
  # argument is reported before any of the (slow) precomputations start.
  for curve in args.curve:
    if curve not in Curves:
      raise ValueError(
        curve +
        ' is not one of the available curves: ' +
        str(Curves.keys())
      )

  for curve in args.curve:
    print('\nPrecomputing G1 - 𝜑 (phi) cubic root endomorphism')
    print('----------------------------------------------------\n')
    cubeRootModP, g1lat, g1babai = genCubicRootEndo(curve, Curves)
    print('\n\nPrecomputing G2 - ψ (Psi) - untwist-Frobenius-twist endomorphism')
    print('----------------------------------------------------\n')
    g2lat, g2babai = genPsiEndo(curve, Curves)

    with open(f'{curve.lower()}_endomorphisms.nim', 'w') as f:
      f.write(copyright())
      f.write('\n\n')
      f.write(inspect.cleandoc(f"""
        import
          ../config/curves,
          ../io/[io_bigints, io_fields]

        # {curve} G1
        # ------------------------------------------------------------
      """))
      f.write('\n\n')
      f.write(inspect.cleandoc(f"""
        const {curve}_cubicRootOfUnity_mod_p* =
          Fp[{curve}].fromHex"0x{Integer(cubeRootModP).hex()}"
      """))
      f.write('\n\n')
      f.write(dumpConst(
        f'{curve}_Lattice_G1',
        dumpLattice(g1lat)
      ))
      f.write('\n')
      f.write(dumpConst(
        f'{curve}_Babai_G1',
        dumpBabai(g1babai)
      ))
      f.write('\n\n')
      f.write(inspect.cleandoc(f"""
        # {curve} G2
        # ------------------------------------------------------------
      """))
      f.write('\n\n')
      f.write(dumpConst(
        f'{curve}_Lattice_G2',
        dumpLattice(g2lat)
      ))
      f.write('\n')
      f.write(dumpConst(
        f'{curve}_Babai_G2',
        dumpBabai(g2babai)
      ))
304          dumpBabai(g2babai)
305        ))