/ src / minisketch / src / fields / clmul_common_impl.h
clmul_common_impl.h
  1  /**********************************************************************
  2   * Copyright (c) 2018 Pieter Wuille, Greg Maxwell, Gleb Naumenko      *
  3   * Distributed under the MIT software license, see the accompanying   *
  4   * file LICENSE or http://www.opensource.org/licenses/mit-license.php.*
  5   **********************************************************************/
  6  
  7  #ifndef _MINISKETCH_FIELDS_CLMUL_COMMON_IMPL_H_
  8  #define _MINISKETCH_FIELDS_CLMUL_COMMON_IMPL_H_ 1
  9  
 10  #include <stdint.h>
 11  #include <immintrin.h>
 12  
 13  #include "../int_utils.h"
 14  #include "../lintrans.h"
 15  
 16  namespace {
 17  
 18  // The memory sanitizer in clang < 11 cannot reason through _mm_clmulepi64_si128 calls.
 19  // Disable memory sanitization in the functions using them for those compilers.
 20  #if defined(__clang__) && (__clang_major__ < 11)
 21  #  if defined(__has_feature)
 22  #    if __has_feature(memory_sanitizer)
 23  #      define NO_SANITIZE_MEMORY __attribute__((no_sanitize("memory")))
 24  #    endif
 25  #  endif
 26  #endif
 27  #ifndef NO_SANITIZE_MEMORY
 28  #  define NO_SANITIZE_MEMORY
 29  #endif
 30  
 31  template<typename I, int BITS, I MOD> NO_SANITIZE_MEMORY I MulWithClMulReduce(I a, I b)
 32  {
 33      static constexpr I MASK = Mask<BITS, I>();
 34  
 35      const __m128i MOD128 = _mm_cvtsi64_si128(MOD);
 36      __m128i product = _mm_clmulepi64_si128(_mm_cvtsi64_si128((uint64_t)a), _mm_cvtsi64_si128((uint64_t)b), 0x00);
 37      if (BITS <= 32) {
 38          __m128i high1 = _mm_srli_epi64(product, BITS);
 39          __m128i red1 = _mm_clmulepi64_si128(high1, MOD128, 0x00);
 40          __m128i high2 = _mm_srli_epi64(red1, BITS);
 41          __m128i red2 = _mm_clmulepi64_si128(high2, MOD128, 0x00);
 42          return _mm_cvtsi128_si64(_mm_xor_si128(_mm_xor_si128(product, red1), red2)) & MASK;
 43      } else if (BITS == 64) {
 44          __m128i red1 = _mm_clmulepi64_si128(product, MOD128, 0x01);
 45          __m128i red2 = _mm_clmulepi64_si128(red1, MOD128, 0x01);
 46          return _mm_cvtsi128_si64(_mm_xor_si128(_mm_xor_si128(product, red1), red2));
 47      } else if ((BITS % 8) == 0) {
 48          __m128i high1 = _mm_srli_si128(product, BITS / 8);
 49          __m128i red1 = _mm_clmulepi64_si128(high1, MOD128, 0x00);
 50          __m128i high2 = _mm_srli_si128(red1, BITS / 8);
 51          __m128i red2 = _mm_clmulepi64_si128(high2, MOD128, 0x00);
 52          return _mm_cvtsi128_si64(_mm_xor_si128(_mm_xor_si128(product, red1), red2)) & MASK;
 53      } else {
 54          __m128i high1 = _mm_or_si128(_mm_srli_epi64(product, BITS), _mm_srli_si128(_mm_slli_epi64(product, 64 - BITS), 8));
 55          __m128i red1 = _mm_clmulepi64_si128(high1, MOD128, 0x00);
 56          if ((uint64_t(MOD) >> (66 - BITS)) == 0) {
 57              __m128i high2 = _mm_srli_epi64(red1, BITS);
 58              __m128i red2 = _mm_clmulepi64_si128(high2, MOD128, 0x00);
 59              return _mm_cvtsi128_si64(_mm_xor_si128(_mm_xor_si128(product, red1), red2)) & MASK;
 60          } else {
 61              __m128i high2 = _mm_or_si128(_mm_srli_epi64(red1, BITS), _mm_srli_si128(_mm_slli_epi64(red1, 64 - BITS), 8));
 62              __m128i red2 = _mm_clmulepi64_si128(high2, MOD128, 0x00);
 63              return _mm_cvtsi128_si64(_mm_xor_si128(_mm_xor_si128(product, red1), red2)) & MASK;
 64          }
 65      }
 66  }
 67  
 68  template<typename I, int BITS, int POS> NO_SANITIZE_MEMORY I MulTrinomial(I a, I b)
 69  {
 70      static constexpr I MASK = Mask<BITS, I>();
 71  
 72      __m128i product = _mm_clmulepi64_si128(_mm_cvtsi64_si128((uint64_t)a), _mm_cvtsi64_si128((uint64_t)b), 0x00);
 73      if (BITS <= 32) {
 74          __m128i high1 = _mm_srli_epi64(product, BITS);
 75          __m128i red1 = _mm_xor_si128(high1, _mm_slli_epi64(high1, POS));
 76          if (POS == 1) {
 77              return _mm_cvtsi128_si64(_mm_xor_si128(product, red1)) & MASK;
 78          } else {
 79              __m128i high2 = _mm_srli_epi64(red1, BITS);
 80              __m128i red2 = _mm_xor_si128(high2, _mm_slli_epi64(high2, POS));
 81              return _mm_cvtsi128_si64(_mm_xor_si128(_mm_xor_si128(product, red1), red2)) & MASK;
 82          }
 83      } else {
 84          __m128i high1 = _mm_or_si128(_mm_srli_epi64(product, BITS), _mm_srli_si128(_mm_slli_epi64(product, 64 - BITS), 8));
 85          if (BITS + POS <= 66) {
 86              __m128i red1 = _mm_xor_si128(high1, _mm_slli_epi64(high1, POS));
 87              if (POS == 1) {
 88                  return _mm_cvtsi128_si64(_mm_xor_si128(product, red1)) & MASK;
 89              } else if (BITS + POS <= 66) {
 90                  __m128i high2 = _mm_srli_epi64(red1, BITS);
 91                  __m128i red2 = _mm_xor_si128(high2, _mm_slli_epi64(high2, POS));
 92                  return _mm_cvtsi128_si64(_mm_xor_si128(_mm_xor_si128(product, red1), red2)) & MASK;
 93              }
 94          } else {
 95              const __m128i MOD128 = _mm_cvtsi64_si128(1 + (((uint64_t)1) << POS));
 96              __m128i red1 = _mm_clmulepi64_si128(high1, MOD128, 0x00);
 97              __m128i high2 = _mm_or_si128(_mm_srli_epi64(red1, BITS), _mm_srli_si128(_mm_slli_epi64(red1, 64 - BITS), 8));
 98              __m128i red2 = _mm_xor_si128(high2, _mm_slli_epi64(high2, POS));
 99              return _mm_cvtsi128_si64(_mm_xor_si128(_mm_xor_si128(product, red1), red2)) & MASK;
100          }
101      }
102  }
103  
104  /** Implementation of fields that use the SSE clmul intrinsic for multiplication. */
105  template<typename I, int B, I MOD, I (*MUL)(I, I), typename F, const F* SQR, const F* SQR2, const F* SQR4, const F* SQR8, const F* SQR16, const F* QRT, typename T, const T* LOAD, const T* SAVE> struct GenField
106  {
107      typedef BitsInt<I, B> O;
108      typedef LFSR<O, MOD> L;
109  
110      static inline constexpr I Sqr1(I a) { return SQR->template Map<O>(a); }
111      static inline constexpr I Sqr2(I a) { return SQR2->template Map<O>(a); }
112      static inline constexpr I Sqr4(I a) { return SQR4->template Map<O>(a); }
113      static inline constexpr I Sqr8(I a) { return SQR8->template Map<O>(a); }
114      static inline constexpr I Sqr16(I a) { return SQR16->template Map<O>(a); }
115  
116  public:
117      typedef I Elem;
118  
119      inline constexpr int Bits() const { return B; }
120  
121      inline constexpr Elem Mul2(Elem val) const { return L::Call(val); }
122  
123      inline Elem Mul(Elem a, Elem b) const { return MUL(a, b); }
124  
125      class Multiplier
126      {
127          Elem m_val;
128      public:
129          inline constexpr explicit Multiplier(const GenField&, Elem a) : m_val(a) {}
130          constexpr Elem operator()(Elem a) const { return MUL(m_val, a); }
131      };
132  
133      /** Compute the square of a. */
134      inline constexpr Elem Sqr(Elem val) const { return SQR->template Map<O>(val); }
135  
136      /** Compute x such that x^2 + x = a (undefined result if no solution exists). */
137      inline constexpr Elem Qrt(Elem val) const { return QRT->template Map<O>(val); }
138  
139      /** Compute the inverse of x1. */
140      inline Elem Inv(Elem val) const { return InvLadder<I, O, B, MUL, Sqr1, Sqr2, Sqr4, Sqr8, Sqr16>(val); }
141  
142      /** Generate a random field element. */
143      Elem FromSeed(uint64_t seed) const {
144          uint64_t k0 = 0x434c4d554c466c64ull; // "CLMULFld"
145          uint64_t k1 = seed;
146          uint64_t count = ((uint64_t)B) << 32;
147          I ret;
148          do {
149              ret = O::Mask(I(SipHash(k0, k1, count++)));
150          } while(ret == 0);
151          return LOAD->template Map<O>(ret);
152      }
153  
154      Elem Deserialize(BitReader& in) const { return LOAD->template Map<O>(in.Read<B, I>()); }
155  
156      void Serialize(BitWriter& out, Elem val) const { out.Write<B, I>(SAVE->template Map<O>(val)); }
157  
158      constexpr Elem FromUint64(uint64_t x) const { return LOAD->template Map<O>(O::Mask(I(x))); }
159      constexpr uint64_t ToUint64(Elem val) const { return uint64_t(SAVE->template Map<O>(val)); }
160  };
161  
162  template<typename I, int B, I MOD, typename F, const F* SQR, const F* SQR2, const F* SQR4, const F* SQR8, const F* SQR16, const F* QRT, typename T, const T* LOAD, const T* SAVE>
163  using Field = GenField<I, B, MOD, MulWithClMulReduce<I, B, MOD>, F, SQR, SQR2, SQR4, SQR8, SQR16, QRT, T, LOAD, SAVE>;
164  
165  template<typename I, int B, int POS, typename F, const F* SQR, const F* SQR2, const F* SQR4, const F* SQR8, const F* SQR16, const F* QRT, typename T, const T* LOAD, const T* SAVE>
166  using FieldTri = GenField<I, B, I(1) + (I(1) << POS), MulTrinomial<I, B, POS>, F, SQR, SQR2, SQR4, SQR8, SQR16, QRT, T, LOAD, SAVE>;
167  
168  }
169  
170  #endif