/ sage / testgen_scalar_mul.sage
testgen_scalar_mul.sage
  1  #!/usr/bin/sage
  2  # vim: syntax=python
  3  # vim: set ts=2 sw=2 et:
  4  
  5  # Constantine
  6  # Copyright (c) 2018-2019    Status Research & Development GmbH
  7  # Copyright (c) 2020-Present Mamy André-Ratsimbazafy
  8  # Licensed and distributed under either of
  9  #   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
 10  #   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
 11  # at your option. This file may not be copied, modified, or distributed except according to those terms.
 12  
 13  # ############################################################
 14  #
 15  #            Scalar multiplication test generator
 16  #
 17  # ############################################################
 18  
 19  # Imports
 20  # ---------------------------------------------------------
 21  
import os, json
import inspect, textwrap
import sys
 24  
 25  # Working directory
 26  # ---------------------------------------------------------
 27  
 28  os.chdir(os.path.dirname(__file__))
 29  
 30  # Sage imports
 31  # ---------------------------------------------------------
 32  # Accelerate arithmetic by accepting probabilistic proofs
 33  from sage.structure.proof.all import arithmetic
 34  arithmetic(False)
 35  
 36  load('curves.sage')
 37  
 38  # Utilities
 39  # ---------------------------------------------------------
 40  
 41  def progressbar(it, prefix="", size=60, file=sys.stdout):
 42    count = len(it)
 43    def show(j):
 44      x = int(size*j/count)
 45      file.write("%s[%s%s] %i/%i\r" % (prefix, "#"*x, "."*(size-x), j, count))
 46      file.flush()
 47    show(0)
 48    for i, item in enumerate(it):
 49      yield item
 50      show(i+1)
 51    file.write("\n")
 52    file.flush()
 53  
 54  def serialize_bigint(x):
 55    return '0x' + Integer(x).hex()
 56  
 57  def serialize_EC_Fp(P):
 58    (Px, Py, Pz) = P
 59    coords = {
 60      'x': serialize_bigint(Px),
 61      'y': serialize_bigint(Py)
 62    }
 63    return coords
 64  
 65  def serialize_EC_Fp2(P):
 66    (Px, Py, Pz) = P
 67    Px = vector(Px)
 68    Py = vector(Py)
 69    coords = {
 70      'x': {
 71        'c0': serialize_bigint(Px[0]),
 72        'c1': serialize_bigint(Px[1])
 73      },
 74      'y': {
 75        'c0': serialize_bigint(Py[0]),
 76        'c1': serialize_bigint(Py[1])
 77      }
 78    }
 79    return coords
 80  
 81  # Generator
 82  # ---------------------------------------------------------
 83  
 84  def genScalarMulG1(curve_name, curve_config, count, seed, scalarBits = None):
 85    p = curve_config[curve_name]['field']['modulus']
 86    r = curve_config[curve_name]['field']['order']
 87    form = curve_config[curve_name]['curve']['form']
 88    a = curve_config[curve_name]['curve']['a']
 89    b = curve_config[curve_name]['curve']['b']
 90  
 91    Fp = GF(p)
 92    G1 = EllipticCurve(Fp, [0, b])
 93    cofactor = G1.order() // r
 94  
 95    out = {
 96      'curve': curve_name,
 97      'group': 'G1',
 98      'modulus': serialize_bigint(p),
 99      'order': serialize_bigint(r),
100      'cofactor': serialize_bigint(cofactor),
101      'form': form
102    }
103    if form == 'short_weierstrass':
104      out['a'] = serialize_bigint(a)
105      out['b'] = serialize_bigint(b)
106  
107    vectors = []
108    set_random_seed(seed)
109    for i in progressbar(range(count)):
110      v = {}
111      P = G1.random_point()
112      scalar = randrange(1 << scalarBits) if scalarBits else randrange(r)
113  
114      P *= cofactor # clear cofactor
115      Q = scalar * P
116  
117      v['id'] = i
118      v['P'] = serialize_EC_Fp(P)
119      v['scalarBits'] = scalarBits if scalarBits else r.bit_length()
120      v['scalar'] = serialize_bigint(scalar)
121      v['Q'] = serialize_EC_Fp(Q)
122      vectors.append(v)
123  
124    out['vectors'] = vectors
125    return out
126  
def genScalarMulG2(curve_name, curve_config, count, seed, scalarBits = None):
  """Generate G2 scalar-multiplication test vectors for ``curve_name``.

  G2 is built as a sextic twist over Fp^(k/6), where k is the embedding
  degree: Fp2 (= Fp[u]/(u^2 - QNR_Fp)) when k/6 == 2, Fp when k/6 == 1.
  The twisted curve is y^2 = x^3 + b/xi (D_Twist) or y^2 = x^3 + b*xi
  (M_Twist), with xi the configured sextic non-residue. For each vector, a
  random point is multiplied by the cofactor (``G2.order() // r``) and
  ``Q = scalar * P`` is recorded.

  Args:
    curve_name: key into ``curve_config``.
    curve_config: mapping whose entry has ``field`` (``modulus``, ``order``),
      ``curve`` (``form``, ``a``, ``b``) and ``tower`` (``embedding_degree``,
      ``twist_degree``, ``twist`` and ``QNR_Fp`` / ``SNR_Fp`` / ``SNR_Fp2``
      as needed) sub-mappings.
    count: number of vectors to generate.
    seed: passed to ``set_random_seed`` so the output is reproducible.
    scalarBits: when truthy, scalars are drawn from [0, 2^scalarBits);
      otherwise from [0, r).

  Returns:
    A JSON-serializable dict with curve/tower metadata and a ``vectors`` list
    of ``{'id', 'P', 'scalarBits', 'scalar', 'Q'}`` entries.

  Raises:
    NotImplementedError: the twist is not sextic or G2 is not over Fp / Fp2.
    ValueError: ``twist`` is neither 'D_Twist' nor 'M_Twist'.
  """
  cfg = curve_config[curve_name]
  p = cfg['field']['modulus']
  r = cfg['field']['order']
  form = cfg['curve']['form']
  a = cfg['curve']['a']
  b = cfg['curve']['b']
  tower = cfg['tower']
  embedding_degree = tower['embedding_degree']
  twist_degree = tower['twist_degree']
  twist = tower['twist']

  G2_field_degree = embedding_degree // twist_degree
  G2_field = f'Fp{G2_field_degree}' if G2_field_degree > 1 else 'Fp'

  # Supported configurations: sextic twist with G2 over Fp2 or over Fp.
  if G2_field_degree == 2:
    non_residue_fp = tower['QNR_Fp']
  elif G2_field_degree == 1 and twist_degree == 6:
    # Only for complete serialization
    non_residue_fp = tower['SNR_Fp']
  else:
    raise NotImplementedError()
  if twist_degree != 6:
    raise NotImplementedError()

  Fp = GF(p)
  K.<u> = PolynomialRing(Fp)

  # After the checks above G2_field is either 'Fp2' or 'Fp'; pick the field,
  # the twist non-residue and the matching point serializer in one place.
  if G2_field == 'Fp2':
    Fp2.<beta> = Fp.extension(u^2 - non_residue_fp)
    G2F = Fp2
    non_residue_twist = tower['SNR_Fp2']
    serialize_point = serialize_EC_Fp2
  else:
    G2F = Fp
    non_residue_twist = tower['SNR_Fp']
    serialize_point = serialize_EC_Fp

  # `a` is not used here: the twist equations below have a zero x-coefficient.
  if twist == 'D_Twist':
    G2 = EllipticCurve(G2F, [0, b/G2F(non_residue_twist)])
  elif twist == 'M_Twist':
    G2 = EllipticCurve(G2F, [0, b*G2F(non_residue_twist)])
  else:
    raise ValueError('G2 must be a D_Twist or M_Twist but found ' + twist)

  cofactor = G2.order() // r

  out = {
    'curve': curve_name,
    'group': 'G2',
    'modulus': serialize_bigint(p),
    'order': serialize_bigint(r),
    'cofactor': serialize_bigint(cofactor),
    'form': form,
    'twist_degree': int(twist_degree),
    'twist': twist,
    'non_residue_fp': int(non_residue_fp),
    'G2_field': G2_field,
    'non_residue_twist': [int(coord) for coord in non_residue_twist] if isinstance(non_residue_twist, list) else int(non_residue_twist)
  }
  if form == 'short_weierstrass':
    out['a'] = serialize_bigint(a)
    out['b'] = serialize_bigint(b)

  # Loop-invariant: the bit width reported in every vector.
  reported_bits = scalarBits if scalarBits else r.bit_length()

  vectors = []
  set_random_seed(seed)
  for i in progressbar(range(count)):
    P = G2.random_point()
    scalar = randrange(1 << scalarBits) if scalarBits else randrange(r)

    P *= cofactor # clear cofactor
    Q = scalar * P

    vectors.append({
      'id': i,
      'P': serialize_point(P),
      'scalarBits': reported_bits,
      'scalar': serialize_bigint(scalar),
      'Q': serialize_point(Q),
    })

  out['vectors'] = vectors
  return out
221  
222  # CLI
223  # ---------------------------------------------------------
224  
if __name__ == "__main__":
  # Usage
  # BLS12-381
  # sage sage/testgen_scalar_mul.sage BLS12_381 G1 {scalarBits: optional int}

  from argparse import ArgumentParser

  # Explicit positionals: a missing group now yields a usage error instead of
  # an IndexError, and scalarBits is parsed as an int by argparse.
  parser = ArgumentParser(description='Generate scalar multiplication test vectors')
  parser.add_argument('curve', help='curve name (a key of Curves)')
  parser.add_argument('group', help='G1 or G2')
  parser.add_argument('scalarBits', nargs='?', type=int, default=None,
                      help='scalar bit width (defaults to the bit length of the group order)')
  args = parser.parse_args()

  curve = args.curve
  group = args.group
  scalarBits = args.scalarBits

  if curve not in Curves:
    raise ValueError(
      curve +
      ' is not one of the available curves: ' +
      str(Curves.keys())
    )
  if group not in ['G1', 'G2']:
    raise ValueError(
      group +
      ' is not a valid group, expected G1 or G2 instead'
    )

  bits = scalarBits if scalarBits else Curves[curve]['field']['order'].bit_length()
  # Built once so the announced filename and the written filename always agree.
  filename = f'tv_{curve}_scalar_mul_{group}_{bits}bit.json'
  print(f'\nGenerating test vectors {filename}')
  print('----------------------------------------------------\n')

  count = 40
  seed = 1337

  generate = genScalarMulG1 if group == 'G1' else genScalarMulG2
  out = generate(curve, Curves, count, seed, scalarBits)

  with open(filename, 'w') as f:
    json.dump(out, f, indent=2)