testgen_scalar_mul.sage
#!/usr/bin/sage
# vim: syntax=python
# vim: set ts=2 sw=2 et:

# Constantine
# Copyright (c) 2018-2019    Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

# ############################################################
#
#              Scalar multiplication test generator
#
# ############################################################

# Imports
# ---------------------------------------------------------

import os, sys, json
import inspect, textwrap

# Working directory
# ---------------------------------------------------------

# Resolve to an absolute path first: when the script is launched from its own
# directory, os.path.dirname(__file__) is '' and os.chdir('') raises.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Sage imports
# ---------------------------------------------------------
# Accelerate arithmetic by accepting probabilistic proofs
from sage.structure.proof.all import arithmetic
arithmetic(False)

# Expected to define the `Curves` configuration dict used by the CLI below
# (it is not defined anywhere in this file).
load('curves.sage')

# Utilities
# ---------------------------------------------------------

def progressbar(it, prefix="", size=60, file=sys.stdout):
  """Yield the items of the sized iterable `it` while redrawing a textual
  progress bar `size` characters wide, preceded by `prefix`, on `file`.
  """
  count = len(it)
  def show(j):
    # An empty iterable would otherwise divide by zero; draw a full bar.
    x = int(size*j/count) if count else size
    file.write("%s[%s%s] %i/%i\r" % (prefix, "#"*x, "."*(size-x), j, count))
    file.flush()
  show(0)
  for i, item in enumerate(it):
    yield item
    show(i+1)
  file.write("\n")
  file.flush()

def serialize_bigint(x):
  """Return the integer value of `x` as a '0x'-prefixed hexadecimal string."""
  return '0x' + Integer(x).hex()

def serialize_EC_Fp(P):
  """Serialize a point whose coordinates live in Fp as {'x': hex, 'y': hex}.

  The projective z coordinate is dropped. Sage stores finite points normalized
  (z = 1); the point at infinity (0 : 1 : 0) would serialize as x=0x0, y=0x1.
  """
  (Px, Py, Pz) = P
  coords = {
    'x': serialize_bigint(Px),
    'y': serialize_bigint(Py)
  }
  return coords

def serialize_EC_Fp2(P):
  """Serialize a point whose coordinates live in Fp2 as
  {'x': {'c0', 'c1'}, 'y': {'c0', 'c1'}}, where c0/c1 are the two Fp
  coordinates of each Fp2 element (obtained via `vector`).

  The projective z coordinate is dropped (see `serialize_EC_Fp`).
  """
  (Px, Py, Pz) = P
  Px = vector(Px)
  Py = vector(Py)
  coords = {
    'x': {
      'c0': serialize_bigint(Px[0]),
      'c1': serialize_bigint(Px[1])
    },
    'y': {
      'c0': serialize_bigint(Py[0]),
      'c1': serialize_bigint(Py[1])
    }
  }
  return coords

# Generator
# ---------------------------------------------------------

def _genScalarMulVectors(G, r, cofactor, count, seed, scalarBits, serialize_point):
  """Draw `count` random (P, scalar, Q = scalar*P) vectors on the curve `G`.

  Each random point is multiplied by `cofactor` (cofactor clearing) before use.
  Scalars are drawn from [0, 2^scalarBits) when `scalarBits` is truthy and from
  [0, r) otherwise. `serialize_point` converts a point to JSON-ready coordinates.
  The RNG is reseeded with `seed` and every iteration draws the point before the
  scalar, so the output is reproducible for a given seed.
  """
  vectors = []
  set_random_seed(seed)
  for i in progressbar(range(count)):
    P = G.random_point()
    scalar = randrange(1 << scalarBits) if scalarBits else randrange(r)

    P *= cofactor # clear cofactor
    Q = scalar * P

    # Key order is kept stable so the emitted JSON layout does not change.
    vectors.append({
      'id': i,
      'P': serialize_point(P),
      'scalarBits': scalarBits if scalarBits else r.bit_length(),
      'scalar': serialize_bigint(scalar),
      'Q': serialize_point(Q),
    })
  return vectors

def genScalarMulG1(curve_name, curve_config, count, seed, scalarBits = None):
  """Generate `count` G1 scalar-multiplication test vectors for `curve_name`.

  `curve_config[curve_name]` must provide field.modulus, field.order,
  curve.form, curve.a and curve.b. Returns a JSON-serializable dict with the
  curve parameters and a 'vectors' list of {id, P, scalarBits, scalar, Q}.
  """
  p = curve_config[curve_name]['field']['modulus']
  r = curve_config[curve_name]['field']['order']
  form = curve_config[curve_name]['curve']['form']
  a = curve_config[curve_name]['curve']['a']
  b = curve_config[curve_name]['curve']['b']

  Fp = GF(p)
  # y^2 = x^3 + a*x + b: use both coefficients. Hardcoding a = 0 would silently
  # produce vectors for the wrong curve whenever the configured a is non-zero.
  G1 = EllipticCurve(Fp, [a, b])
  cofactor = G1.order() // r

  out = {
    'curve': curve_name,
    'group': 'G1',
    'modulus': serialize_bigint(p),
    'order': serialize_bigint(r),
    'cofactor': serialize_bigint(cofactor),
    'form': form
  }
  if form == 'short_weierstrass':
    out['a'] = serialize_bigint(a)
    out['b'] = serialize_bigint(b)

  out['vectors'] = _genScalarMulVectors(
    G1, r, cofactor, count, seed, scalarBits, serialize_EC_Fp
  )
  return out

def genScalarMulG2(curve_name, curve_config, count, seed, scalarBits = None):
  """Generate `count` G2 scalar-multiplication test vectors for `curve_name`.

  G2 is built as the D- or M-type twist of the curve over Fp^(k/d), where k is
  the embedding degree and d the twist degree (only d = 6 over Fp or Fp2 is
  supported). Returns the same structure as `genScalarMulG1`, plus the twist
  and tower parameters.

  Raises NotImplementedError for unsupported tower/twist configurations (or a
  non-zero `a` coefficient), and ValueError for an unknown twist type.
  """
  p = curve_config[curve_name]['field']['modulus']
  r = curve_config[curve_name]['field']['order']
  form = curve_config[curve_name]['curve']['form']
  a = curve_config[curve_name]['curve']['a']
  b = curve_config[curve_name]['curve']['b']
  embedding_degree = curve_config[curve_name]['tower']['embedding_degree']
  twist_degree = curve_config[curve_name]['tower']['twist_degree']
  twist = curve_config[curve_name]['tower']['twist']

  G2_field_degree = embedding_degree // twist_degree
  G2_field = f'Fp{G2_field_degree}' if G2_field_degree > 1 else 'Fp'

  if G2_field_degree == 2:
    non_residue_fp = curve_config[curve_name]['tower']['QNR_Fp']
  elif G2_field_degree == 1:
    if twist_degree == 6:
      # Only for complete serialization
      non_residue_fp = curve_config[curve_name]['tower']['SNR_Fp']
    else:
      raise NotImplementedError()
  else:
    raise NotImplementedError()

  Fp = GF(p)
  K.<u> = PolynomialRing(Fp)

  if G2_field == 'Fp2':
    Fp2.<beta> = Fp.extension(u^2 - non_residue_fp)
    G2F = Fp2
    if twist_degree == 6:
      non_residue_twist = curve_config[curve_name]['tower']['SNR_Fp2']
    else:
      raise NotImplementedError()
  elif G2_field == 'Fp':
    G2F = Fp
    if twist_degree == 6:
      non_residue_twist = curve_config[curve_name]['tower']['SNR_Fp']
    else:
      raise NotImplementedError()
  else:
    raise NotImplementedError()

  # The twisted curve below is y^2 = x^3 + b', i.e. it assumes a = 0 (sextic
  # twists only exist for j-invariant 0 curves). Refuse rather than emit
  # vectors for the wrong curve.
  if a != 0:
    raise NotImplementedError('G2 generation assumes a = 0 (sextic twist)')

  if twist == 'D_Twist':
    G2 = EllipticCurve(G2F, [0, b/G2F(non_residue_twist)])
  elif twist == 'M_Twist':
    G2 = EllipticCurve(G2F, [0, b*G2F(non_residue_twist)])
  else:
    raise ValueError('G2 must be a D_Twist or M_Twist but found ' + twist)

  cofactor = G2.order() // r

  out = {
    'curve': curve_name,
    'group': 'G2',
    'modulus': serialize_bigint(p),
    'order': serialize_bigint(r),
    'cofactor': serialize_bigint(cofactor),
    'form': form,
    'twist_degree': int(twist_degree),
    'twist': twist,
    'non_residue_fp': int(non_residue_fp),
    'G2_field': G2_field,
    'non_residue_twist': [int(coord) for coord in non_residue_twist] if isinstance(non_residue_twist, list) else int(non_residue_twist)
  }
  if form == 'short_weierstrass':
    out['a'] = serialize_bigint(a)
    out['b'] = serialize_bigint(b)

  # G2_field is guaranteed to be 'Fp2' or 'Fp' here (other cases raised above).
  serialize_point = serialize_EC_Fp2 if G2_field == 'Fp2' else serialize_EC_Fp
  out['vectors'] = _genScalarMulVectors(
    G2, r, cofactor, count, seed, scalarBits, serialize_point
  )
  return out

# CLI
# ---------------------------------------------------------

if __name__ == "__main__":
  # Usage
  #   BLS12-381
  #   sage sage/testgen_scalar_mul.sage BLS12_381 G1 {scalarBits: optional int}

  from argparse import ArgumentParser

  parser = ArgumentParser()
  parser.add_argument("curve",nargs="+")
  args = parser.parse_args()

  # Both the curve and the group are mandatory; fail with a usage message
  # instead of an IndexError.
  if len(args.curve) < 2:
    parser.error('expected arguments: <curve> <G1|G2> [scalarBits]')

  curve = args.curve[0]
  group = args.curve[1]
  scalarBits = None
  if len(args.curve) > 2:
    scalarBits = int(args.curve[2])
    # A negative bit width would otherwise fail deep inside `1 << scalarBits`.
    if scalarBits < 0:
      parser.error('scalarBits must be a non-negative integer')

  if curve not in Curves:
    raise ValueError(
      curve +
      ' is not one of the available curves: ' +
      str(Curves.keys())
    )
  elif group not in ['G1', 'G2']:
    raise ValueError(
      group +
      ' is not a valid group, expected G1 or G2 instead'
    )
  else:
    bits = scalarBits if scalarBits else Curves[curve]['field']['order'].bit_length()
    # Single source of truth for the output path, so the announced name
    # matches the file actually written.
    filename = f'tv_{curve}_scalar_mul_{group}_{bits}bits.json'
    print(f'\nGenerating test vectors {filename}')
    print('----------------------------------------------------\n')

    count = 40
    seed = 1337

    if group == 'G1':
      out = genScalarMulG1(curve, Curves, count, seed, scalarBits)
    elif group == 'G2':
      out = genScalarMulG2(curve, Curves, count, seed, scalarBits)

    with open(filename, 'w') as f:
      json.dump(out, f, indent=2)