# erasure_code/share.py
import copy


# Galois field class and logtable
#
# See: https://en.wikipedia.org/wiki/Finite_field
#
# Note that you can substitute "Galois" with "float" in the code, and
# the code will then magically start using the plain old field of rationals
# instead of this spooky modulo polynomial thing. If you are not an expert in
# finite field theory and want to dig deep into how this code works, I
# recommend adding the line "Galois = float" immediately after this class (and
# not using the methods that require serialization)
#
# As a quick intro to finite field theory, the idea is that there exist these
# things called fields, which are basically sets of objects together with
# rules for addition, subtraction, multiplication, division, such that algebra
# within this field is consistent, even if the results look nonsensical from
# a "normal numbers" perspective. For instance, consider the field of integers
# modulo 7. Here, for example, 2 * 5 = 3, 3 * 4 = 5, 6 * 6 = 1, 6 + 6 = 5.
# However, all algebra still works; for example, (a^2 - b^2) = (a + b)(a - b)
# works for all a,b. For this reason, we can do secret sharing arithmetic
# "over" any field. The reason why Galois fields are preferable is that all
# elements in the Galois field are values in [0 ... 255] (at least using the
# canonical serialization that we use here); no amount of addition,
# multiplication, subtraction or division will ever get you anything else.
# This guarantees that our secret shares will always be serializable as byte
# arrays. The way the Galois field we use here works is that the elements are
# polynomials of elements in the field of integers mod 2, so addition and
# subtraction are xor, and multiplication is modulo x^8 + x^4 + x^3 + x + 1,
# and division is defined by a/b = c iff bc = a and b != 0.
In practice, we 32 # do multiplication and division via a precomputed log table using x+1 as a 33 # base 34 35 # per-byte 2^8 Galois field 36 # Note that this imposes a hard limit that the number of extended chunks can 37 # be at most 256 along each dimension 38 39 40 def galoistpl(a): 41 # 2 is not a primitive root, so we have to use 3 as our logarithm base 42 unrolla = [a/(2**k) % 2 for k in range(8)] 43 res = [0] + unrolla 44 for i in range(8): 45 res[i] = (res[i] + unrolla[i]) % 2 46 if res[-1] == 0: 47 res.pop() 48 else: 49 # AES Polynomial 50 for i in range(9): 51 res[i] = (res[i] - [1, 1, 0, 1, 1, 0, 0, 0, 1][i]) % 2 52 res.pop() 53 return sum([res[k] * 2**k for k in range(8)]) 54 55 # Precomputing a multiplication and XOR table for increased speed 56 glogtable = [0] * 256 57 gexptable = [] 58 v = 1 59 for i in range(255): 60 glogtable[v] = i 61 gexptable.append(v) 62 v = galoistpl(v) 63 64 65 class Galois: 66 val = 0 67 68 def __init__(self, val): 69 self.val = val.val if isinstance(self.val, Galois) else val 70 71 def __add__(self, other): 72 return Galois(self.val ^ other.val) 73 74 def __mul__(self, other): 75 if self.val == 0 or other.val == 0: 76 return Galois(0) 77 return Galois(gexptable[(glogtable[self.val] + 78 glogtable[other.val]) % 255]) 79 80 def __sub__(self, other): 81 return Galois(self.val ^ other.val) 82 83 def __div__(self, other): 84 if other.val == 0: 85 raise ZeroDivisionError 86 if self.val == 0: 87 return Galois(0) 88 return Galois(gexptable[(glogtable[self.val] - 89 glogtable[other.val]) % 255]) 90 91 def __int__(self): 92 return self.val 93 94 def __repr__(self): 95 return repr(self.val) 96 97 98 # Modular division class 99 100 def mkModuloClass(n): 101 102 if pow(2, n, n) != 2: 103 raise Exception("n must be prime!") 104 105 class Mod: 106 val = 0 107 108 def __init__(self, val): 109 self.val = val.val if isinstance( 110 self.val, self.__class__) else val 111 112 def __add__(self, other): 113 return self.__class__((self.val + 
other.val) % n) 114 115 def __mul__(self, other): 116 return self.__class__((self.val * other.val) % n) 117 118 def __sub__(self, other): 119 return self.__class__((self.val - other.val) % n) 120 121 def __div__(self, other): 122 return self.__class__((self.val * other.val ** (n-2)) % n) 123 124 def __int__(self): 125 return self.val 126 127 def __repr__(self): 128 return repr(self.val) 129 return Mod 130 131 # Evaluates a polynomial in little-endian form, eg. x^2 + 3x + 2 = [2, 3, 1] 132 # (normally I hate little-endian, but in this case dealing with polynomials 133 # it's justified, since you get the nice property that p[n] is the nth degree 134 # term of p) at coordinate x, eg. eval_poly_at([2, 3, 1], 5) = 42 if you are 135 # using float as your arithmetic 136 137 138 def eval_poly_at(p, x): 139 arithmetic = p[0].__class__ 140 y = arithmetic(0) 141 x_to_the_i = arithmetic(1) 142 for i in range(len(p)): 143 y += x_to_the_i * p[i] 144 x_to_the_i *= x 145 return y 146 147 148 # Given p+1 y values and x values with no errors, recovers the original 149 # p+1 degree polynomial. 
For example, 150 # lagrange_interp([51.0, 59.0, 66.0], [1, 3, 4]) = [50.0, 0, 1.0] 151 # if you are using float as your arithmetic 152 153 154 def lagrange_interp(pieces, xs): 155 arithmetic = pieces[0].__class__ 156 zero, one = arithmetic(0), arithmetic(1) 157 # Generate master numerator polynomial 158 root = [one] 159 for i in range(len(xs)): 160 root.insert(0, zero) 161 for j in range(len(root)-1): 162 root[j] = root[j] - root[j+1] * xs[i] 163 # Generate per-value numerator polynomials by dividing the master 164 # polynomial back by each x coordinate 165 nums = [] 166 for i in range(len(xs)): 167 output = [] 168 last = one 169 for j in range(2, len(root)+1): 170 output.insert(0, last) 171 if j != len(root): 172 last = root[-j] + last * xs[i] 173 nums.append(output) 174 # Generate denominators by evaluating numerator polys at their x 175 denoms = [] 176 for i in range(len(xs)): 177 denom = zero 178 x_to_the_j = one 179 for j in range(len(nums[i])): 180 denom += x_to_the_j * nums[i][j] 181 x_to_the_j *= xs[i] 182 denoms.append(denom) 183 # Generate output polynomial 184 b = [zero for i in range(len(pieces))] 185 for i in range(len(xs)): 186 yslice = pieces[int(i)] / denoms[int(i)] 187 for j in range(len(pieces)): 188 b[j] += nums[i][j] * yslice 189 return b 190 191 192 # Compresses two linear equations of length n into one 193 # equation of length n-1 194 # Format: 195 # 3x + 4y = 80 (ie. 3x + 4y - 80 = 0) -> a = [3,4,-80] 196 # 5x + 2y = 70 (ie. 5x + 2y - 70 = 0) -> b = [5,2,-70] 197 198 199 def elim(a, b): 200 aprime = [x*b[0] for x in a] 201 bprime = [x*a[0] for x in b] 202 c = [aprime[i] - bprime[i] for i in range(1, len(a))] 203 return c 204 205 206 # Linear equation solver 207 # Format: 208 # 3x + 4y = 80, y = 5 (ie. 
3x + 4y - 80z = 0, y = 5, z = 1) 209 # -> coeffs = [3,4,-80], vals = [5,1] 210 211 212 def evaluate(coeffs, vals): 213 arithmetic = coeffs[0].__class__ 214 tot = arithmetic(0) 215 for i in range(len(vals)): 216 tot -= coeffs[i+1] * vals[i] 217 if int(coeffs[0]) == 0: 218 raise ZeroDivisionError 219 return tot / coeffs[0] 220 221 222 # Linear equation system solver 223 # Format: 224 # ax + by + c = 0, dx + ey + f = 0 225 # -> [[a, b, c], [d, e, f]] 226 # eg. 227 # [[3.0, 5.0, -13.0], [9.0, 1.0, -11.0]] -> [1.0, 2.0] 228 229 230 def sys_solve(eqs): 231 arithmetic = eqs[0][0].__class__ 232 one = arithmetic(1) 233 back_eqs = [eqs[0]] 234 while len(eqs) > 1: 235 neweqs = [] 236 for i in range(len(eqs)-1): 237 neweqs.append(elim(eqs[i], eqs[i+1])) 238 eqs = neweqs 239 i = 0 240 while i < len(eqs) - 1 and int(eqs[i][0]) == 0: 241 i += 1 242 back_eqs.insert(0, eqs[i]) 243 kvals = [one] 244 for i in range(len(back_eqs)): 245 kvals.insert(0, evaluate(back_eqs[i], kvals)) 246 return kvals[:-1] 247 248 249 def polydiv(Q, E): 250 qpoly = copy.deepcopy(Q) 251 epoly = copy.deepcopy(E) 252 div = [] 253 while len(qpoly) >= len(epoly): 254 div.insert(0, qpoly[-1] / epoly[-1]) 255 for i in range(2, len(epoly)+1): 256 qpoly[-i] -= div[0] * epoly[-i] 257 qpoly.pop() 258 return div 259 260 261 # Given a set of y coordinates and x coordinates, and the degree of the 262 # original polynomial, determines the original polynomial even if some of 263 # the y coordinates are wrong. If m is the minimal number of pieces (ie. 264 # degree + 1), t is the total number of pieces provided, then the algo can 265 # handle up to (t-m)/2 errors. 
See: 266 # http://en.wikipedia.org/wiki/Berlekamp%E2%80%93Welch_algorithm#Example 267 # (just skip to my example, the rest of the article sucks imo) 268 269 270 def berlekamp_welch_attempt(pieces, xs, master_degree): 271 error_locator_degree = (len(pieces) - master_degree - 1) / 2 272 arithmetic = pieces[0].__class__ 273 zero, one = arithmetic(0), arithmetic(1) 274 # Set up the equations for y[i]E(x[i]) = Q(x[i]) 275 # degree(E) = error_locator_degree 276 # degree(Q) = master_degree + error_locator_degree - 1 277 eqs = [] 278 for i in range(2 * error_locator_degree + master_degree + 1): 279 eqs.append([]) 280 for i in range(2 * error_locator_degree + master_degree + 1): 281 neg_x_to_the_j = zero - one 282 for j in range(error_locator_degree + master_degree + 1): 283 eqs[i].append(neg_x_to_the_j) 284 neg_x_to_the_j *= xs[i] 285 x_to_the_j = one 286 for j in range(error_locator_degree + 1): 287 eqs[i].append(x_to_the_j * pieces[i]) 288 x_to_the_j *= xs[i] 289 # Solve 'em 290 # Assume the top error polynomial term to be one 291 errors = error_locator_degree 292 ones = 1 293 while errors >= 0: 294 try: 295 polys = sys_solve(eqs) + [one] * ones 296 qpoly = polys[:errors + master_degree + 1] 297 epoly = polys[errors + master_degree + 1:] 298 break 299 except ZeroDivisionError: 300 for eq in eqs: 301 eq[-2] += eq[-1] 302 eq.pop() 303 eqs.pop() 304 errors -= 1 305 ones += 1 306 if errors < 0: 307 raise Exception("Not enough data!") 308 # Divide the polynomials 309 qpoly = polys[:error_locator_degree + master_degree + 1] 310 epoly = polys[error_locator_degree + master_degree + 1:] 311 div = [] 312 while len(qpoly) >= len(epoly): 313 div.insert(0, qpoly[-1] / epoly[-1]) 314 for i in range(2, len(epoly)+1): 315 qpoly[-i] -= div[0] * epoly[-i] 316 qpoly.pop() 317 # Check 318 corrects = 0 319 for i, x in enumerate(xs): 320 if int(eval_poly_at(div, x)) == int(pieces[i]): 321 corrects += 1 322 if corrects < master_degree + errors: 323 raise Exception("Answer doesn't match (too 
many errors)!") 324 return div 325 326 327 # Extends a list of integers in [0 ... 255] (if using Galois arithmetic) by 328 # adding n redundant error-correction values 329 330 331 def extend(data, n, arithmetic=Galois): 332 data2 = map(arithmetic, data) 333 data3 = data[:] 334 poly = berlekamp_welch_attempt(data2, 335 map(arithmetic, range(len(data))), 336 len(data) - 1) 337 for i in range(n): 338 data3.append(int(eval_poly_at(poly, arithmetic(len(data) + i)))) 339 return data3 340 341 342 # Repairs a list of integers in [0 ... 255]. Some integers can be erroneous, 343 # and you can put None in place of an integer if you know that a certain 344 # value is defective or missing. Uses the Berlekamp-Welch algorithm to 345 # do error-correction 346 347 348 def repair(data, datasize, arithmetic=Galois): 349 vs, xs = [], [] 350 for i, v in enumerate(data): 351 if v is not None: 352 vs.append(arithmetic(v)) 353 xs.append(arithmetic(i)) 354 poly = berlekamp_welch_attempt(vs, xs, datasize - 1) 355 return [int(eval_poly_at(poly, arithmetic(i))) for i in range(len(data))] 356 357 358 # Extends a list of bytearrays 359 # eg. extend_chunks([map(ord, 'hello'), map(ord, 'world')], 2) 360 # n is the number of redundant error-correction chunks to add 361 362 363 def extend_chunks(data, n, arithmetic=Galois): 364 o = [] 365 for i in range(len(data[0])): 366 o.append(extend(map(lambda x: x[i], data), n, arithmetic)) 367 return map(list, zip(*o)) 368 369 370 # Repairs a list of bytearrays. Use None in place of a missing array. 371 # Individual arrays can contain some missing or erroneous data. 
372 373 374 def repair_chunks(data, datasize, arithmetic=Galois): 375 first_nonzero = 0 376 while not data[first_nonzero]: 377 first_nonzero += 1 378 for i in range(len(data)): 379 if data[i] is None: 380 data[i] = [None] * len(data[first_nonzero]) 381 o = [] 382 for i in range(len(data[0])): 383 o.append(repair(map(lambda x: x[i], data), datasize, arithmetic)) 384 return map(list, zip(*o)) 385 386 387 # Extends either a bytearray or a list of bytearrays or a list of lists... 388 # Used in the cubify method to expand a cube in all dimensions 389 390 391 def deep_extend_chunks(data, n, arithmetic=Galois): 392 if not isinstance(data[0], list): 393 return extend(data, n, arithmetic) 394 else: 395 o = [] 396 for i in range(len(data[0])): 397 o.append( 398 deep_extend_chunks(map(lambda x: x[i], data), n, arithmetic)) 399 return map(list, zip(*o)) 400 401 402 # ISO/IEC 7816-4 padding 403 404 405 def pad(data, size): 406 data = data[:] 407 data.append(128) 408 while len(data) % size != 0: 409 data.append(0) 410 return data 411 412 413 # Removes ISO/IEC 7816-4 padding 414 415 416 def unpad(data): 417 data = data[:] 418 while data[-1] != 128: 419 data.pop() 420 data.pop() 421 return data 422 423 424 # Splits a bytearray into a given number of chunks with some 425 # redundant chunks 426 427 428 def split(data, numchunks, redund): 429 chunksize = len(data) / numchunks + 1 430 data = pad(data, chunksize) 431 chunks = [] 432 for i in range(0, len(data), chunksize): 433 chunks.append(data[i: i+chunksize]) 434 o = extend_chunks(chunks, redund) 435 return o 436 437 438 # Recombines chunks into the original bytearray 439 440 441 def recombine(chunks, datalength): 442 datasize = datalength / len(chunks[0]) + 1 443 c = repair_chunks(chunks, datasize) 444 return unpad(sum(c[:datasize], [])) 445 446 447 h = '0123456789abcdef' 448 hexfy = lambda x: h[x//16]+h[x % 16] 449 unhexfy = lambda x: h.find(x[0]) * 16 + h.find(x[1]) 450 split2 = lambda x: map(lambda a: ''.join(a), zip(x[::2], 
x[1::2])) 451 452 453 # Canonical serialization. First argument is a bytearray, remaining 454 # arguments are strings to prepend 455 456 457 def serialize_chunk(*args): 458 chunk = args[0] 459 if not chunk or chunk[0] is None: 460 return None 461 metadata = args[1:] 462 return '-'.join(map(str, metadata) + [''.join(map(hexfy, chunk))]) 463 464 465 def deserialize_chunk(chunk): 466 data = chunk.split('-') 467 metadata, main = data[:-1], data[-1] 468 return metadata, map(unhexfy, split2(main)) 469 470 471 # Splits a string into a given number of chunks with some redundant chunks 472 473 474 def split_file(f, numchunks=5, redund=5): 475 f = map(ord, f) 476 ec = split(f, numchunks, redund) 477 o = [] 478 for i, c in enumerate(ec): 479 o.append( 480 serialize_chunk(c, *[i, numchunks, numchunks + redund, len(f)])) 481 return o 482 483 484 def recombine_file(chunks): 485 chunks2 = map(deserialize_chunk, chunks) 486 metadata = map(int, chunks2[0][0]) 487 o = [None] * metadata[2] 488 for chunk in chunks2: 489 o[int(chunk[0][0])] = chunk[1] 490 return ''.join(map(chr, recombine(o, metadata[3]))) 491 492 outersplitn = lambda x, k: map(lambda i: x[i:i+k], range(len(x))) 493 494 495 # Splits a bytearray into a hypercube with `dim` dimensions with the original 496 # data being in a sub-cube of width `width` and the expanded cube being of 497 # width `width+redund`. 
The cube is self-healing; if any edge in any dimension 498 # has missing or erroneous pieces, we can use the Berlekamp-Welch algorithm 499 # to fix this 500 501 502 def cubify(f, width, dim, redund): 503 chunksize = len(f) / width**dim + 1 504 data = pad(f, width**dim) 505 chunks = [] 506 for i in range(0, len(data), chunksize * width): 507 for j in range(width): 508 chunks.append(data[i+j*chunksize: i+j*chunksize+chunksize]) 509 510 for i in range(dim): 511 o = [] 512 for j in range(0, len(chunks), width): 513 e = chunks[j: j + width] 514 o.append( 515 deep_extend_chunks(e, redund)) 516 chunks = o 517 518 return chunks[0] 519 520 521 # `pos` is an array of coordinates. Go deep into a nested list 522 523 524 def descend(obj, pos): 525 for p in pos: 526 obj = obj[p] 527 return obj 528 529 530 # Go deep into a nested list and modify the value 531 532 533 def descend_and_set(obj, pos, val): 534 immed = descend(obj, pos[:-1]) 535 immed[pos[-1]] = val 536 537 538 # Use the Berlekamp-Welch algorithm to try to "heal" a particular missing 539 # or damaged coordinate 540 541 542 def heal_cube(cube, width, dim, pos, datalen): 543 for d in range(len(pos)): 544 o = [] 545 for i in range(len(cube)): 546 o.append(descend(cube, pos[:d] + [i] + pos[d+1:])) 547 try: 548 o = repair_chunks(o, width) 549 for i in range(len(cube)): 550 path = pos[:d] + [i] + pos[d+1:] 551 descend_and_set(cube, path, o[i]) 552 except: 553 pass 554 555 556 def pack_metadata(meta): 557 return map(str, meta['coords']) + [ 558 str(meta['base_width']), 559 str(meta['extended_width']), 560 str(meta['filesize']) 561 ] 562 563 564 def unpack_metadata(meta): 565 return { 566 'coords': map(int, meta[:-3]), 567 'base_width': int(meta[-3]), 568 'extended_width': int(meta[-2]), 569 'filesize': int(meta[-1]) 570 } 571 572 573 # Helper to serialize the contents of a cube of byte arrays 574 575 576 def _ser(chunk, meta): 577 if chunk is None or (not isinstance(chunk[0], list) and 578 chunk[0] is not None): 579 u = 
serialize_chunk(chunk, *pack_metadata(meta)) 580 return u 581 else: 582 o = [] 583 for i, c in enumerate(chunk): 584 meta2 = copy.deepcopy(meta) 585 meta2['coords'] += [i] 586 o.append(_ser(c, meta2)) 587 return o 588 589 590 # Converts a deep list into a shallow list 591 592 593 def flatten(chunks): 594 if not isinstance(chunks, list): 595 return [chunks] 596 else: 597 o = [] 598 for c in chunks: 599 o.extend(flatten(c)) 600 return o 601 602 603 # Converts a file into a multidimensional set of chunks with 604 # the desired parameters 605 606 607 def serialize_cubify(f, width, dim, redund): 608 f = map(ord, f) 609 cube = cubify(f, width, dim, redund) 610 metadata = { 611 'base_width': width, 612 'extended_width': width + redund, 613 'coords': [], 614 'filesize': len(f) 615 } 616 cube_of_serialized_chunks = _ser(cube, metadata) 617 return flatten(cube_of_serialized_chunks) 618 619 620 # Converts a set of serialized chunks into a partially filled cube 621 622 623 def construct_cube(pieces): 624 pieces = map(deserialize_chunk, pieces) 625 metadata = unpack_metadata(pieces[0][0]) 626 dim = len(metadata['coords']) 627 cube = None 628 for i in range(dim): 629 cube = [copy.deepcopy(cube) for i in range(metadata['extended_width'])] 630 for p in pieces: 631 descend_and_set(cube, unpack_metadata(p[0])['coords'], p[1]) 632 return cube 633 634 635 # Tries to recreate the chunk at a particular coordinate given a set of 636 # other chunks 637 638 639 def heal_set(pieces, coords): 640 c = construct_cube(pieces) 641 metadata, piecezzz = deserialize_chunk(pieces[0]) 642 metadata = unpack_metadata(metadata) 643 heal_cube(c, 644 metadata['base_width'], 645 len(metadata['coords']), 646 coords, 647 metadata['filesize']) 648 metadata2 = copy.deepcopy(metadata) 649 metadata2["coords"] = [] 650 return filter(lambda x: x, flatten(_ser(c, metadata2))) 651 652 653 def number_to_coords(n, w, dim): 654 c = [0] * dim 655 for i in range(dim): 656 c[i] = n / w**(dim - i - 1) 657 n %= w**(dim - i 
- 1) 658 return c 659 660 661 def full_heal_set(pieces): 662 c = construct_cube(pieces) 663 metadata, piecezzz = deserialize_chunk(pieces[0]) 664 metadata = unpack_metadata(metadata) 665 while 1: 666 done = True 667 unfilled = False 668 i = 0 669 while i < metadata['extended_width'] ** len(metadata['coords']): 670 coords = number_to_coords(i, 671 metadata['extended_width'], 672 len(metadata['coords'])) 673 v = descend(c, coords) 674 heal_cube(c, 675 metadata['base_width'], 676 len(metadata['coords']), 677 coords, 678 metadata['filesize']) 679 v2 = descend(c, coords) 680 if v != v2: 681 done = False 682 if v is None and v2 is None: 683 unfilled = True 684 i += 1 685 if done and not unfilled: 686 break 687 elif done and unfilled: 688 raise Exception("not enough data or too much corrupted data") 689 o = [] 690 for i in range(metadata['base_width'] ** len(metadata['coords'])): 691 coords = number_to_coords(i, 692 metadata['base_width'], 693 len(metadata['coords'])) 694 o.extend(descend(c, coords)) 695 return ''.join(map(chr, unpad(o)))