/ src / minisketch / src / test.cpp
test.cpp
  1  /**********************************************************************
  2   * Copyright (c) 2018,2021 Pieter Wuille, Greg Maxwell, Gleb Naumenko *
  3   * Distributed under the MIT software license, see the accompanying   *
  4   * file LICENSE or http://www.opensource.org/licenses/mit-license.php.*
  5   **********************************************************************/
  6  
  7  #include <algorithm>
  8  #include <cstdio>
  9  #include <limits>
 10  #include <random>
 11  #include <stdexcept>
 12  #include <string>
 13  #include <vector>
 14  
 15  #include "../include/minisketch.h"
 16  #include "util.h"
 17  
 18  namespace {
 19  
 20  uint64_t Combination(uint64_t n, uint64_t k) {
 21      if (n - k < k) k = n - k;
 22      uint64_t ret = 1;
 23      for (uint64_t i = 1; i <= k; ++i) {
 24          ret = (ret * n) / i;
 25          --n;
 26      }
 27      return ret;
 28  }
 29  
 30  /** Create a vector with Minisketch objects, one for each implementation. */
 31  std::vector<Minisketch> CreateSketches(uint32_t bits, size_t capacity) {
 32      if (!Minisketch::BitsSupported(bits)) return {};
 33      std::vector<Minisketch> ret;
 34      for (uint32_t impl = 0; impl <= Minisketch::MaxImplementation(); ++impl) {
 35          if (Minisketch::ImplementationSupported(bits, impl)) {
 36              CHECK(Minisketch::BitsSupported(bits));
 37              ret.push_back(Minisketch(bits, impl, capacity));
 38              CHECK((bool)ret.back());
 39          } else {
 40              // implementation 0 must always work unless field size is disabled
 41              CHECK(impl != 0 || !Minisketch::BitsSupported(bits));
 42          }
 43      }
 44      return ret;
 45  }
 46  
 47  /** Test properties by exhaustively decoding all 2**(bits*capacity) sketches
 48   *  with specified capacity and bits. */
 49  void TestExhaustive(uint32_t bits, size_t capacity) {
 50      auto sketches = CreateSketches(bits, capacity);
 51      if (sketches.empty()) return;
 52      auto sketches_rebuild = CreateSketches(bits, capacity);
 53  
 54      std::vector<unsigned char> serialized;
 55      std::vector<unsigned char> serialized_empty;
 56      std::vector<uint64_t> counts; //!< counts[i] = number of results with i elements
 57      std::vector<uint64_t> elements_0; //!< Result vector for elements for impl=0
 58      std::vector<uint64_t> elements_other; //!< Result vector for elements for other impls
 59      std::vector<uint64_t> elements_too_small; //!< Result vector that's too small
 60  
 61      counts.resize(capacity + 1);
 62      serialized.resize(sketches[0].GetSerializedSize());
 63      serialized_empty.resize(sketches[0].GetSerializedSize());
 64  
 65      // Iterate over all (bits)-bit sketches with (capacity) syndromes.
 66      for (uint64_t x = 0; (x >> (bits * capacity)) == 0; ++x) {
 67          // Construct the serialization.
 68          for (size_t i = 0; i < serialized.size(); ++i) {
 69              serialized[i] = (x >> (i * 8)) & 0xFF;
 70          }
 71  
 72          // Compute all the solutions
 73          sketches[0].Deserialize(serialized);
 74          elements_0.resize(64);
 75          bool decodable_0 = sketches[0].Decode(elements_0);
 76          std::sort(elements_0.begin(), elements_0.end());
 77  
 78          // Verify that decoding with other implementations agrees.
 79          for (size_t impl = 1; impl < sketches.size(); ++impl) {
 80              sketches[impl].Deserialize(serialized);
 81              elements_other.resize(64);
 82              bool decodable_other = sketches[impl].Decode(elements_other);
 83              CHECK(decodable_other == decodable_0);
 84              std::sort(elements_other.begin(), elements_other.end());
 85              CHECK(elements_other == elements_0);
 86          }
 87  
 88          // If there are solutions:
 89          if (decodable_0) {
 90              if (!elements_0.empty()) {
 91                  // Decoding with limit one less than the number of elements should fail.
 92                  elements_too_small.resize(elements_0.size() - 1);
 93                  for (size_t impl = 0; impl < sketches.size(); ++impl) {
 94                      CHECK(!sketches[impl].Decode(elements_too_small));
 95                  }
 96              }
 97  
 98              // Reconstruct the sketch from the solutions.
 99              for (size_t impl = 0; impl < sketches.size(); ++impl) {
100                  // Clear the sketch.
101                  sketches_rebuild[impl].Deserialize(serialized_empty);
102                  // Load all decoded elements into it.
103                  for (uint64_t elem : elements_0) {
104                      CHECK(elem != 0);
105                      CHECK(elem >> bits == 0);
106                      sketches_rebuild[impl].Add(elem);
107                  }
108                  // Reserialize the result
109                  auto serialized_rebuild = sketches_rebuild[impl].Serialize();
110                  // Compare
111                  CHECK(serialized == serialized_rebuild);
112                  // Count it
113                  if (impl == 0 && elements_0.size() <= capacity) ++counts[elements_0.size()];
114              }
115          }
116      }
117  
118      // Verify that the number of decodable sketches with given elements is expected.
119      uint64_t mask = bits == 64 ? UINT64_MAX : (uint64_t{1} << bits) - 1;
120      for (uint64_t i = 0; i <= capacity && (i & mask) == i; ++i) {
121          CHECK(counts[i] == Combination(mask, i));
122      }
123  }
124  
125  /** Test properties of sketches with random elements put in. */
126  void TestRandomized(uint32_t bits, size_t max_capacity, size_t iter) {
127      std::random_device rnd;
128      std::uniform_int_distribution<uint64_t> capacity_dist(0, std::min<uint64_t>(std::numeric_limits<uint64_t>::max() >> (64 - bits), max_capacity));
129      std::uniform_int_distribution<uint64_t> element_dist(1, std::numeric_limits<uint64_t>::max() >> (64 - bits));
130      std::uniform_int_distribution<uint64_t> rand64(0, std::numeric_limits<uint64_t>::max());
131      std::uniform_int_distribution<int64_t> size_offset_dist(-3, 3);
132  
133      std::vector<uint64_t> decode_0;
134      std::vector<uint64_t> decode_other;
135      std::vector<uint64_t> decode_temp;
136      std::vector<uint64_t> elements;
137  
138      for (size_t i = 0; i < iter; ++i) {
139          // Determine capacity, and construct Minisketch objects for all implementations.
140          uint64_t capacity = capacity_dist(rnd);
141          auto sketches = CreateSketches(bits, capacity);
142          // Sanity checks
143          if (sketches.empty()) return;
144          for (size_t impl = 0; impl < sketches.size(); ++impl) {
145              CHECK(sketches[impl].GetBits() == bits);
146              CHECK(sketches[impl].GetCapacity() == capacity);
147              CHECK(sketches[impl].GetSerializedSize() == sketches[0].GetSerializedSize());
148          }
149          // Determine the number of elements, and create a vector to store them in.
150          size_t element_count = std::max<int64_t>(0, std::max<int64_t>(0, capacity + size_offset_dist(rnd)));
151          elements.resize(element_count);
152          // Add the elements to all sketches
153          for (size_t j = 0; j < element_count; ++j) {
154              uint64_t elem = element_dist(rnd);
155              CHECK(elem != 0);
156              elements[j] = elem;
157              for (auto& sketch : sketches) sketch.Add(elem);
158          }
159          // Remove pairs of duplicates in elements, as they cancel out.
160          std::sort(elements.begin(), elements.end());
161          size_t real_element_count = element_count;
162          for (size_t pos = 0; pos + 1 < elements.size(); ++pos) {
163              if (elements[pos] == elements[pos + 1]) {
164                  real_element_count -= 2;
165                  // Set both elements to 0; afterwards we will move these to the end.
166                  elements[pos] = 0;
167                  elements[pos + 1] = 0;
168                  ++pos;
169              }
170          }
171          if (real_element_count < element_count) {
172              // Move all introduced zeroes (masking duplicates) to the end.
173              std::sort(elements.begin(), elements.end(), [](uint64_t a, uint64_t b) { return a != b && (b == 0 || (a != 0 && a < b)); });
174              CHECK(elements[real_element_count] == 0);
175              elements.resize(real_element_count);
176          }
177          // Create and compare serializations
178          auto serialized_0 = sketches[0].Serialize();
179          for (size_t impl = 1; impl < sketches.size(); ++impl) {
180              auto serialized_other = sketches[impl].Serialize();
181              CHECK(serialized_other == serialized_0);
182          }
183          // Deserialize and reserialize them
184          for (size_t impl = 0; impl < sketches.size(); ++impl) {
185              sketches[impl].Deserialize(serialized_0);
186              auto reserialized = sketches[impl].Serialize();
187              CHECK(reserialized == serialized_0);
188          }
189          // Decode with limit set to the capacity, and compare results
190          decode_0.resize(capacity);
191          bool decodable_0 = sketches[0].Decode(decode_0);
192          std::sort(decode_0.begin(), decode_0.end());
193          for (size_t impl = 1; impl < sketches.size(); ++impl) {
194              decode_other.resize(capacity);
195              bool decodable_other = sketches[impl].Decode(decode_other);
196              CHECK(decodable_other == decodable_0);
197              std::sort(decode_other.begin(), decode_other.end());
198              CHECK(decode_other == decode_0);
199          }
200          // If the result is decodable, it should also be decodable with limit
201          // set to the actual number of elements, and not with one less.
202          if (decodable_0) {
203              for (auto& sketch : sketches) {
204                  decode_temp.resize(decode_0.size());
205                  bool decodable = sketch.Decode(decode_temp);
206                  CHECK(decodable);
207                  std::sort(decode_temp.begin(), decode_temp.end());
208                  CHECK(decode_temp == decode_0);
209                  if (!decode_0.empty()) {
210                      decode_temp.resize(decode_0.size() - 1);
211                      decodable = sketch.Decode(decode_temp);
212                      CHECK(!decodable);
213                  }
214              }
215          }
216          // If the actual number of elements is not higher than the capacity, the
217          // result should be decodable, and the result should match what we put in.
218          if (real_element_count <= capacity) {
219              CHECK(decodable_0);
220              CHECK(decode_0 == elements);
221          }
222      }
223  }
224  
225  void TestComputeFunctions() {
226      for (uint32_t bits = 0; bits <= 256; ++bits) {
227          for (uint32_t fpbits = 0; fpbits <= 512; ++fpbits) {
228              std::vector<size_t> table_max_elements(1025);
229              for (size_t capacity = 0; capacity <= 1024; ++capacity) {
230                  table_max_elements[capacity] = minisketch_compute_max_elements(bits, capacity, fpbits);
231                  // Exception for bits==0
232                  if (bits == 0) CHECK(table_max_elements[capacity] == 0);
233                  // A sketch with capacity N cannot guarantee decoding more than N elements.
234                  CHECK(table_max_elements[capacity] <= capacity);
235                  // When asking for N bits of false positive protection, either no solution exists, or no more than ceil(N / bits) excess capacity should be needed.
236                  if (bits > 0) CHECK(table_max_elements[capacity] == 0 || capacity - table_max_elements[capacity] <= (fpbits + bits - 1) / bits);
237                  // Increasing capacity by one, if there is a solution, should always increment the max_elements by at least one as well.
238                  if (capacity > 0) CHECK(table_max_elements[capacity] == 0 || table_max_elements[capacity] > table_max_elements[capacity - 1]);
239              }
240  
241              std::vector<size_t> table_capacity(513);
242              for (size_t max_elements = 0; max_elements <= 512; ++max_elements) {
243                  table_capacity[max_elements] = minisketch_compute_capacity(bits, max_elements, fpbits);
244                  // Exception for bits==0
245                  if (bits == 0) CHECK(table_capacity[max_elements] == 0);
246                  // To be able to decode N elements, capacity needs to be at least N.
247                  if (bits > 0) CHECK(table_capacity[max_elements] >= max_elements);
248                  // A sketch of N bits in total cannot have more than N bits of false positive protection;
249                  if (bits > 0) CHECK(bits * table_capacity[max_elements] >= fpbits);
250                  // When asking for N bits of false positive protection, no more than ceil(N / bits) excess capacity should be needed.
251                  if (bits > 0) CHECK(table_capacity[max_elements] - max_elements <= (fpbits + bits - 1) / bits);
252                  // Increasing max_elements by one can only increment the capacity by 0 or 1.
253                  if (max_elements > 0 && fpbits < 256) CHECK(table_capacity[max_elements] == table_capacity[max_elements - 1] || table_capacity[max_elements] == table_capacity[max_elements - 1] + 1);
254                  // Check round-tripping max_elements->capacity->max_elements (only a lower bound)
255                  CHECK(table_capacity[max_elements] <= 1024);
256                  CHECK(table_max_elements[table_capacity[max_elements]] == 0 || table_max_elements[table_capacity[max_elements]] >= max_elements);
257              }
258  
259              for (size_t capacity = 0; capacity <= 512; ++capacity) {
260                  // Check round-tripping capacity->max_elements->capacity (exact, if it exists)
261                  CHECK(table_max_elements[capacity] <= 512);
262                  CHECK(table_max_elements[capacity] == 0 || table_capacity[table_max_elements[capacity]] == capacity);
263              }
264          }
265      }
266  }
267  
268  } // namespace
269  
270  int main(int argc, char** argv) {
271      uint64_t test_complexity = 4;
272      if (argc > 1) {
273          size_t len = 0;
274          std::string arg{argv[1]};
275          try {
276              test_complexity = 0;
277              long long complexity = std::stoll(arg, &len);
278              if (complexity >= 1 && len == arg.size() && ((uint64_t)complexity <= std::numeric_limits<uint64_t>::max() >> 10)) {
279                  test_complexity = complexity;
280              }
281          } catch (const std::logic_error&) {}
282          if (test_complexity == 0) {
283              fprintf(stderr, "Invalid complexity specified: '%s'\n", arg.c_str());
284              return 1;
285          }
286      }
287  
288  #ifdef MINISKETCH_VERIFY
289      const char* mode = " in verify mode";
290  #else
291      const char* mode = "";
292  #endif
293      printf("Running libminisketch tests%s with complexity=%llu\n", mode, (unsigned long long)test_complexity);
294  
295      TestComputeFunctions();
296  
297      for (unsigned j = 2; j <= 64; ++j) {
298          TestRandomized(j, 8, (test_complexity << 10) / j);
299          TestRandomized(j, 128, (test_complexity << 7) / j);
300          TestRandomized(j, 4096, test_complexity / j);
301      }
302  
303      // Test capacity==0 together with all field sizes, and then
304      // all combinations of bits and capacity up to a certain bits*capacity,
305      // depending on test_complexity.
306      for (int weight = 0; weight <= 40; ++weight) {
307          for (int bits = 2; weight == 0 ? bits <= 64 : (bits <= 32 && bits <= weight); ++bits) {
308              int capacity = weight / bits;
309              if (capacity * bits != weight) continue;
310              TestExhaustive(bits, capacity);
311          }
312          if (weight >= 16 && test_complexity >> (weight - 16) == 0) break;
313      }
314  
315      printf("All tests successful.\n");
316      return 0;
317  }