// int_utils.h
/**********************************************************************
 * Copyright (c) 2018 Pieter Wuille, Greg Maxwell, Gleb Naumenko      *
 * Distributed under the MIT software license, see the accompanying   *
 * file LICENSE or http://www.opensource.org/licenses/mit-license.php.*
 **********************************************************************/

#ifndef _MINISKETCH_INT_UTILS_H_
#define _MINISKETCH_INT_UTILS_H_

#include <stdint.h>
#include <stdlib.h>

#include <limits>
#include <algorithm>
#include <type_traits>
#include <utility>

#if defined(__cpp_lib_int_pow2) && __cpp_lib_int_pow2 >= 202002L
#  include <bit>
#elif defined(_MSC_VER)
#  include <intrin.h>
#endif

/** Rotate the 64-bit value `x` left by `bits` positions (0 < bits < 64). */
template<int bits>
static constexpr inline uint64_t Rot(uint64_t x) {
    // A rotation by 0 or 64 would shift a 64-bit value by 64, which is undefined.
    static_assert(bits > 0 && bits < 64, "Rot requires 0 < bits < 64");
    return (x << bits) | (x >> (64 - bits));
}

/** Apply one add-rotate-xor mixing round to the four SipHash state words, in place. */
static inline void SipHashRound(uint64_t& v0, uint64_t& v1, uint64_t& v2, uint64_t& v3) {
    v0 += v1; v1 = Rot<13>(v1); v1 ^= v0;
    v0 = Rot<32>(v0);
    v2 += v3; v3 = Rot<16>(v3); v3 ^= v2;
    v0 += v3; v3 = Rot<21>(v3); v3 ^= v0;
    v2 += v1; v1 = Rot<17>(v1); v1 ^= v2;
    v2 = Rot<32>(v2);
}

/** Hash the single 64-bit word `data` under the 128-bit key (`k0`, `k1`).
 *
 * Two rounds are run after absorbing `data`, two more after absorbing the
 * constant word 0x0800000000000000 (the value 8 in the top byte), and four
 * after the final `v2 ^= 0xFF`.
 */
inline uint64_t SipHash(uint64_t k0, uint64_t k1, uint64_t data) {
    uint64_t v0 = 0x736f6d6570736575ULL ^ k0;
    uint64_t v1 = 0x646f72616e646f6dULL ^ k1;
    uint64_t v2 = 0x6c7967656e657261ULL ^ k0;
    uint64_t v3 = 0x7465646279746573ULL ^ k1 ^ data;
    SipHashRound(v0, v1, v2, v3);
    SipHashRound(v0, v1, v2, v3);
    v0 ^= data;
    v3 ^= 0x800000000000000ULL;
    SipHashRound(v0, v1, v2, v3);
    SipHashRound(v0, v1, v2, v3);
    v0 ^= 0x800000000000000ULL;
    v2 ^= 0xFF;
    SipHashRound(v0, v1, v2, v3);
    SipHashRound(v0, v1, v2, v3);
    SipHashRound(v0, v1, v2, v3);
    SipHashRound(v0, v1, v2, v3);
    return v0 ^ v1 ^ v2 ^ v3;
}

/** Packs fixed-width integers into a byte buffer, least significant bit first.
 *
 * The writer holds a raw pointer to the caller's buffer and performs no
 * bounds checking. Up to 7 pending bits are kept in `state` until a full
 * byte is available or Flush() is called.
 */
class BitWriter {
    unsigned char state = 0;  // Partially filled output byte.
    int offset = 0;           // Number of valid low bits in `state` (0..7).
    unsigned char* out;       // Next byte position to write.

    template<int BITS, typename I>
    inline void WriteInner(I val) {
        // We right shift by up to 8 bits below. Verify that's well defined for the type I.
        static_assert(std::numeric_limits<I>::digits > 8, "BitWriter::WriteInner needs I > 8 bits");
        int bits = BITS;
        if (bits + offset >= 8) {
            // Top up the pending byte with the low (8 - offset) bits of val and emit it.
            state |= ((val & ((I(1) << (8 - offset)) - 1)) << offset);
            *(out++) = state;
            val >>= (8 - offset);
            bits -= 8 - offset;
            offset = 0;
            state = 0;
        }
        // Emit whole bytes while at least 8 bits remain.
        while (bits >= 8) {
            *(out++) = val & 255;
            val >>= 8;
            bits -= 8;
        }
        // Keep the remaining (< 8) bits pending.
        state |= ((val & ((I(1) << bits) - 1)) << offset);
        offset += bits;
    }


public:
    /** Construct a writer that stores bytes starting at `output`. */
    BitWriter(unsigned char* output) : out(output) {}

    /** Append the low `BITS` bits of `val` to the stream. */
    template<int BITS, typename I>
    inline void Write(I val) {
        // If I is smaller than an unsigned int, invoke WriteInner with argument converted to unsigned.
        using compute_type = typename std::conditional<
            (std::numeric_limits<I>::digits < std::numeric_limits<unsigned>::digits),
            unsigned, I>::type;
        return WriteInner<BITS, compute_type>(val);
    }

    /** Emit any pending bits as one final byte (unused high bits are zero). */
    inline void Flush() {
        if (offset) {
            *(out++) = state;
            state = 0;
            offset = 0;
        }
    }
};

/** Reads fixed-width integers from a byte buffer, least significant bit first.
 *
 * The reader holds a raw pointer to the caller's buffer and performs no
 * bounds checking. Bits of the last fetched byte that were not consumed yet
 * are kept in `state`.
 */
class BitReader {
    unsigned char state = 0;  // Unconsumed bits of the last fetched byte.
    int offset = 0;           // Number of valid low bits in `state` (0..7).
    const unsigned char* in;  // Next byte position to read.

public:
    /** Construct a reader that consumes bytes starting at `input`. */
    BitReader(const unsigned char* input) : in(input) {}

    /** Extract the next `BITS` bits from the stream and return them as an `I`. */
    template<int BITS, typename I>
    inline I Read() {
        int bits = BITS;
        if (offset >= bits) {
            // The request can be served entirely from the pending bits.
            I ret = state & ((1 << bits) - 1);
            state >>= bits;
            offset -= bits;
            return ret;
        }
        I val = state;
        int out = offset;
        // Consume whole bytes while they fit completely in the result.
        while (out + 8 <= bits) {
            val |= ((I(*(in++))) << out);
            out += 8;
        }
        if (out < bits) {
            // Take the low (bits - out) bits of one more byte; keep the rest pending.
            unsigned char c = *(in++);
            val |= (c & ((I(1) << (bits - out)) - 1)) << out;
            state = c >> (bits - out);
            offset = 8 - (bits - out);
        } else {
            state = 0;
            offset = 0;
        }
        return val;
    }
};

/** Return a value of type I with its `BITS` lowest bits set (BITS must be > 0). */
template<int BITS, typename I>
constexpr inline I Mask() { return ((I((I(-1)) << (std::numeric_limits<I>::digits - BITS))) >> (std::numeric_limits<I>::digits - BITS)); }

/** Compute the number of bits needed to represent val.
 *
 * The result is the index of the highest set bit plus one, and 0 when val is 0.
 * `max` must be an upper bound on the result; it is only consulted by the
 * portable fallback and is ignored by the other implementations.
 */
template<typename I>
static inline int CountBits(I val, int max) {
#if defined(__cpp_lib_int_pow2) && __cpp_lib_int_pow2 >= 202002L
    // c++20 impl
    (void)max;
    return std::bit_width(val);
#elif defined(_MSC_VER)
    (void)max;
    unsigned long index;
    unsigned char ret;
    if (std::numeric_limits<I>::digits <= 32) {
        ret = _BitScanReverse(&index, val);
    } else {
        ret = _BitScanReverse64(&index, val);
    }
    // A zero return value is treated as "no bit set".
    if (!ret) return 0;
    return index + 1;
#elif defined(HAVE_CLZ)
    (void)max;
    if (val == 0) return 0;
    // Pick the narrowest count-leading-zeros builtin that is wide enough for I.
    if (std::numeric_limits<unsigned>::digits >= std::numeric_limits<I>::digits) {
        return std::numeric_limits<unsigned>::digits - __builtin_clz(val);
    } else if (std::numeric_limits<unsigned long>::digits >= std::numeric_limits<I>::digits) {
        return std::numeric_limits<unsigned long>::digits - __builtin_clzl(val);
    } else {
        return std::numeric_limits<unsigned long long>::digits - __builtin_clzll(val);
    }
#else
    // Walk down from the upper bound until the top candidate bit is set.
    while (max && (val >> (max - 1) == 0)) --max;
    return max;
#endif
}

/** Static helpers for treating an unsigned integer type `I` as a `BITS`-bit value. */
template<typename I, int BITS>
class BitsInt {
private:
    static_assert(std::is_integral<I>::value && std::is_unsigned<I>::value, "BitsInt requires an unsigned integer type");
    static_assert(BITS > 0 && BITS <= std::numeric_limits<I>::digits, "BitsInt requires 1 <= Bits <= representation type size");

    /** Value with the low BITS bits set. Declared before the member Mask() so the free template is found. */
    static constexpr I MASK = Mask<BITS, I>();

public:

    typedef I Repr;

    static constexpr int SIZE = BITS;

    static void inline Swap(I& a, I& b) {
        std::swap(a, b);
    }

    static constexpr inline bool IsZero(I a) { return a == 0; }
    static constexpr inline bool IsOne(I a) { return a == 1; }
    /** Clear all bits of `val` above the low BITS bits. */
    static constexpr inline I Mask(I val) { return val & MASK; }
    /** Shift `val` left by `bits` and truncate the result to BITS bits. */
    static constexpr inline I Shift(I val, int bits) { return ((val << bits) & MASK); }
    /** Shift `val` left by `bits` without truncating to BITS bits. */
    static constexpr inline I UnsafeShift(I val, int bits) { return (val << bits); }

    /** Return `Count` bits of `val`, starting at bit position `Offset`. */
    template<int Offset, int Count>
    static constexpr inline int MidBits(I val) {
        static_assert(Count > 0, "BitsInt::MidBits needs Count > 0");
        static_assert(Count + Offset <= BITS, "BitsInt::MidBits overflow of Count+Offset");
        return (val >> Offset) & ((I(1) << Count) - 1);
    }

    /** Return `val` shifted right by (BITS - Count); no masking is applied to bits at or above position BITS. */
    template<int Count>
    static constexpr inline int TopBits(I val) {
        static_assert(Count > 0, "BitsInt::TopBits needs Count > 0");
        static_assert(Count <= BITS, "BitsInt::TopBits needs Count <= BITS");
        return static_cast<int>(val >> (BITS - Count));
    }

    /** Return `val ^ v` if `cond` is true and `val` otherwise, without branching. */
    static inline constexpr I CondXorWith(I val, bool cond, I v) {
        // -I(cond) is all ones when cond is true and zero otherwise.
        return val ^ (-I(cond) & v);
    }

    /** Return `val ^ MOD` if `cond` is true and `val` otherwise, without branching. */
    template<I MOD>
    static inline constexpr I CondXorWith(I val, bool cond) {
        return val ^ (-I(cond) & MOD);
    }

    /** Number of bits needed to represent `val`; `max` must be an upper bound on the result. */
    static inline int Bits(I val, int max) { return CountBits<I>(val, max); }
};

/** Class which implements a stateless LFSR for generic moduli. */
template<typename F, uint32_t MOD>
struct LFSR {
    typedef typename F::Repr I;
    /** Shift a value `a` up once, treating it as an `N`-bit LFSR, with pattern `MOD`. */
    static inline constexpr I Call(const I& a) {
        // XOR in MOD exactly when the bit shifted out of the top was set.
        return F::template CondXorWith<MOD>(F::Shift(a, 1), F::template TopBits<1>(a));
    }
};

/** Helper class for carryless multiplications. */
template<typename I, int N, typename L, typename F, int K> struct GFMulHelper;
/** Recursion end: no bits of b are left to process, so the partial product is zero. */
template<typename I, int N, typename L, typename F> struct GFMulHelper<I, N, L, F, 0>
{
    static inline constexpr I Run(const I&, const I&) { return I(0); }
};
/** Recursive step: processes bit (N - K) of b, then recurses with a advanced by one LFSR step. */
template<typename I, int N, typename L, typename F, int K> struct GFMulHelper
{
    static inline constexpr I Run(const I& a, const I& b) { return F::CondXorWith(GFMulHelper<I, N, L, F, K - 1>::Run(L::Call(a), b), F::template MidBits<N - K, 1>(b), a); }
};

/** Compute the carry-less multiplication of a and b, with N bits, using L as LFSR type. */
template<typename I, int N, typename L, typename F> inline constexpr I GFMul(const I& a, const I& b) { return GFMulHelper<I, N, L, F, N>::Run(a, b); }

/** Compute the inverse of x using an extgcd algorithm.
 *
 * 0 and 1 are returned unchanged. `r` starts as MOD with a tracked length of
 * BITS + 1 while F::Shift truncates to BITS bits, so MOD is presumably the
 * modulus without its leading x^BITS term.
 */
template<typename I, typename F, int BITS, uint32_t MOD>
inline I InvExtGCD(I x)
{
    if (F::IsZero(x) || F::IsOne(x)) return x;
    I t(0), newt(1);
    I r(MOD), newr = x;
    int rlen = BITS + 1, newrlen = F::Bits(newr, BITS);
    while (newr) {
        // Cancel the leading term of r using newr shifted up by the length difference.
        int q = rlen - newrlen;
        r ^= F::Shift(newr, q);
        t ^= F::UnsafeShift(newt, q);
        rlen = F::Bits(r, rlen - 1);
        // Keep r as the larger of the two remainders.
        if (r < newr) {
            F::Swap(t, newt);
            F::Swap(r, newr);
            std::swap(rlen, newrlen);
        }
    }
    return t;
}

/** Compute the inverse of x1 using an exponentiation ladder.
 *
 * The `MUL` argument is a multiplication function, `SQR` is a squaring function, and the `SQRi` arguments
 * compute x**(2**i).
 *
 * With those contracts, each xN below equals x1**(2**N - 1), r is built up to
 * x1**(2**(BITS-1) - 1), and the returned value is x1**(2**BITS - 2).
 */
template<typename I, typename F, int BITS, I (*MUL)(I, I), I (*SQR)(I), I (*SQR2)(I), I(*SQR4)(I), I(*SQR8)(I), I(*SQR16)(I)>
inline I InvLadder(I x1)
{
    static constexpr int INV_EXP = BITS - 1;
    I x2 = (INV_EXP >= 2) ? MUL(SQR(x1), x1) : I();
    I x4 = (INV_EXP >= 4) ? MUL(SQR2(x2), x2) : I();
    I x8 = (INV_EXP >= 8) ? MUL(SQR4(x4), x4) : I();
    I x16 = (INV_EXP >= 16) ? MUL(SQR8(x8), x8) : I();
    I x32 = (INV_EXP >= 32) ? MUL(SQR16(x16), x16) : I();
    // Start from the largest precomputed power that fits in INV_EXP.
    I r;
    if (INV_EXP >= 32) {
        r = x32;
    } else if (INV_EXP >= 16) {
        r = x16;
    } else if (INV_EXP >= 8) {
        r = x8;
    } else if (INV_EXP >= 4) {
        r = x4;
    } else if (INV_EXP >= 2) {
        r = x2;
    } else {
        r = x1;
    }
    // Fold in the remaining set bits of INV_EXP, from high to low.
    if (INV_EXP >= 32 && (INV_EXP & 16)) r = MUL(SQR16(r), x16);
    if (INV_EXP >= 16 && (INV_EXP & 8)) r = MUL(SQR8(r), x8);
    if (INV_EXP >= 8 && (INV_EXP & 4)) r = MUL(SQR4(r), x4);
    if (INV_EXP >= 4 && (INV_EXP & 2)) r = MUL(SQR2(r), x2);
    if (INV_EXP >= 2 && (INV_EXP & 1)) r = MUL(SQR(r), x1);
    return SQR(r);
}

#endif