# erasure_code/share.py
  1  import copy
  2  
  3  
  4  # Galois field class and logtable
  5  #
  6  # See: https://en.wikipedia.org/wiki/Finite_field
  7  #
  8  # Note that you can substitute "Galois" with "float" in the code, and
  9  # the code will then magically start using the plain old field of rationals
 10  # instead of this spooky modulo polynomial thing. If you are not an expert in
 11  # finite field theory and want to dig deep into how this code works, I
 12  # recommend adding the line "Galois = float" immediately after this class (and
 13  # not using the methods that require serialization)
 14  #
 15  # As a quick intro to finite field theory, the idea is that there exist these
 16  # things called fields, which are basically sets of objects together with
 17  # rules for addition, subtraction, multiplication, division, such that algebra
 18  # within this field is consistent, even if the results look nonsensical from
 19  # a "normal numbers" perspective. For instance, consider the field of integers
 20  # modulo 7. Here, for example, 2 * 5 = 3, 3 * 4 = 5, 6 * 6 = 1, 6 + 6 = 5.
 21  # However, all algebra still works; for example, (a^2 - b^2) = (a + b)(a - b)
 22  # works for all a,b. For this reason, we can do secret sharing arithmetic
 23  # "over" any field. The reason why Galois fields are preferable is that all
 24  # elements in the Galois field are values in [0 ... 255] (at least using the
 25  # canonical serialization that we use here); no amount of addition,
 26  # multiplication, subtraction or division will ever get you anything else.
 27  # This guarantees that our secret shares will always be serializable as byte
 28  # arrays. The way the Galois field we use here works is that the elements are
 29  # polynomials of elements in the field of integers mod 2, so addition and
 30  # subtraction are xor, and multiplication is modulo x^8 + x^4 + x^3 + x + 1,
 31  # and division is defined by a/b = c iff bc = a and b != 0. In practice, we
 32  # do multiplication and division via a precomputed log table using x+1 as a
 33  # base
 34  
 35  # per-byte 2^8 Galois field
 36  # Note that this imposes a hard limit that the number of extended chunks can
 37  # be at most 256 along each dimension
 38  
 39  
 40  def galoistpl(a):
 41      # 2 is not a primitive root, so we have to use 3 as our logarithm base
 42      unrolla = [a/(2**k) % 2 for k in range(8)]
 43      res = [0] + unrolla
 44      for i in range(8):
 45          res[i] = (res[i] + unrolla[i]) % 2
 46      if res[-1] == 0:
 47          res.pop()
 48      else:
 49          # AES Polynomial
 50          for i in range(9):
 51              res[i] = (res[i] - [1, 1, 0, 1, 1, 0, 0, 0, 1][i]) % 2
 52          res.pop()
 53      return sum([res[k] * 2**k for k in range(8)])
 54  
 55  # Precomputing a multiplication and XOR table for increased speed
 56  glogtable = [0] * 256
 57  gexptable = []
 58  v = 1
 59  for i in range(255):
 60      glogtable[v] = i
 61      gexptable.append(v)
 62      v = galoistpl(v)
 63  
 64  
 65  class Galois:
 66      val = 0
 67  
 68      def __init__(self, val):
 69          self.val = val.val if isinstance(self.val, Galois) else val
 70  
 71      def __add__(self, other):
 72          return Galois(self.val ^ other.val)
 73  
 74      def __mul__(self, other):
 75          if self.val == 0 or other.val == 0:
 76              return Galois(0)
 77          return Galois(gexptable[(glogtable[self.val] +
 78                                   glogtable[other.val]) % 255])
 79  
 80      def __sub__(self, other):
 81          return Galois(self.val ^ other.val)
 82  
 83      def __div__(self, other):
 84          if other.val == 0:
 85              raise ZeroDivisionError
 86          if self.val == 0:
 87              return Galois(0)
 88          return Galois(gexptable[(glogtable[self.val] -
 89                                   glogtable[other.val]) % 255])
 90  
 91      def __int__(self):
 92          return self.val
 93  
 94      def __repr__(self):
 95          return repr(self.val)
 96  
 97  
 98  # Modular division class
 99  
def mkModuloClass(n):
    """Build and return an arithmetic class for integers modulo n.

    n must be prime so that every nonzero element has a multiplicative
    inverse (computed in __div__ via Fermat's little theorem).
    """
    # Fermat primality test, base 2. NOTE(review): base-2 pseudoprimes
    # (e.g. 341) would slip through; a stronger test would be stricter.
    if pow(2, n, n) != 2:
        raise Exception("n must be prime!")

    class Mod:
        # Class-level default; instances always overwrite this in __init__
        val = 0

        def __init__(self, val):
            # Copy-construct from another Mod instance; check the incoming
            # value, not self.val (which would always be the class
            # default 0 here, so the original test never fired)
            self.val = val.val if isinstance(val, self.__class__) else val

        def __add__(self, other):
            return self.__class__((self.val + other.val) % n)

        def __mul__(self, other):
            return self.__class__((self.val * other.val) % n)

        def __sub__(self, other):
            return self.__class__((self.val - other.val) % n)

        def __div__(self, other):
            # a / b = a * b^(n-2) mod n (Fermat inverse); 3-argument pow
            # does modular exponentiation without huge intermediates
            return self.__class__((self.val * pow(other.val, n - 2, n)) % n)

        # Python 3 maps the / operator to __truediv__, not __div__
        __truediv__ = __div__

        def __int__(self):
            return self.val

        def __repr__(self):
            return repr(self.val)
    return Mod
130  
131  # Evaluates a polynomial in little-endian form, eg. x^2 + 3x + 2 = [2, 3, 1]
132  # (normally I hate little-endian, but in this case dealing with polynomials
133  # it's justified, since you get the nice property that p[n] is the nth degree
134  # term of p) at coordinate x, eg. eval_poly_at([2, 3, 1], 5) = 42 if you are
135  # using float as your arithmetic
136  
137  
def eval_poly_at(p, x):
    """Evaluate polynomial `p` (little-endian coefficients, so p[n] is
    the degree-n term) at coordinate x, using the arithmetic class of
    the coefficients. e.g. eval_poly_at([2, 3, 1], 5) = 42 with floats.
    """
    arithmetic = p[0].__class__
    total = arithmetic(0)
    power = arithmetic(1)
    for coeff in p:
        total += power * coeff
        power *= x
    return total
146  
147  
148  # Given p+1 y values and x values with no errors, recovers the original
149  # p+1 degree polynomial. For example,
150  # lagrange_interp([51.0, 59.0, 66.0], [1, 3, 4]) = [50.0, 0, 1.0]
151  # if you are using float as your arithmetic
152  
153  
def lagrange_interp(pieces, xs):
    """Given len(xs) correct y values (`pieces`) at coordinates `xs`,
    recover the unique degree len(xs)-1 polynomial through those points,
    returned in little-endian coefficient form.

    e.g. lagrange_interp([51.0, 59.0, 66.0], [1, 3, 4]) = [50.0, 0, 1.0]
    when using float arithmetic.
    """
    # All arithmetic happens in the field of the supplied values
    arithmetic = pieces[0].__class__
    zero, one = arithmetic(0), arithmetic(1)
    # Generate master numerator polynomial: the product of (x - xs[i])
    # over all i, built up one factor at a time
    root = [one]
    for i in range(len(xs)):
        # Multiply the running product by (x - xs[i]): shift up one
        # degree, then subtract xs[i] times each coefficient
        root.insert(0, zero)
        for j in range(len(root)-1):
            root[j] = root[j] - root[j+1] * xs[i]
    # Generate per-value numerator polynomials by dividing the master
    # polynomial back by each x coordinate
    nums = []
    for i in range(len(xs)):
        output = []
        last = one
        # Synthetic division of `root` by (x - xs[i]), producing the
        # quotient coefficients from the top degree downwards
        for j in range(2, len(root)+1):
            output.insert(0, last)
            if j != len(root):
                last = root[-j] + last * xs[i]
        nums.append(output)
    # Generate denominators by evaluating numerator polys at their x
    denoms = []
    for i in range(len(xs)):
        denom = zero
        x_to_the_j = one
        for j in range(len(nums[i])):
            denom += x_to_the_j * nums[i][j]
            x_to_the_j *= xs[i]
        denoms.append(denom)
    # Generate output polynomial: sum over i of nums[i] scaled by
    # pieces[i] / denoms[i]
    b = [zero for i in range(len(pieces))]
    for i in range(len(xs)):
        yslice = pieces[int(i)] / denoms[int(i)]
        for j in range(len(pieces)):
            b[j] += nums[i][j] * yslice
    return b
190  
191  
192  # Compresses two linear equations of length n into one
193  # equation of length n-1
194  # Format:
195  # 3x + 4y = 80 (ie. 3x + 4y - 80 = 0) -> a = [3,4,-80]
196  # 5x + 2y = 70 (ie. 5x + 2y - 70 = 0) -> b = [5,2,-70]
197  
198  
def elim(a, b):
    """Combine two linear equations of length n into one of length n-1
    by cross-multiplying so the leading variable cancels.

    Format: 3x + 4y = 80 (ie. 3x + 4y - 80 = 0) -> a = [3, 4, -80].
    """
    lead_b = b[0]
    lead_a = a[0]
    return [a[i] * lead_b - b[i] * lead_a for i in range(1, len(a))]
204  
205  
206  # Linear equation solver
207  # Format:
208  # 3x + 4y = 80, y = 5 (ie. 3x + 4y - 80z = 0, y = 5, z = 1)
209  #      -> coeffs = [3,4,-80], vals = [5,1]
210  
211  
def evaluate(coeffs, vals):
    """Solve one linear equation for its leading unknown.

    Format: 3x + 4y = 80, y = 5 (ie. 3x + 4y - 80z = 0, y = 5, z = 1)
         -> coeffs = [3, 4, -80], vals = [5, 1]; returns x.
    """
    arithmetic = coeffs[0].__class__
    total = arithmetic(0)
    # Move every known term to the right-hand side
    for coeff, known in zip(coeffs[1:], vals):
        total -= coeff * known
    # A zero leading coefficient means the equation cannot pin down
    # the unknown; signal it so callers can back off
    if int(coeffs[0]) == 0:
        raise ZeroDivisionError
    return total / coeffs[0]
220  
221  
222  # Linear equation system solver
223  # Format:
224  # ax + by + c = 0, dx + ey + f = 0
225  # -> [[a, b, c], [d, e, f]]
226  # eg.
227  # [[3.0, 5.0, -13.0], [9.0, 1.0, -11.0]] -> [1.0, 2.0]
228  
229  
def sys_solve(eqs):
    """Solve a homogeneous linear system by repeated elimination.

    Format: ax + by + c = 0, dx + ey + f = 0 -> [[a, b, c], [d, e, f]]
    e.g. [[3.0, 5.0, -13.0], [9.0, 1.0, -11.0]] -> [1.0, 2.0]
    """
    arithmetic = eqs[0][0].__class__
    one = arithmetic(1)
    # back_eqs collects one representative equation per elimination
    # round, most-reduced first, for back-substitution
    back_eqs = [eqs[0]]
    while len(eqs) > 1:
        # Eliminate the leading variable from each adjacent pair
        eqs = [elim(first, second) for first, second in zip(eqs, eqs[1:])]
        # Prefer an equation whose leading coefficient is nonzero
        pick = 0
        while pick < len(eqs) - 1 and int(eqs[pick][0]) == 0:
            pick += 1
        back_eqs.insert(0, eqs[pick])
    # Back-substitute, seeding with the implicit trailing 1
    kvals = [one]
    for eq in back_eqs:
        kvals.insert(0, evaluate(eq, kvals))
    return kvals[:-1]
247  
248  
def polydiv(Q, E):
    """Divide polynomial Q by polynomial E (little-endian coefficients)
    and return the quotient, discarding any remainder."""
    remainder = copy.deepcopy(Q)
    divisor = copy.deepcopy(E)
    quotient = []
    # Standard long division, peeling off the top coefficient each round
    while len(remainder) >= len(divisor):
        term = remainder[-1] / divisor[-1]
        quotient.insert(0, term)
        for k in range(2, len(divisor) + 1):
            remainder[-k] -= term * divisor[-k]
        remainder.pop()
    return quotient
259  
260  
261  # Given a set of y coordinates and x coordinates, and the degree of the
262  # original polynomial, determines the original polynomial even if some of
263  # the y coordinates are wrong. If m is the minimal number of pieces (ie.
264  # degree + 1), t is the total number of pieces provided, then the algo can
265  # handle up to (t-m)/2 errors. See:
266  # http://en.wikipedia.org/wiki/Berlekamp%E2%80%93Welch_algorithm#Example
267  # (just skip to my example, the rest of the article sucks imo)
268  
269  
def berlekamp_welch_attempt(pieces, xs, master_degree):
    """Recover the degree-`master_degree` polynomial through the points
    (xs[i], pieces[i]) via the Berlekamp-Welch algorithm, tolerating up
    to (len(pieces) - master_degree - 1) // 2 erroneous y values.

    Raises Exception when there is not enough consistent data.
    """
    # Floor division: the error count must be an integer (plain / would
    # produce a float under Python 3 and break the range() calls below)
    error_locator_degree = (len(pieces) - master_degree - 1) // 2
    arithmetic = pieces[0].__class__
    zero, one = arithmetic(0), arithmetic(1)
    # Set up the equations for y[i]E(x[i]) = Q(x[i])
    # degree(E) = error_locator_degree
    # degree(Q) = master_degree + error_locator_degree - 1
    eqs = []
    for i in range(2 * error_locator_degree + master_degree + 1):
        eqs.append([])
    for i in range(2 * error_locator_degree + master_degree + 1):
        # Q coefficients enter negated (everything moved to one side)
        neg_x_to_the_j = zero - one
        for j in range(error_locator_degree + master_degree + 1):
            eqs[i].append(neg_x_to_the_j)
            neg_x_to_the_j *= xs[i]
        x_to_the_j = one
        for j in range(error_locator_degree + 1):
            eqs[i].append(x_to_the_j * pieces[i])
            x_to_the_j *= xs[i]
    # Solve 'em
    # Assume the top error polynomial term to be one
    errors = error_locator_degree
    ones = 1
    while errors >= 0:
        try:
            polys = sys_solve(eqs) + [one] * ones
            break
        except ZeroDivisionError:
            # Singular system: assume one fewer actual error, fold the
            # top E coefficient into the previous column and retry
            for eq in eqs:
                eq[-2] += eq[-1]
                eq.pop()
            eqs.pop()
            errors -= 1
            ones += 1
    if errors < 0:
        raise Exception("Not enough data!")
    # Split the solution into Q and E, then divide Q / E to recover the
    # original polynomial
    qpoly = polys[:error_locator_degree + master_degree + 1]
    epoly = polys[error_locator_degree + master_degree + 1:]
    div = []
    while len(qpoly) >= len(epoly):
        div.insert(0, qpoly[-1] / epoly[-1])
        for i in range(2, len(epoly)+1):
            qpoly[-i] -= div[0] * epoly[-i]
        qpoly.pop()
    # Check that the recovered polynomial matches enough of the input
    corrects = 0
    for i, x in enumerate(xs):
        if int(eval_poly_at(div, x)) == int(pieces[i]):
            corrects += 1
    if corrects < master_degree + errors:
        raise Exception("Answer doesn't match (too many errors)!")
    return div
325  
326  
327  # Extends a list of integers in [0 ... 255] (if using Galois arithmetic) by
328  # adding n redundant error-correction values
329  
330  
def extend(data, n, arithmetic=Galois):
    """Extend a list of integers in [0 ... 255] (when using Galois
    arithmetic) with n redundant error-correction values, by fitting
    the unique degree len(data)-1 polynomial through the data and
    evaluating it at the next n x coordinates.
    """
    # Materialize as lists: under Python 3, map() returns an iterator,
    # which berlekamp_welch_attempt cannot index
    field_values = [arithmetic(v) for v in data]
    field_xs = [arithmetic(i) for i in range(len(data))]
    out = list(data)
    poly = berlekamp_welch_attempt(field_values, field_xs, len(data) - 1)
    for i in range(n):
        out.append(int(eval_poly_at(poly, arithmetic(len(data) + i))))
    return out
340  
341  
342  # Repairs a list of integers in [0 ... 255]. Some integers can be erroneous,
343  # and you can put None in place of an integer if you know that a certain
344  # value is defective or missing. Uses the Berlekamp-Welch algorithm to
345  # do error-correction
346  
347  
def repair(data, datasize, arithmetic=Galois):
    """Repair a list of integers in [0 ... 255]. Entries may be wrong,
    and None marks a value known to be missing or defective. Uses the
    Berlekamp-Welch algorithm for error correction; `datasize` is the
    number of original (non-redundant) values.
    """
    values, coords = [], []
    for position, entry in enumerate(data):
        if entry is None:
            continue
        values.append(arithmetic(entry))
        coords.append(arithmetic(position))
    poly = berlekamp_welch_attempt(values, coords, datasize - 1)
    return [int(eval_poly_at(poly, arithmetic(k)))
            for k in range(len(data))]
356  
357  
358  # Extends a list of bytearrays
359  # eg. extend_chunks([map(ord, 'hello'), map(ord, 'world')], 2)
360  # n is the number of redundant error-correction chunks to add
361  
362  
def extend_chunks(data, n, arithmetic=Galois):
    """Extend a list of equal-length bytearrays column-wise with n
    redundant error-correction chunks.
    e.g. extend_chunks([[104, 105], [119, 111]], 2)
    """
    o = []
    for i in range(len(data[0])):
        # List comprehension instead of map(): Python 3's map iterator
        # cannot be copied/sliced inside extend()
        o.append(extend([row[i] for row in data], n, arithmetic))
    # Transpose back to row (chunk) order; materialize for Python 3
    return [list(col) for col in zip(*o)]
368  
369  
370  # Repairs a list of bytearrays. Use None in place of a missing array.
371  # Individual arrays can contain some missing or erroneous data.
372  
373  
def repair_chunks(data, datasize, arithmetic=Galois):
    """Repair a list of bytearrays. Use None in place of a missing
    array; individual arrays may contain missing or erroneous values.
    """
    # Find any present chunk so we know how long a chunk should be
    first_nonzero = 0
    while not data[first_nonzero]:
        first_nonzero += 1
    # Replace missing chunks with all-None placeholders of that length
    for i in range(len(data)):
        if data[i] is None:
            data[i] = [None] * len(data[first_nonzero])
    o = []
    for i in range(len(data[0])):
        # List comprehension instead of map(): Python 3's map iterator
        # cannot be indexed inside repair()
        o.append(repair([row[i] for row in data], datasize, arithmetic))
    # Transpose back to row (chunk) order; materialize for Python 3
    return [list(col) for col in zip(*o)]
385  
386  
387  # Extends either a bytearray or a list of bytearrays or a list of lists...
388  # Used in the cubify method to expand a cube in all dimensions
389  
390  
def deep_extend_chunks(data, n, arithmetic=Galois):
    """Extend either a bytearray or a (possibly nested) list of
    bytearrays; used by cubify to expand a cube in every dimension.
    """
    if not isinstance(data[0], list):
        return extend(data, n, arithmetic)
    o = []
    for i in range(len(data[0])):
        # List comprehension instead of map(): Python 3's map iterator
        # cannot be indexed by the recursive call
        o.append(deep_extend_chunks([row[i] for row in data],
                                    n, arithmetic))
    # Transpose back; materialize the result for Python 3
    return [list(col) for col in zip(*o)]
400  
401  
402  # ISO/IEC 7816-4 padding
403  
404  
def pad(data, size):
    """Apply ISO/IEC 7816-4 padding: append the 0x80 marker, then zero
    bytes until the length is a multiple of `size`. Returns a new list.
    """
    padded = list(data)
    padded.append(128)
    while len(padded) % size:
        padded.append(0)
    return padded
411  
412  
413  # Removes ISO/IEC 7816-4 padding
414  
415  
def unpad(data):
    """Strip ISO/IEC 7816-4 padding: drop trailing zeros and the 0x80
    marker. Returns a new list; the input is not modified.
    """
    trimmed = data[:]
    # Pop until we have removed the 0x80 marker itself
    while trimmed.pop() != 128:
        pass
    return trimmed
422  
423  
424  # Splits a bytearray into a given number of chunks with some
425  # redundant chunks
426  
427  
def split(data, numchunks, redund):
    """Split a bytearray into data chunks plus `redund` redundant
    error-correction chunks.
    """
    # Floor division: the chunk size must be an integer (plain / would
    # produce a float under Python 3 and break range() below)
    chunksize = len(data) // numchunks + 1
    data = pad(data, chunksize)
    chunks = [data[i: i + chunksize]
              for i in range(0, len(data), chunksize)]
    return extend_chunks(chunks, redund)
436  
437  
438  # Recombines chunks into the original bytearray
439  
440  
def recombine(chunks, datalength):
    """Recombine (and if necessary repair) chunks into the original
    bytearray of length `datalength`.
    """
    # Number of original data chunks; floor division mirrors split()
    # (plain / would produce a float under Python 3)
    datasize = datalength // len(chunks[0]) + 1
    c = repair_chunks(chunks, datasize)
    return unpad(sum(c[:datasize], []))
445  
446  
h = '0123456789abcdef'


def hexfy(x):
    """Encode a byte value in [0 ... 255] as two lowercase hex chars."""
    return h[x // 16] + h[x % 16]


def unhexfy(x):
    """Inverse of hexfy: decode two hex chars into an integer."""
    return h.find(x[0]) * 16 + h.find(x[1])


def split2(x):
    """Split a string into a list of two-character pieces (a trailing
    odd character is dropped). Returns a real list so callers can index
    it under Python 3, where map() would return an iterator."""
    return [''.join(pair) for pair in zip(x[::2], x[1::2])]
451  
452  
453  # Canonical serialization. First argument is a bytearray, remaining
454  # arguments are strings to prepend
455  
456  
def serialize_chunk(*args):
    """Canonical serialization. First argument is a bytearray, remaining
    arguments are metadata values to prepend; fields are '-'-joined with
    the chunk hex-encoded last. Returns None for a missing chunk.
    """
    chunk = args[0]
    if not chunk or chunk[0] is None:
        return None
    metadata = args[1:]
    # List comprehension: under Python 3, map() + list concatenation
    # raises TypeError
    return '-'.join([str(m) for m in metadata] +
                    [''.join(hexfy(b) for b in chunk)])
463  
464  
def deserialize_chunk(chunk):
    """Inverse of serialize_chunk: returns (metadata_strings, byte_list).
    """
    data = chunk.split('-')
    metadata, main = data[:-1], data[-1]
    # Materialize as a list: callers index the result, which fails on
    # Python 3's map iterator
    return metadata, [unhexfy(pair) for pair in split2(main)]
469  
470  
471  # Splits a string into a given number of chunks with some redundant chunks
472  
473  
def split_file(f, numchunks=5, redund=5):
    """Split a string into numchunks data chunks plus redund redundant
    chunks, each serialized with its index and size metadata.
    """
    # Materialize the byte values: Python 3's map iterator has no len()
    byte_values = [ord(ch) for ch in f]
    ec = split(byte_values, numchunks, redund)
    o = []
    for i, c in enumerate(ec):
        o.append(serialize_chunk(c, i, numchunks, numchunks + redund,
                                 len(byte_values)))
    return o
482  
483  
def recombine_file(chunks):
    """Inverse of split_file: reassemble the original string from a set
    of serialized chunks (missing chunks are tolerated up to the
    redundancy limit).
    """
    # Materialize: the results are indexed below, which fails on
    # Python 3's map iterator
    chunks2 = [deserialize_chunk(c) for c in chunks]
    # Metadata fields: [index, numchunks, total_chunks, filesize]
    metadata = [int(m) for m in chunks2[0][0]]
    o = [None] * metadata[2]
    for chunk in chunks2:
        o[int(chunk[0][0])] = chunk[1]
    return ''.join(map(chr, recombine(o, metadata[3])))
491  
492  outersplitn = lambda x, k: map(lambda i: x[i:i+k], range(len(x)))
493  
494  
495  # Splits a bytearray into a hypercube with `dim` dimensions with the original
496  # data being in a sub-cube of width `width` and the expanded cube being of
497  # width `width+redund`. The cube is self-healing; if any edge in any dimension
498  # has missing or erroneous pieces, we can use the Berlekamp-Welch algorithm
499  # to fix this
500  
501  
def cubify(f, width, dim, redund):
    """Split a bytearray into a `dim`-dimensional hypercube: the
    original data fills a sub-cube of width `width`, and each dimension
    is extended to width `width + redund` with error-correction values,
    so any axis line can later be healed with Berlekamp-Welch.
    """
    # Floor division: chunk size must be an integer (plain / would
    # produce a float under Python 3 and break the slicing below)
    chunksize = len(f) // width**dim + 1
    data = pad(f, width**dim)
    chunks = []
    for i in range(0, len(data), chunksize * width):
        for j in range(width):
            chunks.append(data[i+j*chunksize: i+j*chunksize+chunksize])

    # Fold the flat chunk list into a cube, extending along one
    # dimension per pass
    for _ in range(dim):
        o = []
        for j in range(0, len(chunks), width):
            o.append(deep_extend_chunks(chunks[j: j + width], redund))
        chunks = o

    return chunks[0]
519  
520  
521  # `pos` is an array of coordinates. Go deep into a nested list
522  
523  
def descend(obj, pos):
    """Walk into a nested list along the coordinate path `pos` and
    return whatever is found there."""
    node = obj
    for index in pos:
        node = node[index]
    return node
528  
529  
530  # Go deep into a nested list and modify the value
531  
532  
def descend_and_set(obj, pos, val):
    """Walk into a nested list along all but the last coordinate of
    `pos`, then assign `val` at the final coordinate."""
    node = obj
    for index in pos[:-1]:
        node = node[index]
    node[pos[-1]] = val
536  
537  
538  # Use the Berlekamp-Welch algorithm to try to "heal" a particular missing
539  # or damaged coordinate
540  
541  
def heal_cube(cube, width, dim, pos, datalen):
    """Best-effort repair of the cube cell at coordinates `pos`: try the
    Berlekamp-Welch repair along every axis through that cell, keeping
    whichever axes succeed.

    NOTE(review): `dim` and `datalen` are currently unused; kept for
    interface compatibility with existing callers.
    """
    for d in range(len(pos)):
        # Collect the full line through `pos` along axis d
        line = []
        for i in range(len(cube)):
            line.append(descend(cube, pos[:d] + [i] + pos[d+1:]))
        try:
            line = repair_chunks(line, width)
            for i in range(len(cube)):
                path = pos[:d] + [i] + pos[d+1:]
                descend_and_set(cube, path, line[i])
        except Exception:
            # Deliberate best-effort: too few intact pieces along this
            # axis; leave it and hope another axis can repair the cell.
            # (Exception, not bare except, so KeyboardInterrupt and
            # SystemExit still propagate.)
            pass
554  
555  
def pack_metadata(meta):
    """Flatten the cube metadata dict into a list of strings:
    the coordinates, then base_width, extended_width, filesize.
    """
    # List comprehension: under Python 3, map() + list concatenation
    # raises TypeError
    return [str(c) for c in meta['coords']] + [
        str(meta['base_width']),
        str(meta['extended_width']),
        str(meta['filesize'])
    ]
562  
563  
def unpack_metadata(meta):
    """Inverse of pack_metadata: parse a flat list of strings back into
    the cube metadata dict.
    """
    return {
        # Materialize the coords: callers take len() of them, which
        # fails on Python 3's map iterator
        'coords': [int(c) for c in meta[:-3]],
        'base_width': int(meta[-3]),
        'extended_width': int(meta[-2]),
        'filesize': int(meta[-1])
    }
571  
572  
573  # Helper to serialize the contents of a cube of byte arrays
574  
575  
def _ser(chunk, meta):
    """Recursively serialize the contents of a cube of byte arrays,
    tagging each leaf chunk with its coordinates via the metadata."""
    # A leaf is either a missing chunk or a flat list of byte values
    is_leaf = chunk is None or (not isinstance(chunk[0], list) and
                                chunk[0] is not None)
    if is_leaf:
        return serialize_chunk(chunk, *pack_metadata(meta))
    serialized = []
    for i, sub in enumerate(chunk):
        # Each child gets its own metadata copy with its index appended
        sub_meta = copy.deepcopy(meta)
        sub_meta['coords'] += [i]
        serialized.append(_ser(sub, sub_meta))
    return serialized
588  
589  
590  # Converts a deep list into a shallow list
591  
592  
def flatten(chunks):
    """Recursively collapse an arbitrarily nested list into a flat list
    of its non-list leaves."""
    if not isinstance(chunks, list):
        return [chunks]
    return [leaf for item in chunks for leaf in flatten(item)]
601  
602  
603  # Converts a file into a multidimensional set of chunks with
604  # the desired parameters
605  
606  
def serialize_cubify(f, width, dim, redund):
    """Convert a string into a flat list of serialized chunks forming a
    self-healing hypercube with the given parameters.
    """
    # Materialize the byte values: Python 3's map iterator has no len()
    byte_values = [ord(ch) for ch in f]
    cube = cubify(byte_values, width, dim, redund)
    metadata = {
        'base_width': width,
        'extended_width': width + redund,
        'coords': [],
        'filesize': len(byte_values)
    }
    cube_of_serialized_chunks = _ser(cube, metadata)
    return flatten(cube_of_serialized_chunks)
618  
619  
620  # Converts a set of serialized chunks into a partially filled cube
621  
622  
def construct_cube(pieces):
    """Build a partially filled cube (None for missing cells) from a
    set of serialized chunks.
    """
    # Materialize: the result is indexed below, which fails on
    # Python 3's map iterator
    pieces = [deserialize_chunk(p) for p in pieces]
    metadata = unpack_metadata(pieces[0][0])
    dim = len(metadata['coords'])
    # Build an empty dim-dimensional cube of Nones, one layer at a time
    # (distinct loop variables: the original shadowed `i` in the
    # comprehension)
    cube = None
    for _ in range(dim):
        cube = [copy.deepcopy(cube)
                for _ in range(metadata['extended_width'])]
    for p in pieces:
        descend_and_set(cube, unpack_metadata(p[0])['coords'], p[1])
    return cube
633  
634  
635  # Tries to recreate the chunk at a particular coordinate given a set of
636  # other chunks
637  
638  
def heal_set(pieces, coords):
    """Try to recreate the chunk at coordinates `coords` from a set of
    serialized chunks, returning the re-serialized (healed) set with
    missing entries dropped.
    """
    c = construct_cube(pieces)
    metadata, _ = deserialize_chunk(pieces[0])
    metadata = unpack_metadata(metadata)
    heal_cube(c,
              metadata['base_width'],
              len(metadata['coords']),
              coords,
              metadata['filesize'])
    metadata2 = copy.deepcopy(metadata)
    metadata2["coords"] = []
    # List comprehension: filter() returns a lazy iterator under
    # Python 3, but callers expect a list
    return [x for x in flatten(_ser(c, metadata2)) if x]
651  
652  
def number_to_coords(n, w, dim):
    """Convert a flat index n into its base-w coordinates inside a
    dim-dimensional cube, most significant coordinate first.
    """
    c = [0] * dim
    for i in range(dim):
        place = w ** (dim - i - 1)
        # Floor division: coordinates must be integers (plain / would
        # produce floats under Python 3)
        c[i] = n // place
        n %= place
    return c
659  
660  
def full_heal_set(pieces):
    """Rebuild the original string from a set of serialized cube chunks,
    repeatedly sweeping heal_cube over every coordinate until the cube
    reaches a fixed point, then reading out and unpadding the base
    sub-cube.

    Raises Exception when a sweep makes no progress while cells are
    still missing (not enough pieces, or too much corruption).
    """
    c = construct_cube(pieces)
    metadata, piecezzz = deserialize_chunk(pieces[0])
    metadata = unpack_metadata(metadata)
    while 1:
        done = True       # set False if any cell changes this sweep
        unfilled = False  # set True if any cell is still missing
        i = 0
        while i < metadata['extended_width'] ** len(metadata['coords']):
            coords = number_to_coords(i,
                                      metadata['extended_width'],
                                      len(metadata['coords']))
            v = descend(c, coords)
            # Best-effort repair along every axis through this cell
            heal_cube(c,
                      metadata['base_width'],
                      len(metadata['coords']),
                      coords,
                      metadata['filesize'])
            v2 = descend(c, coords)
            if v != v2:
                done = False
            if v is None and v2 is None:
                unfilled = True
            i += 1
        if done and not unfilled:
            break
        elif done and unfilled:
            # Fixed point reached with holes remaining: unrecoverable
            raise Exception("not enough data or too much corrupted data")
    # Read the original data back out of the base sub-cube, in index
    # order, and strip the padding
    o = []
    for i in range(metadata['base_width'] ** len(metadata['coords'])):
        coords = number_to_coords(i,
                                  metadata['base_width'],
                                  len(metadata['coords']))
        o.extend(descend(c, coords))
    return ''.join(map(chr, unpad(o)))