/ src / minisketch / src / sketch_impl.h
sketch_impl.h
  1  /**********************************************************************
  2   * Copyright (c) 2018 Pieter Wuille, Greg Maxwell, Gleb Naumenko      *
  3   * Distributed under the MIT software license, see the accompanying   *
  4   * file LICENSE or http://www.opensource.org/licenses/mit-license.php.*
  5   **********************************************************************/
  6  
  7  #ifndef _MINISKETCH_SKETCH_IMPL_H_
  8  #define _MINISKETCH_SKETCH_IMPL_H_
  9  
 10  #include <random>
 11  
 12  #include "util.h"
 13  #include "sketch.h"
 14  #include "int_utils.h"
 15  
 16  /** Compute the remainder of a polynomial division of val by mod, putting the result in mod. */
 17  template<typename F>
 18  void PolyMod(const std::vector<typename F::Elem>& mod, std::vector<typename F::Elem>& val, const F& field) {
 19      size_t modsize = mod.size();
 20      CHECK_SAFE(modsize > 0 && mod.back() == 1);
 21      if (val.size() < modsize) return;
 22      CHECK_SAFE(val.back() != 0);
 23      while (val.size() >= modsize) {
 24          auto term = val.back();
 25          val.pop_back();
 26          if (term != 0) {
 27              typename F::Multiplier mul(field, term);
 28              for (size_t x = 0; x < mod.size() - 1; ++x) {
 29                  val[val.size() - modsize + 1 + x] ^= mul(mod[x]);
 30              }
 31          }
 32      }
 33      while (val.size() > 0 && val.back() == 0) val.pop_back();
 34  }
 35  
 36  /** Compute the quotient of a polynomial division of val by mod, putting the quotient in div and the remainder in val. */
 37  template<typename F>
 38  void DivMod(const std::vector<typename F::Elem>& mod, std::vector<typename F::Elem>& val, std::vector<typename F::Elem>& div, const F& field) {
 39      size_t modsize = mod.size();
 40      CHECK_SAFE(mod.size() > 0 && mod.back() == 1);
 41      if (val.size() < mod.size()) {
 42          div.clear();
 43          return;
 44      }
 45      CHECK_SAFE(val.back() != 0);
 46      div.resize(val.size() - mod.size() + 1);
 47      while (val.size() >= modsize) {
 48          auto term = val.back();
 49          div[val.size() - modsize] = term;
 50          val.pop_back();
 51          if (term != 0) {
 52              typename F::Multiplier mul(field, term);
 53              for (size_t x = 0; x < mod.size() - 1; ++x) {
 54                  val[val.size() - modsize + 1 + x] ^= mul(mod[x]);
 55              }
 56          }
 57      }
 58  }
 59  
 60  /** Make a polynomial monic. */
 61  template<typename F>
 62  typename F::Elem MakeMonic(std::vector<typename F::Elem>& a, const F& field) {
 63      CHECK_SAFE(a.back() != 0);
 64      if (a.back() == 1) return 0;
 65      auto inv = field.Inv(a.back());
 66      typename F::Multiplier mul(field, inv);
 67      a.back() = 1;
 68      for (size_t i = 0; i < a.size() - 1; ++i) {
 69          a[i] = mul(a[i]);
 70      }
 71      return inv;
 72  }
 73  
 74  /** Compute the GCD of two polynomials, putting the result in a. b will be cleared. */
 75  template<typename F>
 76  void GCD(std::vector<typename F::Elem>& a, std::vector<typename F::Elem>& b, const F& field) {
 77      if (a.size() < b.size()) std::swap(a, b);
 78      while (b.size() > 0) {
 79          if (b.size() == 1) {
 80              a.resize(1);
 81              a[0] = 1;
 82              return;
 83          }
 84          MakeMonic(b, field);
 85          PolyMod(b, a, field);
 86          std::swap(a, b);
 87      }
 88  }
 89  
 90  /** Square a polynomial. */
 91  template<typename F>
 92  void Sqr(std::vector<typename F::Elem>& poly, const F& field) {
 93      if (poly.size() == 0) return;
 94      poly.resize(poly.size() * 2 - 1);
 95      for (size_t i = 0; i < poly.size(); ++i) {
 96          auto x = poly.size() - i - 1;
 97          poly[x] = (x & 1) ? 0 : field.Sqr(poly[x / 2]);
 98      }
 99  }
100  
/** Compute the trace map of (param*x) modulo mod, putting the result in out.
 *
 * With y = param*x and B = field.Bits(), `out` receives
 *     y + y^2 + y^4 + ... + y^(2^(B-1))   (mod `mod`),
 * built iteratively: start from y, then B-1 times replace out by out^2 + y and reduce.
 */
template<typename F>
void TraceMod(const std::vector<typename F::Elem>& mod, std::vector<typename F::Elem>& out, const typename F::Elem& param, const F& field) {
    out.reserve(mod.size() * 2);
    out.assign(2, 0);
    out[1] = param; // out = param*x
    for (int remaining = field.Bits() - 1; remaining > 0; --remaining) {
        Sqr(out, field);
        // Squaring leaves every odd-degree coefficient zero, so writing coefficient 1 adds y.
        if (out.size() < 2) out.resize(2);
        out[1] = param;
        PolyMod(mod, out, field);
    }
}
116  
/** One step of the root finding algorithm; finds roots of stack[pos] and appends them to roots. Stack elements >= pos are destroyed.
 *
 * It operates on a stack of polynomials. The polynomial operated on is `stack[pos]`, where elements of `stack` with index higher
 * than `pos` are used as scratch space (the stack is grown when it is too short).
 *
 * `stack[pos]` is assumed to be a monic, square-free polynomial of degree at least 1. If `fully_factorizable` is true, it is
 * also assumed to have no irreducible factors of degree higher than 1.
 *
 * This implements the Berlekamp trace algorithm, plus an efficient test to fail fast in
 * case the polynomial cannot be fully factored.
 *
 * @param stack              Polynomial stack (coefficient vectors, lowest degree first).
 * @param pos                Index in `stack` of the polynomial to factor.
 * @param roots              Output; every root found is appended.
 * @param fully_factorizable Whether stack[pos] is already known to split into linear factors.
 * @param depth              Number of split attempts made on the path leading to this polynomial;
 *                           used for the degree sanity check below.
 * @param randv              Multiplier used in the trace map; advanced with field.Mul2() after every attempt.
 * @param field              Field the coefficients belong to.
 * @return true if all roots of stack[pos] were found, false otherwise (not fully factorizable, or a
 *         consistency check failed). On failure `roots` may already contain some roots.
 */
template<typename F>
bool RecFindRoots(std::vector<std::vector<typename F::Elem>>& stack, size_t pos, std::vector<typename F::Elem>& roots, bool fully_factorizable, int depth, typename F::Elem randv, const F& field) {
    auto& ppoly = stack[pos];
    // We assert ppoly.size() > 1 (instead of just ppoly.size() > 0) to additionally exclude
    // constant polynomials because
    //  - ppoly is not constant initially (this is ensured by FindRoots()), and
    //  - we never recurse on a constant polynomial.
    CHECK_SAFE(ppoly.size() > 1 && ppoly.back() == 1);
    /* 1st degree input: x + c has the constant term c as its root (addition is XOR, so -c == c). */
    if (ppoly.size() == 2) {
        roots.push_back(ppoly[0]);
        return true;
    }
    /* 2nd degree input: use direct quadratic solver.
     * For x^2 + b*x + c, substituting x = b*z gives z^2 + z = c/b^2 (`input` below). field.Qrt()
     * proposes a z, which is verified; the two roots are then b*z and b*z + b. */
    if (ppoly.size() == 3) {
        CHECK_RETURN(ppoly[1] != 0, false); // Equations of the form (x^2 + a) have two identical solutions; contradicts square-free assumption.
        auto input = field.Mul(ppoly[0], field.Sqr(field.Inv(ppoly[1])));
        auto root = field.Qrt(input);
        if ((field.Sqr(root) ^ root) != input) {
            CHECK_SAFE(!fully_factorizable);
            return false; // No root found.
        }
        auto sol = field.Mul(root, ppoly[1]);
        roots.push_back(sol);
        roots.push_back(sol ^ ppoly[1]);
        return true;
    }
    /* 3rd degree input and more: recurse further. */
    if (pos + 3 > stack.size()) {
        // Allocate memory if necessary.
        stack.resize((pos + 3) * 2);
    }
    // The resize above may reallocate the stack, invalidating `ppoly`; take fresh references.
    auto& poly = stack[pos];
    auto& tmp = stack[pos + 1];
    auto& trace = stack[pos + 2];
    trace.clear();
    tmp.clear();
    for (int iter = 0;; ++iter) {
        // Compute the polynomial (trace(x*randv) mod poly(x)) symbolically,
        // and put the result in `trace`.
        TraceMod(poly, trace, randv, field);

        if (iter >= 1 && !fully_factorizable) {
            // If the polynomial cannot be factorized completely (it has an
            // irreducible factor of degree higher than 1), we want to avoid
            // the case where this is only detected after trying all BITS
            // independent split attempts fail (see the CHECK_RETURN below).
            //
            // Observe that if we call y = randv*x, it is true that:
            //
            //   trace = y + y^2 + y^4 + y^8 + ... y^(FIELDSIZE/2) mod poly
            //
            // Due to the Frobenius endomorphism, this means:
            //
            //   trace^2 = y^2 + y^4 + y^8 + ... + y^FIELDSIZE mod poly
            //
            // Or, adding them up:
            //
            //   trace + trace^2 = y + y^FIELDSIZE mod poly.
            //                   = randv*x + randv^FIELDSIZE*x^FIELDSIZE
            //                   = randv*x + randv*x^FIELDSIZE
            //                   = randv*(x + x^FIELDSIZE).
            //     (all mod poly)
            //
            // x + x^FIELDSIZE is the polynomial which has every field element
            // as root once. Whenever x + x^FIELDSIZE is a multiple of poly,
            // this means it only has unique first degree factors. The same
            // holds for its constant multiple randv*(x + x^FIELDSIZE) =
            // trace + trace^2.
            //
            // We use this test to quickly verify whether the polynomial is
            // fully factorizable after already having computed a trace.
            // We don't invoke it immediately; only when splitting has failed
            // at least once, which avoids it for most polynomials that are
            // fully factorizable (or at least pushes the test down the
            // recursion to factors which are smaller and thus faster).
            tmp = trace;
            Sqr(tmp, field);
            for (size_t i = 0; i < trace.size(); ++i) {
                tmp[i] ^= trace[i];
            }
            // Strip zero high-order coefficients, as PolyMod expects a nonzero leading coefficient.
            while (tmp.size() && tmp.back() == 0) tmp.pop_back();
            PolyMod(poly, tmp, field);

            // Whenever the test fails, we can immediately abort the root
            // finding. Whenever it succeeds, we can remember and pass down
            // the information that it is in fact fully factorizable, avoiding
            // the need to run the test again.
            if (tmp.size() != 0) return false;
            fully_factorizable = true;
        }

        if (fully_factorizable) {
            // Every successful iteration of this algorithm splits the input
            // polynomial further into buckets, each corresponding to a subset
            // of 2^(BITS-depth) roots. If after depth splits the degree of
            // the polynomial exceeds 2^(BITS-depth), something is wrong.
            // (poly.size() - 2) >> s == 0 is equivalent to degree <= 2^s; the first
            // disjunct avoids shifting by the full width of size_t or more.
            CHECK_RETURN(field.Bits() - depth >= std::numeric_limits<decltype(poly.size())>::digits ||
                (poly.size() - 2) >> (field.Bits() - depth) == 0, false);
        }

        depth++;
        // In every iteration we multiply randv by 2. As a result, the set
        // of randv values forms a GF(2)-linearly independent basis of splits.
        randv = field.Mul2(randv);
        tmp = poly;
        GCD(trace, tmp, field);
        // Stop once gcd(trace, poly) is a proper, non-constant factor of poly.
        if (trace.size() != poly.size() && trace.size() > 1) break;
    }
    // Split poly into trace (the gcd found above) and tmp = poly / trace; the remainder
    // left in poly is discarded below.
    MakeMonic(trace, field);
    DivMod(trace, poly, tmp, field);
    // At this point, the stack looks like [... (poly) tmp trace], and we want to recursively
    // find roots of trace and tmp (= poly/trace). As we don't care about poly anymore, move
    // trace into its position first.
    std::swap(poly, trace);
    // Now the stack is [... (trace) tmp ...]. First we factor tmp (at pos = pos+1), and then
    // we factor trace (at pos = pos).
    if (!RecFindRoots(stack, pos + 1, roots, fully_factorizable, depth, randv, field)) return false;
    // The stack position pos contains trace, the polynomial with all of poly's roots which (after
    // multiplication with randv) have trace 0. This is never the case for irreducible factors
    // (which always end up in tmp), so we can set fully_factorizable to true when recursing.
    bool ret = RecFindRoots(stack, pos, roots, true, depth, randv, field);
    // Because of the above, recursion can never fail here.
    CHECK_SAFE(ret);
    return ret;
}
254  
/** Returns the roots of a fully factorizable polynomial.
 *
 * The input is a coefficient vector (lowest degree first). It is assumed to be square-free,
 * monic (RecFindRoots checks that the leading coefficient is 1) and not the zero polynomial
 * (represented by an empty vector). `basis` must be nonzero; it seeds the trace splits.
 *
 * A constant polynomial has no roots, so an empty vector is returned for it. If the polynomial
 * is not fully factorizable (it has fewer roots than its degree), an empty vector is returned as well.
 */
template<typename F>
std::vector<typename F::Elem> FindRoots(const std::vector<typename F::Elem>& poly, typename F::Elem basis, const F& field) {
    std::vector<typename F::Elem> roots;
    CHECK_RETURN(poly.size() != 0, {});
    CHECK_RETURN(basis != 0, {});
    const size_t degree = poly.size() - 1;
    if (degree == 0) return roots; // Constant polynomial: nothing to find.
    roots.reserve(degree);

    // The recursive algorithm works on a stack of polynomials whose bottom entry is the input.
    std::vector<std::vector<typename F::Elem>> stack(1, poly);
    const bool complete = RecFindRoots(stack, 0, roots, false, 0, basis, field);
    if (!complete) return {};

    CHECK_RETURN(degree == roots.size(), {});
    return roots;
}
280  
/** Run the Berlekamp-Massey algorithm over a syndrome sequence.
 *
 * Computes the shortest linear recurrence (connection polynomial) generating `syndromes`.
 * The polynomial is a coefficient vector, lowest degree first, starting from the constant
 * polynomial 1. Decode() reverses it before root finding.
 *
 * @param syndromes  The full syndrome sequence (as produced by ReconstructAllSyndromes()).
 * @param max_degree Give up and return an empty vector if the degree would exceed this.
 * @param field      Field the syndromes belong to.
 * @return The polynomial, or an empty vector if the degree limit was exceeded or the
 *         final polynomial has a zero leading coefficient.
 */
template<typename F>
std::vector<typename F::Elem> BerlekampMassey(const std::vector<typename F::Elem>& syndromes, size_t max_degree, const F& field) {
    // table[k] multiplies by syndromes[k]; it is built once per syndrome and reused when
    // computing later discrepancies.
    std::vector<typename F::Multiplier> table;
    // current: polynomial under construction; prev: the polynomial before the last length
    // change; tmp: scratch copy used to rotate current into prev.
    std::vector<typename F::Elem> current, prev, tmp;
    current.reserve(syndromes.size() / 2 + 1);
    prev.reserve(syndromes.size() / 2 + 1);
    tmp.reserve(syndromes.size() / 2 + 1);
    current.resize(1);
    current[0] = 1;
    prev.resize(1);
    prev[0] = 1;
    // b is the discrepancy recorded at the last length change. Its inverse is computed lazily,
    // only when a later nonzero discrepancy needs it.
    typename F::Elem b = 1, b_inv = 1;
    bool b_have_inv = true;
    table.reserve(syndromes.size());

    for (size_t n = 0; n != syndromes.size(); ++n) {
        table.emplace_back(field, syndromes[n]);
        // Discrepancy: how far the current recurrence is from predicting syndromes[n].
        auto discrepancy = syndromes[n];
        for (size_t i = 1; i < current.size(); ++i) discrepancy ^= table[n - i](current[i]);
        if (discrepancy != 0) {
            // Offset at which the scaled prev polynomial is added into current.
            int x = static_cast<int>(n + 1 - (current.size() - 1) - (prev.size() - 1));
            if (!b_have_inv) {
                b_inv = field.Inv(b);
                b_have_inv = true;
            }
            // Length change when 2 * degree(current) <= n: current grows and its old value becomes prev.
            bool swap = 2 * (current.size() - 1) <= n;
            if (swap) {
                if (prev.size() + x - 1 > max_degree) return {}; // We'd exceed maximum degree
                tmp = current;
                current.resize(prev.size() + x);
            }
            // current += (discrepancy / b) * x^x * prev
            typename F::Multiplier mul(field, field.Mul(discrepancy, b_inv));
            for (size_t i = 0; i < prev.size(); ++i) current[i + x] ^= mul(prev[i]);
            if (swap) {
                std::swap(prev, tmp);
                b = discrepancy;
                b_have_inv = false;
            }
        }
    }
    CHECK_RETURN(current.size() && current.back() != 0, {});
    return current;
}
324  
/** Expand the odd syndromes into the full syndrome sequence.
 *
 * Output index j holds the syndrome for power j+1. Even indices 2k copy odd_syndromes[k]
 * (power 2k+1). Odd indices 2k+1 (power 2k+2) are the square of index k (power k+1).
 * The result has twice as many entries as the input.
 */
template<typename F>
std::vector<typename F::Elem> ReconstructAllSyndromes(const std::vector<typename F::Elem>& odd_syndromes, const F& field) {
    std::vector<typename F::Elem> result;
    result.reserve(2 * odd_syndromes.size());
    for (const auto& odd : odd_syndromes) {
        result.push_back(odd);
        // After pushing odd_syndromes[k] the size is 2k+1, so size()/2 == k: the entry to square.
        result.push_back(field.Sqr(result[result.size() / 2]));
    }
    return result;
}
335  
/** Add one element to a sketch's odd syndromes: osyndromes[i] ^= data^(2i+1). */
template<typename F>
void AddToOddSyndromes(std::vector<typename F::Elem>& osyndromes, typename F::Elem data, const F& field) {
    // Consecutive odd powers differ by a factor data^2, so one precomputed multiplier suffices.
    typename F::Multiplier step(field, field.Sqr(data));
    auto power = data;
    for (size_t i = 0; i < osyndromes.size(); ++i) {
        osyndromes[i] ^= power;
        power = step(power);
    }
}
345  
/** Decode a set of odd syndromes into the field elements they encode.
 *
 * Reconstructs the full syndrome sequence, runs Berlekamp-Massey with the degree limited to
 * the number of odd syndromes (the sketch capacity), and finds the roots of the resulting
 * polynomial. `basis` must be nonzero and seeds the root finder's trace splits; it defaults
 * to 1, the same fixed basis SketchImpl::SetSeed uses for seed -1.
 *
 * Returns the decoded elements. An empty vector is returned both when decoding fails and
 * when the syndromes encode the empty set.
 */
template<typename F>
std::vector<typename F::Elem> FullDecode(const std::vector<typename F::Elem>& osyndromes, const F& field, typename F::Elem basis = 1) {
    // Template arguments are deduced from `field`. The previous explicit <typename F::Elem>
    // named the wrong template parameter.
    auto asyndromes = ReconstructAllSyndromes(osyndromes, field);
    auto poly = BerlekampMassey(asyndromes, osyndromes.size(), field);
    // Empty: Berlekamp-Massey failed. Constant: no elements. FindRoots rejects an empty
    // polynomial via CHECK_RETURN, so handle both here (as SketchImpl::Decode does).
    if (poly.size() <= 1) return {};
    std::reverse(poly.begin(), poly.end());
    return FindRoots(poly, basis, field);
}
353  
354  template<typename F>
355  class SketchImpl final : public Sketch
356  {
357      const F m_field;
358      std::vector<typename F::Elem> m_syndromes;
359      typename F::Elem m_basis;
360  
361  public:
362      template<typename... Args>
363      SketchImpl(int implementation, int bits, const Args&... args) : Sketch(implementation, bits), m_field(args...) {
364          std::random_device rng;
365          std::uniform_int_distribution<uint64_t> dist;
366          m_basis = m_field.FromSeed(dist(rng));
367      }
368  
369      size_t Syndromes() const override { return m_syndromes.size(); }
370      void Init(size_t count) override { m_syndromes.assign(count, 0); }
371  
372      void Add(uint64_t val) override
373      {
374          auto elem = m_field.FromUint64(val);
375          AddToOddSyndromes(m_syndromes, elem, m_field);
376      }
377  
378      void Serialize(unsigned char* ptr) const override
379      {
380          BitWriter writer(ptr);
381          for (const auto& val : m_syndromes) {
382              m_field.Serialize(writer, val);
383          }
384          writer.Flush();
385      }
386  
387      void Deserialize(const unsigned char* ptr) override
388      {
389          BitReader reader(ptr);
390          for (auto& val : m_syndromes) {
391              val = m_field.Deserialize(reader);
392          }
393      }
394  
395      int Decode(int max_count, uint64_t* out) const override
396      {
397          auto all_syndromes = ReconstructAllSyndromes(m_syndromes, m_field);
398          auto poly = BerlekampMassey(all_syndromes, max_count, m_field);
399          if (poly.size() == 0) return -1;
400          if (poly.size() == 1) return 0;
401          if ((int)poly.size() > 1 + max_count) return -1;
402          std::reverse(poly.begin(), poly.end());
403          auto roots = FindRoots(poly, m_basis, m_field);
404          if (roots.size() == 0) return -1;
405  
406          for (const auto& root : roots) {
407              *(out++) = m_field.ToUint64(root);
408          }
409          return static_cast<int>(roots.size());
410      }
411  
412      size_t Merge(const Sketch* other_sketch) override
413      {
414          // Sad cast. This is safe only because the caller code in minisketch.cpp checks
415          // that implementation and field size match.
416          const SketchImpl* other = static_cast<const SketchImpl*>(other_sketch);
417          m_syndromes.resize(std::min(m_syndromes.size(), other->m_syndromes.size()));
418          for (size_t i = 0; i < m_syndromes.size(); ++i) {
419              m_syndromes[i] ^= other->m_syndromes[i];
420          }
421          return m_syndromes.size();
422      }
423  
424      void SetSeed(uint64_t seed) override
425      {
426          if (seed == (uint64_t)-1) {
427              m_basis = 1;
428          } else {
429              m_basis = m_field.FromSeed(seed);
430          }
431      }
432  };
433  
434  #endif