/ src / crypto / muhash.cpp
muhash.cpp
  1  // Copyright (c) 2017-present The Bitcoin Core developers
  2  // Distributed under the MIT software license, see the accompanying
  3  // file COPYING or http://www.opensource.org/licenses/mit-license.php.
  4  
  5  #include <crypto/muhash.h>
  6  
  7  #include <crypto/chacha20.h>
  8  #include <crypto/common.h>
  9  #include <hash.h>
 10  #include <span.h>
 11  #include <uint256.h>
 12  #include <util/check.h>
 13  
 14  #include <bit>
 15  #include <cstring>
 16  #include <limits>
 17  
 18  namespace {
 19  
// Local shorthands for the limb types and sizes declared by Num3072.
using limb_t = Num3072::limb_t;
using signed_limb_t = Num3072::signed_limb_t;
using double_limb_t = Num3072::double_limb_t;
using signed_double_limb_t = Num3072::signed_double_limb_t;
constexpr int LIMB_SIZE = Num3072::LIMB_SIZE;
constexpr int SIGNED_LIMB_SIZE = Num3072::SIGNED_LIMB_SIZE;
constexpr int LIMBS = Num3072::LIMBS;
constexpr int SIGNED_LIMBS = Num3072::SIGNED_LIMBS;
/** Index of the signed limb that holds bit 3072 (the 2^3072 term of the modulus). */
constexpr int FINAL_LIMB_POSITION = 3072 / SIGNED_LIMB_SIZE;
/** Position of bit 3072 inside signed limb FINAL_LIMB_POSITION. */
constexpr int FINAL_LIMB_MODULUS_BITS = 3072 % SIGNED_LIMB_SIZE;
/** A limb_t with all bits set. */
constexpr limb_t MAX_LIMB = (limb_t)(-1);
/** Mask selecting the low SIGNED_LIMB_SIZE bits of a limb. */
constexpr limb_t MAX_SIGNED_LIMB = (((limb_t)1) << SIGNED_LIMB_SIZE) - 1;
/** 2^3072 - 1103717, the largest 3072-bit safe prime number, is used as the modulus. */
constexpr limb_t MAX_PRIME_DIFF = 1103717;
/** The modular inverse of (2**3072 - MAX_PRIME_DIFF) mod (MAX_SIGNED_LIMB + 1). */
constexpr limb_t MODULUS_INVERSE = limb_t(0x70a1421da087d93);
 36  
 37  
 38  /** Extract the lowest limb of [c0,c1,c2] into n, and left shift the number by 1 limb. */
 39  inline void extract3(limb_t& c0, limb_t& c1, limb_t& c2, limb_t& n)
 40  {
 41      n = c0;
 42      c0 = c1;
 43      c1 = c2;
 44      c2 = 0;
 45  }
 46  
 47  /** [c0,c1] = a * b */
 48  inline void mul(limb_t& c0, limb_t& c1, const limb_t& a, const limb_t& b)
 49  {
 50      double_limb_t t = (double_limb_t)a * b;
 51      c1 = t >> LIMB_SIZE;
 52      c0 = t;
 53  }
 54  
 55  /* [c0,c1,c2] += n * [d0,d1,d2]. c2 is 0 initially */
 56  inline void mulnadd3(limb_t& c0, limb_t& c1, limb_t& c2, limb_t& d0, limb_t& d1, limb_t& d2, const limb_t& n)
 57  {
 58      double_limb_t t = (double_limb_t)d0 * n + c0;
 59      c0 = t;
 60      t >>= LIMB_SIZE;
 61      t += (double_limb_t)d1 * n + c1;
 62      c1 = t;
 63      t >>= LIMB_SIZE;
 64      c2 = t + d2 * n;
 65  }
 66  
 67  /* [c0,c1] *= n */
 68  inline void muln2(limb_t& c0, limb_t& c1, const limb_t& n)
 69  {
 70      double_limb_t t = (double_limb_t)c0 * n;
 71      c0 = t;
 72      t >>= LIMB_SIZE;
 73      t += (double_limb_t)c1 * n;
 74      c1 = t;
 75  }
 76  
 77  /** [c0,c1,c2] += a * b */
 78  inline void muladd3(limb_t& c0, limb_t& c1, limb_t& c2, const limb_t& a, const limb_t& b)
 79  {
 80      double_limb_t t = (double_limb_t)a * b;
 81      limb_t th = t >> LIMB_SIZE;
 82      limb_t tl = t;
 83  
 84      c0 += tl;
 85      th += (c0 < tl) ? 1 : 0;
 86      c1 += th;
 87      c2 += (c1 < th) ? 1 : 0;
 88  }
 89  
 90  /**
 91   * Add limb a to [c0,c1]: [c0,c1] += a. Then extract the lowest
 92   * limb of [c0,c1] into n, and left shift the number by 1 limb.
 93   * */
 94  inline void addnextract2(limb_t& c0, limb_t& c1, const limb_t& a, limb_t& n)
 95  {
 96      limb_t c2 = 0;
 97  
 98      // add
 99      c0 += a;
100      if (c0 < a) {
101          c1 += 1;
102  
103          // Handle case when c1 has overflown
104          if (c1 == 0) c2 = 1;
105      }
106  
107      // extract
108      n = c0;
109      c0 = c1;
110      c1 = c2;
111  }
112  
113  } // namespace
114  
115  /** Indicates whether d is larger than the modulus. */
116  bool Num3072::IsOverflow() const
117  {
118      if (this->limbs[0] <= std::numeric_limits<limb_t>::max() - MAX_PRIME_DIFF) return false;
119      for (int i = 1; i < LIMBS; ++i) {
120          if (this->limbs[i] != std::numeric_limits<limb_t>::max()) return false;
121      }
122      return true;
123  }
124  
125  void Num3072::FullReduce()
126  {
127      limb_t c0 = MAX_PRIME_DIFF;
128      limb_t c1 = 0;
129      for (int i = 0; i < LIMBS; ++i) {
130          addnextract2(c0, c1, this->limbs[i], this->limbs[i]);
131      }
132  }
133  
134  namespace {
135  /** A type representing a number in signed limb representation. */
136  struct Num3072Signed
137  {
138      /** The represented value is sum(limbs[i]*2^(SIGNED_LIMB_SIZE*i), i=0..SIGNED_LIMBS-1).
139       *  Note that limbs may be negative, or exceed 2^SIGNED_LIMB_SIZE-1. */
140      signed_limb_t limbs[SIGNED_LIMBS];
141  
142      /** Construct a Num3072Signed with value 0. */
143      Num3072Signed()
144      {
145          memset(limbs, 0, sizeof(limbs));
146      }
147  
148      /** Convert a Num3072 to a Num3072Signed. Output will be normalized and in
149       *  range 0..2^3072-1. */
150      void FromNum3072(const Num3072& in)
151      {
152          double_limb_t c = 0;
153          int b = 0, outpos = 0;
154          for (int i = 0; i < LIMBS; ++i) {
155              c += double_limb_t{in.limbs[i]} << b;
156              b += LIMB_SIZE;
157              while (b >= SIGNED_LIMB_SIZE) {
158                  limbs[outpos++] = limb_t(c) & MAX_SIGNED_LIMB;
159                  c >>= SIGNED_LIMB_SIZE;
160                  b -= SIGNED_LIMB_SIZE;
161              }
162          }
163          Assume(outpos == SIGNED_LIMBS - 1);
164          limbs[SIGNED_LIMBS - 1] = c;
165          c >>= SIGNED_LIMB_SIZE;
166          Assume(c == 0);
167      }
168  
169      /** Convert a Num3072Signed to a Num3072. Input must be in range 0..modulus-1. */
170      void ToNum3072(Num3072& out) const
171      {
172          double_limb_t c = 0;
173          int b = 0, outpos = 0;
174          for (int i = 0; i < SIGNED_LIMBS; ++i) {
175              c += double_limb_t(limbs[i]) << b;
176              b += SIGNED_LIMB_SIZE;
177              if (b >= LIMB_SIZE) {
178                  out.limbs[outpos++] = c;
179                  c >>= LIMB_SIZE;
180                  b -= LIMB_SIZE;
181              }
182          }
183          Assume(outpos == LIMBS);
184          Assume(c == 0);
185      }
186  
187      /** Take a Num3072Signed in range 1-2*2^3072..2^3072-1, and:
188       *  - optionally negate it (if negate is true)
189       *  - reduce it modulo the modulus (2^3072 - MAX_PRIME_DIFF)
190       *  - produce output with all limbs in range 0..2^SIGNED_LIMB_SIZE-1
191       */
192      void Normalize(bool negate)
193      {
194          // Add modulus if this was negative. This brings the range of *this to 1-2^3072..2^3072-1.
195          signed_limb_t cond_add = limbs[SIGNED_LIMBS-1] >> (LIMB_SIZE-1); // -1 if this is negative; 0 otherwise
196          limbs[0] += signed_limb_t(-MAX_PRIME_DIFF) & cond_add;
197          limbs[FINAL_LIMB_POSITION] += (signed_limb_t(1) << FINAL_LIMB_MODULUS_BITS) & cond_add;
198          // Next negate all limbs if negate was set. This does not change the range of *this.
199          signed_limb_t cond_negate = -signed_limb_t(negate); // -1 if this negate is true; 0 otherwise
200          for (int i = 0; i < SIGNED_LIMBS; ++i) {
201              limbs[i] = (limbs[i] ^ cond_negate) - cond_negate;
202          }
203          // Perform carry (make all limbs except the top one be in range 0..2^SIGNED_LIMB_SIZE-1).
204          for (int i = 0; i < SIGNED_LIMBS - 1; ++i) {
205              limbs[i + 1] += limbs[i] >> SIGNED_LIMB_SIZE;
206              limbs[i] &= MAX_SIGNED_LIMB;
207          }
208          // Again add modulus if *this was negative. This brings the range of *this to 0..2^3072-1.
209          cond_add = limbs[SIGNED_LIMBS-1] >> (LIMB_SIZE-1); // -1 if this is negative; 0 otherwise
210          limbs[0] += signed_limb_t(-MAX_PRIME_DIFF) & cond_add;
211          limbs[FINAL_LIMB_POSITION] += (signed_limb_t(1) << FINAL_LIMB_MODULUS_BITS) & cond_add;
212          // Perform another carry. Now all limbs are in range 0..2^SIGNED_LIMB_SIZE-1.
213          for (int i = 0; i < SIGNED_LIMBS - 1; ++i) {
214              limbs[i + 1] += limbs[i] >> SIGNED_LIMB_SIZE;
215              limbs[i] &= MAX_SIGNED_LIMB;
216          }
217      }
218  };
219  
220  /** 2x2 transformation matrix with signed_limb_t elements. */
221  struct SignedMatrix
222  {
223      signed_limb_t u, v, q, r;
224  };
225  
/** Compute the transformation matrix for SIGNED_LIMB_SIZE divsteps.
 *
 * eta: initial eta value
 * f:   bottom SIGNED_LIMB_SIZE bits of initial f value
 * g:   bottom SIGNED_LIMB_SIZE bits of initial g value
 * out: resulting transformation matrix, scaled by 2^SIGNED_LIMB_SIZE
 * return: eta value after SIGNED_LIMB_SIZE divsteps
 *
 * NOTE(review): the return type is the unsigned limb_t although the returned eta is a
 * signed_limb_t that may be negative; the caller in this file assigns the result straight
 * back to a signed_limb_t. Consider returning signed_limb_t instead.
 */
inline limb_t ComputeDivstepMatrix(signed_limb_t eta, limb_t f, limb_t g, SignedMatrix& out)
{
    /** NEGINV256[i] = -1/(2*i+1) (mod 256) */
    static const uint8_t NEGINV256[128] = {
        0xFF, 0x55, 0x33, 0x49, 0xC7, 0x5D, 0x3B, 0x11, 0x0F, 0xE5, 0xC3, 0x59,
        0xD7, 0xED, 0xCB, 0x21, 0x1F, 0x75, 0x53, 0x69, 0xE7, 0x7D, 0x5B, 0x31,
        0x2F, 0x05, 0xE3, 0x79, 0xF7, 0x0D, 0xEB, 0x41, 0x3F, 0x95, 0x73, 0x89,
        0x07, 0x9D, 0x7B, 0x51, 0x4F, 0x25, 0x03, 0x99, 0x17, 0x2D, 0x0B, 0x61,
        0x5F, 0xB5, 0x93, 0xA9, 0x27, 0xBD, 0x9B, 0x71, 0x6F, 0x45, 0x23, 0xB9,
        0x37, 0x4D, 0x2B, 0x81, 0x7F, 0xD5, 0xB3, 0xC9, 0x47, 0xDD, 0xBB, 0x91,
        0x8F, 0x65, 0x43, 0xD9, 0x57, 0x6D, 0x4B, 0xA1, 0x9F, 0xF5, 0xD3, 0xE9,
        0x67, 0xFD, 0xDB, 0xB1, 0xAF, 0x85, 0x63, 0xF9, 0x77, 0x8D, 0x6B, 0xC1,
        0xBF, 0x15, 0xF3, 0x09, 0x87, 0x1D, 0xFB, 0xD1, 0xCF, 0xA5, 0x83, 0x19,
        0x97, 0xAD, 0x8B, 0xE1, 0xDF, 0x35, 0x13, 0x29, 0xA7, 0x3D, 0x1B, 0xF1,
        0xEF, 0xC5, 0xA3, 0x39, 0xB7, 0xCD, 0xAB, 0x01
    };
    // Coefficients of the returned SignedMatrix; starts off as the identity matrix.
    // They are kept unsigned here so the shifts and negations below wrap instead of overflowing.
    limb_t u = 1, v = 0, q = 0, r = 1;
    // The number of divsteps still left.
    int i = SIGNED_LIMB_SIZE;
    while (true) {
        /* Use a sentinel bit to count zeros only up to i. */
        int zeros = std::countr_zero(g | (MAX_LIMB << i));
        /* Perform zeros divsteps at once; they all just divide g by two. */
        g >>= zeros;
        u <<= zeros;
        v <<= zeros;
        eta -= zeros;
        i -= zeros;
        /* We're done once we've performed SIGNED_LIMB_SIZE divsteps. */
        if (i == 0) break;
        /* If eta is negative, negate it and replace f,g with g,-f. The matrix rows are
         * swapped/negated the same way so that the matrix keeps describing the transformation. */
        if (eta < 0) {
            limb_t tmp;
            eta = -eta;
            tmp = f; f = g; g = -tmp;
            tmp = u; u = q; q = -tmp;
            tmp = v; v = r; r = -tmp;
        }
        /* eta is now >= 0. In what follows we're going to cancel out the bottom bits of g. No more
         * than i can be cancelled out (as we'd be done before that point), and no more than eta+1
         * can be done as its sign will flip once that happens. */
        int limit = ((int)eta + 1) > i ? i : ((int)eta + 1);
        /* m is a mask for the bottom min(limit, 8) bits (our table only supports 8 bits). */
        limb_t m = (MAX_LIMB >> (LIMB_SIZE - limit)) & 255U;
        /* Find what multiple of f must be added to g to cancel its bottom min(limit, 8) bits.
         * The table is indexed by (f >> 1) & 127, i.e. bits 1..7 of f. */
        limb_t w = (g * NEGINV256[(f >> 1) & 127]) & m;
        /* Do so, applying the same row operation to the matrix. */
        g += f * w;
        q += u * w;
        r += v * w;
    }
    /* Reinterpret the accumulated unsigned coefficients as signed matrix entries. */
    out.u = (signed_limb_t)u;
    out.v = (signed_limb_t)v;
    out.q = (signed_limb_t)q;
    out.r = (signed_limb_t)r;
    return eta;
}
292  
293  /** Apply matrix t/2^SIGNED_LIMB_SIZE to vector [d,e], modulo modulus.
294   *
295   * On input and output, d and e are in range 1-2*modulus..modulus-1.
296   */
297  inline void UpdateDE(Num3072Signed& d, Num3072Signed& e, const SignedMatrix& t)
298  {
299      const signed_limb_t u = t.u, v=t.v, q=t.q, r=t.r;
300  
301      /* [md,me] start as zero; plus [u,q] if d is negative; plus [v,r] if e is negative. */
302      signed_limb_t sd = d.limbs[SIGNED_LIMBS - 1] >> (LIMB_SIZE - 1);
303      signed_limb_t se = e.limbs[SIGNED_LIMBS - 1] >> (LIMB_SIZE - 1);
304      signed_limb_t md = (u & sd) + (v & se);
305      signed_limb_t me = (q & sd) + (r & se);
306      /* Begin computing t*[d,e]. */
307      signed_limb_t di = d.limbs[0], ei = e.limbs[0];
308      signed_double_limb_t cd = (signed_double_limb_t)u * di + (signed_double_limb_t)v * ei;
309      signed_double_limb_t ce = (signed_double_limb_t)q * di + (signed_double_limb_t)r * ei;
310      /* Correct md,me so that t*[d,e]+modulus*[md,me] has SIGNED_LIMB_SIZE zero bottom bits. */
311      md -= (MODULUS_INVERSE * limb_t(cd) + md) & MAX_SIGNED_LIMB;
312      me -= (MODULUS_INVERSE * limb_t(ce) + me) & MAX_SIGNED_LIMB;
313      /* Update the beginning of computation for t*[d,e]+modulus*[md,me] now md,me are known. */
314      cd -= (signed_double_limb_t)1103717 * md;
315      ce -= (signed_double_limb_t)1103717 * me;
316      /* Verify that the low SIGNED_LIMB_SIZE bits of the computation are indeed zero, and then throw them away. */
317      Assume((cd & MAX_SIGNED_LIMB) == 0);
318      Assume((ce & MAX_SIGNED_LIMB) == 0);
319      cd >>= SIGNED_LIMB_SIZE;
320      ce >>= SIGNED_LIMB_SIZE;
321      /* Now iteratively compute limb i=1..SIGNED_LIMBS-2 of t*[d,e]+modulus*[md,me], and store them in output
322       * limb i-1 (shifting down by SIGNED_LIMB_SIZE bits). The corresponding limbs in modulus are all zero,
323       * so modulus/md/me are not actually involved here. */
324      for (int i = 1; i < SIGNED_LIMBS - 1; ++i) {
325          di = d.limbs[i];
326          ei = e.limbs[i];
327          cd += (signed_double_limb_t)u * di + (signed_double_limb_t)v * ei;
328          ce += (signed_double_limb_t)q * di + (signed_double_limb_t)r * ei;
329          d.limbs[i - 1] = (signed_limb_t)cd & MAX_SIGNED_LIMB; cd >>= SIGNED_LIMB_SIZE;
330          e.limbs[i - 1] = (signed_limb_t)ce & MAX_SIGNED_LIMB; ce >>= SIGNED_LIMB_SIZE;
331      }
332      /* Compute limb SIGNED_LIMBS-1 of t*[d,e]+modulus*[md,me], and store it in output limb SIGNED_LIMBS-2. */
333      di = d.limbs[SIGNED_LIMBS - 1];
334      ei = e.limbs[SIGNED_LIMBS - 1];
335      cd += (signed_double_limb_t)u * di + (signed_double_limb_t)v * ei;
336      ce += (signed_double_limb_t)q * di + (signed_double_limb_t)r * ei;
337      cd += (signed_double_limb_t)md << FINAL_LIMB_MODULUS_BITS;
338      ce += (signed_double_limb_t)me << FINAL_LIMB_MODULUS_BITS;
339      d.limbs[SIGNED_LIMBS - 2] = (signed_limb_t)cd & MAX_SIGNED_LIMB; cd >>= SIGNED_LIMB_SIZE;
340      e.limbs[SIGNED_LIMBS - 2] = (signed_limb_t)ce & MAX_SIGNED_LIMB; ce >>= SIGNED_LIMB_SIZE;
341      /* What remains goes into output limb SINGED_LIMBS-1 */
342      d.limbs[SIGNED_LIMBS - 1] = (signed_limb_t)cd;
343      e.limbs[SIGNED_LIMBS - 1] = (signed_limb_t)ce;
344  }
345  
346  /** Apply matrix t/2^SIGNED_LIMB_SIZE to vector (f,g).
347   *
348   * The matrix t must be chosen such that t*(f,g) results in multiples of 2^SIGNED_LIMB_SIZE.
349   * This is the case for matrices computed by ComputeDivstepMatrix().
350   */
351  inline void UpdateFG(Num3072Signed& f, Num3072Signed& g, const SignedMatrix& t, int len)
352  {
353      const signed_limb_t u = t.u, v=t.v, q=t.q, r=t.r;
354  
355      signed_limb_t fi, gi;
356      signed_double_limb_t cf, cg;
357      /* Start computing t*[f,g]. */
358      fi = f.limbs[0];
359      gi = g.limbs[0];
360      cf = (signed_double_limb_t)u * fi + (signed_double_limb_t)v * gi;
361      cg = (signed_double_limb_t)q * fi + (signed_double_limb_t)r * gi;
362      /* Verify that the bottom SIGNED_LIMB_BITS bits of the result are zero, and then throw them away. */
363      Assume((cf & MAX_SIGNED_LIMB) == 0);
364      Assume((cg & MAX_SIGNED_LIMB) == 0);
365      cf >>= SIGNED_LIMB_SIZE;
366      cg >>= SIGNED_LIMB_SIZE;
367      /* Now iteratively compute limb i=1..SIGNED_LIMBS-1 of t*[f,g], and store them in output limb i-1 (shifting
368       * down by SIGNED_LIMB_BITS bits). */
369      for (int i = 1; i < len; ++i) {
370          fi = f.limbs[i];
371          gi = g.limbs[i];
372          cf += (signed_double_limb_t)u * fi + (signed_double_limb_t)v * gi;
373          cg += (signed_double_limb_t)q * fi + (signed_double_limb_t)r * gi;
374          f.limbs[i - 1] = (signed_limb_t)cf & MAX_SIGNED_LIMB; cf >>= SIGNED_LIMB_SIZE;
375          g.limbs[i - 1] = (signed_limb_t)cg & MAX_SIGNED_LIMB; cg >>= SIGNED_LIMB_SIZE;
376      }
377      /* What remains is limb SIGNED_LIMBS of t*[f,g]; store it as output limb SIGNED_LIMBS-1. */
378      f.limbs[len - 1] = (signed_limb_t)cf;
379      g.limbs[len - 1] = (signed_limb_t)cg;
380  
381  }
382  } // namespace
383  
384  Num3072 Num3072::GetInverse() const
385  {
386      // Compute a modular inverse based on a variant of the safegcd algorithm:
387      // - Paper: https://gcd.cr.yp.to/papers.html
388      // - Inspired by this code in libsecp256k1:
389      //   https://github.com/bitcoin-core/secp256k1/blob/master/src/modinv32_impl.h
390      // - Explanation of the algorithm:
391      //   https://github.com/bitcoin-core/secp256k1/blob/master/doc/safegcd_implementation.md
392  
393      // Local variables d, e, f, g:
394      // - f and g are the variables whose gcd we compute (despite knowing the answer is 1):
395      //   - f is always odd, and initialized as modulus
396      //   - g is initialized as *this (called x in what follows)
397      // - d and e are the numbers for which at every step it is the case that:
398      //   - f = d * x mod modulus; d is initialized as 0
399      //   - g = e * x mod modulus; e is initialized as 1
400      Num3072Signed d, e, f, g;
401      e.limbs[0] = 1;
402      // F is initialized as modulus, which in signed limb representation can be expressed
403      // simply as 2^3072 + -MAX_PRIME_DIFF.
404      f.limbs[0] = -MAX_PRIME_DIFF;
405      f.limbs[FINAL_LIMB_POSITION] = ((limb_t)1) << FINAL_LIMB_MODULUS_BITS;
406      g.FromNum3072(*this);
407      int len = SIGNED_LIMBS; //!< The number of significant limbs in f and g
408      signed_limb_t eta = -1; //!< State to track knowledge about ratio of f and g
409      // Perform divsteps on [f,g] until g=0 is reached, keeping (d,e) synchronized with them.
410      while (true) {
411          // Compute transformation matrix t that represents the next SIGNED_LIMB_SIZE divsteps
412          // to apply. This can be computed from just the bottom limb of f and g, and eta.
413          SignedMatrix t;
414          eta = ComputeDivstepMatrix(eta, f.limbs[0], g.limbs[0], t);
415          // Apply that transformation matrix to the full [f,g] vector.
416          UpdateFG(f, g, t, len);
417          // Apply that transformation matrix to the full [d,e] vector (mod modulus).
418          UpdateDE(d, e, t);
419  
420          // Check if g is zero.
421          if (g.limbs[0] == 0) {
422              signed_limb_t cond = 0;
423              for (int j = 1; j < len; ++j) {
424                  cond |= g.limbs[j];
425              }
426              // If so, we're done.
427              if (cond == 0) break;
428          }
429  
430          // Check if the top limbs of both f and g are both 0 or -1.
431          signed_limb_t fn = f.limbs[len - 1], gn = g.limbs[len - 1];
432          signed_limb_t cond = ((signed_limb_t)len - 2) >> (LIMB_SIZE - 1);
433          cond |= fn ^ (fn >> (LIMB_SIZE - 1));
434          cond |= gn ^ (gn >> (LIMB_SIZE - 1));
435          if (cond == 0) {
436              // If so, drop the top limb, shrinking the size of f and g, by
437              // propagating the sign to the previous limb.
438              f.limbs[len - 2] |= (limb_t)f.limbs[len - 1] << SIGNED_LIMB_SIZE;
439              g.limbs[len - 2] |= (limb_t)g.limbs[len - 1] << SIGNED_LIMB_SIZE;
440              --len;
441          }
442      }
443      // At some point, [f,g] will have been rewritten into [f',0], such that gcd(f,g) = gcd(f',0).
444      // This is proven in the paper. As f started out being modulus, a prime number, we know that
445      // gcd is 1, and thus f' is 1 or -1.
446      Assume((f.limbs[0] & MAX_SIGNED_LIMB) == 1 || (f.limbs[0] & MAX_SIGNED_LIMB) == MAX_SIGNED_LIMB);
447      // As we've maintained the invariant that f = d * x mod modulus, we get d/f mod modulus is the
448      // modular inverse of x we're looking for. As f is 1 or -1, it is also true that d/f = d*f.
449      // Normalize d to prepare it for output, while negating it if f is negative.
450      d.Normalize(f.limbs[len - 1] >> (LIMB_SIZE  - 1));
451      Num3072 ret;
452      d.ToNum3072(ret);
453      return ret;
454  }
455  
456  void Num3072::Multiply(const Num3072& a)
457  {
458      limb_t c0 = 0, c1 = 0, c2 = 0;
459      Num3072 tmp;
460  
461      /* Compute limbs 0..N-2 of this*a into tmp, including one reduction. */
462      for (int j = 0; j < LIMBS - 1; ++j) {
463          limb_t d0 = 0, d1 = 0, d2 = 0;
464          mul(d0, d1, this->limbs[1 + j], a.limbs[LIMBS + j - (1 + j)]);
465          for (int i = 2 + j; i < LIMBS; ++i) muladd3(d0, d1, d2, this->limbs[i], a.limbs[LIMBS + j - i]);
466          mulnadd3(c0, c1, c2, d0, d1, d2, MAX_PRIME_DIFF);
467          for (int i = 0; i < j + 1; ++i) muladd3(c0, c1, c2, this->limbs[i], a.limbs[j - i]);
468          extract3(c0, c1, c2, tmp.limbs[j]);
469      }
470  
471      /* Compute limb N-1 of a*b into tmp. */
472      assert(c2 == 0);
473      for (int i = 0; i < LIMBS; ++i) muladd3(c0, c1, c2, this->limbs[i], a.limbs[LIMBS - 1 - i]);
474      extract3(c0, c1, c2, tmp.limbs[LIMBS - 1]);
475  
476      /* Perform a second reduction. */
477      muln2(c0, c1, MAX_PRIME_DIFF);
478      for (int j = 0; j < LIMBS; ++j) {
479          addnextract2(c0, c1, tmp.limbs[j], this->limbs[j]);
480      }
481  
482      assert(c1 == 0);
483      assert(c0 == 0 || c0 == 1);
484  
485      /* Perform up to two more reductions if the internal state has already
486       * overflown the MAX of Num3072 or if it is larger than the modulus or
487       * if both are the case.
488       * */
489      if (this->IsOverflow()) this->FullReduce();
490      if (c0) this->FullReduce();
491  }
492  
493  void Num3072::SetToOne()
494  {
495      this->limbs[0] = 1;
496      for (int i = 1; i < LIMBS; ++i) this->limbs[i] = 0;
497  }
498  
499  void Num3072::Divide(const Num3072& a)
500  {
501      if (this->IsOverflow()) this->FullReduce();
502  
503      Num3072 inv{};
504      if (a.IsOverflow()) {
505          Num3072 b = a;
506          b.FullReduce();
507          inv = b.GetInverse();
508      } else {
509          inv = a.GetInverse();
510      }
511  
512      this->Multiply(inv);
513      if (this->IsOverflow()) this->FullReduce();
514  }
515  
516  Num3072::Num3072(const unsigned char (&data)[BYTE_SIZE]) {
517      for (int i = 0; i < LIMBS; ++i) {
518          if (sizeof(limb_t) == 4) {
519              this->limbs[i] = ReadLE32(data + 4 * i);
520          } else if (sizeof(limb_t) == 8) {
521              this->limbs[i] = ReadLE64(data + 8 * i);
522          }
523      }
524  }
525  
526  void Num3072::ToBytes(unsigned char (&out)[BYTE_SIZE]) {
527      for (int i = 0; i < LIMBS; ++i) {
528          if (sizeof(limb_t) == 4) {
529              WriteLE32(out + i * 4, this->limbs[i]);
530          } else if (sizeof(limb_t) == 8) {
531              WriteLE64(out + i * 8, this->limbs[i]);
532          }
533      }
534  }
535  
536  Num3072 MuHash3072::ToNum3072(std::span<const unsigned char> in) {
537      unsigned char tmp[Num3072::BYTE_SIZE];
538  
539      uint256 hashed_in{(HashWriter{} << in).GetSHA256()};
540      static_assert(sizeof(tmp) % ChaCha20Aligned::BLOCKLEN == 0);
541      ChaCha20Aligned{MakeByteSpan(hashed_in)}.Keystream(MakeWritableByteSpan(tmp));
542      Num3072 out{tmp};
543  
544      return out;
545  }
546  
547  MuHash3072::MuHash3072(std::span<const unsigned char> in) noexcept
548  {
549      m_numerator = ToNum3072(in);
550  }
551  
552  void MuHash3072::Finalize(uint256& out) noexcept
553  {
554      m_numerator.Divide(m_denominator);
555      m_denominator.SetToOne();  // Needed to keep the MuHash object valid
556  
557      unsigned char data[Num3072::BYTE_SIZE];
558      m_numerator.ToBytes(data);
559  
560      out = (HashWriter{} << data).GetSHA256();
561  }
562  
563  MuHash3072& MuHash3072::operator*=(const MuHash3072& mul) noexcept
564  {
565      m_numerator.Multiply(mul.m_numerator);
566      m_denominator.Multiply(mul.m_denominator);
567      return *this;
568  }
569  
570  MuHash3072& MuHash3072::operator/=(const MuHash3072& div) noexcept
571  {
572      m_numerator.Multiply(div.m_denominator);
573      m_denominator.Multiply(div.m_numerator);
574      return *this;
575  }
576  
577  MuHash3072& MuHash3072::Insert(std::span<const unsigned char> in) noexcept {
578      m_numerator.Multiply(ToNum3072(in));
579      return *this;
580  }
581  
582  MuHash3072& MuHash3072::Remove(std::span<const unsigned char> in) noexcept {
583      m_denominator.Multiply(ToNum3072(in));
584      return *this;
585  }