/ src / minisketch / doc / gen_params.sage
gen_params.sage
  1  #!/usr/bin/env sage
  2  r"""
  3  Generate finite field parameters for minisketch.
  4  
  5  This script selects the finite fields used by minisketch
  6   for various sizes and generates the required tables for
  7   the implementation.
  8  
  9  The output (after formatting) can be found in src/fields/*.cpp.
 10  
 11  """
# B is the two-element field GF(2) (generator named b); P is the univariate
# polynomial ring over B with indeterminate p. The functions below build
# candidate field moduli as polynomials in p and use B for matrix entries.
B.<b> = GF(2)
P.<p> = B[]
 14  
 15  def apply_map(m, v):
 16      r = 0
 17      i = 0
 18      while v != 0:
 19          if (v & 1):
 20              r ^^= m[i]
 21          i += 1
 22          v >>= 1
 23      return r
 24  
 25  def recurse_moduli(acc, maxweight, maxdegree):
 26      for pos in range(maxweight, maxdegree + 1, 1):
 27          poly = acc + p^pos
 28          if maxweight == 1:
 29              if poly.is_irreducible():
 30                  return (pos, poly)
 31          else:
 32              (deg, ret) = recurse_moduli(poly, maxweight - 1, pos - 1)
 33              if ret is not None:
 34                  return (pos, ret)
 35      return (None, None)
 36  
 37  def compute_moduli(bits):
 38      # Return all optimal irreducible polynomials for GF(2^bits)
 39      # The result is a list of tuples (weight, degree of second-highest nonzero coefficient, polynomial)
 40      maxdegree = bits - 1
 41      result = []
 42      for weight in range(1, bits, 2):
 43          deg, res = None, None
 44          while True:
 45              ret = recurse_moduli(p^bits + 1, weight, maxdegree)
 46              if ret[0] is not None:
 47                  (deg, res) = ret
 48                  maxdegree = deg - 1
 49              else:
 50                  break
 51          if res is not None:
 52              result.append((weight + 2, deg, res))
 53      return result
 54  
 55  def bits_to_int(vals):
 56      ret = 0
 57      base = 1
 58      for val in vals:
 59          ret += Integer(val) * base
 60          base *= 2
 61      return ret
 62  
 63  def sqr_table(f, bits, n=1):
 64      ret = []
 65      for i in range(bits):
 66          ret.append((f^(2^n*i)).integer_representation())
 67      return ret
 68  
 69  # Compute x**(2**n)
 70  def pow2(x, n):
 71      for i in range(n):
 72          x = x**2
 73      return x
 74  
 75  def qrt_table(F, f, bits):
 76      # Table for solving x2 + x = a
 77      # This implements the technique from https://www.raco.cat/index.php/PublicacionsMatematiques/article/viewFile/37927/40412, Lemma 1
 78      for i in range(bits):
 79          if (f**i).trace() != 0:
 80              u = f**i
 81      ret = []
 82      for i in range(0, bits):
 83          d = f^i
 84          y = sum(pow2(d, j) * sum(pow2(u, k) for k in range(j)) for j in range(1, bits))
 85          ret.append(y.integer_representation() ^^ (y.integer_representation() & 1))
 86      return ret
 87  
def conv_tables(F, NF, bits):
    # Generate a F(2) linear projection that maps elements from one field
    #  to an isomorphic field with a different modulus.
    #
    # Returns (ret, ret2): ret[i] is the integer representation in NF of the
    #  image of F's i-th basis power, and ret2 is the table for the inverse
    #  direction. Both are meant to be used with apply_map on integer
    #  representations. Raises AssertionError if a self-check below fails.
    f = F.gen()
    fp = f.minimal_polynomial()
    assert(fp == F.modulus())
    # Reinterpret F's modulus over NF and pick one of its roots there (the
    #  smallest in sorted order, so the choice is deterministic): sending f
    #  to that root defines the map from F to NF.
    nfp = fp.change_ring(NF)
    nf = sorted(nfp.roots(multiplicities=False))[0]
    ret = []
    matrepr = [[B(0) for x in range(bits)] for y in range(bits)]
    for i in range(bits):
        # Image of f^i is nf^i; it becomes table entry i and column i of
        #  the bit matrix (row j holds bit j).
        val = (nf**i).integer_representation()
        ret.append(val)
        for j in range(bits):
            matrepr[j][i] = B((val >> j) & 1)
    # Invert the matrix over GF(2); after transposing, row i is column i of
    #  the inverse, which is packed into the i-th entry of the reverse table.
    mat = Matrix(matrepr).inverse().transpose()
    ret2 = []
    for i in range(bits):
        ret2.append(bits_to_int(mat[i]))

    # Self-check on 100 random pairs: mapping F -> NF with ret must commute
    #  with multiplication.
    for t in range(100):
        f1a = F.random_element()
        f1b = F.random_element()
        f1r = f1a * f1b
        f2a = NF.fetch_int(apply_map(ret, f1a.integer_representation()))
        f2b = NF.fetch_int(apply_map(ret, f1b.integer_representation()))
        f2r = NF.fetch_int(apply_map(ret, f1r.integer_representation()))
        f2s = f2a * f2b
        assert(f2r == f2s)

    # Same self-check for the reverse direction NF -> F using ret2.
    for t in range(100):
        f2a = NF.random_element()
        f2b = NF.random_element()
        f2r = f2a * f2b
        f1a = F.fetch_int(apply_map(ret2, f2a.integer_representation()))
        f1b = F.fetch_int(apply_map(ret2, f2b.integer_representation()))
        f1r = F.fetch_int(apply_map(ret2, f2r.integer_representation()))
        f1s = f1a * f1b
        assert(f1r == f1s)

    return (ret, ret2)
129  
def fmt(i,typ):
    """Format the integer i as a C literal: "0" for zero, lowercase hex otherwise.

    The typ argument is accepted but does not influence the output.
    """
    return "0" if i == 0 else "0x%x" % i
135  
def lintranstype(typ, bits, maxtbl):
    """Return the C++ type name "RecLinTrans<typ, w0, w1, ...>" for a bits-bit map.

    The bits are split into the fewest groups of at most maxtbl bits each,
    with the group widths kept as even as possible (larger groups first).
    """
    group_limit = min(maxtbl, bits)
    group_count = (bits + group_limit - 1) // group_limit
    widths = []
    remaining = bits
    for groups_left in range(group_count, 0, -1):
        # Ceiling division spreads what is left evenly over the groups left.
        width = (remaining + groups_left - 1) // groups_left
        widths.append(width)
        remaining -= width
    return "RecLinTrans<%s, %s>" % (typ, ", ".join("%i" % w for w in widths))
146  
# Output styles understood by pick_modulus and print_result.
# INT, CLMUL and CLMUL_TRI make print_result emit C++ table and field
# typedefs; MD makes it print a markdown list entry with the modulus.
INT=0
CLMUL=1
CLMUL_TRI=2
MD=3
151  
def print_modulus_md(mod):
    """Render the polynomial mod as markdown/HTML text, highest degree first.

    Nonzero terms are written as "x<sup>k</sup>", "x" or "1" and joined
    with " + ". The string is returned, not printed.
    """
    top = mod.degree()
    terms = []
    for offset, coeff in enumerate(reversed(list(mod))):
        if not coeff:
            continue
        pos = top - offset
        if pos == 0:
            terms.append("1")
        elif pos == 1:
            terms.append("x")
        else:
            terms.append("x<sup>%i</sup>" % pos)
    return " + ".join(terms)
167  
def pick_modulus(bits, style):
    """Choose the modulus polynomial for the bits-bit field in the given style.

    Picks the lexicographically-first lowest-weight modulus from
    compute_moduli(bits), optionally subject to implementation-specific
    constraints:

    * INT, MD: no constraint.
    * CLMUL: moduli with bits + (degree of the second-highest term) <= 66
      are preferred, with the remaining ones kept as a fallback. Returns
      None when the chosen modulus is x^bits + x + 1, which is left to
      CLMUL_TRI.
    * CLMUL_TRI: same preference order, restricted to weight-3 moduli.

    Returns the chosen polynomial, or None when no modulus qualifies.
    Raises ValueError for an unknown style.
    """
    if style not in (INT, MD, CLMUL, CLMUL_TRI):
        # Explicit raise rather than assert, so the check survives python -O.
        raise ValueError("unknown style: %r" % (style,))

    moduli = compute_moduli(bits)

    if style == CLMUL or style == CLMUL_TRI:
        # Fast CLMUL reduction constrains bits plus the degree of the
        #  second-highest term (x[1]); the check applied here is <= 66.
        # Qualifying moduli go first; the full list follows as a fallback.
        moduli = [x for x in moduli if bits + x[1] <= 66] + moduli

    if style == CLMUL:
        if not moduli or moduli[0][2].change_ring(ZZ)(2) == 3 + 2**bits:
            # For modulus 3, CLMUL_TRI is obviously better.
            return None
    elif style == CLMUL_TRI:
        # x[0] is the weight: keep trinomials only.
        moduli = [x for x in moduli if x[0] == 3]

    if not moduli:
        return None
    return moduli[0][2]
194  
def print_result(bits, style):
    # Print the generated definitions for the bits-bit field in one style.
    #
    # For INT, CLMUL and CLMUL_TRI this prints C++ typedefs and constexpr
    #  tables (squaring, repeated-squaring, QRT and, where needed, modulus
    #  conversion tables) followed by the Field/FieldTri typedef. For MD it
    #  prints a single markdown list entry with the modulus.
    # Prints nothing when pick_modulus yields no modulus for this style.
    if style == INT:
        multi_sqr = False
        need_trans = False
        table_id = "%i" % bits
    elif style == MD:
        # MD returns early below, before any of these flags are read.
        pass
    elif style == CLMUL:
        multi_sqr = True
        need_trans = True
        table_id = "%i" % bits
    elif style == CLMUL_TRI:
        multi_sqr = True
        need_trans = True
        table_id = "TRI%i" % bits
    else:
        assert(False)

    # nmodulus is the INT-style modulus; it is the reference the chosen
    #  modulus is compared against when deciding on conversion tables.
    nmodulus = pick_modulus(bits, INT)
    modulus = pick_modulus(bits, style)
    if modulus is None:
        return

    if style == MD:
        print("* *%s*" % print_modulus_md(modulus))
        return

    # Smallest unsigned C integer type that holds `bits` bits.
    if bits > 32:
        typ = "uint64_t"
    elif bits > 16:
        typ = "uint32_t"
    elif bits > 8:
        typ = "uint16_t"
    else:
        typ = "uint8_t"

    # Linear-transformation table types with groups of at most 4 and 6 bits.
    ttyp = lintranstype(typ, bits, 4)
    rtyp = lintranstype(typ, bits, 6)

    F.<f> = GF(2**bits, modulus=modulus)

    include_table = True
    if style != INT and style != CLMUL:
        # Only CLMUL_TRI reaches this point. If the CLMUL style picks the
        #  same modulus, do not print the tables again; refer to them under
        #  the plain (non-TRI) table id instead.
        cmodulus = pick_modulus(bits, CLMUL)
        if cmodulus == modulus:
            include_table = False
            table_id = "%i" % bits

    if include_table:
        print("typedef %s StatTable%s;" % (rtyp, table_id))
        rtyp = "StatTable%s" % table_id
        if (style == INT):
            print("typedef %s DynTable%s;" % (ttyp, table_id))
            ttyp = "DynTable%s" % table_id

    if need_trans:
        if modulus != nmodulus:
            # If the bitstream modulus is not the best modulus for
            #  this implementation a conversion table will be needed.
            ctyp = rtyp
            NF.<nf> = GF(2**bits, modulus=nmodulus)
            # ctables[0] maps from the nmodulus field to the modulus field
            #  (printed as LOAD_TABLE); ctables[1] is the reverse (SAVE_TABLE).
            ctables = conv_tables(NF, F, bits)
            loadtbl = "&LOAD_TABLE_%s" % table_id
            savetbl = "&SAVE_TABLE_%s" % table_id
            if include_table:
                print("constexpr %s LOAD_TABLE_%s({%s});" % (ctyp, table_id, ", ".join([fmt(x,typ) for x in ctables[0]])))
                print("constexpr %s SAVE_TABLE_%s({%s});" % (ctyp, table_id, ", ".join([fmt(x,typ) for x in ctables[1]])))
        else:
            # Same modulus on both sides: refer to the identity transform.
            ctyp = "IdTrans"
            loadtbl = "&ID_TRANS"
            savetbl = "&ID_TRANS"
    else:
        assert(modulus == nmodulus)

    if include_table:
        print("constexpr %s SQR_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 1)])))
    if multi_sqr:
        # Repeated squaring is a linearised polynomial so in F(2^n) it is
        #  F(2) linear and can be computed by a simple bit-matrix.
        # Repeated squaring is especially useful in powering ladders such as
        #  for inversion.
        # When certain repeated squaring tables are not in use, use the QRT
        # table instead to make the C++ compiler happy (it always has the
        # same type).
        sqr2 = "&QRT_TABLE_%s" % table_id
        sqr4 = "&QRT_TABLE_%s" % table_id
        sqr8 = "&QRT_TABLE_%s" % table_id
        sqr16 = "&QRT_TABLE_%s" % table_id
        # The table for n-fold squaring is only emitted when bits - 1 >= 2*n.
        if ((bits - 1) >= 4):
            if include_table:
                print("constexpr %s SQR2_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 2)])))
            sqr2 = "&SQR2_TABLE_%s" % table_id
        if ((bits - 1) >= 8):
            if include_table:
                print("constexpr %s SQR4_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 4)])))
            sqr4 = "&SQR4_TABLE_%s" % table_id
        if ((bits - 1) >= 16):
            if include_table:
                print("constexpr %s SQR8_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 8)])))
            sqr8 = "&SQR8_TABLE_%s" % table_id
        if ((bits - 1) >= 32):
            if include_table:
                print("constexpr %s SQR16_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 16)])))
            sqr16 = "&SQR16_TABLE_%s" % table_id
    if include_table:
        print("constexpr %s QRT_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in qrt_table(F, f, bits)])))

    # NOTE(review): modulus_weight is computed but never used below.
    modulus_weight = modulus.hamming_weight()
    # Degree and integer value (evaluated at 2) of the modulus with its
    #  leading p**bits term removed.
    modulus_degree = (modulus - p**bits).degree()
    modulus_int = (modulus - p**bits).change_ring(ZZ)(2)

    # NOTE(review): lfsr is always the empty string here, so it adds nothing
    #  to the typedefs it is formatted into.
    lfsr = ""

    if style == INT:
        print("typedef Field<%s, %i, %i, %s, %s, &SQR_TABLE_%s, &QRT_TABLE_%s%s> Field%i;" % (typ, bits, modulus_int, rtyp, ttyp, table_id, table_id, lfsr, bits))
    elif style == CLMUL:
        print("typedef Field<%s, %i, %i, %s, &SQR_TABLE_%s, %s, %s, %s, %s, &QRT_TABLE_%s, %s, %s, %s%s> Field%i;" % (typ, bits, modulus_int, rtyp, table_id, sqr2, sqr4, sqr8, sqr16, table_id, ctyp, loadtbl, savetbl, lfsr, bits))
    elif style == CLMUL_TRI:
        # The trinomial variant is parameterised by the degree of the middle
        #  term rather than by the integer value of the modulus.
        print("typedef FieldTri<%s, %i, %i, %s, &SQR_TABLE_%s, %s, %s, %s, %s, &QRT_TABLE_%s, %s, %s, %s> FieldTri%i;" % (typ, bits, modulus_degree, rtyp, table_id, sqr2, sqr4, sqr8, sqr16, table_id, ctyp, loadtbl, savetbl, bits))
    else:
        assert(False)
316  
# First pass: INT-style definitions for every field size from 2 to 64 bits,
# each wrapped in its own ENABLE_FIELD_INT_<bits> preprocessor guard.
for bits in range(2, 65):
    print("#ifdef ENABLE_FIELD_INT_%i" % bits)
    print("// %i bit field" % bits)
    print_result(bits, INT)
    print("#endif")
    print("")

# Second pass: CLMUL and CLMUL_TRI definitions for the same sizes. Note that
# the guard macro printed here is also ENABLE_FIELD_INT_<bits>.
for bits in range(2, 65):
    print("#ifdef ENABLE_FIELD_INT_%i" % bits)
    print("// %i bit field" % bits)
    print_result(bits, CLMUL)
    print_result(bits, CLMUL_TRI)
    print("#endif")
    print("")

# Third pass: one markdown list entry per field size giving its modulus.
for bits in range(2, 65):
    print_result(bits, MD)