derive_endomorphisms.sage
#!/usr/bin/sage
# vim: syntax=python
# vim: set ts=2 sw=2 et:

# Constantine
# Copyright (c) 2018-2019    Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

# ############################################################
#
#              Endomorphism acceleration constants
#
# ############################################################

# For the curve named on the command line, this script derives:
#   - G1: a non-trivial cube root of unity 𝜑 (mod p) matching an eigenvalue
#     λᵩ (mod r), plus the LLL-reduced lattice and Babai rounding constants
#     for a 2-dimensional scalar decomposition;
#   - G2: the lattice and Babai constants for the ψ endomorphism with
#     eigenvalue t - 1, in dimension deg(CyclotomicField(k)).
# The results are written as Nim constants to `<curve>_endomorphisms.nim`.

# Imports
# ---------------------------------------------------------

import os
import inspect
import textwrap

# Working directory
# ---------------------------------------------------------

# Work relative to this script's directory so that `load('curves.sage')` and
# the generated .nim file resolve there. `abspath` is required: for a bare
# filename `os.path.dirname` returns '' and `os.chdir('')` raises.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Sage imports
# ---------------------------------------------------------
# Accelerate arithmetic by accepting probabilistic proofs
from sage.structure.proof.all import arithmetic
arithmetic(False)

load('curves.sage')

# Utilities
# ---------------------------------------------------------

def fp2_to_hex(a):
  """Format an Fp2 element as '0x<c0> + β * 0x<c1>'.

  Assumes `vector(a)` yields the two base-field coordinates (c0, c1) — TODO confirm.
  """
  v = vector(a)
  return '0x' + Integer(v[0]).hex() + ' + β * ' + '0x' + Integer(v[1]).hex()

def pretty_print_lattice(Lat):
  """Print every lattice entry as signed hex, right-aligned in columns."""
  print('Lattice:')
  latHex = [
    ['0x' + x.hex() if x >= 0 else '-0x' + (-x).hex() for x in vec]
    for vec in Lat
  ]
  maxlen = max(len(cell) for row in latHex for cell in row)
  for row in latHex:
    print(' '.join(cell.rjust(maxlen + 2) for cell in row))

def pretty_print_babai(Basis):
  """Print the Babai rounding constants 𝛼̅ᵢ as signed hex."""
  print('Babai:')
  for i, v in enumerate(Basis):
    if v < 0:
      print(f' 𝛼\u0305{i}: -0x{Integer(int(-v)).hex()}')
    else:
      print(f' 𝛼\u0305{i}: 0x{Integer(int(v)).hex()}')

def derive_lattice(r, lambdaR, m):
  """Return the LLL-reduced basis of the scalar-decomposition lattice.

  Before reduction the m×m basis is
    row 0:     (r, 0, …, 0)
    row i ≥ 1: (-λⁱ, 0, …, 1 at column i, …, 0)
  so every row (k₀, …, k_{m-1}) satisfies Σ kᵢ·λⁱ ≡ 0 (mod r).

  Args:
    r: the group order (curve_config[...]['field']['order']).
    lambdaR: the endomorphism eigenvalue λ (mod r).
    m: the decomposition dimension.
  """
  lat = Matrix(matrix.identity(m))
  lat[0, 0] = r
  for i in range(1, m):
    lat[i, 0] = -lambdaR**i

  return lat.LLL()

def derive_babai(r, lattice, m):
  """Return the Babai rounding constants for `lattice`.

  Solves ahat · lattice = (r, 0, …, 0) and returns ⌊ahatᵢ · 2^v / r⌋, where v
  is the bit length of r rounded up to a multiple of 64 (presumably the word
  size of the consuming Nim code — TODO confirm).
  """
  basis = m * [0]
  basis[0] = r

  ahat = vector(basis) * lattice.inverse()
  v = int(r).bit_length()
  v = int(((v + 64 - 1) // 64) * 64)

  return [(a << v) // r for a in ahat]

# TODO: maximum infinity norm

# G1 Endomorphism
# ---------------------------------------------------------

def check_cubic_root_endo(G1, Fp, r, cofactor, lambdaR, phiP):
  """Check that (x, y) -> (𝜑·x, y) equals scalar multiplication by λᵩ on G1.

  Each check below is an `assert`, so this raises AssertionError on failure,
  in particular when 𝜑 and λᵩ are not a matching pair. Prints
  'Endomorphism OK' on success.
  """
  ## Check the Endomorphism for p mod 3 == 1
  ## Endomorphism can be field multiplication by one of the non-trivial cube root of unity 𝜑
  ## Rationale:
  ##   curve equation is y² = x³ + b, and y² = (x𝜑)³ + b <=> y² = x³ + b (with 𝜑³ == 1) so we are still on the curve
  ##   this means that multiplying by 𝜑 the x-coordinate is equivalent to a scalar multiplication by some λᵩ
  ##   with λᵩ² + λᵩ + 1 ≡ 0 (mod r) and 𝜑² + 𝜑 + 1 ≡ 0 (mod p), see below.
  ##   Hence we have a 2 dimensional decomposition of the scalar multiplication
  ##   i.e. For any [s]P, we can find a corresponding [k1]P + [k2][λᵩ]P with [λᵩ]P being a simple field multiplication by 𝜑
  ## Finding cube roots:
  ##   x³−1=0 <=> (x−1)(x²+x+1) = 0, if x != 1, x solves (x²+x+1) = 0 <=> x = (-1±√-3)/2

  assert phiP**3 == Fp(1)
  assert lambdaR**3 % r == 1

  Prand = G1.random_point()
  P = Prand * cofactor
  # (0 : 1 : 0) is the point at infinity in projective coordinates
  assert P != G1([0, 1, 0])

  (Px, Py, Pz) = P

  Qendo = G1([Px*phiP, Py, Pz])
  Qlambda = lambdaR * P

  assert P != Qendo
  assert P != Qlambda

  assert Qendo == Qlambda
  print('Endomorphism OK')

def genCubicRootEndo(curve_name, curve_config):
  """Derive the G1 cube-root-of-unity endomorphism constants.

  Returns:
    (phi1, lattice, babai): the cube root of unity 𝜑 (mod p) that matches the
    larger of the two λᵩ candidates, the LLL-reduced 2×2 lattice for that λᵩ,
    and the corresponding Babai rounding constants.

  Raises:
    AssertionError: if neither cube root of unity matches the chosen λᵩ.
  """
  p = curve_config[curve_name]['field']['modulus']
  r = curve_config[curve_name]['field']['order']
  b = curve_config[curve_name]['curve']['b']

  print('Constructing G1')
  Fp = GF(p)
  G1 = EllipticCurve(Fp, [0, b])
  print('Computing cofactor')
  cofactor = G1.order() // r
  print('cofactor: 0x' + Integer(cofactor).hex())

  # slow for large inputs - https://pari.math.u-bordeaux.fr/archives/pari-dev-0412/msg00020.html
  if curve_name != 'BW6_761':
    print('Finding cube roots')
    (phi1, phi2) = (Fp(root) for root in Fp(1).nth_root(3, all=True) if root != 1)
    (lambda1, lambda2) = (GF(r)(root) for root in GF(r)(1).nth_root(3, all=True) if root != 1)
  else:
    print('Skip finding cube roots for BW6_761, too slow, use values from paper https://eprint.iacr.org/2020/351')
    phi1 = Integer('0x531dc16c6ecd27aa846c61024e4cca6c1f31e53bd9603c2d17be416c5e4426ee4a737f73b6f952ab5e57926fa701848e0a235a0a398300c65759fc45183151f2f082d4dcb5e37cb6290012d96f8819c547ba8a4000002f962140000000002a')
    phi2 = Integer('0xcfca638f1500e327035cdf02acb2744d06e68545f7e64c256ab7ae14297a1a823132b971cdefc65870636cb60d217ff87fa59308c07a8fab8579e02ed3cddca5b093ed79b1c57b5fe3f89c11811c1e214983de300000535e7bc00000000060')
    lambda1 = Integer('0x9b3af05dd14f6ec619aaf7d34594aabc5ed1347970dec00452217cc900000008508c00000000001')
    lambda2 = Integer('-0x9b3af05dd14f6ec619aaf7d34594aabc5ed1347970dec00452217cc900000008508c00000000002')

  print('𝜑1 (mod p): 0x' + Integer(phi1).hex())
  print('λᵩ1 (mod r): 0x' + Integer(lambda1).hex())
  print('𝜑2 (mod p): 0x' + Integer(phi2).hex())
  print('λᵩ2 (mod r): 0x' + Integer(lambda2).hex())

  # TODO: is there a better way than spray-and-pray?
  # TODO: Should we maximize or minimize lambda
  #       to maximize/minimize the scalar norm?
  # TODO: Or is there a way to ensure
  #       that the Babai basis is mostly positive?
  if lambda1 < lambda2:
    lambda1, lambda2 = lambda2, lambda1

  # Only one of the two cube roots 𝜑 pairs with the chosen λᵩ.
  # check_cubic_root_endo reports a mismatch through a failed assert, so try
  # 𝜑1 first and fall back to 𝜑2. If the fallback fails too, the
  # AssertionError propagates and no "Success" message is printed.
  try:
    check_cubic_root_endo(G1, Fp, r, cofactor, int(lambda1), phi1)
  except AssertionError:
    print('Failure with:')
    print(' 𝜑 (mod p): 0x' + Integer(phi1).hex())
    print(' λᵩ (mod r): 0x' + Integer(lambda1).hex())
    phi1, phi2 = phi2, phi1
    check_cubic_root_endo(G1, Fp, r, cofactor, int(lambda1), phi1)

  print('Success with:')
  print(' 𝜑 (mod p): 0x' + Integer(phi1).hex())
  print(' λᵩ (mod r): 0x' + Integer(lambda1).hex())

  print('Deriving Lattice')
  lattice = derive_lattice(r, lambda1, 2)
  pretty_print_lattice(lattice)

  print('Deriving Babai basis')
  babai = derive_babai(r, lattice, 2)
  pretty_print_babai(babai)

  return phi1, lattice, babai

# G2 Endomorphism
# ---------------------------------------------------------

def genPsiEndo(curve_name, curve_config):
  """Derive the G2 ψ-endomorphism lattice and Babai constants.

  The decomposition dimension m is the degree of the k-th cyclotomic field,
  with k the embedding degree, and the eigenvalue is λψ = t - 1.

  Returns:
    (lattice, babai): the LLL-reduced m×m lattice and its Babai constants.
  """
  t = curve_config[curve_name]['field']['trace']
  r = curve_config[curve_name]['field']['order']
  k = curve_config[curve_name]['tower']['embedding_degree']

  # Decomposition factor depends on the embedding degree
  m = CyclotomicField(k).degree()
  # λψ is the trace of Frobenius - 1
  lambda_psi = t - 1

  print('Deriving Lattice')
  lattice = derive_lattice(r, lambda_psi, m)
  pretty_print_lattice(lattice)

  print('Deriving Babai basis')
  babai = derive_babai(r, lattice, m)
  pretty_print_babai(babai)

  return lattice, babai

# Dump
# ---------------------------------------------------------

def _dumpBigIntIsNeg(val):
  """Format `val` as a Nim tuple `(BigInt[bits].fromHex"0x<|val|>", isNeg)`.

  `bits` is the bit length of |val|, clamped to at least 1 so that zero still
  gets a valid BigInt[1].
  """
  mag = int(abs(val))
  return (
    f'(BigInt[{max(1, mag.bit_length())}].fromHex"0x{Integer(mag).hex()}", '
    + ('false' if val >= 0 else 'true') + ')'
  )

def dumpLattice(lattice):
  """Render a lattice as the body of a Nim tuple-of-tuples constant.

  Each row is parenthesized, each cell is a `(BigInt, isNeg)` tuple, and every
  cell except the very last one is followed by a comma.
  """
  lastRow = lattice.nrows() - 1
  lastCol = lattice.ncols() - 1

  parts = [' # (BigInt, isNeg)\n']
  for rowID, row in enumerate(lattice):
    for colID, val in enumerate(row):
      parts.append(' ')
      parts.append('(' if colID == 0 else ' ')
      parts.append(_dumpBigIntIsNeg(val))
      parts.append(')' if colID == lastCol else '')
      parts.append(',\n' if (rowID != lastRow or colID != lastCol) else '\n')

  return ''.join(parts)

def dumpBabai(vec):
  """Render a Babai vector as the body of a Nim tuple constant (one entry per line)."""
  lastRow = len(vec) - 1

  parts = [' # (BigInt, isNeg)\n']
  for rowID, val in enumerate(vec):
    parts.append(' ')
    parts.append(_dumpBigIntIsNeg(val))
    parts.append(',\n' if rowID != lastRow else '\n')

  return ''.join(parts)

def dumpConst(name, inner):
  """Wrap `inner` in an exported Nim declaration `const name* = ( … )`."""
  return f'const {name}* = (\n' + inner + ')\n'

# CLI
# ---------------------------------------------------------

if __name__ == "__main__":
  # Usage
  #   BLS12-381
  #   sage sage/derive_endomorphisms.sage BLS12_381

  from argparse import ArgumentParser

  parser = ArgumentParser()
  parser.add_argument("curve", nargs="+")
  args = parser.parse_args()

  # Only the first positional curve name is used.
  curve = args.curve[0]

  if curve not in Curves:
    raise ValueError(
      curve +
      ' is not one of the available curves: ' +
      str(Curves.keys())
    )

  print('\nPrecomputing G1 - 𝜑 (phi) cubic root endomorphism')
  print('----------------------------------------------------\n')
  cubeRootModP, g1lat, g1babai = genCubicRootEndo(curve, Curves)
  print('\n\nPrecomputing G2 - ψ (Psi) - untwist-Frobenius-twist endomorphism')
  print('----------------------------------------------------\n')
  g2lat, g2babai = genPsiEndo(curve, Curves)

  with open(f'{curve.lower()}_endomorphisms.nim', 'w') as f:
    # copyright() is presumably provided by curves.sage — TODO confirm
    f.write(copyright())
    f.write('\n\n')
    f.write(inspect.cleandoc(f"""
      import
        ../config/curves,
        ../io/[io_bigints, io_fields]

      # {curve} G1
      # ------------------------------------------------------------
    """))
    f.write('\n\n')
    f.write(inspect.cleandoc(f"""
      const {curve}_cubicRootOfUnity_mod_p* =
        Fp[{curve}].fromHex"0x{Integer(cubeRootModP).hex()}"
    """))
    f.write('\n\n')
    f.write(dumpConst(
      f'{curve}_Lattice_G1',
      dumpLattice(g1lat)
    ))
    f.write('\n')
    f.write(dumpConst(
      f'{curve}_Babai_G1',
      dumpBabai(g1babai)
    ))
    f.write('\n\n')
    f.write(inspect.cleandoc(f"""
      # {curve} G2
      # ------------------------------------------------------------
    """))
    f.write('\n\n')
    f.write(dumpConst(
      f'{curve}_Lattice_G2',
      dumpLattice(g2lat)
    ))
    f.write('\n')
    f.write(dumpConst(
      f'{curve}_Babai_G2',
      dumpBabai(g2babai)
    ))