/ src / crypto / muhash.cpp
muhash.cpp
  1  // Copyright (c) 2017-present The Bitcoin Core developers
  2  // Distributed under the MIT software license, see the accompanying
  3  // file COPYING or http://www.opensource.org/licenses/mit-license.php.
  4  
  5  #include <crypto/muhash.h>
  6  
  7  #include <crypto/chacha20.h>
  8  #include <crypto/common.h>
  9  #include <hash.h>
 10  #include <util/check.h>
 11  
 12  #include <bit>
 13  #include <cassert>
 14  #include <cstdio>
 15  #include <limits>
 16  
 17  namespace {
 18  
 19  using limb_t = Num3072::limb_t;
 20  using signed_limb_t = Num3072::signed_limb_t;
 21  using double_limb_t = Num3072::double_limb_t;
 22  using signed_double_limb_t = Num3072::signed_double_limb_t;
 23  constexpr int LIMB_SIZE = Num3072::LIMB_SIZE;
 24  constexpr int SIGNED_LIMB_SIZE = Num3072::SIGNED_LIMB_SIZE;
 25  constexpr int LIMBS = Num3072::LIMBS;
 26  constexpr int SIGNED_LIMBS = Num3072::SIGNED_LIMBS;
 27  constexpr int FINAL_LIMB_POSITION = 3072 / SIGNED_LIMB_SIZE;
 28  constexpr int FINAL_LIMB_MODULUS_BITS = 3072 % SIGNED_LIMB_SIZE;
 29  constexpr limb_t MAX_LIMB = (limb_t)(-1);
 30  constexpr limb_t MAX_SIGNED_LIMB = (((limb_t)1) << SIGNED_LIMB_SIZE) - 1;
 31  /** 2^3072 - 1103717, the largest 3072-bit safe prime number, is used as the modulus. */
 32  constexpr limb_t MAX_PRIME_DIFF = 1103717;
 33  /** The modular inverse of (2**3072 - MAX_PRIME_DIFF) mod (MAX_SIGNED_LIMB + 1). */
 34  constexpr limb_t MODULUS_INVERSE = limb_t(0x70a1421da087d93);
 35  
 36  
 37  /** Extract the lowest limb of [c0,c1,c2] into n, and left shift the number by 1 limb. */
 38  inline void extract3(limb_t& c0, limb_t& c1, limb_t& c2, limb_t& n)
 39  {
 40      n = c0;
 41      c0 = c1;
 42      c1 = c2;
 43      c2 = 0;
 44  }
 45  
 46  /** [c0,c1] = a * b */
 47  inline void mul(limb_t& c0, limb_t& c1, const limb_t& a, const limb_t& b)
 48  {
 49      double_limb_t t = (double_limb_t)a * b;
 50      c1 = t >> LIMB_SIZE;
 51      c0 = t;
 52  }
 53  
 54  /* [c0,c1,c2] += n * [d0,d1,d2]. c2 is 0 initially */
 55  inline void mulnadd3(limb_t& c0, limb_t& c1, limb_t& c2, limb_t& d0, limb_t& d1, limb_t& d2, const limb_t& n)
 56  {
 57      double_limb_t t = (double_limb_t)d0 * n + c0;
 58      c0 = t;
 59      t >>= LIMB_SIZE;
 60      t += (double_limb_t)d1 * n + c1;
 61      c1 = t;
 62      t >>= LIMB_SIZE;
 63      c2 = t + d2 * n;
 64  }
 65  
 66  /* [c0,c1] *= n */
 67  inline void muln2(limb_t& c0, limb_t& c1, const limb_t& n)
 68  {
 69      double_limb_t t = (double_limb_t)c0 * n;
 70      c0 = t;
 71      t >>= LIMB_SIZE;
 72      t += (double_limb_t)c1 * n;
 73      c1 = t;
 74  }
 75  
 76  /** [c0,c1,c2] += a * b */
 77  inline void muladd3(limb_t& c0, limb_t& c1, limb_t& c2, const limb_t& a, const limb_t& b)
 78  {
 79      double_limb_t t = (double_limb_t)a * b;
 80      limb_t th = t >> LIMB_SIZE;
 81      limb_t tl = t;
 82  
 83      c0 += tl;
 84      th += (c0 < tl) ? 1 : 0;
 85      c1 += th;
 86      c2 += (c1 < th) ? 1 : 0;
 87  }
 88  
 89  /**
 90   * Add limb a to [c0,c1]: [c0,c1] += a. Then extract the lowest
 91   * limb of [c0,c1] into n, and left shift the number by 1 limb.
 92   * */
 93  inline void addnextract2(limb_t& c0, limb_t& c1, const limb_t& a, limb_t& n)
 94  {
 95      limb_t c2 = 0;
 96  
 97      // add
 98      c0 += a;
 99      if (c0 < a) {
100          c1 += 1;
101  
102          // Handle case when c1 has overflown
103          if (c1 == 0) c2 = 1;
104      }
105  
106      // extract
107      n = c0;
108      c0 = c1;
109      c1 = c2;
110  }
111  
112  } // namespace
113  
114  /** Indicates whether d is larger than the modulus. */
115  bool Num3072::IsOverflow() const
116  {
117      if (this->limbs[0] <= std::numeric_limits<limb_t>::max() - MAX_PRIME_DIFF) return false;
118      for (int i = 1; i < LIMBS; ++i) {
119          if (this->limbs[i] != std::numeric_limits<limb_t>::max()) return false;
120      }
121      return true;
122  }
123  
124  void Num3072::FullReduce()
125  {
126      limb_t c0 = MAX_PRIME_DIFF;
127      limb_t c1 = 0;
128      for (int i = 0; i < LIMBS; ++i) {
129          addnextract2(c0, c1, this->limbs[i], this->limbs[i]);
130      }
131  }
132  
133  namespace {
134  /** A type representing a number in signed limb representation. */
135  struct Num3072Signed
136  {
137      /** The represented value is sum(limbs[i]*2^(SIGNED_LIMB_SIZE*i), i=0..SIGNED_LIMBS-1).
138       *  Note that limbs may be negative, or exceed 2^SIGNED_LIMB_SIZE-1. */
139      signed_limb_t limbs[SIGNED_LIMBS];
140  
141      /** Construct a Num3072Signed with value 0. */
142      Num3072Signed()
143      {
144          memset(limbs, 0, sizeof(limbs));
145      }
146  
147      /** Convert a Num3072 to a Num3072Signed. Output will be normalized and in
148       *  range 0..2^3072-1. */
149      void FromNum3072(const Num3072& in)
150      {
151          double_limb_t c = 0;
152          int b = 0, outpos = 0;
153          for (int i = 0; i < LIMBS; ++i) {
154              c += double_limb_t{in.limbs[i]} << b;
155              b += LIMB_SIZE;
156              while (b >= SIGNED_LIMB_SIZE) {
157                  limbs[outpos++] = limb_t(c) & MAX_SIGNED_LIMB;
158                  c >>= SIGNED_LIMB_SIZE;
159                  b -= SIGNED_LIMB_SIZE;
160              }
161          }
162          Assume(outpos == SIGNED_LIMBS - 1);
163          limbs[SIGNED_LIMBS - 1] = c;
164          c >>= SIGNED_LIMB_SIZE;
165          Assume(c == 0);
166      }
167  
168      /** Convert a Num3072Signed to a Num3072. Input must be in range 0..modulus-1. */
169      void ToNum3072(Num3072& out) const
170      {
171          double_limb_t c = 0;
172          int b = 0, outpos = 0;
173          for (int i = 0; i < SIGNED_LIMBS; ++i) {
174              c += double_limb_t(limbs[i]) << b;
175              b += SIGNED_LIMB_SIZE;
176              if (b >= LIMB_SIZE) {
177                  out.limbs[outpos++] = c;
178                  c >>= LIMB_SIZE;
179                  b -= LIMB_SIZE;
180              }
181          }
182          Assume(outpos == LIMBS);
183          Assume(c == 0);
184      }
185  
186      /** Take a Num3072Signed in range 1-2*2^3072..2^3072-1, and:
187       *  - optionally negate it (if negate is true)
188       *  - reduce it modulo the modulus (2^3072 - MAX_PRIME_DIFF)
189       *  - produce output with all limbs in range 0..2^SIGNED_LIMB_SIZE-1
190       */
191      void Normalize(bool negate)
192      {
193          // Add modulus if this was negative. This brings the range of *this to 1-2^3072..2^3072-1.
194          signed_limb_t cond_add = limbs[SIGNED_LIMBS-1] >> (LIMB_SIZE-1); // -1 if this is negative; 0 otherwise
195          limbs[0] += signed_limb_t(-MAX_PRIME_DIFF) & cond_add;
196          limbs[FINAL_LIMB_POSITION] += (signed_limb_t(1) << FINAL_LIMB_MODULUS_BITS) & cond_add;
197          // Next negate all limbs if negate was set. This does not change the range of *this.
198          signed_limb_t cond_negate = -signed_limb_t(negate); // -1 if this negate is true; 0 otherwise
199          for (int i = 0; i < SIGNED_LIMBS; ++i) {
200              limbs[i] = (limbs[i] ^ cond_negate) - cond_negate;
201          }
202          // Perform carry (make all limbs except the top one be in range 0..2^SIGNED_LIMB_SIZE-1).
203          for (int i = 0; i < SIGNED_LIMBS - 1; ++i) {
204              limbs[i + 1] += limbs[i] >> SIGNED_LIMB_SIZE;
205              limbs[i] &= MAX_SIGNED_LIMB;
206          }
207          // Again add modulus if *this was negative. This brings the range of *this to 0..2^3072-1.
208          cond_add = limbs[SIGNED_LIMBS-1] >> (LIMB_SIZE-1); // -1 if this is negative; 0 otherwise
209          limbs[0] += signed_limb_t(-MAX_PRIME_DIFF) & cond_add;
210          limbs[FINAL_LIMB_POSITION] += (signed_limb_t(1) << FINAL_LIMB_MODULUS_BITS) & cond_add;
211          // Perform another carry. Now all limbs are in range 0..2^SIGNED_LIMB_SIZE-1.
212          for (int i = 0; i < SIGNED_LIMBS - 1; ++i) {
213              limbs[i + 1] += limbs[i] >> SIGNED_LIMB_SIZE;
214              limbs[i] &= MAX_SIGNED_LIMB;
215          }
216      }
217  };
218  
219  /** 2x2 transformation matrix with signed_limb_t elements. */
220  struct SignedMatrix
221  {
222      signed_limb_t u, v, q, r;
223  };
224  
225  /** Compute the transformation matrix for SIGNED_LIMB_SIZE divsteps.
226   *
227   * eta: initial eta value
228   * f:   bottom SIGNED_LIMB_SIZE bits of initial f value
229   * g:   bottom SIGNED_LIMB_SIZE bits of initial g value
230   * out: resulting transformation matrix, scaled by 2^SIGNED_LIMB_SIZE
231   * return: eta value after SIGNED_LIMB_SIZE divsteps
232   */
233  inline limb_t ComputeDivstepMatrix(signed_limb_t eta, limb_t f, limb_t g, SignedMatrix& out)
234  {
235      /** inv256[i] = -1/(2*i+1) (mod 256) */
236      static const uint8_t NEGINV256[128] = {
237          0xFF, 0x55, 0x33, 0x49, 0xC7, 0x5D, 0x3B, 0x11, 0x0F, 0xE5, 0xC3, 0x59,
238          0xD7, 0xED, 0xCB, 0x21, 0x1F, 0x75, 0x53, 0x69, 0xE7, 0x7D, 0x5B, 0x31,
239          0x2F, 0x05, 0xE3, 0x79, 0xF7, 0x0D, 0xEB, 0x41, 0x3F, 0x95, 0x73, 0x89,
240          0x07, 0x9D, 0x7B, 0x51, 0x4F, 0x25, 0x03, 0x99, 0x17, 0x2D, 0x0B, 0x61,
241          0x5F, 0xB5, 0x93, 0xA9, 0x27, 0xBD, 0x9B, 0x71, 0x6F, 0x45, 0x23, 0xB9,
242          0x37, 0x4D, 0x2B, 0x81, 0x7F, 0xD5, 0xB3, 0xC9, 0x47, 0xDD, 0xBB, 0x91,
243          0x8F, 0x65, 0x43, 0xD9, 0x57, 0x6D, 0x4B, 0xA1, 0x9F, 0xF5, 0xD3, 0xE9,
244          0x67, 0xFD, 0xDB, 0xB1, 0xAF, 0x85, 0x63, 0xF9, 0x77, 0x8D, 0x6B, 0xC1,
245          0xBF, 0x15, 0xF3, 0x09, 0x87, 0x1D, 0xFB, 0xD1, 0xCF, 0xA5, 0x83, 0x19,
246          0x97, 0xAD, 0x8B, 0xE1, 0xDF, 0x35, 0x13, 0x29, 0xA7, 0x3D, 0x1B, 0xF1,
247          0xEF, 0xC5, 0xA3, 0x39, 0xB7, 0xCD, 0xAB, 0x01
248      };
249      // Coefficients of returned SignedMatrix; starts off as identity matrix. */
250      limb_t u = 1, v = 0, q = 0, r = 1;
251      // The number of divsteps still left.
252      int i = SIGNED_LIMB_SIZE;
253      while (true) {
254          /* Use a sentinel bit to count zeros only up to i. */
255          int zeros = std::countr_zero(g | (MAX_LIMB << i));
256          /* Perform zeros divsteps at once; they all just divide g by two. */
257          g >>= zeros;
258          u <<= zeros;
259          v <<= zeros;
260          eta -= zeros;
261          i -= zeros;
262           /* We're done once we've performed SIGNED_LIMB_SIZE divsteps. */
263          if (i == 0) break;
264          /* If eta is negative, negate it and replace f,g with g,-f. */
265          if (eta < 0) {
266              limb_t tmp;
267              eta = -eta;
268              tmp = f; f = g; g = -tmp;
269              tmp = u; u = q; q = -tmp;
270              tmp = v; v = r; r = -tmp;
271          }
272          /* eta is now >= 0. In what follows we're going to cancel out the bottom bits of g. No more
273           * than i can be cancelled out (as we'd be done before that point), and no more than eta+1
274           * can be done as its sign will flip once that happens. */
275          int limit = ((int)eta + 1) > i ? i : ((int)eta + 1);
276          /* m is a mask for the bottom min(limit, 8) bits (our table only supports 8 bits). */
277          limb_t m = (MAX_LIMB >> (LIMB_SIZE - limit)) & 255U;
278          /* Find what multiple of f must be added to g to cancel its bottom min(limit, 8) bits. */
279          limb_t w = (g * NEGINV256[(f >> 1) & 127]) & m;
280          /* Do so. */
281          g += f * w;
282          q += u * w;
283          r += v * w;
284      }
285      out.u = (signed_limb_t)u;
286      out.v = (signed_limb_t)v;
287      out.q = (signed_limb_t)q;
288      out.r = (signed_limb_t)r;
289      return eta;
290  }
291  
292  /** Apply matrix t/2^SIGNED_LIMB_SIZE to vector [d,e], modulo modulus.
293   *
294   * On input and output, d and e are in range 1-2*modulus..modulus-1.
295   */
296  inline void UpdateDE(Num3072Signed& d, Num3072Signed& e, const SignedMatrix& t)
297  {
298      const signed_limb_t u = t.u, v=t.v, q=t.q, r=t.r;
299  
300      /* [md,me] start as zero; plus [u,q] if d is negative; plus [v,r] if e is negative. */
301      signed_limb_t sd = d.limbs[SIGNED_LIMBS - 1] >> (LIMB_SIZE - 1);
302      signed_limb_t se = e.limbs[SIGNED_LIMBS - 1] >> (LIMB_SIZE - 1);
303      signed_limb_t md = (u & sd) + (v & se);
304      signed_limb_t me = (q & sd) + (r & se);
305      /* Begin computing t*[d,e]. */
306      signed_limb_t di = d.limbs[0], ei = e.limbs[0];
307      signed_double_limb_t cd = (signed_double_limb_t)u * di + (signed_double_limb_t)v * ei;
308      signed_double_limb_t ce = (signed_double_limb_t)q * di + (signed_double_limb_t)r * ei;
309      /* Correct md,me so that t*[d,e]+modulus*[md,me] has SIGNED_LIMB_SIZE zero bottom bits. */
310      md -= (MODULUS_INVERSE * limb_t(cd) + md) & MAX_SIGNED_LIMB;
311      me -= (MODULUS_INVERSE * limb_t(ce) + me) & MAX_SIGNED_LIMB;
312      /* Update the beginning of computation for t*[d,e]+modulus*[md,me] now md,me are known. */
313      cd -= (signed_double_limb_t)1103717 * md;
314      ce -= (signed_double_limb_t)1103717 * me;
315      /* Verify that the low SIGNED_LIMB_SIZE bits of the computation are indeed zero, and then throw them away. */
316      Assume((cd & MAX_SIGNED_LIMB) == 0);
317      Assume((ce & MAX_SIGNED_LIMB) == 0);
318      cd >>= SIGNED_LIMB_SIZE;
319      ce >>= SIGNED_LIMB_SIZE;
320      /* Now iteratively compute limb i=1..SIGNED_LIMBS-2 of t*[d,e]+modulus*[md,me], and store them in output
321       * limb i-1 (shifting down by SIGNED_LIMB_SIZE bits). The corresponding limbs in modulus are all zero,
322       * so modulus/md/me are not actually involved here. */
323      for (int i = 1; i < SIGNED_LIMBS - 1; ++i) {
324          di = d.limbs[i];
325          ei = e.limbs[i];
326          cd += (signed_double_limb_t)u * di + (signed_double_limb_t)v * ei;
327          ce += (signed_double_limb_t)q * di + (signed_double_limb_t)r * ei;
328          d.limbs[i - 1] = (signed_limb_t)cd & MAX_SIGNED_LIMB; cd >>= SIGNED_LIMB_SIZE;
329          e.limbs[i - 1] = (signed_limb_t)ce & MAX_SIGNED_LIMB; ce >>= SIGNED_LIMB_SIZE;
330      }
331      /* Compute limb SIGNED_LIMBS-1 of t*[d,e]+modulus*[md,me], and store it in output limb SIGNED_LIMBS-2. */
332      di = d.limbs[SIGNED_LIMBS - 1];
333      ei = e.limbs[SIGNED_LIMBS - 1];
334      cd += (signed_double_limb_t)u * di + (signed_double_limb_t)v * ei;
335      ce += (signed_double_limb_t)q * di + (signed_double_limb_t)r * ei;
336      cd += (signed_double_limb_t)md << FINAL_LIMB_MODULUS_BITS;
337      ce += (signed_double_limb_t)me << FINAL_LIMB_MODULUS_BITS;
338      d.limbs[SIGNED_LIMBS - 2] = (signed_limb_t)cd & MAX_SIGNED_LIMB; cd >>= SIGNED_LIMB_SIZE;
339      e.limbs[SIGNED_LIMBS - 2] = (signed_limb_t)ce & MAX_SIGNED_LIMB; ce >>= SIGNED_LIMB_SIZE;
340      /* What remains goes into output limb SINGED_LIMBS-1 */
341      d.limbs[SIGNED_LIMBS - 1] = (signed_limb_t)cd;
342      e.limbs[SIGNED_LIMBS - 1] = (signed_limb_t)ce;
343  }
344  
345  /** Apply matrix t/2^SIGNED_LIMB_SIZE to vector (f,g).
346   *
347   * The matrix t must be chosen such that t*(f,g) results in multiples of 2^SIGNED_LIMB_SIZE.
348   * This is the case for matrices computed by ComputeDivstepMatrix().
349   */
350  inline void UpdateFG(Num3072Signed& f, Num3072Signed& g, const SignedMatrix& t, int len)
351  {
352      const signed_limb_t u = t.u, v=t.v, q=t.q, r=t.r;
353  
354      signed_limb_t fi, gi;
355      signed_double_limb_t cf, cg;
356      /* Start computing t*[f,g]. */
357      fi = f.limbs[0];
358      gi = g.limbs[0];
359      cf = (signed_double_limb_t)u * fi + (signed_double_limb_t)v * gi;
360      cg = (signed_double_limb_t)q * fi + (signed_double_limb_t)r * gi;
361      /* Verify that the bottom SIGNED_LIMB_BITS bits of the result are zero, and then throw them away. */
362      Assume((cf & MAX_SIGNED_LIMB) == 0);
363      Assume((cg & MAX_SIGNED_LIMB) == 0);
364      cf >>= SIGNED_LIMB_SIZE;
365      cg >>= SIGNED_LIMB_SIZE;
366      /* Now iteratively compute limb i=1..SIGNED_LIMBS-1 of t*[f,g], and store them in output limb i-1 (shifting
367       * down by SIGNED_LIMB_BITS bits). */
368      for (int i = 1; i < len; ++i) {
369          fi = f.limbs[i];
370          gi = g.limbs[i];
371          cf += (signed_double_limb_t)u * fi + (signed_double_limb_t)v * gi;
372          cg += (signed_double_limb_t)q * fi + (signed_double_limb_t)r * gi;
373          f.limbs[i - 1] = (signed_limb_t)cf & MAX_SIGNED_LIMB; cf >>= SIGNED_LIMB_SIZE;
374          g.limbs[i - 1] = (signed_limb_t)cg & MAX_SIGNED_LIMB; cg >>= SIGNED_LIMB_SIZE;
375      }
376      /* What remains is limb SIGNED_LIMBS of t*[f,g]; store it as output limb SIGNED_LIMBS-1. */
377      f.limbs[len - 1] = (signed_limb_t)cf;
378      g.limbs[len - 1] = (signed_limb_t)cg;
379  
380  }
381  } // namespace
382  
383  Num3072 Num3072::GetInverse() const
384  {
385      // Compute a modular inverse based on a variant of the safegcd algorithm:
386      // - Paper: https://gcd.cr.yp.to/papers.html
387      // - Inspired by this code in libsecp256k1:
388      //   https://github.com/bitcoin-core/secp256k1/blob/master/src/modinv32_impl.h
389      // - Explanation of the algorithm:
390      //   https://github.com/bitcoin-core/secp256k1/blob/master/doc/safegcd_implementation.md
391  
392      // Local variables d, e, f, g:
393      // - f and g are the variables whose gcd we compute (despite knowing the answer is 1):
394      //   - f is always odd, and initialized as modulus
395      //   - g is initialized as *this (called x in what follows)
396      // - d and e are the numbers for which at every step it is the case that:
397      //   - f = d * x mod modulus; d is initialized as 0
398      //   - g = e * x mod modulus; e is initialized as 1
399      Num3072Signed d, e, f, g;
400      e.limbs[0] = 1;
401      // F is initialized as modulus, which in signed limb representation can be expressed
402      // simply as 2^3072 + -MAX_PRIME_DIFF.
403      f.limbs[0] = -MAX_PRIME_DIFF;
404      f.limbs[FINAL_LIMB_POSITION] = ((limb_t)1) << FINAL_LIMB_MODULUS_BITS;
405      g.FromNum3072(*this);
406      int len = SIGNED_LIMBS; //!< The number of significant limbs in f and g
407      signed_limb_t eta = -1; //!< State to track knowledge about ratio of f and g
408      // Perform divsteps on [f,g] until g=0 is reached, keeping (d,e) synchronized with them.
409      while (true) {
410          // Compute transformation matrix t that represents the next SIGNED_LIMB_SIZE divsteps
411          // to apply. This can be computed from just the bottom limb of f and g, and eta.
412          SignedMatrix t;
413          eta = ComputeDivstepMatrix(eta, f.limbs[0], g.limbs[0], t);
414          // Apply that transformation matrix to the full [f,g] vector.
415          UpdateFG(f, g, t, len);
416          // Apply that transformation matrix to the full [d,e] vector (mod modulus).
417          UpdateDE(d, e, t);
418  
419          // Check if g is zero.
420          if (g.limbs[0] == 0) {
421              signed_limb_t cond = 0;
422              for (int j = 1; j < len; ++j) {
423                  cond |= g.limbs[j];
424              }
425              // If so, we're done.
426              if (cond == 0) break;
427          }
428  
429          // Check if the top limbs of both f and g are both 0 or -1.
430          signed_limb_t fn = f.limbs[len - 1], gn = g.limbs[len - 1];
431          signed_limb_t cond = ((signed_limb_t)len - 2) >> (LIMB_SIZE - 1);
432          cond |= fn ^ (fn >> (LIMB_SIZE - 1));
433          cond |= gn ^ (gn >> (LIMB_SIZE - 1));
434          if (cond == 0) {
435              // If so, drop the top limb, shrinking the size of f and g, by
436              // propagating the sign to the previous limb.
437              f.limbs[len - 2] |= (limb_t)f.limbs[len - 1] << SIGNED_LIMB_SIZE;
438              g.limbs[len - 2] |= (limb_t)g.limbs[len - 1] << SIGNED_LIMB_SIZE;
439              --len;
440          }
441      }
442      // At some point, [f,g] will have been rewritten into [f',0], such that gcd(f,g) = gcd(f',0).
443      // This is proven in the paper. As f started out being modulus, a prime number, we know that
444      // gcd is 1, and thus f' is 1 or -1.
445      Assume((f.limbs[0] & MAX_SIGNED_LIMB) == 1 || (f.limbs[0] & MAX_SIGNED_LIMB) == MAX_SIGNED_LIMB);
446      // As we've maintained the invariant that f = d * x mod modulus, we get d/f mod modulus is the
447      // modular inverse of x we're looking for. As f is 1 or -1, it is also true that d/f = d*f.
448      // Normalize d to prepare it for output, while negating it if f is negative.
449      d.Normalize(f.limbs[len - 1] >> (LIMB_SIZE  - 1));
450      Num3072 ret;
451      d.ToNum3072(ret);
452      return ret;
453  }
454  
455  void Num3072::Multiply(const Num3072& a)
456  {
457      limb_t c0 = 0, c1 = 0, c2 = 0;
458      Num3072 tmp;
459  
460      /* Compute limbs 0..N-2 of this*a into tmp, including one reduction. */
461      for (int j = 0; j < LIMBS - 1; ++j) {
462          limb_t d0 = 0, d1 = 0, d2 = 0;
463          mul(d0, d1, this->limbs[1 + j], a.limbs[LIMBS + j - (1 + j)]);
464          for (int i = 2 + j; i < LIMBS; ++i) muladd3(d0, d1, d2, this->limbs[i], a.limbs[LIMBS + j - i]);
465          mulnadd3(c0, c1, c2, d0, d1, d2, MAX_PRIME_DIFF);
466          for (int i = 0; i < j + 1; ++i) muladd3(c0, c1, c2, this->limbs[i], a.limbs[j - i]);
467          extract3(c0, c1, c2, tmp.limbs[j]);
468      }
469  
470      /* Compute limb N-1 of a*b into tmp. */
471      assert(c2 == 0);
472      for (int i = 0; i < LIMBS; ++i) muladd3(c0, c1, c2, this->limbs[i], a.limbs[LIMBS - 1 - i]);
473      extract3(c0, c1, c2, tmp.limbs[LIMBS - 1]);
474  
475      /* Perform a second reduction. */
476      muln2(c0, c1, MAX_PRIME_DIFF);
477      for (int j = 0; j < LIMBS; ++j) {
478          addnextract2(c0, c1, tmp.limbs[j], this->limbs[j]);
479      }
480  
481      assert(c1 == 0);
482      assert(c0 == 0 || c0 == 1);
483  
484      /* Perform up to two more reductions if the internal state has already
485       * overflown the MAX of Num3072 or if it is larger than the modulus or
486       * if both are the case.
487       * */
488      if (this->IsOverflow()) this->FullReduce();
489      if (c0) this->FullReduce();
490  }
491  
492  void Num3072::SetToOne()
493  {
494      this->limbs[0] = 1;
495      for (int i = 1; i < LIMBS; ++i) this->limbs[i] = 0;
496  }
497  
498  void Num3072::Divide(const Num3072& a)
499  {
500      if (this->IsOverflow()) this->FullReduce();
501  
502      Num3072 inv{};
503      if (a.IsOverflow()) {
504          Num3072 b = a;
505          b.FullReduce();
506          inv = b.GetInverse();
507      } else {
508          inv = a.GetInverse();
509      }
510  
511      this->Multiply(inv);
512      if (this->IsOverflow()) this->FullReduce();
513  }
514  
515  Num3072::Num3072(const unsigned char (&data)[BYTE_SIZE]) {
516      for (int i = 0; i < LIMBS; ++i) {
517          if (sizeof(limb_t) == 4) {
518              this->limbs[i] = ReadLE32(data + 4 * i);
519          } else if (sizeof(limb_t) == 8) {
520              this->limbs[i] = ReadLE64(data + 8 * i);
521          }
522      }
523  }
524  
525  void Num3072::ToBytes(unsigned char (&out)[BYTE_SIZE]) {
526      for (int i = 0; i < LIMBS; ++i) {
527          if (sizeof(limb_t) == 4) {
528              WriteLE32(out + i * 4, this->limbs[i]);
529          } else if (sizeof(limb_t) == 8) {
530              WriteLE64(out + i * 8, this->limbs[i]);
531          }
532      }
533  }
534  
535  Num3072 MuHash3072::ToNum3072(std::span<const unsigned char> in) {
536      unsigned char tmp[Num3072::BYTE_SIZE];
537  
538      uint256 hashed_in{(HashWriter{} << in).GetSHA256()};
539      static_assert(sizeof(tmp) % ChaCha20Aligned::BLOCKLEN == 0);
540      ChaCha20Aligned{MakeByteSpan(hashed_in)}.Keystream(MakeWritableByteSpan(tmp));
541      Num3072 out{tmp};
542  
543      return out;
544  }
545  
546  MuHash3072::MuHash3072(std::span<const unsigned char> in) noexcept
547  {
548      m_numerator = ToNum3072(in);
549  }
550  
551  void MuHash3072::Finalize(uint256& out) noexcept
552  {
553      m_numerator.Divide(m_denominator);
554      m_denominator.SetToOne();  // Needed to keep the MuHash object valid
555  
556      unsigned char data[Num3072::BYTE_SIZE];
557      m_numerator.ToBytes(data);
558  
559      out = (HashWriter{} << data).GetSHA256();
560  }
561  
562  MuHash3072& MuHash3072::operator*=(const MuHash3072& mul) noexcept
563  {
564      m_numerator.Multiply(mul.m_numerator);
565      m_denominator.Multiply(mul.m_denominator);
566      return *this;
567  }
568  
569  MuHash3072& MuHash3072::operator/=(const MuHash3072& div) noexcept
570  {
571      m_numerator.Multiply(div.m_denominator);
572      m_denominator.Multiply(div.m_numerator);
573      return *this;
574  }
575  
576  MuHash3072& MuHash3072::Insert(std::span<const unsigned char> in) noexcept {
577      m_numerator.Multiply(ToNum3072(in));
578      return *this;
579  }
580  
581  MuHash3072& MuHash3072::Remove(std::span<const unsigned char> in) noexcept {
582      m_denominator.Multiply(ToNum3072(in));
583      return *this;
584  }