// test-sampling.cpp
1 #include "ggml.h" 2 #include "llama.h" 3 4 #ifdef NDEBUG 5 #undef NDEBUG 6 #endif 7 8 #include <algorithm> 9 #include <cmath> 10 #include <string> 11 #include <vector> 12 13 static void dump(const llama_token_data_array * candidates) { 14 for (size_t i = 0; i < candidates->size; i++) { 15 printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit); 16 } 17 } 18 19 #define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0) 20 21 static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) { 22 const size_t n_vocab = probs.size(); 23 std::vector<llama_token_data> candidates; 24 candidates.reserve(n_vocab); 25 for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { 26 const float logit = logf(probs[token_id]); 27 candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); 28 } 29 30 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; 31 llama_sample_softmax(nullptr, &candidates_p); 32 DUMP(&candidates_p); 33 llama_sample_top_k(nullptr, &candidates_p, k, 1); 34 DUMP(&candidates_p); 35 36 GGML_ASSERT(candidates_p.size == expected_probs.size()); 37 for (size_t i = 0; i < candidates_p.size; i++) { 38 GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5); 39 } 40 } 41 42 static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { 43 const size_t n_vocab = probs.size(); 44 std::vector<llama_token_data> candidates; 45 candidates.reserve(n_vocab); 46 for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { 47 const float logit = logf(probs[token_id]); 48 candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); 49 } 50 51 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; 52 llama_sample_softmax(nullptr, &candidates_p); 53 
DUMP(&candidates_p); 54 llama_sample_top_p(nullptr, &candidates_p, p, 1); 55 DUMP(&candidates_p); 56 57 GGML_ASSERT(candidates_p.size == expected_probs.size()); 58 for (size_t i = 0; i < candidates_p.size; i++) { 59 GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); 60 } 61 } 62 63 static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) { 64 const size_t n_vocab = probs.size(); 65 std::vector<llama_token_data> candidates; 66 candidates.reserve(n_vocab); 67 for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { 68 const float logit = logf(probs[token_id]); 69 candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); 70 } 71 72 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; 73 DUMP(&candidates_p); 74 llama_sample_tail_free(nullptr, &candidates_p, z, 1); 75 DUMP(&candidates_p); 76 77 GGML_ASSERT(candidates_p.size == expected_probs.size()); 78 for (size_t i = 0; i < candidates_p.size; i++) { 79 GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); 80 } 81 } 82 83 static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { 84 const size_t n_vocab = probs.size(); 85 std::vector<llama_token_data> candidates; 86 candidates.reserve(n_vocab); 87 for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { 88 const float logit = logf(probs[token_id]); 89 candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); 90 } 91 92 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; 93 DUMP(&candidates_p); 94 llama_sample_min_p(nullptr, &candidates_p, p, 1); 95 DUMP(&candidates_p); 96 llama_sample_softmax(nullptr, &candidates_p); 97 98 GGML_ASSERT(candidates_p.size == expected_probs.size()); 99 for (size_t i = 0; i < candidates_p.size; i++) { 100 GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); 101 } 
102 } 103 104 static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { 105 const size_t n_vocab = probs.size(); 106 std::vector<llama_token_data> candidates; 107 candidates.reserve(n_vocab); 108 for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { 109 const float logit = logf(probs[token_id]); 110 candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); 111 } 112 113 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; 114 DUMP(&candidates_p); 115 llama_sample_typical(nullptr, &candidates_p, p, 1); 116 DUMP(&candidates_p); 117 118 GGML_ASSERT(candidates_p.size == expected_probs.size()); 119 for (size_t i = 0; i < candidates_p.size; i++) { 120 GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); 121 } 122 } 123 124 static void test_repetition_penalties( 125 const std::vector<float> & probs, const std::vector<llama_token> & last_tokens, 126 const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence 127 ) { 128 GGML_ASSERT(probs.size() == expected_probs.size()); 129 130 const size_t n_vocab = probs.size(); 131 std::vector<llama_token_data> candidates; 132 candidates.reserve(n_vocab); 133 for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { 134 const float logit = logf(probs[token_id]); 135 candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); 136 } 137 138 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; 139 llama_sample_softmax(nullptr, &candidates_p); 140 DUMP(&candidates_p); 141 llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); 142 llama_sample_softmax(nullptr, &candidates_p); 143 DUMP(&candidates_p); 144 145 GGML_ASSERT(candidates_p.size == expected_probs.size()); 146 for (size_t i = 
0; i < candidates_p.size; i++) { 147 GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); 148 } 149 } 150 151 static void test_sampler_queue( 152 const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p 153 ) { 154 std::vector<llama_token_data> candidates; 155 candidates.reserve(n_vocab); 156 for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { 157 const float logit = logf(token_id); 158 candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); 159 } 160 161 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; 162 163 llama_token min_token_id = 0; 164 const llama_token max_token_id = n_vocab-1; 165 166 for (auto s : samplers_sequence) { 167 switch (s){ 168 case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break; 169 case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break; 170 case 'y': GGML_ASSERT(false && "typical test not implemented"); break; 171 case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break; 172 case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break; 173 case 't': GGML_ASSERT(false && "temperature test not implemented"); break; 174 default : GGML_ASSERT(false && "Unknown sampler"); break; 175 } 176 177 llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests 178 179 const int size = candidates_p.size; 180 181 if (s == 'k') { 182 const int expected_size = std::min(size, top_k); 183 min_token_id = std::max(min_token_id, (llama_token)(n_vocab - top_k)); 184 185 GGML_ASSERT(size == expected_size); 186 GGML_ASSERT(candidates_p.data[0].id == max_token_id); 187 GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id); 188 } else if (s == 'p') { 189 const int softmax_divisor = n_vocab * (n_vocab-1) / 2 - min_token_id * (min_token_id-1) / 2; 190 const int softmax_numerator_target = ceilf(top_p * softmax_divisor); 191 192 
min_token_id = n_vocab; 193 int expected_size = 0; 194 int cumsum = 0; 195 do { // do-while because always at least one token is sampled 196 min_token_id--; 197 expected_size++; 198 199 cumsum += min_token_id; 200 } while (cumsum < softmax_numerator_target); 201 202 // token 0 has p == 0, need special consideration for cumsum because top_p immediately returns 203 if (min_token_id == 1) { 204 min_token_id--; 205 expected_size += 1; 206 } 207 208 GGML_ASSERT(size == expected_size); 209 GGML_ASSERT(candidates_p.data[0].id == max_token_id); 210 GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id); 211 } else if (s == 'm') { 212 int expected_size = ceilf((1.0f-min_p) * n_vocab); 213 expected_size = std::max(expected_size, 1); 214 expected_size = std::min(expected_size, size); 215 216 min_token_id = floorf(min_p * n_vocab); 217 min_token_id = std::max(min_token_id, 1); 218 min_token_id = std::max(min_token_id, (llama_token)(n_vocab - size)); 219 min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1)); 220 221 GGML_ASSERT(size == expected_size); 222 GGML_ASSERT(candidates_p.data[0].id == max_token_id); 223 GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id); 224 } else { 225 GGML_ASSERT(false); 226 } 227 } 228 229 printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n", 230 samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p); 231 } 232 233 int main(void) { 234 ggml_time_init(); 235 236 test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1); 237 test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3); 238 test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4); 239 test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0); 240 241 test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0); 242 test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f); 243 test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f); 244 test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1); 245 246 
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f); 247 test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f); 248 test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.26f); 249 test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.49f); 250 test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.51f); 251 test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.74f); 252 test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f); 253 test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f); 254 255 test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f); 256 test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f); 257 test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f); 258 259 test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f); 260 test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f); 261 262 test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f); 263 test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); 264 test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); 265 266 test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f); 267 test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f); 268 test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f); 269 270 test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); 271 test_sampler_queue(10000, "k", 1, 1.0f, 1.0f); 272 test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f); 273 test_sampler_queue(10000, "p", 10000, 0.0f, 
1.0f); 274 test_sampler_queue(10000, "m", 10000, 1.0f, 1.0f); 275 test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12); 276 277 test_sampler_queue(10000, "k", 100, 1.0000f, 1.0f); 278 test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f); 279 test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f); 280 test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f); 281 test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f); 282 283 test_sampler_queue(10000, "kp", 100, 0.8f, 0.1f); 284 test_sampler_queue(10000, "km", 100, 0.8f, 0.1f); 285 test_sampler_queue(10000, "pk", 100, 0.8f, 0.1f); 286 test_sampler_queue(10000, "pm", 100, 0.8f, 0.1f); 287 test_sampler_queue(10000, "mk", 100, 0.8f, 0.1f); 288 test_sampler_queue(10000, "mp", 100, 0.8f, 9997.9f/9999.0f); 289 test_sampler_queue(10000, "mp", 100, 0.8f, 0.1f); 290 291 test_sampler_queue(10000, "kpm", 100, 0.8f, 0.1f); 292 test_sampler_queue(10000, "kmp", 100, 0.8f, 0.1f); 293 test_sampler_queue(10000, "pkm", 100, 0.8f, 0.1f); 294 test_sampler_queue(10000, "pmk", 100, 0.8f, 0.1f); 295 test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f); 296 test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f); 297 298 printf("OK\n"); 299 300 return 0; 301 }