gen_params.sage
#!/usr/bin/env sage
r"""
Generate finite field parameters for minisketch.

This script selects the finite fields used by minisketch
for various sizes and generates the required tables for
the implementation.

The output (after formatting) can be found in src/fields/*.cpp.

"""
# NOTE: this is SageMath source, not plain Python. The Sage preparser turns
# `^` into exponentiation and `^^` into bitwise XOR, and `R.<x> = ...` binds
# both the structure R and its generator x.
B.<b> = GF(2)
P.<p> = B[]


def apply_map(m, v):
    """Apply the GF(2)-linear map given by table m to the integer v.

    m[i] is the image of bit i; the result is the XOR of m[i] over every
    set bit i of v.
    """
    r = 0
    i = 0
    while v != 0:
        if (v & 1):
            r ^^= m[i]
        i += 1
        v >>= 1
    return r


def recurse_moduli(acc, maxweight, maxdegree):
    """Search for an irreducible polynomial of the form acc + (maxweight terms).

    Exactly maxweight distinct powers of p are added to acc; the highest
    added exponent is at most maxdegree and candidates are tried in
    ascending order of that exponent.

    Returns (pos, poly), where pos is the highest exponent that was added
    and poly the irreducible polynomial found, or (None, None) if there is
    no such polynomial.
    """
    for pos in range(maxweight, maxdegree + 1, 1):
        poly = acc + p^pos
        if maxweight == 1:
            if poly.is_irreducible():
                return (pos, poly)
        else:
            # Fill in the remaining terms strictly below the current one.
            _, ret = recurse_moduli(poly, maxweight - 1, pos - 1)
            if ret is not None:
                return (pos, ret)
    return (None, None)


def compute_moduli(bits):
    """Return all optimal irreducible polynomials for GF(2^bits).

    The result is a list of tuples
    (weight, degree of second-highest nonzero coefficient, polynomial).
    """
    # maxdegree is shared across weights on purpose: a heavier modulus is
    # only recorded when its second-highest term is lower than that of
    # every modulus found so far.
    maxdegree = bits - 1
    result = []
    for weight in range(1, bits, 2):
        deg, res = None, None
        while True:
            ret = recurse_moduli(p^bits + 1, weight, maxdegree)
            if ret[0] is not None:
                (deg, res) = ret
                maxdegree = deg - 1
            else:
                break
        if res is not None:
            # +2 accounts for the fixed p^bits and constant terms.
            result.append((weight + 2, deg, res))
    return result


def bits_to_int(vals):
    """Pack an iterable of 0/1 values (least significant first) into an integer."""
    ret = 0
    base = 1
    for val in vals:
        ret += Integer(val) * base
        base *= 2
    return ret


def sqr_table(f, bits, n=1):
    """Return the table of f^(2^n * i) for i in range(bits), as integers.

    Entry i is the image of the basis element f^i under n repeated
    squarings.
    """
    return [(f^(2^n*i)).integer_representation() for i in range(bits)]


def pow2(x, n):
    """Compute x**(2**n) by squaring n times."""
    for i in range(n):
        x = x**2
    return x


def qrt_table(F, f, bits):
    """Return the table for solving x^2 + x = a.

    This implements the technique from
    https://www.raco.cat/index.php/PublicacionsMatematiques/article/viewFile/37927/40412, Lemma 1

    F is accepted for interface compatibility but is not used.
    """
    # Pick an element with nonzero trace. The loop does not break, so u is
    # the highest power f**i (i < bits) with nonzero trace.
    for i in range(bits):
        if (f**i).trace() != 0:
            u = f**i
    ret = []
    for i in range(0, bits):
        d = f^i
        y = sum(pow2(d, j) * sum(pow2(u, k) for k in range(j)) for j in range(1, bits))
        yi = y.integer_representation()
        # Clear the lowest bit of the stored value.
        ret.append(yi ^^ (yi & 1))
    return ret


def conv_tables(F, NF, bits):
    """Generate tables converting between two fields with different moduli.

    Generate a F(2) linear projection that maps elements from one field
    to an isomorphic field with a different modulus.

    Returns (ret, ret2): ret maps integer representations from F to NF and
    ret2 maps them back from NF to F; both are meant to be used with
    apply_map(). The maps are sanity-checked against random products.
    """
    f = F.gen()
    fp = f.minimal_polynomial()
    assert(fp == F.modulus())
    # A root of F's modulus inside NF is the image of F's generator.
    nfp = fp.change_ring(NF)
    nf = sorted(nfp.roots(multiplicities=False))[0]
    ret = []
    matrepr = [[B(0) for x in range(bits)] for y in range(bits)]
    for i in range(bits):
        val = (nf**i).integer_representation()
        ret.append(val)
        for j in range(bits):
            matrepr[j][i] = B((val >> j) & 1)
    # Invert the bit matrix to obtain the reverse mapping.
    mat = Matrix(matrepr).inverse().transpose()
    ret2 = []
    for i in range(bits):
        ret2.append(bits_to_int(mat[i]))

    # Check that the forward map respects multiplication.
    for t in range(100):
        f1a = F.random_element()
        f1b = F.random_element()
        f1r = f1a * f1b
        f2a = NF.fetch_int(apply_map(ret, f1a.integer_representation()))
        f2b = NF.fetch_int(apply_map(ret, f1b.integer_representation()))
        f2r = NF.fetch_int(apply_map(ret, f1r.integer_representation()))
        f2s = f2a * f2b
        assert(f2r == f2s)

    # Check that the reverse map respects multiplication.
    for t in range(100):
        f2a = NF.random_element()
        f2b = NF.random_element()
        f2r = f2a * f2b
        f1a = F.fetch_int(apply_map(ret2, f2a.integer_representation()))
        f1b = F.fetch_int(apply_map(ret2, f2b.integer_representation()))
        f1r = F.fetch_int(apply_map(ret2, f2r.integer_representation()))
        f1s = f1a * f1b
        assert(f1r == f1s)

    return (ret, ret2)


def fmt(i,typ):
    """Format integer i as a C++ literal ("0" or lowercase hex); typ is unused."""
    if i == 0:
        return "0"
    else:
        return "0x%x" % i


def _table_literal(values, typ):
    """Join the formatted entries of a table into a C++ initializer body."""
    return ", ".join([fmt(x,typ) for x in values])


def lintranstype(typ, bits, maxtbl):
    """Return the "RecLinTrans<...>" type name for a bits-wide linear map.

    The bits are split into ceil(bits / min(maxtbl, bits)) groups whose
    sizes are as even as possible (larger groups first).
    """
    gsize = min(maxtbl, bits)
    array_size = (bits + gsize - 1) // gsize
    bits_list = []
    total = 0
    for i in range(array_size):
        # Ceiling of (remaining bits) / (remaining groups).
        rsize = (bits - total + array_size - i - 1) // (array_size - i)
        total += rsize
        bits_list.append(rsize)
    return "RecLinTrans<%s, %s>" % (typ, ", ".join("%i" % x for x in bits_list))


# Output styles understood by pick_modulus() and print_result().
INT=0
CLMUL=1
CLMUL_TRI=2
MD=3


def print_modulus_md(mod):
    """Render polynomial mod as Markdown/HTML, highest-degree term first."""
    ret = ""
    pos = mod.degree()
    for c in reversed(list(mod)):
        if c:
            if ret:
                ret += " + "
            if pos == 0:
                ret += "1"
            elif pos == 1:
                ret += "x"
            else:
                ret += "x<sup>%i</sup>" % pos
        pos -= 1
    return ret


def pick_modulus(bits, style):
    """Pick the modulus to use for GF(2^bits) with the given style.

    Choose the lexicographically-first lowest-weight modulus,
    optionally subject to implementation specific constraints.

    Returns the chosen polynomial, or None when the style has no suitable
    modulus for this size. Raises ValueError for an unknown style.
    """
    moduli = compute_moduli(bits)
    if style == INT or style == MD:
        pass
    elif style == CLMUL or style == CLMUL_TRI:
        # Fast CLMUL reduction is selected with the check
        # bits + (degree of the second-highest term) <= 66.
        # Moduli passing the check are tried first; the unfiltered list is
        # appended as a fallback.
        moduli = [m for m in moduli if bits + m[1] <= 66] + moduli
        if style == CLMUL:
            if not moduli or moduli[0][2].change_ring(ZZ)(2) == 3 + 2**bits:
                # For modulus 3, CLMUL_TRI is obviously better.
                return None
        else:
            # CLMUL_TRI only supports trinomials.
            moduli = [m for m in moduli if m[0] == 3]
    else:
        raise ValueError("unknown style: %r" % (style,))
    if not moduli:
        return None
    return moduli[0][2]


def print_result(bits, style):
    """Print the generated output for a bits-wide field in the given style.

    For INT, CLMUL and CLMUL_TRI this prints C++ table definitions and the
    field typedef; for MD it prints one Markdown bullet with the modulus.
    Nothing is printed when pick_modulus() returns None. Raises ValueError
    for an unknown style.
    """
    if style == INT:
        multi_sqr = False
        need_trans = False
        table_id = "%i" % bits
    elif style == MD:
        pass
    elif style == CLMUL:
        multi_sqr = True
        need_trans = True
        table_id = "%i" % bits
    elif style == CLMUL_TRI:
        multi_sqr = True
        need_trans = True
        table_id = "TRI%i" % bits
    else:
        raise ValueError("unknown style: %r" % (style,))

    # nmodulus is the INT-style modulus, which is the one used for the
    # bitstream; modulus is the one used by this implementation style.
    nmodulus = pick_modulus(bits, INT)
    modulus = pick_modulus(bits, style)
    if modulus is None:
        return

    if style == MD:
        print("* *%s*" % print_modulus_md(modulus))
        return

    # Smallest unsigned C++ integer type that holds a field element.
    if bits > 32:
        typ = "uint64_t"
    elif bits > 16:
        typ = "uint32_t"
    elif bits > 8:
        typ = "uint16_t"
    else:
        typ = "uint8_t"

    ttyp = lintranstype(typ, bits, 4)
    rtyp = lintranstype(typ, bits, 6)

    F.<f> = GF(2**bits, modulus=modulus)

    include_table = True
    if style != INT and style != CLMUL:
        # When the CLMUL style picks the same modulus, its tables are
        # referenced (by its table id) instead of being emitted again.
        cmodulus = pick_modulus(bits, CLMUL)
        if cmodulus == modulus:
            include_table = False
            table_id = "%i" % bits

    if include_table:
        print("typedef %s StatTable%s;" % (rtyp, table_id))
        rtyp = "StatTable%s" % table_id
        if (style == INT):
            print("typedef %s DynTable%s;" % (ttyp, table_id))
            ttyp = "DynTable%s" % table_id

    if need_trans:
        if modulus != nmodulus:
            # If the bitstream modulus is not the best modulus for
            # this implementation a conversion table will be needed.
            ctyp = rtyp
            NF.<nf> = GF(2**bits, modulus=nmodulus)
            ctables = conv_tables(NF, F, bits)
            loadtbl = "&LOAD_TABLE_%s" % table_id
            savetbl = "&SAVE_TABLE_%s" % table_id
            if include_table:
                print("constexpr %s LOAD_TABLE_%s({%s});" % (ctyp, table_id, _table_literal(ctables[0], typ)))
                print("constexpr %s SAVE_TABLE_%s({%s});" % (ctyp, table_id, _table_literal(ctables[1], typ)))
        else:
            ctyp = "IdTrans"
            loadtbl = "&ID_TRANS"
            savetbl = "&ID_TRANS"
    else:
        assert(modulus == nmodulus)

    if include_table:
        print("constexpr %s SQR_TABLE_%s({%s});" % (rtyp, table_id, _table_literal(sqr_table(f, bits, 1), typ)))
    if multi_sqr:
        # Repeated squaring is a linearised polynomial so in F(2^n) it is
        # F(2) linear and can be computed by a simple bit-matrix.
        # Repeated squaring is especially useful in powering ladders such as
        # for inversion.
        # When certain repeated squaring tables are not in use, use the QRT
        # table instead to make the C++ compiler happy (it always has the
        # same type).
        sqr2 = "&QRT_TABLE_%s" % table_id
        sqr4 = "&QRT_TABLE_%s" % table_id
        sqr8 = "&QRT_TABLE_%s" % table_id
        sqr16 = "&QRT_TABLE_%s" % table_id
        if ((bits - 1) >= 4):
            if include_table:
                print("constexpr %s SQR2_TABLE_%s({%s});" % (rtyp, table_id, _table_literal(sqr_table(f, bits, 2), typ)))
            sqr2 = "&SQR2_TABLE_%s" % table_id
        if ((bits - 1) >= 8):
            if include_table:
                print("constexpr %s SQR4_TABLE_%s({%s});" % (rtyp, table_id, _table_literal(sqr_table(f, bits, 4), typ)))
            sqr4 = "&SQR4_TABLE_%s" % table_id
        if ((bits - 1) >= 16):
            if include_table:
                print("constexpr %s SQR8_TABLE_%s({%s});" % (rtyp, table_id, _table_literal(sqr_table(f, bits, 8), typ)))
            sqr8 = "&SQR8_TABLE_%s" % table_id
        if ((bits - 1) >= 32):
            if include_table:
                print("constexpr %s SQR16_TABLE_%s({%s});" % (rtyp, table_id, _table_literal(sqr_table(f, bits, 16), typ)))
            sqr16 = "&SQR16_TABLE_%s" % table_id
    if include_table:
        print("constexpr %s QRT_TABLE_%s({%s});" % (rtyp, table_id, _table_literal(qrt_table(F, f, bits), typ)))

    # The modulus without its leading p**bits term: its degree, and its
    # coefficients packed into an integer.
    modulus_degree = (modulus - p**bits).degree()
    modulus_int = (modulus - p**bits).change_ring(ZZ)(2)

    lfsr = ""

    if style == INT:
        print("typedef Field<%s, %i, %i, %s, %s, &SQR_TABLE_%s, &QRT_TABLE_%s%s> Field%i;" % (typ, bits, modulus_int, rtyp, ttyp, table_id, table_id, lfsr, bits))
    elif style == CLMUL:
        print("typedef Field<%s, %i, %i, %s, &SQR_TABLE_%s, %s, %s, %s, %s, &QRT_TABLE_%s, %s, %s, %s%s> Field%i;" % (typ, bits, modulus_int, rtyp, table_id, sqr2, sqr4, sqr8, sqr16, table_id, ctyp, loadtbl, savetbl, lfsr, bits))
    elif style == CLMUL_TRI:
        print("typedef FieldTri<%s, %i, %i, %s, &SQR_TABLE_%s, %s, %s, %s, %s, &QRT_TABLE_%s, %s, %s, %s> FieldTri%i;" % (typ, bits, modulus_degree, rtyp, table_id, sqr2, sqr4, sqr8, sqr16, table_id, ctyp, loadtbl, savetbl, bits))
    else:
        raise ValueError("unknown style: %r" % (style,))


for bits in range(2, 65):
    print("#ifdef ENABLE_FIELD_INT_%i" % bits)
    print("// %i bit field" % bits)
    print_result(bits, INT)
    print("#endif")
    print("")

for bits in range(2, 65):
    print("#ifdef ENABLE_FIELD_INT_%i" % bits)
    print("// %i bit field" % bits)
    print_result(bits, CLMUL)
    print_result(bits, CLMUL_TRI)
    print("#endif")
    print("")

for bits in range(2, 65):
    print_result(bits, MD)