gen_basefpbits.sage
# Sage script that prints, to stdout, C++ source for
#   uint64_t BaseFPBits(uint32_t bits, uint32_t capacity)
# and a unit test TestBaseFPBits() for it.
#
# The exact value (computed by BaseFPBits below) is approximated in the
# generated C++ by floor(log2(capacity!)) plus a small per-(bits, capacity)
# correction; this script derives and emits those correction tables.
# NOTE(review): the generated code calls a C++ Log2Factorial defined
# elsewhere; it is assumed to match the Sage Log2Factorial below — confirm.

# Exact values are required for results below FPBITS; at or above it, the
# generated code only has to agree that the result is >= FPBITS as well
# (this is what the emitted CHECK in TestBaseFPBits verifies).
FPBITS = 256

# Overkill accuracy
F = RealField(400)

# Dimensions of the exact-value table embedded in the generated unit test:
# bits in 1..TEST_MAX_BITS, capacity in 0..TEST_CAPACITIES-1.
TEST_MAX_BITS = 20
TEST_CAPACITIES = 100


def BaseFPBits(bits, capacity):
    """Return bits*capacity - ceil(log2(sum_{i=0..capacity} C(2^bits - 1, i))), exactly."""
    count = sum(binomial(2**bits - 1, i) for i in range(capacity + 1))
    return bits * capacity - int(ceil(F(log(count, 2))))


def Log2Factorial(capacity):
    """Return floor(log2(capacity!)), the approximation used by the generated code."""
    return int(floor(log(factorial(capacity), 2)))


print("uint64_t BaseFPBits(uint32_t bits, uint32_t capacity) {")
print("    // Correction table for low bits/capacities")
TBLS = {}   # bits -> corrections (exact - approx) for capacities SKIPS[bits]+1, SKIPS[bits]+2, ...
FARS = {}   # bits -> correction for capacities past the table (None if nothing was trimmed)
SKIPS = {}  # bits -> number of leading capacities (starting at 1) whose correction is 0
for bits in range(1, 32):
    # (exact, approx) pairs for capacity = 1 .. min(2^bits, FPBITS) - 1.
    TBL = []
    for capacity in range(1, min(2**bits, FPBITS)):
        TBL.append((BaseFPBits(bits, capacity), Log2Factorial(capacity)))

    # Trim the tail: drop entries where the approximation is already exact, or
    # where both values are >= FPBITS (exactness not required there). MIN is
    # the smallest correction among the dropped entries; the generated code
    # returns it for every capacity past the end of the table.
    MIN = None
    while TBL and (TBL[-1][0] == TBL[-1][1] or (TBL[-1][0] >= FPBITS and TBL[-1][1] >= FPBITS)):
        diff = TBL[-1][0] - TBL[-1][1]
        MIN = diff if MIN is None else min(MIN, diff)
        TBL.pop()
    # Tail entries whose correction equals MIN are served by that far value too.
    while MIN is not None and TBL and TBL[-1][0] - TBL[-1][1] == MIN:
        TBL.pop()

    # Leading capacities with a zero correction need no table entries.
    SKIP = 0
    while SKIP < len(TBL) and TBL[SKIP][0] == TBL[SKIP][1]:
        SKIP += 1
    DIFFS = [exact - approx for exact, approx in TBL[SKIP:]]

    # Tables too large to pack into one 64-bit constant become a uint8_t array.
    if DIFFS and len(DIFFS) * Integer(max(DIFFS)).nbits() > 64:
        if min(DIFFS) < 0 or max(DIFFS) > 255:
            raise ValueError("correction table for bits=%i does not fit in uint8_t" % bits)
        print("    static constexpr uint8_t ADD%i[] = {%s};" % (bits, ", ".join("%i" % d for d in DIFFS)))
    TBLS[bits] = DIFFS
    FARS[bits] = MIN
    SKIPS[bits] = SKIP
print("")
print("    if (capacity == 0) return 0;")
print("    uint64_t ret = 0;")
print("    if (bits < 32 && capacity >= (1U << bits)) {")
print("        ret = uint64_t{bits} * (capacity - (1U << bits) + 1);")
print("        capacity = (1U << bits) - 1;")
print("    }")
print("    ret += Log2Factorial(capacity);")
print("    switch (bits) {")
for bits in sorted(TBLS):
    diffs = TBLS[bits]
    skip = SKIPS[bits]
    if not diffs:
        # No correction needed anywhere: the emitted "default" case covers it.
        continue
    first = 1 + skip          # first tabulated capacity (table index 0)
    last = len(diffs) + skip  # last tabulated capacity
    width = Integer(max(diffs)).nbits()
    if len(diffs) == 1:
        add = "%i" % diffs[0]
    elif len(diffs) * width <= 64:
        # Pack the whole table into a single 64-bit constant, `width` bits per entry.
        code = sum(2**(width * i) * d for i, d in enumerate(diffs))
        if width == 1:
            add = "(0x%x >> (capacity - %i)) & 1" % (code, first)
        else:
            add = "(0x%x >> %i * (capacity - %i)) & %i" % (code, width, first, 2**width - 1)
    else:
        add = "ADD%i[capacity - %i]" % (bits, first)
    if last == 2**bits - 1:
        # The emitted code clamps capacity to (1 << bits) - 1 above, so it can
        # never exceed the table and no far value is needed.
        print("        case %i: return ret + (capacity <= %i ? 0 : %s);" % (bits, skip, add))
    else:
        if FARS[bits] is None:
            raise ValueError("no far correction available for bits=%i beyond capacity %i" % (bits, last))
        print("        case %i: return ret + (capacity <= %i ? 0 : capacity > %i ? %i : %s);" % (bits, skip, last, FARS[bits], add))
print("        default: return ret;")
print("    }")
print("}")

print("void TestBaseFPBits() {")
rows = ", ".join(
    "{" + ", ".join("%i" % BaseFPBits(bits, capacity) for capacity in range(0, TEST_CAPACITIES)) + "}"
    for bits in range(1, TEST_MAX_BITS + 1)
)
print("    static constexpr uint16_t TBL[%i][%i] = {%s};" % (TEST_MAX_BITS, TEST_CAPACITIES, rows))
print("    for (int bits = 1; bits <= %i; ++bits) {" % TEST_MAX_BITS)
print("        for (int capacity = 0; capacity < %i; ++capacity) {" % TEST_CAPACITIES)
print("            uint64_t computed = BaseFPBits(bits, capacity), exact = TBL[bits - 1][capacity];")
print("            CHECK(exact == computed || (exact >= %i && computed >= %i));" % (FPBITS, FPBITS))
print("        }")
print("    }")
print("}")