// sketch_impl.h
1 /********************************************************************** 2 * Copyright (c) 2018 Pieter Wuille, Greg Maxwell, Gleb Naumenko * 3 * Distributed under the MIT software license, see the accompanying * 4 * file LICENSE or http://www.opensource.org/licenses/mit-license.php.* 5 **********************************************************************/ 6 7 #ifndef _MINISKETCH_SKETCH_IMPL_H_ 8 #define _MINISKETCH_SKETCH_IMPL_H_ 9 10 #include <random> 11 12 #include "util.h" 13 #include "sketch.h" 14 #include "int_utils.h" 15 16 /** Compute the remainder of a polynomial division of val by mod, putting the result in mod. */ 17 template<typename F> 18 void PolyMod(const std::vector<typename F::Elem>& mod, std::vector<typename F::Elem>& val, const F& field) { 19 size_t modsize = mod.size(); 20 CHECK_SAFE(modsize > 0 && mod.back() == 1); 21 if (val.size() < modsize) return; 22 CHECK_SAFE(val.back() != 0); 23 while (val.size() >= modsize) { 24 auto term = val.back(); 25 val.pop_back(); 26 if (term != 0) { 27 typename F::Multiplier mul(field, term); 28 for (size_t x = 0; x < mod.size() - 1; ++x) { 29 val[val.size() - modsize + 1 + x] ^= mul(mod[x]); 30 } 31 } 32 } 33 while (val.size() > 0 && val.back() == 0) val.pop_back(); 34 } 35 36 /** Compute the quotient of a polynomial division of val by mod, putting the quotient in div and the remainder in val. 
*/
// mod must be non-empty and monic (last coefficient 1). If val is shorter than mod, div is cleared
// and val is left untouched. Otherwise val's leading coefficient must be nonzero; on return div
// holds the quotient and val the remainder, with trailing zero coefficients stripped the same way
// PolyMod normalizes its result (a zero remainder is an empty vector).
template<typename F>
void DivMod(const std::vector<typename F::Elem>& mod, std::vector<typename F::Elem>& val, std::vector<typename F::Elem>& div, const F& field) {
    const size_t modsize = mod.size();
    CHECK_SAFE(modsize > 0 && mod.back() == 1);
    if (val.size() < modsize) {
        div.clear();
        return;
    }
    CHECK_SAFE(val.back() != 0);
    div.resize(val.size() - modsize + 1);
    while (val.size() >= modsize) {
        auto term = val.back();
        // mod is monic, so the current leading coefficient of val is directly the
        // quotient coefficient of x^(deg(val) - deg(mod)).
        div[val.size() - modsize] = term;
        val.pop_back();
        if (term != 0) {
            typename F::Multiplier mul(field, term);
            for (size_t x = 0; x < modsize - 1; ++x) {
                val[val.size() - modsize + 1 + x] ^= mul(mod[x]);
            }
        }
    }
    // Normalize the remainder so callers can rely on val.back() != 0 for a non-zero remainder.
    while (val.size() > 0 && val.back() == 0) val.pop_back();
}

/** Make a polynomial monic by multiplying every coefficient by the inverse of its leading coefficient.
 *
 * a must be non-empty with a nonzero leading coefficient. Returns the inverse that was applied,
 * or 0 if a was already monic (in which case a is left unchanged).
 */
template<typename F>
typename F::Elem MakeMonic(std::vector<typename F::Elem>& a, const F& field) {
    CHECK_SAFE(a.size() > 0 && a.back() != 0);
    if (a.back() == 1) return 0;
    auto inv = field.Inv(a.back());
    typename F::Multiplier mul(field, inv);
    a.back() = 1;
    for (size_t i = 0; i < a.size() - 1; ++i) {
        a[i] = mul(a[i]);
    }
    return inv;
}

/** Compute the GCD of two polynomials, putting the result in a. b will be cleared.
 *
 * Euclidean algorithm: repeatedly make the divisor monic and replace (a, b) by (b, a mod b).
 * Whenever at least one reduction step happens, the result is monic. If b is the zero
 * polynomial on entry (after the initial size-based swap), a is returned unnormalized.
 * Inputs are expected to be normalized (no zero leading coefficient).
 */
template<typename F>
void GCD(std::vector<typename F::Elem>& a, std::vector<typename F::Elem>& b, const F& field) {
    if (a.size() < b.size()) std::swap(a, b);
    while (b.size() > 0) {
        if (b.size() == 1) {
            // b is a constant (nonzero for normalized input), so the GCD is 1.
            a.resize(1);
            a[0] = 1;
            b.clear(); // Honor the documented contract on this early exit as well.
            return;
        }
        MakeMonic(b, field);
        PolyMod(b, a, field);
        std::swap(a, b);
    }
}

/** Square a polynomial in place.
 *
 * Coefficients are added with XOR (characteristic 2), so the cross terms of the square cancel:
 * (sum c_i x^i)^2 = sum c_i^2 x^(2i). The result holds field.Sqr(c_i) at index 2i and zero at
 * every odd index. The zero polynomial (empty vector) is left unchanged.
 */
template<typename F>
void Sqr(std::vector<typename F::Elem>& poly, const F& field) {
    if (poly.size() == 0) return;
    poly.resize(poly.size() * 2 - 1);
    // Fill from the top down so that each source coefficient poly[x / 2] is read
    // before its own slot is overwritten.
    for (size_t i = 0; i < poly.size(); ++i) {
        auto x = poly.size() - i - 1;
        poly[x] = (x & 1) ? 0 : field.Sqr(poly[x / 2]);
    }
}

/** Compute the trace map of (param*x) modulo mod, putting the result in out.
*/ 102 template<typename F> 103 void TraceMod(const std::vector<typename F::Elem>& mod, std::vector<typename F::Elem>& out, const typename F::Elem& param, const F& field) { 104 out.reserve(mod.size() * 2); 105 out.resize(2); 106 out[0] = 0; 107 out[1] = param; 108 109 for (int i = 0; i < field.Bits() - 1; ++i) { 110 Sqr(out, field); 111 if (out.size() < 2) out.resize(2); 112 out[1] = param; 113 PolyMod(mod, out, field); 114 } 115 } 116 117 /** One step of the root finding algorithm; finds roots of stack[pos] and adds them to roots. Stack elements >= pos are destroyed. 118 * 119 * It operates on a stack of polynomials. The polynomial operated on is `stack[pos]`, where elements of `stack` with index higher 120 * than `pos` are used as scratch space. 121 * 122 * `stack[pos]` is assumed to be square-free polynomial. If `fully_factorizable` is true, it is also assumed to have no irreducible 123 * factors of degree higher than 1. 124 125 * This implements the Berlekamp trace algorithm, plus an efficient test to fail fast in 126 * case the polynomial cannot be fully factored. 127 */ 128 template<typename F> 129 bool RecFindRoots(std::vector<std::vector<typename F::Elem>>& stack, size_t pos, std::vector<typename F::Elem>& roots, bool fully_factorizable, int depth, typename F::Elem randv, const F& field) { 130 auto& ppoly = stack[pos]; 131 // We assert ppoly.size() > 1 (instead of just ppoly.size() > 0) to additionally exclude 132 // constants polynomials because 133 // - ppoly is not constant initially (this is ensured by FindRoots()), and 134 // - we never recurse on a constant polynomial. 135 CHECK_SAFE(ppoly.size() > 1 && ppoly.back() == 1); 136 /* 1st degree input: constant term is the root. */ 137 if (ppoly.size() == 2) { 138 roots.push_back(ppoly[0]); 139 return true; 140 } 141 /* 2nd degree input: use direct quadratic solver. 
*/ 142 if (ppoly.size() == 3) { 143 CHECK_RETURN(ppoly[1] != 0, false); // Equations of the form (x^2 + a) have two identical solutions; contradicts square-free assumption. */ 144 auto input = field.Mul(ppoly[0], field.Sqr(field.Inv(ppoly[1]))); 145 auto root = field.Qrt(input); 146 if ((field.Sqr(root) ^ root) != input) { 147 CHECK_SAFE(!fully_factorizable); 148 return false; // No root found. 149 } 150 auto sol = field.Mul(root, ppoly[1]); 151 roots.push_back(sol); 152 roots.push_back(sol ^ ppoly[1]); 153 return true; 154 } 155 /* 3rd degree input and more: recurse further. */ 156 if (pos + 3 > stack.size()) { 157 // Allocate memory if necessary. 158 stack.resize((pos + 3) * 2); 159 } 160 auto& poly = stack[pos]; 161 auto& tmp = stack[pos + 1]; 162 auto& trace = stack[pos + 2]; 163 trace.clear(); 164 tmp.clear(); 165 for (int iter = 0;; ++iter) { 166 // Compute the polynomial (trace(x*randv) mod poly(x)) symbolically, 167 // and put the result in `trace`. 168 TraceMod(poly, trace, randv, field); 169 170 if (iter >= 1 && !fully_factorizable) { 171 // If the polynomial cannot be factorized completely (it has an 172 // irreducible factor of degree higher than 1), we want to avoid 173 // the case where this is only detected after trying all BITS 174 // independent split attempts fail (see the assert below). 175 // 176 // Observe that if we call y = randv*x, it is true that: 177 // 178 // trace = y + y^2 + y^4 + y^8 + ... y^(FIELDSIZE/2) mod poly 179 // 180 // Due to the Frobenius endomorphism, this means: 181 // 182 // trace^2 = y^2 + y^4 + y^8 + ... + y^FIELDSIZE mod poly 183 // 184 // Or, adding them up: 185 // 186 // trace + trace^2 = y + y^FIELDSIZE mod poly. 187 // = randv*x + randv^FIELDSIZE*x^FIELDSIZE 188 // = randv*x + randv*x^FIELDSIZE 189 // = randv*(x + x^FIELDSIZE). 190 // (all mod poly) 191 // 192 // x + x^FIELDSIZE is the polynomial which has every field element 193 // as root once. 
Whenever x + x^FIELDSIZE is multiple of poly, 194 // this means it only has unique first degree factors. The same 195 // holds for its constant multiple randv*(x + x^FIELDSIZE) = 196 // trace + trace^2. 197 // 198 // We use this test to quickly verify whether the polynomial is 199 // fully factorizable after already having computed a trace. 200 // We don't invoke it immediately; only when splitting has failed 201 // at least once, which avoids it for most polynomials that are 202 // fully factorizable (or at least pushes the test down the 203 // recursion to factors which are smaller and thus faster). 204 tmp = trace; 205 Sqr(tmp, field); 206 for (size_t i = 0; i < trace.size(); ++i) { 207 tmp[i] ^= trace[i]; 208 } 209 while (tmp.size() && tmp.back() == 0) tmp.pop_back(); 210 PolyMod(poly, tmp, field); 211 212 // Whenever the test fails, we can immediately abort the root 213 // finding. Whenever it succeeds, we can remember and pass down 214 // the information that it is in fact fully factorizable, avoiding 215 // the need to run the test again. 216 if (tmp.size() != 0) return false; 217 fully_factorizable = true; 218 } 219 220 if (fully_factorizable) { 221 // Every successful iteration of this algorithm splits the input 222 // polynomial further into buckets, each corresponding to a subset 223 // of 2^(BITS-depth) roots. If after depth splits the degree of 224 // the polynomial is >= 2^(BITS-depth), something is wrong. 225 CHECK_RETURN(field.Bits() - depth >= std::numeric_limits<decltype(poly.size())>::digits || 226 (poly.size() - 2) >> (field.Bits() - depth) == 0, false); 227 } 228 229 depth++; 230 // In every iteration we multiply randv by 2. As a result, the set 231 // of randv values forms a GF(2)-linearly independent basis of splits. 
232 randv = field.Mul2(randv); 233 tmp = poly; 234 GCD(trace, tmp, field); 235 if (trace.size() != poly.size() && trace.size() > 1) break; 236 } 237 MakeMonic(trace, field); 238 DivMod(trace, poly, tmp, field); 239 // At this point, the stack looks like [... (poly) tmp trace], and we want to recursively 240 // find roots of trace and tmp (= poly/trace). As we don't care about poly anymore, move 241 // trace into its position first. 242 std::swap(poly, trace); 243 // Now the stack is [... (trace) tmp ...]. First we factor tmp (at pos = pos+1), and then 244 // we factor trace (at pos = pos). 245 if (!RecFindRoots(stack, pos + 1, roots, fully_factorizable, depth, randv, field)) return false; 246 // The stack position pos contains trace, the polynomial with all of poly's roots which (after 247 // multiplication with randv) have trace 0. This is never the case for irreducible factors 248 // (which always end up in tmp), so we can set fully_factorizable to true when recursing. 249 bool ret = RecFindRoots(stack, pos, roots, true, depth, randv, field); 250 // Because of the above, recursion can never fail here. 251 CHECK_SAFE(ret); 252 return ret; 253 } 254 255 /** Returns the roots of a fully factorizable polynomial 256 * 257 * This function assumes that the input polynomial is square-free 258 * and not the zero polynomial (represented by an empty vector). 259 * 260 * In case the square-free polynomial is not fully factorizable, i.e., it 261 * has fewer roots than its degree, the empty vector is returned. 262 */ 263 template<typename F> 264 std::vector<typename F::Elem> FindRoots(const std::vector<typename F::Elem>& poly, typename F::Elem basis, const F& field) { 265 std::vector<typename F::Elem> roots; 266 CHECK_RETURN(poly.size() != 0, {}); 267 CHECK_RETURN(basis != 0, {}); 268 if (poly.size() == 1) return roots; // No roots when the polynomial is a constant. 
269 roots.reserve(poly.size() - 1); 270 std::vector<std::vector<typename F::Elem>> stack = {poly}; 271 272 // Invoke the recursive factorization algorithm. 273 if (!RecFindRoots(stack, 0, roots, false, 0, basis, field)) { 274 // Not fully factorizable. 275 return {}; 276 } 277 CHECK_RETURN(poly.size() - 1 == roots.size(), {}); 278 return roots; 279 } 280 281 template<typename F> 282 std::vector<typename F::Elem> BerlekampMassey(const std::vector<typename F::Elem>& syndromes, size_t max_degree, const F& field) { 283 std::vector<typename F::Multiplier> table; 284 std::vector<typename F::Elem> current, prev, tmp; 285 current.reserve(syndromes.size() / 2 + 1); 286 prev.reserve(syndromes.size() / 2 + 1); 287 tmp.reserve(syndromes.size() / 2 + 1); 288 current.resize(1); 289 current[0] = 1; 290 prev.resize(1); 291 prev[0] = 1; 292 typename F::Elem b = 1, b_inv = 1; 293 bool b_have_inv = true; 294 table.reserve(syndromes.size()); 295 296 for (size_t n = 0; n != syndromes.size(); ++n) { 297 table.emplace_back(field, syndromes[n]); 298 auto discrepancy = syndromes[n]; 299 for (size_t i = 1; i < current.size(); ++i) discrepancy ^= table[n - i](current[i]); 300 if (discrepancy != 0) { 301 int x = static_cast<int>(n + 1 - (current.size() - 1) - (prev.size() - 1)); 302 if (!b_have_inv) { 303 b_inv = field.Inv(b); 304 b_have_inv = true; 305 } 306 bool swap = 2 * (current.size() - 1) <= n; 307 if (swap) { 308 if (prev.size() + x - 1 > max_degree) return {}; // We'd exceed maximum degree 309 tmp = current; 310 current.resize(prev.size() + x); 311 } 312 typename F::Multiplier mul(field, field.Mul(discrepancy, b_inv)); 313 for (size_t i = 0; i < prev.size(); ++i) current[i + x] ^= mul(prev[i]); 314 if (swap) { 315 std::swap(prev, tmp); 316 b = discrepancy; 317 b_have_inv = false; 318 } 319 } 320 } 321 CHECK_RETURN(current.size() && current.back() != 0, {}); 322 return current; 323 } 324 325 template<typename F> 326 std::vector<typename F::Elem> ReconstructAllSyndromes(const 
std::vector<typename F::Elem>& odd_syndromes, const F& field) { 327 std::vector<typename F::Elem> all_syndromes; 328 all_syndromes.resize(odd_syndromes.size() * 2); 329 for (size_t i = 0; i < odd_syndromes.size(); ++i) { 330 all_syndromes[i * 2] = odd_syndromes[i]; 331 all_syndromes[i * 2 + 1] = field.Sqr(all_syndromes[i]); 332 } 333 return all_syndromes; 334 } 335 336 template<typename F> 337 void AddToOddSyndromes(std::vector<typename F::Elem>& osyndromes, typename F::Elem data, const F& field) { 338 auto sqr = field.Sqr(data); 339 typename F::Multiplier mul(field, sqr); 340 for (auto& osyndrome : osyndromes) { 341 osyndrome ^= data; 342 data = mul(data); 343 } 344 } 345 346 template<typename F> 347 std::vector<typename F::Elem> FullDecode(const std::vector<typename F::Elem>& osyndromes, const F& field) { 348 auto asyndromes = ReconstructAllSyndromes<typename F::Elem>(osyndromes, field); 349 auto poly = BerlekampMassey(asyndromes, field); 350 std::reverse(poly.begin(), poly.end()); 351 return FindRoots(poly, field); 352 } 353 354 template<typename F> 355 class SketchImpl final : public Sketch 356 { 357 const F m_field; 358 std::vector<typename F::Elem> m_syndromes; 359 typename F::Elem m_basis; 360 361 public: 362 template<typename... Args> 363 SketchImpl(int implementation, int bits, const Args&... args) : Sketch(implementation, bits), m_field(args...) 
{ 364 std::random_device rng; 365 std::uniform_int_distribution<uint64_t> dist; 366 m_basis = m_field.FromSeed(dist(rng)); 367 } 368 369 size_t Syndromes() const override { return m_syndromes.size(); } 370 void Init(size_t count) override { m_syndromes.assign(count, 0); } 371 372 void Add(uint64_t val) override 373 { 374 auto elem = m_field.FromUint64(val); 375 AddToOddSyndromes(m_syndromes, elem, m_field); 376 } 377 378 void Serialize(unsigned char* ptr) const override 379 { 380 BitWriter writer(ptr); 381 for (const auto& val : m_syndromes) { 382 m_field.Serialize(writer, val); 383 } 384 writer.Flush(); 385 } 386 387 void Deserialize(const unsigned char* ptr) override 388 { 389 BitReader reader(ptr); 390 for (auto& val : m_syndromes) { 391 val = m_field.Deserialize(reader); 392 } 393 } 394 395 int Decode(int max_count, uint64_t* out) const override 396 { 397 auto all_syndromes = ReconstructAllSyndromes(m_syndromes, m_field); 398 auto poly = BerlekampMassey(all_syndromes, max_count, m_field); 399 if (poly.size() == 0) return -1; 400 if (poly.size() == 1) return 0; 401 if ((int)poly.size() > 1 + max_count) return -1; 402 std::reverse(poly.begin(), poly.end()); 403 auto roots = FindRoots(poly, m_basis, m_field); 404 if (roots.size() == 0) return -1; 405 406 for (const auto& root : roots) { 407 *(out++) = m_field.ToUint64(root); 408 } 409 return static_cast<int>(roots.size()); 410 } 411 412 size_t Merge(const Sketch* other_sketch) override 413 { 414 // Sad cast. This is safe only because the caller code in minisketch.cpp checks 415 // that implementation and field size match. 
416 const SketchImpl* other = static_cast<const SketchImpl*>(other_sketch); 417 m_syndromes.resize(std::min(m_syndromes.size(), other->m_syndromes.size())); 418 for (size_t i = 0; i < m_syndromes.size(); ++i) { 419 m_syndromes[i] ^= other->m_syndromes[i]; 420 } 421 return m_syndromes.size(); 422 } 423 424 void SetSeed(uint64_t seed) override 425 { 426 if (seed == (uint64_t)-1) { 427 m_basis = 1; 428 } else { 429 m_basis = m_field.FromSeed(seed); 430 } 431 } 432 }; 433 434 #endif