sampling.cpp
1 #define LLAMA_API_INTERNAL 2 #include "sampling.h" 3 #include <random> 4 5 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { 6 struct llama_sampling_context * result = new llama_sampling_context(); 7 8 result->params = params; 9 result->grammar = nullptr; 10 11 // if there is a grammar, parse it 12 if (!params.grammar.empty()) { 13 result->parsed_grammar = grammar_parser::parse(params.grammar.c_str()); 14 15 // will be empty (default) if there are parse errors 16 if (result->parsed_grammar.rules.empty()) { 17 fprintf(stderr, "%s: failed to parse grammar\n", __func__); 18 delete result; 19 return nullptr; 20 } 21 22 // Ensure that there is a "root" node. 23 if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) { 24 fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); 25 delete result; 26 return nullptr; 27 } 28 29 std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules()); 30 31 result->grammar = llama_grammar_init( 32 grammar_rules.data(), 33 grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); 34 } 35 36 result->prev.resize(params.n_prev); 37 38 result->n_valid = 0; 39 40 llama_sampling_set_rng_seed(result, params.seed); 41 42 return result; 43 } 44 45 void llama_sampling_free(struct llama_sampling_context * ctx) { 46 if (ctx->grammar != NULL) { 47 llama_grammar_free(ctx->grammar); 48 } 49 50 delete ctx; 51 } 52 53 void llama_sampling_reset(llama_sampling_context * ctx) { 54 if (ctx->grammar != NULL) { 55 llama_grammar_free(ctx->grammar); 56 ctx->grammar = NULL; 57 } 58 59 if (!ctx->parsed_grammar.rules.empty()) { 60 std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules()); 61 62 ctx->grammar = llama_grammar_init( 63 grammar_rules.data(), 64 grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); 65 } 66 67 std::fill(ctx->prev.begin(), ctx->prev.end(), 0); 68 ctx->cur.clear(); 69 ctx->n_valid = 0; 70 } 71 72 void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { 73 if (seed == LLAMA_DEFAULT_SEED) { 74 seed = std::random_device{}(); 75 } 76 ctx->rng.seed(seed); 77 } 78 79 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { 80 if (dst->grammar) { 81 llama_grammar_free(dst->grammar); 82 dst->grammar = nullptr; 83 } 84 85 if (src->grammar) { 86 dst->grammar = llama_grammar_copy(src->grammar); 87 } 88 89 dst->prev = src->prev; 90 } 91 92 llama_token llama_sampling_last(llama_sampling_context * ctx) { 93 return ctx->prev.back(); 94 } 95 96 std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) { 97 const int size = ctx_sampling->prev.size(); 98 99 n = std::min(n, size); 100 101 std::string result; 102 103 for (int i = size - n; i < size; i++) { 104 result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]); 105 } 106 107 return result; 108 } 109 110 std::string llama_sampling_print(const llama_sampling_params & params) { 111 char result[1024]; 112 113 snprintf(result, sizeof(result), 114 "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" 115 "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" 116 "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", 117 params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, 118 params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, 119 params.mirostat, params.mirostat_eta, params.mirostat_tau); 120 121 return std::string(result); 122 } 123 124 std::string llama_sampling_order_print(const llama_sampling_params & params) { 125 std::string result = "CFG -> Penalties "; 126 if (params.mirostat == 0) { 127 for (auto sampler_type : params.samplers_sequence) { 128 const auto sampler_type_name = llama_sampling_type_to_str(sampler_type); 129 if (!sampler_type_name.empty()) { 130 result += "-> " + sampler_type_name + " "; 131 } 132 } 133 } else { 134 result += "-> mirostat "; 135 } 136 137 return result; 138 } 139 140 std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { 141 switch (sampler_type) { 142 case llama_sampler_type::TOP_K: return "top_k"; 143 case llama_sampler_type::TFS_Z: return "tfs_z"; 144 case llama_sampler_type::TYPICAL_P: return "typical_p"; 145 case llama_sampler_type::TOP_P: return "top_p"; 146 case llama_sampler_type::MIN_P: return "min_p"; 147 case llama_sampler_type::TEMPERATURE: return "temperature"; 148 default : return ""; 149 } 150 } 151 152 std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) { 153 std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map { 154 {"top_k", llama_sampler_type::TOP_K}, 155 {"top_p", llama_sampler_type::TOP_P}, 156 {"typical_p", llama_sampler_type::TYPICAL_P}, 157 {"min_p", llama_sampler_type::MIN_P}, 158 {"tfs_z", llama_sampler_type::TFS_Z}, 159 {"temperature", llama_sampler_type::TEMPERATURE} 160 }; 161 162 // since samplers names are written multiple ways 163 // make it ready for both system names and input names 164 std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map { 165 {"top-k", llama_sampler_type::TOP_K}, 166 {"top-p", llama_sampler_type::TOP_P}, 167 {"nucleus", llama_sampler_type::TOP_P}, 168 {"typical-p", llama_sampler_type::TYPICAL_P}, 169 {"typical", llama_sampler_type::TYPICAL_P}, 170 {"min-p", llama_sampler_type::MIN_P}, 171 {"tfs-z", llama_sampler_type::TFS_Z}, 172 {"tfs", llama_sampler_type::TFS_Z}, 173 {"temp", llama_sampler_type::TEMPERATURE} 174 }; 175 176 std::vector<llama_sampler_type> sampler_types; 177 sampler_types.reserve(names.size()); 178 for (const auto & name : names) 179 { 180 auto sampler_item = sampler_canonical_name_map.find(name); 181 if (sampler_item != sampler_canonical_name_map.end()) 182 { 183 sampler_types.push_back(sampler_item->second); 184 } 185 else 186 { 187 if (allow_alt_names) 188 { 189 sampler_item = sampler_alt_name_map.find(name); 190 if (sampler_item != sampler_alt_name_map.end()) 191 { 192 sampler_types.push_back(sampler_item->second); 193 } 194 } 195 } 196 } 197 return sampler_types; 198 } 199 200 std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) { 201 std::unordered_map<char, llama_sampler_type> sampler_name_map { 202 {'k', llama_sampler_type::TOP_K}, 203 {'p', llama_sampler_type::TOP_P}, 204 {'y', llama_sampler_type::TYPICAL_P}, 205 {'m', llama_sampler_type::MIN_P}, 206 {'f', llama_sampler_type::TFS_Z}, 207 {'t', llama_sampler_type::TEMPERATURE} 208 }; 209 210 std::vector<llama_sampler_type> sampler_types; 211 sampler_types.reserve(names_string.size()); 212 for (const auto & c : names_string) { 213 const auto sampler_item = sampler_name_map.find(c); 214 if (sampler_item != sampler_name_map.end()) { 215 sampler_types.push_back(sampler_item->second); 216 } 217 } 218 return sampler_types; 219 } 220 221 // no reasons to expose this function in header 222 static void sampler_queue( 223 struct llama_context * ctx_main, 224 const llama_sampling_params & params, 225 llama_token_data_array & cur_p, 226 size_t min_keep) { 227 const float temp = params.temp; 228 const float dynatemp_range = params.dynatemp_range; 229 const float dynatemp_exponent = params.dynatemp_exponent; 230 const int32_t top_k = params.top_k; 231 const float top_p = params.top_p; 232 const float min_p = params.min_p; 233 const float tfs_z = params.tfs_z; 234 const float typical_p = params.typical_p; 235 const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence; 236 237 for (auto sampler_type : samplers_sequence) { 238 switch (sampler_type) { 239 case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; 240 case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; 241 case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; 242 case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; 243 case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; 244 case llama_sampler_type::TEMPERATURE: 245 if (dynatemp_range > 0) { 246 float dynatemp_min = std::max(0.0f, temp - dynatemp_range); 247 float dynatemp_max = std::max(0.0f, temp + dynatemp_range); 248 llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent); 249 } else { 250 llama_sample_temp(ctx_main, &cur_p, temp); 251 } 252 break; 253 default : break; 254 } 255 } 256 } 257 258 static llama_token llama_sampling_sample_impl( 259 struct llama_sampling_context * ctx_sampling, 260 struct llama_context * ctx_main, 261 struct llama_context * ctx_cfg, 262 const int idx, 263 bool is_resampling) { 264 const llama_sampling_params & params = ctx_sampling->params; 265 266 const float temp = params.temp; 267 const int mirostat = params.mirostat; 268 const float mirostat_tau = params.mirostat_tau; 269 const float mirostat_eta = params.mirostat_eta; 270 271 std::vector<float> original_logits; 272 auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits); 273 if (ctx_sampling->grammar != NULL && !is_resampling) { 274 GGML_ASSERT(!original_logits.empty()); 275 } 276 llama_token id = 0; 277 // Get a pointer to the logits 278 float * logits = llama_get_logits_ith(ctx_main, idx); 279 280 if (temp < 0.0) { 281 // greedy sampling, with probs 282 llama_sample_softmax(ctx_main, &cur_p); 283 id = cur_p.data[0].id; 284 } else if (temp == 0.0) { 285 // greedy sampling, no probs 286 id = llama_sample_token_greedy(ctx_main, &cur_p); 287 } else { 288 if (mirostat == 1) { 289 const int mirostat_m = 100; 290 llama_sample_temp(ctx_main, &cur_p, temp); 291 id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); 292 } else if (mirostat == 2) { 293 llama_sample_temp(ctx_main, &cur_p, temp); 294 id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); 295 } else { 296 // temperature sampling 297 size_t min_keep = std::max(1, params.min_keep); 298 299 sampler_queue(ctx_main, params, cur_p, min_keep); 300 301 id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); 302 303 //{ 304 // const int n_top = 10; 305 // LOG("top %d candidates:\n", n_top); 306 307 // for (int i = 0; i < n_top; i++) { 308 // const llama_token id = cur_p.data[i].id; 309 // (void)id; // To avoid a warning that id is unused when logging is disabled. 310 // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p); 311 // } 312 //} 313 314 //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str()); 315 } 316 } 317 318 if (ctx_sampling->grammar != NULL && !is_resampling) { 319 // Create an array with a single token data element for the sampled id 320 llama_token_data single_token_data = {id, logits[id], 0.0f}; 321 llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; 322 323 // Apply grammar constraints to the single token 324 llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar); 325 326 // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY 327 bool is_valid = single_token_data_array.data[0].logit != -INFINITY; 328 329 // If the token is not valid according to the grammar, perform resampling 330 if (!is_valid) { 331 LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str()); 332 333 // Restore logits from the copy 334 std::copy(original_logits.begin(), original_logits.end(), logits); 335 336 return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true); 337 } 338 } 339 340 ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size; 341 342 return id; 343 } 344 345 static llama_token_data_array llama_sampling_prepare_impl( 346 struct llama_sampling_context * ctx_sampling, 347 struct llama_context * ctx_main, 348 struct llama_context * ctx_cfg, 349 const int idx, 350 bool apply_grammar, 351 std::vector<float> * original_logits) { 352 const llama_sampling_params & params = ctx_sampling->params; 353 354 const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); 355 356 const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; 357 const float penalty_repeat = params.penalty_repeat; 358 const float penalty_freq = params.penalty_freq; 359 const float penalty_present = params.penalty_present; 360 361 const bool penalize_nl = params.penalize_nl; 362 363 auto & prev = ctx_sampling->prev; 364 auto & cur = ctx_sampling->cur; 365 366 // Get a pointer to the logits 367 float * logits = llama_get_logits_ith(ctx_main, idx); 368 369 if (ctx_sampling->grammar != NULL && !apply_grammar) { 370 GGML_ASSERT(original_logits != NULL); 371 // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this. 372 *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))}; 373 } 374 375 // apply params.logit_bias map 376 for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { 377 logits[it->first] += it->second; 378 } 379 380 if (ctx_cfg) { 381 float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); 382 llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); 383 } 384 385 cur.clear(); 386 387 for (llama_token token_id = 0; token_id < n_vocab; token_id++) { 388 cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); 389 } 390 391 llama_token_data_array cur_p = { cur.data(), cur.size(), false }; 392 393 // apply penalties 394 const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev; 395 const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); 396 if (penalty_tokens_used_size) { 397 const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; 398 399 llama_sample_repetition_penalties(ctx_main, &cur_p, 400 penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, 401 penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); 402 403 if (!penalize_nl) { 404 for (size_t idx = 0; idx < cur_p.size; idx++) { 405 if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { 406 cur_p.data[idx].logit = nl_logit; 407 break; 408 } 409 } 410 } 411 } 412 413 // apply grammar checks before sampling logic 414 if (apply_grammar && ctx_sampling->grammar != NULL) { 415 llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); 416 } 417 418 return cur_p; 419 } 420 421 llama_token llama_sampling_sample( 422 struct llama_sampling_context * ctx_sampling, 423 struct llama_context * ctx_main, 424 struct llama_context * ctx_cfg, 425 const int idx) { 426 // Call the implementation function with is_resampling set to false by default 427 return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false); 428 } 429 430 llama_token_data_array llama_sampling_prepare( 431 struct llama_sampling_context * ctx_sampling, 432 struct llama_context * ctx_main, 433 struct llama_context * ctx_cfg, 434 const int idx, 435 bool apply_grammar, 436 std::vector<float> * original_logits) { 437 return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits); 438 } 439 440 void llama_sampling_accept( 441 struct llama_sampling_context * ctx_sampling, 442 struct llama_context * ctx_main, 443 llama_token id, 444 bool apply_grammar) { 445 ctx_sampling->prev.erase(ctx_sampling->prev.begin()); 446 ctx_sampling->prev.push_back(id); 447 448 if (ctx_sampling->grammar != NULL && apply_grammar) { 449 llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); 450 } 451 }