// test-quantize-fns.cpp
1 // Unit tests for quantization specific functions - quantize, dequantize and dot product 2 3 #include "ggml.h" 4 5 #undef NDEBUG 6 #include <assert.h> 7 #include <math.h> 8 #include <stdio.h> 9 #include <string> 10 #include <vector> 11 12 #if defined(_MSC_VER) 13 #pragma warning(disable: 4244 4267) // possible loss of data 14 #endif 15 16 constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f; 17 constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f; 18 constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f; 19 constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f; 20 constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f; 21 constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f; 22 constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f; 23 24 static const char* RESULT_STR[] = {"ok", "FAILED"}; 25 26 27 // Generate synthetic data 28 static void generate_data(float offset, size_t n, float * dst) { 29 for (size_t i = 0; i < n; i++) { 30 dst[i] = 0.1 + 2*cosf(i + offset); 31 } 32 } 33 34 // Calculate RMSE between two float arrays 35 static float array_rmse(const float * a1, const float * a2, size_t n) { 36 double sum = 0; 37 for (size_t i = 0; i < n; i++) { 38 double diff = a1[i] - a2[i]; 39 sum += diff * diff; 40 } 41 return sqrtf(sum) / n; 42 } 43 44 // Total quantization error on test data 45 static float total_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) { 46 std::vector<uint8_t> tmp_q(2*test_size); 47 std::vector<float> tmp_out(test_size); 48 49 qfns.from_float(test_data, tmp_q.data(), test_size); 50 qfns.to_float(tmp_q.data(), tmp_out.data(), test_size); 51 return array_rmse(test_data, tmp_out.data(), test_size); 52 } 53 54 // Total quantization error on test data 55 static float reference_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) { 56 std::vector<uint8_t> tmp_q(2*test_size); 57 std::vector<float> tmp_out(test_size); 58 std::vector<float> 
tmp_out_ref(test_size); 59 60 qfns.from_float(test_data, tmp_q.data(), test_size); 61 qfns.to_float(tmp_q.data(), tmp_out.data(), test_size); 62 63 qfns.from_float_reference(test_data, tmp_q.data(), test_size); 64 qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size); 65 66 return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size); 67 } 68 69 static float dot_product(const float * a1, const float * a2, size_t test_size) { 70 double sum = 0; 71 for (size_t i = 0; i < test_size; i++) { 72 sum += a1[i] * a2[i]; 73 } 74 return sum; 75 } 76 77 // Total dot product error 78 static float dot_product_error( 79 ggml_type_traits_t & qfns, size_t test_size, const float * test_data1, const float *test_data2 80 ) { 81 std::vector<uint8_t> tmp_q1(2*test_size); 82 std::vector<uint8_t> tmp_q2(2*test_size); 83 84 auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type); 85 86 qfns.from_float(test_data1, tmp_q1.data(), test_size); 87 vdot.from_float(test_data2, tmp_q2.data(), test_size); 88 89 float result = INFINITY; 90 qfns.vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1); 91 92 const float dot_ref = dot_product(test_data1, test_data2, test_size); 93 94 return fabsf(result - dot_ref) / test_size; 95 } 96 97 int main(int argc, char * argv[]) { 98 bool verbose = false; 99 const size_t test_size = 32 * 128; 100 101 std::string arg; 102 for (int i = 1; i < argc; i++) { 103 arg = argv[i]; 104 105 if (arg == "-v") { 106 verbose = true; 107 } else { 108 fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); 109 return 1; 110 } 111 } 112 113 std::vector<float> test_data(test_size); 114 std::vector<float> test_data2(test_size); 115 116 generate_data(0.0, test_data.size(), test_data.data()); 117 generate_data(1.0, test_data2.size(), test_data2.data()); 118 119 // Initialize GGML, ensures float conversion tables are initialized 120 struct ggml_init_params ggml_params = { 121 /* .mem_size = */ 1*1024, 122 /* .mem_buffer = */ NULL, 123 /* .no_alloc 
= */ true, 124 }; 125 struct ggml_context * ctx = ggml_init(ggml_params); 126 127 int num_failed = 0; 128 bool failed = false; 129 130 for (int i = 0; i < GGML_TYPE_COUNT; i++) { 131 ggml_type type = (ggml_type) i; 132 ggml_type_traits_t qfns = ggml_internal_get_type_traits(type); 133 134 // deprecated - skip 135 if (qfns.blck_size == 0) { 136 continue; 137 } 138 139 const ggml_type ei = (ggml_type)i; 140 141 printf("Testing %s\n", ggml_type_name((ggml_type) i)); 142 ggml_quantize_init(ei); 143 144 if (qfns.from_float && qfns.to_float) { 145 const float total_error = total_quantization_error(qfns, test_size, test_data.data()); 146 const float max_quantization_error = 147 type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : 148 type == GGML_TYPE_IQ2_S ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : 149 type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : 150 type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : 151 type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : MAX_QUANTIZATION_TOTAL_ERROR; 152 failed = !(total_error < max_quantization_error); 153 num_failed += failed; 154 if (failed || verbose) { 155 printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error); 156 } 157 158 const float reference_error = reference_quantization_error(qfns, test_size, test_data.data()); 159 failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR); 160 num_failed += failed; 161 if (failed || verbose) { 162 printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error); 163 } 164 165 const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data()); 166 const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS || 167 type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S 168 ? 
MAX_DOT_PRODUCT_ERROR_LOWBIT 169 : MAX_DOT_PRODUCT_ERROR; 170 failed = !(vec_dot_error < max_allowed_error); 171 num_failed += failed; 172 if (failed || verbose) { 173 printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error); 174 } 175 } 176 } 177 178 if (num_failed || verbose) { 179 printf("%d tests failed\n", num_failed); 180 } 181 182 ggml_free(ctx); 183 184 return num_failed > 0; 185 }