/ tests / test-quantize-fns.cpp
test-quantize-fns.cpp
  1  // Unit tests for quantization specific functions - quantize, dequantize and dot product
  2  
  3  #include "ggml.h"
  4  
  5  #undef NDEBUG
  6  #include <assert.h>
  7  #include <math.h>
  8  #include <stdio.h>
  9  #include <string>
 10  #include <vector>
 11  
 12  #if defined(_MSC_VER)
 13  #pragma warning(disable: 4244 4267) // possible loss of data
 14  #endif
 15  
 16  constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;
 17  constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;
 18  constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
 19  constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
 20  constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;
 21  constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
 22  constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
 23  
 24  static const char* RESULT_STR[] = {"ok", "FAILED"};
 25  
 26  
 27  // Generate synthetic data
 28  static void generate_data(float offset, size_t n, float * dst) {
 29      for (size_t i = 0; i < n; i++) {
 30          dst[i] = 0.1 + 2*cosf(i + offset);
 31      }
 32  }
 33  
 34  // Calculate RMSE between two float arrays
 35  static float array_rmse(const float * a1, const float * a2, size_t n) {
 36      double sum = 0;
 37      for (size_t i = 0; i < n; i++) {
 38          double diff = a1[i] - a2[i];
 39          sum += diff * diff;
 40      }
 41      return sqrtf(sum) / n;
 42  }
 43  
 44  // Total quantization error on test data
 45  static float total_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
 46      std::vector<uint8_t> tmp_q(2*test_size);
 47      std::vector<float> tmp_out(test_size);
 48  
 49      qfns.from_float(test_data, tmp_q.data(), test_size);
 50      qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
 51      return array_rmse(test_data, tmp_out.data(), test_size);
 52  }
 53  
 54  // Total quantization error on test data
 55  static float reference_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
 56      std::vector<uint8_t> tmp_q(2*test_size);
 57      std::vector<float> tmp_out(test_size);
 58      std::vector<float> tmp_out_ref(test_size);
 59  
 60      qfns.from_float(test_data, tmp_q.data(), test_size);
 61      qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
 62  
 63      qfns.from_float_reference(test_data, tmp_q.data(), test_size);
 64      qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
 65  
 66      return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
 67  }
 68  
 69  static float dot_product(const float * a1, const float * a2, size_t test_size) {
 70      double sum = 0;
 71      for (size_t i = 0; i < test_size; i++) {
 72          sum += a1[i] * a2[i];
 73      }
 74      return sum;
 75  }
 76  
 77  // Total dot product error
 78  static float dot_product_error(
 79      ggml_type_traits_t & qfns, size_t test_size, const float * test_data1, const float *test_data2
 80  ) {
 81      std::vector<uint8_t> tmp_q1(2*test_size);
 82      std::vector<uint8_t> tmp_q2(2*test_size);
 83  
 84      auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);
 85  
 86      qfns.from_float(test_data1, tmp_q1.data(), test_size);
 87      vdot.from_float(test_data2, tmp_q2.data(), test_size);
 88  
 89      float result = INFINITY;
 90      qfns.vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
 91  
 92      const float dot_ref = dot_product(test_data1, test_data2, test_size);
 93  
 94      return fabsf(result - dot_ref) / test_size;
 95  }
 96  
 97  int main(int argc, char * argv[]) {
 98      bool verbose = false;
 99      const size_t test_size = 32 * 128;
100  
101      std::string arg;
102      for (int i = 1; i < argc; i++) {
103          arg = argv[i];
104  
105          if (arg == "-v") {
106              verbose = true;
107          } else {
108              fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
109              return 1;
110          }
111      }
112  
113      std::vector<float> test_data(test_size);
114      std::vector<float> test_data2(test_size);
115  
116      generate_data(0.0, test_data.size(), test_data.data());
117      generate_data(1.0, test_data2.size(), test_data2.data());
118  
119      // Initialize GGML, ensures float conversion tables are initialized
120      struct ggml_init_params ggml_params = {
121          /* .mem_size   = */ 1*1024,
122          /* .mem_buffer = */ NULL,
123          /* .no_alloc   = */ true,
124      };
125      struct ggml_context * ctx = ggml_init(ggml_params);
126  
127      int num_failed = 0;
128      bool failed = false;
129  
130      for (int i = 0; i < GGML_TYPE_COUNT; i++) {
131          ggml_type type = (ggml_type) i;
132          ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
133  
134          // deprecated - skip
135          if (qfns.blck_size == 0) {
136              continue;
137          }
138  
139          const ggml_type ei = (ggml_type)i;
140  
141          printf("Testing %s\n", ggml_type_name((ggml_type) i));
142          ggml_quantize_init(ei);
143  
144          if (qfns.from_float && qfns.to_float) {
145              const float total_error = total_quantization_error(qfns, test_size, test_data.data());
146              const float max_quantization_error =
147                  type == GGML_TYPE_Q2_K    ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
148                  type == GGML_TYPE_IQ2_S   ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
149                  type == GGML_TYPE_Q3_K    ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
150                  type == GGML_TYPE_IQ3_S   ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
151                  type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : MAX_QUANTIZATION_TOTAL_ERROR;
152              failed = !(total_error < max_quantization_error);
153              num_failed += failed;
154              if (failed || verbose) {
155                  printf("%5s absolute quantization error:    %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
156              }
157  
158              const float reference_error = reference_quantization_error(qfns, test_size, test_data.data());
159              failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);
160              num_failed += failed;
161              if (failed || verbose) {
162                  printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
163              }
164  
165              const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data());
166              const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
167                                              type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
168                                            ? MAX_DOT_PRODUCT_ERROR_LOWBIT
169                                            : MAX_DOT_PRODUCT_ERROR;
170              failed = !(vec_dot_error < max_allowed_error);
171              num_failed += failed;
172              if (failed || verbose) {
173                  printf("%5s dot product error:              %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
174              }
175          }
176      }
177  
178      if (num_failed || verbose) {
179          printf("%d tests failed\n", num_failed);
180      }
181  
182      ggml_free(ctx);
183  
184      return num_failed > 0;
185  }