/ src / minisketch / src / int_utils.h
int_utils.h
  1  /**********************************************************************
  2   * Copyright (c) 2018 Pieter Wuille, Greg Maxwell, Gleb Naumenko      *
  3   * Distributed under the MIT software license, see the accompanying   *
  4   * file LICENSE or http://www.opensource.org/licenses/mit-license.php.*
  5   **********************************************************************/
  6  
  7  #ifndef _MINISKETCH_INT_UTILS_H_
  8  #define _MINISKETCH_INT_UTILS_H_
  9  
 10  #include <stdint.h>
 11  #include <stdlib.h>
 12  
 13  #include <limits>
 14  #include <algorithm>
 15  #include <type_traits>
 16  
 17  #if defined(__cpp_lib_int_pow2) && __cpp_lib_int_pow2 >= 202002L
 18  #  include <bit>
 19  #elif defined(_MSC_VER)
 20  #  include <intrin.h>
 21  #endif
 22  
 23  template<int bits>
 24  static constexpr inline uint64_t Rot(uint64_t x) { return (x << bits) | (x >> (64 - bits)); }
 25  
 26  static inline void SipHashRound(uint64_t& v0, uint64_t& v1, uint64_t& v2, uint64_t& v3) {
 27      v0 += v1; v1 = Rot<13>(v1); v1 ^= v0;
 28      v0 = Rot<32>(v0);
 29      v2 += v3; v3 = Rot<16>(v3); v3 ^= v2;
 30      v0 += v3; v3 = Rot<21>(v3); v3 ^= v0;
 31      v2 += v1; v1 = Rot<17>(v1); v1 ^= v2;
 32      v2 = Rot<32>(v2);
 33  }
 34  
 35  inline uint64_t SipHash(uint64_t k0, uint64_t k1, uint64_t data) {
 36      uint64_t v0 = 0x736f6d6570736575ULL ^ k0;
 37      uint64_t v1 = 0x646f72616e646f6dULL ^ k1;
 38      uint64_t v2 = 0x6c7967656e657261ULL ^ k0;
 39      uint64_t v3 = 0x7465646279746573ULL ^ k1 ^ data;
 40      SipHashRound(v0, v1, v2, v3);
 41      SipHashRound(v0, v1, v2, v3);
 42      v0 ^= data;
 43      v3 ^= 0x800000000000000ULL;
 44      SipHashRound(v0, v1, v2, v3);
 45      SipHashRound(v0, v1, v2, v3);
 46      v0 ^= 0x800000000000000ULL;
 47      v2 ^= 0xFF;
 48      SipHashRound(v0, v1, v2, v3);
 49      SipHashRound(v0, v1, v2, v3);
 50      SipHashRound(v0, v1, v2, v3);
 51      SipHashRound(v0, v1, v2, v3);
 52      return v0 ^ v1 ^ v2 ^ v3;
 53  }
 54  
 55  class BitWriter {
 56      unsigned char state = 0;
 57      int offset = 0;
 58      unsigned char* out;
 59  
 60      template<int BITS, typename I>
 61      inline void WriteInner(I val) {
 62          // We right shift by up to 8 bits below. Verify that's well defined for the type I.
 63          static_assert(std::numeric_limits<I>::digits > 8, "BitWriter::WriteInner needs I > 8 bits");
 64          int bits = BITS;
 65          if (bits + offset >= 8) {
 66              state |= ((val & ((I(1) << (8 - offset)) - 1)) << offset);
 67              *(out++) = state;
 68              val >>= (8 - offset);
 69              bits -= 8 - offset;
 70              offset = 0;
 71              state = 0;
 72          }
 73          while (bits >= 8) {
 74              *(out++) = val & 255;
 75              val >>= 8;
 76              bits -= 8;
 77          }
 78          state |= ((val & ((I(1) << bits) - 1)) << offset);
 79          offset += bits;
 80      }
 81  
 82  
 83  public:
 84      BitWriter(unsigned char* output) : out(output) {}
 85  
 86      template<int BITS, typename I>
 87      inline void Write(I val) {
 88          // If I is smaller than an unsigned int, invoke WriteInner with argument converted to unsigned.
 89          using compute_type = typename std::conditional<
 90              (std::numeric_limits<I>::digits < std::numeric_limits<unsigned>::digits),
 91              unsigned, I>::type;
 92          return WriteInner<BITS, compute_type>(val);
 93      }
 94  
 95      inline void Flush() {
 96          if (offset) {
 97              *(out++) = state;
 98              state = 0;
 99              offset = 0;
100          }
101      }
102  };
103  
/** Unpacks fixed-width integers from a byte buffer, least significant bits first.
 *
 * Bits of a partially consumed byte are kept in `leftover` between calls. The
 * source buffer is supplied by the caller and is not bounds-checked here.
 */
class BitReader {
    unsigned char leftover = 0;  //!< Unconsumed bits of the last byte read, right-aligned.
    int avail = 0;               //!< Number of valid bits in `leftover` (0..7).
    const unsigned char* src;    //!< Next byte of the input buffer to be read.

public:
    /** Start reading at `input`. */
    BitReader(const unsigned char* input) : src(input) {}

    /** Read the next BITS bits and return them as a value of type I. */
    template<int BITS, typename I>
    inline I Read() {
        const int bits = BITS;

        // Fast path: the buffered bits alone satisfy the request.
        if (avail >= bits) {
            I ret = leftover & ((1 << bits) - 1);
            leftover >>= bits;
            avail -= bits;
            return ret;
        }

        // Start from the buffered bits, then append whole bytes while they fit.
        I val = leftover;
        int got = avail;
        for (; got + 8 <= bits; got += 8) {
            val |= ((I(*src)) << got);
            ++src;
        }

        if (got >= bits) {
            // Request ended exactly on a byte boundary; nothing is buffered.
            leftover = 0;
            avail = 0;
            return val;
        }

        // Take the missing low bits from one more byte and buffer the rest of it.
        const int need = bits - got;
        unsigned char c = *src;
        ++src;
        val |= (c & ((I(1) << need) - 1)) << got;
        leftover = c >> need;
        avail = 8 - need;
        return val;
    }
};
139  
/** Return a value of type I with its BITS lowest bits set.
 *
 * BITS must satisfy 1 <= BITS <= std::numeric_limits<I>::digits. This is
 * checked at compile time, because any other value would make the shift
 * amounts below negative or equal to the width of I, which is undefined.
 * Intended for unsigned I (BitsInt, which uses this, requires that).
 */
template<int BITS, typename I>
constexpr inline I Mask() {
    static_assert(BITS > 0 && BITS <= std::numeric_limits<I>::digits, "Mask requires 1 <= BITS <= number of value bits in I");
    // Shift an all-ones value up and back down again so only the BITS lowest bits survive.
    return ((I((I(-1)) << (std::numeric_limits<I>::digits - BITS))) >> (std::numeric_limits<I>::digits - BITS));
}
143  
/** Return the number of significant bits in `val`: the 1-based position of its
 *  highest set bit, or 0 when `val` is 0.
 *
 *  `max` is only used by the portable fallback, where it is the bit count the
 *  downward search starts from; `val` must have no bit set at position `max`
 *  or above (the fallback would simply return `max` in that case). The other
 *  implementations ignore it.
 */
template<typename I>
static inline int CountBits(I val, int max) {
#if defined(__cpp_lib_int_pow2) && __cpp_lib_int_pow2 >= 202002L
    // C++20: the standard library provides this directly.
    (void)max;
    return std::bit_width(val);
#elif defined(_MSC_VER)
    // MSVC: bit-scan intrinsics report the index of the highest set bit.
    (void)max;
    unsigned long index;
    const unsigned char found = (std::numeric_limits<I>::digits <= 32)
        ? _BitScanReverse(&index, val)
        : _BitScanReverse64(&index, val);
    return found ? index + 1 : 0;
#elif defined(HAVE_CLZ)
    // GCC/Clang: count leading zeros in the narrowest builtin type that holds I.
    (void)max;
    if (val == 0) return 0;  // The clz builtins are undefined for 0.
    if (std::numeric_limits<unsigned>::digits >= std::numeric_limits<I>::digits) {
        return std::numeric_limits<unsigned>::digits - __builtin_clz(val);
    }
    if (std::numeric_limits<unsigned long>::digits >= std::numeric_limits<I>::digits) {
        return std::numeric_limits<unsigned long>::digits - __builtin_clzl(val);
    }
    return std::numeric_limits<unsigned long long>::digits - __builtin_clzll(val);
#else
    // Portable fallback: walk down from `max` until a set bit is found.
    for (; max != 0; --max) {
        if ((val >> (max - 1)) != 0) break;
    }
    return max;
#endif
}
177  
178  template<typename I, int BITS>
179  class BitsInt {
180  private:
181      static_assert(std::is_integral<I>::value && std::is_unsigned<I>::value, "BitsInt requires an unsigned integer type");
182      static_assert(BITS > 0 && BITS <= std::numeric_limits<I>::digits, "BitsInt requires 1 <= Bits <= representation type size");
183  
184      static constexpr I MASK = Mask<BITS, I>();
185  
186  public:
187  
188      typedef I Repr;
189  
190      static constexpr int SIZE = BITS;
191  
192      static void inline Swap(I& a, I& b) {
193          std::swap(a, b);
194      }
195  
196      static constexpr inline bool IsZero(I a) { return a == 0; }
197      static constexpr inline bool IsOne(I a) { return a == 1; }
198      static constexpr inline I Mask(I val) { return val & MASK; }
199      static constexpr inline I Shift(I val, int bits) { return ((val << bits) & MASK); }
200      static constexpr inline I UnsafeShift(I val, int bits) { return (val << bits); }
201  
202      template<int Offset, int Count>
203      static constexpr inline int MidBits(I val) {
204          static_assert(Count > 0, "BITSInt::MidBits needs Count > 0");
205          static_assert(Count + Offset <= BITS, "BitsInt::MidBits overflow of Count+Offset");
206          return (val >> Offset) & ((I(1) << Count) - 1);
207      }
208  
209      template<int Count>
210      static constexpr inline int TopBits(I val) {
211          static_assert(Count > 0, "BitsInt::TopBits needs Count > 0");
212          static_assert(Count <= BITS, "BitsInt::TopBits needs Offset <= BITS");
213          return static_cast<int>(val >> (BITS - Count));
214      }
215  
216      static inline constexpr I CondXorWith(I val, bool cond, I v) {
217          return val ^ (-I(cond) & v);
218      }
219  
220      template<I MOD>
221      static inline constexpr I CondXorWith(I val, bool cond) {
222          return val ^ (-I(cond) & MOD);
223      }
224  
225      static inline int Bits(I val, int max) { return CountBits<I>(val, max); }
226  };
227  
/** Stateless single-step LFSR over the representation type of F, with feedback pattern MOD.
 *
 *  F supplies the bit operations (Repr, Shift, TopBits, CondXorWith); with
 *  F = BitsInt<I, BITS> the register is BITS bits wide.
 */
template<typename F, uint32_t MOD>
struct LFSR {
    using I = typename F::Repr;

    /** Advance `a` by one step: shift it up by one position and XOR MOD into
     *  the result when the top bit of `a` was set. */
    static inline constexpr I Call(const I& a)
    {
        return F::template CondXorWith<MOD>(
            F::Shift(a, 1),
            F::template TopBits<1>(a));
    }
};
237  
/** Compile-time unrolled helper for multiplication.
 *
 *  GFMulHelper<I, N, L, F, K>::Run(a, b) handles bit N-K of `b` and recurses
 *  with K-1 for the higher bits, after stepping `a` once through L::Call.
 *  The K = 0 specialization ends the recursion.
 */
template<typename I, int N, typename L, typename F, int K> struct GFMulHelper;

/** Recursion end: no bits of `b` left, so the contribution is zero. */
template<typename I, int N, typename L, typename F>
struct GFMulHelper<I, N, L, F, 0>
{
    static inline constexpr I Run(const I&, const I&) { return I(0); }
};

/** Recursive step: XOR `a` into the result when bit N-K of `b` is set. */
template<typename I, int N, typename L, typename F, int K>
struct GFMulHelper
{
    static inline constexpr I Run(const I& a, const I& b)
    {
        using Rest = GFMulHelper<I, N, L, F, K - 1>;
        return F::CondXorWith(
            Rest::Run(L::Call(a), b),
            F::template MidBits<N - K, 1>(b),
            a);
    }
};
248  
249  /** Compute the carry-less multiplication of a and b, with N bits, using L as LFSR type. */
250  template<typename I, int N, typename L, typename F> inline constexpr I GFMul(const I& a, const I& b) { return GFMulHelper<I, N, L, F, N>::Run(a, b); }
251  
/** Compute the inverse of `x` with an extended-GCD style loop.
 *
 *  0 and 1 are returned unchanged. F supplies the bit operations on I.
 *  NOTE(review): MOD presumably holds the low BITS bits of the modulus, with
 *  the bit at position BITS implicit; that is why its length starts at BITS + 1.
 */
template<typename I, typename F, int BITS, uint32_t MOD>
inline I InvExtGCD(I x)
{
    if (F::IsZero(x) || F::IsOne(x)) return x;

    // (rem, coef) and (next_rem, next_coef) are the two remainder/coefficient pairs.
    I coef(0);
    I next_coef(1);
    I rem(MOD);
    I next_rem = x;
    // Bit lengths of the two remainders.
    int rem_len = BITS + 1;
    int next_rem_len = F::Bits(next_rem, BITS);

    while (next_rem != 0) {
        // Cancel the top bit of rem with a suitably shifted copy of next_rem.
        const int shift = rem_len - next_rem_len;
        rem ^= F::Shift(next_rem, shift);
        coef ^= F::UnsafeShift(next_coef, shift);
        rem_len = F::Bits(rem, rem_len - 1);

        // Keep the larger remainder in `rem`.
        if (!(rem < next_rem)) continue;
        F::Swap(coef, next_coef);
        F::Swap(rem, next_rem);
        std::swap(rem_len, next_rem_len);
    }
    return coef;
}
273  
274  /** Compute the inverse of x1 using an exponentiation ladder.
275   *
276   * The `MUL` argument is a multiplication function, `SQR` is a squaring function, and the `SQRi` arguments
277   * compute x**(2**i).
278   */
279  template<typename I, typename F, int BITS, I (*MUL)(I, I), I (*SQR)(I), I (*SQR2)(I), I(*SQR4)(I), I(*SQR8)(I), I(*SQR16)(I)>
280  inline I InvLadder(I x1)
281  {
282      static constexpr int INV_EXP = BITS - 1;
283      I x2 = (INV_EXP >= 2) ? MUL(SQR(x1), x1) : I();
284      I x4 = (INV_EXP >= 4) ? MUL(SQR2(x2), x2) : I();
285      I x8 = (INV_EXP >= 8) ? MUL(SQR4(x4), x4) : I();
286      I x16 = (INV_EXP >= 16) ? MUL(SQR8(x8), x8) : I();
287      I x32 = (INV_EXP >= 32) ? MUL(SQR16(x16), x16) : I();
288      I r;
289      if (INV_EXP >= 32) {
290          r = x32;
291      } else if (INV_EXP >= 16) {
292          r = x16;
293      } else if (INV_EXP >= 8) {
294          r = x8;
295      } else if (INV_EXP >= 4) {
296          r = x4;
297      } else if (INV_EXP >= 2) {
298          r = x2;
299      } else {
300          r = x1;
301      }
302      if (INV_EXP >= 32 && (INV_EXP & 16)) r = MUL(SQR16(r), x16);
303      if (INV_EXP >= 16 && (INV_EXP & 8)) r = MUL(SQR8(r), x8);
304      if (INV_EXP >= 8 && (INV_EXP & 4)) r = MUL(SQR4(r), x4);
305      if (INV_EXP >= 4 && (INV_EXP & 2)) r = MUL(SQR2(r), x2);
306      if (INV_EXP >= 2 && (INV_EXP & 1)) r = MUL(SQR(r), x1);
307      return SQR(r);
308  }
309  
310  #endif