/ src / minisketch / tests / pyminisketch.py
pyminisketch.py
  1  #!/usr/bin/env python3
  2  # Copyright (c) 2020 Pieter Wuille
  3  # Distributed under the MIT software license, see the accompanying
  4  # file LICENSE or http://www.opensource.org/licenses/mit-license.php.
  5  
  6  """Native Python (slow) reimplementation of libminisketch' algorithms."""
  7  
  8  import random
  9  import unittest
 10  
# Irreducible polynomials over GF(2) to use (represented as integers).
#
# GF2_MODULI[d] is the modulus that defines GF(2^d). Bit i of each integer is the coefficient of
# x^i, so e.g. 2**8 + 2**4 + 2**3 + 2**1 + 1 stands for x^8 + x^4 + x^3 + x + 1. Entries 0 and 1
# are None: no field is defined for those sizes (GF2Ops asserts against them).
#
# Most fields can be defined by multiple such polynomials. Minisketch uses the one with the minimal
# number of nonzero coefficients, and tie-breaking by picking the lexicographically first among
# those.
#
# All polynomials for degrees 2 through 64 (inclusive) are given.
GF2_MODULI = [
    None, None,
    2**2 + 2**1 + 1,
    2**3 + 2**1 + 1,
    2**4 + 2**1 + 1,
    2**5 + 2**2 + 1,
    2**6 + 2**1 + 1,
    2**7 + 2**1 + 1,
    2**8 + 2**4 + 2**3 + 2**1 + 1,
    2**9 + 2**1 + 1,
    2**10 + 2**3 + 1,
    2**11 + 2**2 + 1,
    2**12 + 2**3 + 1,
    2**13 + 2**4 + 2**3 + 2**1 + 1,
    2**14 + 2**5 + 1,
    2**15 + 2**1 + 1,
    2**16 + 2**5 + 2**3 + 2**1 + 1,
    2**17 + 2**3 + 1,
    2**18 + 2**3 + 1,
    2**19 + 2**5 + 2**2 + 2**1 + 1,
    2**20 + 2**3 + 1,
    2**21 + 2**2 + 1,
    2**22 + 2**1 + 1,
    2**23 + 2**5 + 1,
    2**24 + 2**4 + 2**3 + 2**1 + 1,
    2**25 + 2**3 + 1,
    2**26 + 2**4 + 2**3 + 2**1 + 1,
    2**27 + 2**5 + 2**2 + 2**1 + 1,
    2**28 + 2**1 + 1,
    2**29 + 2**2 + 1,
    2**30 + 2**1 + 1,
    2**31 + 2**3 + 1,
    2**32 + 2**7 + 2**3 + 2**2 + 1,
    2**33 + 2**10 + 1,
    2**34 + 2**7 + 1,
    2**35 + 2**2 + 1,
    2**36 + 2**9 + 1,
    2**37 + 2**6 + 2**4 + 2**1 + 1,
    2**38 + 2**6 + 2**5 + 2**1 + 1,
    2**39 + 2**4 + 1,
    2**40 + 2**5 + 2**4 + 2**3 + 1,
    2**41 + 2**3 + 1,
    2**42 + 2**7 + 1,
    2**43 + 2**6 + 2**4 + 2**3 + 1,
    2**44 + 2**5 + 1,
    2**45 + 2**4 + 2**3 + 2**1 + 1,
    2**46 + 2**1 + 1,
    2**47 + 2**5 + 1,
    2**48 + 2**5 + 2**3 + 2**2 + 1,
    2**49 + 2**9 + 1,
    2**50 + 2**4 + 2**3 + 2**2 + 1,
    2**51 + 2**6 + 2**3 + 2**1 + 1,
    2**52 + 2**3 + 1,
    2**53 + 2**6 + 2**2 + 2**1 + 1,
    2**54 + 2**9 + 1,
    2**55 + 2**7 + 1,
    2**56 + 2**7 + 2**4 + 2**2 + 1,
    2**57 + 2**4 + 1,
    2**58 + 2**19 + 1,
    2**59 + 2**7 + 2**4 + 2**2 + 1,
    2**60 + 2**1 + 1,
    2**61 + 2**5 + 2**2 + 2**1 + 1,
    2**62 + 2**29 + 1,
    2**63 + 2**1 + 1,
    2**64 + 2**4 + 2**3 + 2**1 + 1
]
 84  
 85  class GF2Ops:
 86      """Class to perform GF(2^field_size) operations on elements represented as integers.
 87  
 88      Given that elements are represented as integers, addition is simply xor, and not
 89      exposed here.
 90      """
 91  
 92      def __init__(self, field_size):
 93          """Construct a GF2Ops object for the specified field size."""
 94          self.field_size = field_size
 95          self._modulus = GF2_MODULI[field_size]
 96          assert self._modulus is not None
 97  
 98      def mul2(self, x):
 99          """Multiply x by 2 in GF(2^field_size)."""
100          x <<= 1
101          if x >> self.field_size:
102              x ^= self._modulus
103          return x
104  
105      def mul(self, x, y):
106          """Multiply x by y in GF(2^field_size)."""
107          ret = 0
108          while y:
109              if y & 1:
110                  ret ^= x
111              y >>= 1
112              x = self.mul2(x)
113          return ret
114  
115      def sqr(self, x):
116          """Square x in GF(2^field_size)."""
117          return self.mul(x, x)
118  
119      def inv(self, x):
120          """Compute the inverse of x in GF(2^field_size)."""
121          assert x != 0
122          # Use the extended polynomial Euclidean GCD algorithm on (modulus, x), over GF(2).
123          # See https://en.wikipedia.org/wiki/Polynomial_greatest_common_divisor.
124          t1, t2 = 0, 1
125          r1, r2 = self._modulus, x
126          r1l, r2l = self.field_size + 1, r2.bit_length()
127          while r2:
128              q = r1l - r2l
129              r1 ^= r2 << q
130              t1 ^= t2 << q
131              r1l = r1.bit_length()
132              if r1 < r2:
133                  t1, t2 = t2, t1
134                  r1, r2 = r2, r1
135                  r1l, r2l = r2l, r1l
136          assert r1 == 1
137          return t1
138  
class TestGF2Ops(unittest.TestCase):
    """Randomized checks that GF2Ops satisfies basic field identities."""

    def field_size_test(self, field_size):
        """Check GF2Ops(field_size) on 100 random pairs of field elements."""
        gf = GF2Ops(field_size)
        for iteration in range(100):
            a = random.randrange(1 << field_size)
            b = random.randrange(1 << field_size)
            a_doubled = gf.mul2(a)
            prod = gf.mul(a, b)
            # Doubling agrees with multiplying by 2 from either side.
            self.assertEqual(a_doubled, gf.mul(a, 2))
            self.assertEqual(a_doubled, gf.mul(2, a))
            # 0 is absorbing and 1 is neutral for multiplication.
            self.assertEqual(prod == 0, a == 0 or b == 0)
            self.assertEqual(prod == a, b == 1 or a == 0)
            self.assertEqual(prod == b, a == 1 or b == 0)
            # Multiplication is commutative.
            self.assertEqual(prod, gf.mul(b, a))
            if iteration < 10:
                # Squaring field_size times is the identity: a^(2^field_size) == a.
                power = a
                for _ in range(field_size):
                    power = gf.sqr(power)
                self.assertEqual(power, a)
            if b == 0:
                continue
            b_inv = gf.inv(b)
            # b == 1/b holds exactly when b == 1.
            self.assertEqual(b == b_inv, b == 1)
            # b * (1/b) == 1, and 1/(1/b) == b.
            self.assertEqual(gf.mul(b, b_inv), 1)
            self.assertEqual(gf.inv(b_inv), b)
            if a != 0:
                # 1/(a*b) == (1/a) * (1/b)
                a_inv = gf.inv(a)
                self.assertEqual(gf.inv(prod), gf.mul(a_inv, b_inv))

    def test(self):
        """Run the checks for every supported field size, 2 through 64."""
        for field_size in range(2, 65):
            self.field_size_test(field_size)
177  
178  # The operations below operate on polynomials over GF(2^field_size), represented as lists of
179  # integers:
180  #
181  #   [a, b, c, ...] = a + b*x + c*x^2 + ...
182  #
183  # As an invariant, there are never any trailing zeroes in the list representation.
184  #
185  # Examples:
186  # * [] = 0
187  # * [3] = 3
188  # * [0, 1] = x
189  # * [2, 0, 5] = 5*x^2 + 2
190  
def poly_monic(poly, gf):
    """Return poly scaled so that its leading coefficient is 1 (poly must be nonempty)."""
    # Scaling by a nonzero constant keeps the roots; the inverse of the top coefficient does it.
    scale = gf.inv(poly[-1])
    return [gf.mul(scale, coef) for coef in poly]
196  
def poly_divmod(poly, mod, gf):
    """Divide poly by the monic polynomial mod; return (quotient, remainder).

    If poly has lower degree than mod, the result is ([], poly), with poly itself (not a copy)
    as the remainder. Otherwise the remainder is a new list with trailing zeros stripped.
    """
    assert len(mod) > 0 and mod[-1] == 1  # Only monic divisors are supported.
    mod_len = len(mod)
    if len(poly) < mod_len:
        return ([], poly)
    rem = list(poly)
    quot = [0] * (len(rem) - mod_len + 1)
    while len(rem) >= mod_len:
        lead = rem.pop()
        # Degree of this quotient term: the popped coefficient sat at index shift + mod_len - 1.
        shift = len(rem) + 1 - mod_len
        quot[shift] = lead
        if lead:
            # Subtract lead * x^shift * mod; its top term cancels the coefficient just popped.
            for i, coef in enumerate(mod[:-1]):
                rem[shift + i] ^= gf.mul(lead, coef)
    # Restore the no-trailing-zeros invariant.
    while rem and rem[-1] == 0:
        rem.pop()
    return quot, rem
216  
def poly_gcd(a, b, gf):
    """Return the GCD of polynomials a and b, computed with Euclid's algorithm.

    When at least one division step runs, the result is monic. If the shorter input is empty,
    the longer input is returned unchanged.
    See https://en.wikipedia.org/wiki/Polynomial_greatest_common_divisor#Euclid's_algorithm.
    """
    big, small = (b, a) if len(a) < len(b) else (a, b)
    while small:
        # poly_divmod requires a monic divisor; scaling does not change the GCD.
        divisor = poly_monic(small, gf)
        _, remainder = poly_divmod(big, divisor, gf)
        big, small = divisor, remainder
    return big
227  
def poly_sqr(poly, gf):
    """Return the square of polynomial poly.

    In characteristic 2 the Frobenius map is additive ((a + b)^2 = a^2 + b^2), so squaring a
    polynomial squares each coefficient and moves it to twice its degree:
    (c0 + c1*x + c2*x^2)^2 = c0^2 + c1^2*x^2 + c2^2*x^4.
    See https://en.wikipedia.org/wiki/Frobenius_endomorphism.
    """
    if not poly:
        return []
    result = [0] * (2 * len(poly) - 1)
    # Even degrees receive the squared coefficients; odd degrees stay zero.
    result[::2] = [gf.sqr(coef) for coef in poly]
    return result
237  
def poly_tracemod(poly, param, gf):
    """Return (y + y^2 + y^4 + ... + y^(2^(field_size-1))) mod poly, where y = param*x.

    poly must be monic, since it is used as the divisor in poly_divmod.
    """
    acc = [0, param]  # the polynomial y itself
    for _ in range(1, gf.field_size):
        # acc = y + ... + y^(2^i); squaring turns it into y^2 + ... + y^(2^(i+1)).
        acc = poly_sqr(acc, gf)
        acc.extend([0] * (2 - len(acc)))
        # A square has no odd-degree terms, so setting the x coefficient adds y back in.
        acc[1] = param
        # Reduce to keep the intermediate polynomial small.
        acc = poly_divmod(acc, poly, gf)[1]
    return acc
252  
def poly_frobeniusmod(poly, gf):
    """Return x^(2^field_size) mod poly, by squaring x field_size times.

    poly must be monic, since it is used as the divisor in poly_divmod.
    """
    power = [0, 1]  # the polynomial x
    for _ in range(gf.field_size):
        squared = poly_sqr(power, gf)
        power = poly_divmod(squared, poly, gf)[1]
    return power
259  
def poly_find_roots(poly, gf):
    """Find the roots of poly if fully factorizable with unique roots, [] otherwise.

    poly is a nonempty coefficient list (lowest degree first) over GF(2^gf.field_size). The
    roots are returned in ascending order. A nonzero constant polynomial has no roots and
    also yields [].
    """
    assert len(poly) > 0
    # If the polynomial is constant (and nonzero), it has no roots.
    if len(poly) == 1:
        return []
    # Make the polynomial monic (which doesn't change its roots).
    poly = poly_monic(poly, gf)
    # If the polynomial is of the form x+a, its root is a (in characteristic 2, -a == a).
    if len(poly) == 2:
        return [poly[0]]
    # Otherwise, first test that poly can be completely factored into unique roots. The polynomial
    # x^(2^field_size)-x has every field element once as root, so poly splits into distinct linear
    # factors exactly when it divides x^(2^field_size)-x. Compute x^(2^field_size) mod poly, which
    # must equal x if that is the case (poly has degree >= 2 here, so x is already reduced).
    if poly_frobeniusmod(poly, gf) != [0, 1]:
        return []

    def rec_split(poly, randv):
        """Recursively split poly using the Berlekamp trace algorithm."""
        # See https://hal.archives-ouvertes.fr/hal-00626997/document.
        assert len(poly) > 1 and poly[-1] == 1  # Require a monic, non-constant poly.
        # If poly is of the form x+a, its root is a.
        if len(poly) == 2:
            return [poly[0]]
        # Try consecutive randomization factors randv, until one is found that factors poly.
        while True:
            # Compute the trace of (randv*x) mod poly. The trace maps half of the field to 0 and
            # the other half to 1; which half is which is controlled by randv. Reducing modulo
            # poly does not change the value at the roots of poly, so at every root r of poly
            # the result evaluates to Tr(randv*r), which is 0 or 1.
            trace = poly_tracemod(poly, randv, gf)
            # Using the set {2^i*a for i=0..fieldsize-1} gives optimally independent randv values
            # (no more than fieldsize are ever needed).
            randv = gf.mul2(randv)
            # Now take the GCD of this trace polynomial with poly. Up to a constant factor, the
            # result is the product of (x - r) over the roots r of poly where the trace is 0.
            gcd = poly_gcd(trace, poly, gf)
            # If the result has degree at least 1, and lower than that of poly, we found a
            # useful factorization.
            if len(gcd) != len(poly) and len(gcd) > 1:
                break
            # Otherwise, continue with another randv.
        # Find the actual factors: the monic version of the GCD above, and poly divided by it.
        # Divide by the normalized factor1 (not the raw gcd). This guarantees the monic divisor
        # poly_divmod asserts on, independently of how poly_gcd scales its result.
        factor1 = poly_monic(gcd, gf)
        factor2, _ = poly_divmod(poly, factor1, gf)
        # Recurse.
        return rec_split(factor1, randv) + rec_split(factor2, randv)

    # Invoke the recursive splitting with a random initial factor, and sort the results.
    return sorted(rec_split(poly, random.randrange(1, 1 << gf.field_size)))
311  
class TestPolyFindRoots(unittest.TestCase):
    """Randomized tests for poly_find_roots."""

    def field_size_test(self, field_size):
        """Build polynomials from random root lists and check the roots are recovered."""
        gf = GF2Ops(field_size)
        for test_size in [0, 1, 2, 3, 10]:
            roots = [random.randrange(1 << field_size) for _ in range(test_size)]
            # Expand the product of (x + r) == (x - r) over all roots, keeping multiplicities.
            poly = [1]
            for root in roots:
                # Multiply by (x + root): shift up one degree, then add root * poly.
                shifted = [0] + poly
                for degree, coef in enumerate(poly):
                    shifted[degree] ^= gf.mul(coef, root)
                poly = shifted
            found_roots = poly_find_roots(poly, gf)
            # Distinct roots must all be found. Any repeated root means the polynomial does not
            # split into unique roots, so nothing may be returned.
            has_duplicates = len(set(roots)) != len(roots)
            self.assertEqual(found_roots, [] if has_duplicates else sorted(roots))

    def test(self):
        """Run the root-finding checks for every field size from 2 to 64."""
        for field_size in range(2, 65):
            self.field_size_test(field_size)
340  
def berlekamp_massey(syndromes, gf):
    """Run the Berlekamp-Massey algorithm over GF(2^field_size).

    Given a sequence of field elements, return the connection polynomial of the shortest linear
    feedback shift register (LFSR) that generates it, as a coefficient list (lowest degree
    first). The LFSR length is len(result) - 1. Trailing zero coefficients are not stripped.
    """
    # See https://en.wikipedia.org/wiki/Berlekamp%E2%80%93Massey_algorithm.
    current = [1]  # connection polynomial found so far
    prev = [1]  # connection polynomial from before the last length change
    b_inv = 1  # inverse of the discrepancy at the last length change
    for n, syndrome in enumerate(syndromes):
        # Discrepancy between syndrome n and what the current LFSR predicts for it.
        discrepancy = syndrome
        for i, coef in enumerate(current[1:], start=1):
            discrepancy ^= gf.mul(syndromes[n - i], coef)
        if not discrepancy:
            continue
        # x: how far prev must be shifted up so its correction lines up with step n.
        x = n + 1 - (len(current) - 1) - (len(prev) - 1)
        lengthen = 2 * (len(current) - 1) <= n
        if lengthen:
            # The LFSR must grow: remember the old polynomial and make room for the new terms.
            saved = list(current)
            current.extend([0] * (len(prev) + x - len(current)))
        # Cancel the discrepancy by subtracting (discrepancy / b) * x^x * prev.
        mul = gf.mul(discrepancy, b_inv)
        for i, v in enumerate(prev):
            current[i + x] ^= gf.mul(mul, v)
        if lengthen:
            prev = saved
            b_inv = gf.inv(discrepancy)
    return current
372  
class Minisketch:
    """A Minisketch sketch: a fixed-capacity summary of a set of field_size-bit elements.

    The state is the list of the first `capacity` odd power sums s1, s3, s5, ... of the added
    elements, computed in GF(2^field_size).
    """

    def __init__(self, field_size, capacity):
        """Create an empty sketch for field_size-bit elements with the given capacity."""
        self.field_size = field_size
        self.capacity = capacity
        self.odd_syndromes = [0 for _ in range(capacity)]
        self.gf = GF2Ops(field_size)

    def add(self, element):
        """Add element (1 <= element < 2**field_size) to this sketch.

        Power sums are combined with XOR, so adding the same element twice cancels it out.
        """
        element_sq = self.gf.sqr(element)
        power = element  # element^1, then element^3, element^5, ...
        for pos in range(self.capacity):
            self.odd_syndromes[pos] ^= power
            power = self.gf.mul(element_sq, power)

    def serialized_size(self):
        """Return the serialization length in bytes: capacity * field_size bits, rounded up."""
        total_bits = self.capacity * self.field_size
        return (total_bits + 7) // 8

    def serialize(self):
        """Return the syndromes packed little-endian, field_size bits each, as bytes."""
        packed = 0
        shift = 0
        for i in range(self.capacity):
            packed |= self.odd_syndromes[i] << shift
            shift += self.field_size
        return packed.to_bytes(self.serialized_size(), 'little')

    def deserialize(self, byte_data):
        """Overwrite this sketch's syndromes with those packed in byte_data (see serialize)."""
        assert len(byte_data) == self.serialized_size()
        packed = int.from_bytes(byte_data, 'little')
        mask = (1 << self.field_size) - 1
        for i in range(self.capacity):
            self.odd_syndromes[i] = (packed >> (i * self.field_size)) & mask

    def clone(self):
        """Return a copy of this sketch with its own syndrome list (the GF2Ops is shared)."""
        twin = Minisketch(self.field_size, self.capacity)
        twin.odd_syndromes = list(self.odd_syndromes)
        twin.gf = self.gf
        return twin

    def merge(self, other):
        """Fold other into this sketch in place; same effect as XOR-ing the serializations."""
        assert self.capacity == other.capacity
        assert self.field_size == other.field_size
        for i in range(self.capacity):
            self.odd_syndromes[i] ^= other.odd_syndromes[i]

    def decode(self, max_count=None):
        """Recover the elements summarized by this sketch.

        Returns a sorted list of elements, or None if the sketch cannot be decoded (including
        when more than max_count elements would be needed, if max_count is given).
        """
        # Only odd power sums are stored; the even ones follow because squaring is additive in
        # characteristic 2: s2 = s1^2, s4 = s2^2, s6 = s3^2, and in general s(2k) = s(k)^2.
        # all_syndromes[j] holds s(j+1).
        all_syndromes = [0] * (2 * len(self.odd_syndromes))
        for i, odd in enumerate(self.odd_syndromes):
            all_syndromes[2 * i] = odd
            all_syndromes[2 * i + 1] = self.gf.sqr(all_syndromes[i])
        # Find the shortest LFSR (connection polynomial) that generates the syndromes.
        poly = berlekamp_massey(all_syndromes, self.gf)
        if not poly:
            return None
        if len(poly) == 1:
            return []
        if max_count is not None and len(poly) - 1 > max_count:
            return None
        # The connection polynomial (1 - m1*x)(1 - m2*x)...(1 - mn*x) has roots 1/m1, ..., 1/mn.
        # Reversing its coefficients yields (x - m1)(x - m2)...(x - mn), whose roots are the
        # elements themselves.
        roots = poly_find_roots(poly[::-1], self.gf)
        return roots if roots else None
456  
class TestMinisketch(unittest.TestCase):
    """Round-trip tests for Minisketch: add, serialize, deserialize, merge and decode."""

    @classmethod
    def construct_data(cls, field_size, num_a_only, num_b_only, num_both):
        """Construct two random lists of distinct elements in [1..2**field_size-1].

        The first list holds num_a_only elements absent from the second, the second holds
        num_b_only elements absent from the first, and num_both elements appear in both. Both
        lists are returned shuffled. Requesting more than 2**field_size - 1 elements in total
        never terminates.
        """
        total = num_a_only + num_b_only + num_both
        sample = []
        seen = set()
        # Draw distinct values by rejection rather than with random.sample, which needs len() of
        # its population and therefore cannot handle a range as large as range(1, 2**64).
        while len(sample) < total:
            candidate = random.randrange(1, 1 << field_size)
            if candidate in seen:
                continue
            seen.add(candidate)
            sample.append(candidate)
        # Layout of sample: [a-only..., shared..., b-only...]
        full_a = sample[:num_a_only + num_both]
        full_b = sample[num_a_only:]
        random.shuffle(full_a)
        random.shuffle(full_b)
        return full_a, full_b

    def field_size_capacity_test(self, field_size, capacity):
        """Round-trip one random pair of sets through sketches of the given field and capacity."""
        used_capacity = random.randrange(capacity + 1)
        num_a = random.randrange(used_capacity + 1)
        # Never request more distinct elements than the field has nonzero values.
        max_both = min(2 * capacity, (1 << field_size) - 1 - used_capacity)
        num_both = random.randrange(max_both + 1)
        full_a, full_b = self.construct_data(field_size, num_a, used_capacity - num_a, num_both)
        sketch_a = Minisketch(field_size, capacity)
        for element in full_a:
            sketch_a.add(element)
        sketch_b = Minisketch(field_size, capacity)
        for element in full_b:
            sketch_b.add(element)
        # Ship b's sketch through serialization, then merge it into a copy of a's sketch.
        received = Minisketch(field_size, capacity)
        received.deserialize(sketch_b.serialize())
        combined = sketch_a.clone()
        combined.merge(received)
        # The merged sketch must decode to the symmetric difference of the two sets.
        self.assertEqual(combined.decode(), sorted(set(full_a) ^ set(full_b)))

    def test(self):
        """Run the round-trip test over all field sizes and a spread of capacities."""
        for field_size in range(2, 65):
            max_capacity = (1 << field_size) - 1
            for capacity in [0, 1, 2, 5, 10, field_size]:
                self.field_size_capacity_test(field_size, min(capacity, max_capacity))
505  
# When executed as a script, run every unittest.TestCase defined in this module.
if __name__ == '__main__':
    unittest.main()