// common/sampling.cpp
#define LLAMA_API_INTERNAL
#include "sampling.h"

#include <algorithm>
#include <random>
#include <unordered_map>

struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
    struct llama_sampling_context * result = new llama_sampling_context();

    result->params  = params;
    result->grammar = nullptr;

    // if there is a grammar, parse it
    if (!params.grammar.empty()) {
        result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());

        // will be empty (default) if there are parse errors
        if (result->parsed_grammar.rules.empty()) {
            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
            delete result;
            return nullptr;
        }

        // ensure that there is a "root" node
        if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
            fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
            delete result;
            return nullptr;
        }

        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());

        result->grammar = llama_grammar_init(
                grammar_rules.data(),
                grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
    }

    result->prev.resize(params.n_prev);

    result->n_valid = 0;

    llama_sampling_set_rng_seed(result, params.seed);

    return result;
}
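
// Usage sketch (hypothetical caller; field values are illustrative only):
//
//     llama_sampling_params sparams;
//     sparams.temp = 0.8f;
//
//     llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
//     if (ctx_sampling == nullptr) {
//         // the grammar failed to parse or lacked a "root" symbol
//     }
//     ...
//     llama_sampling_free(ctx_sampling);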

void llama_sampling_free(struct llama_sampling_context * ctx) {
    if (ctx->grammar != NULL) {
        llama_grammar_free(ctx->grammar);
    }

    delete ctx;
}

void llama_sampling_reset(llama_sampling_context * ctx) {
    if (ctx->grammar != NULL) {
        llama_grammar_free(ctx->grammar);
        ctx->grammar = NULL;
    }

    if (!ctx->parsed_grammar.rules.empty()) {
        std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());

        ctx->grammar = llama_grammar_init(
                grammar_rules.data(),
                grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
    }

    std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
    ctx->cur.clear();
    ctx->n_valid = 0;
}
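
// Note: reset rebuilds the grammar state from the already-parsed rules and
// resets the token history, so a sampling context can be reused across
// independent generations without re-parsing the grammar string.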

void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
    if (seed == LLAMA_DEFAULT_SEED) {
        seed = std::random_device{}();
    }
    ctx->rng.seed(seed);
}
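
// Example (sketch): passing LLAMA_DEFAULT_SEED draws a fresh nondeterministic
// seed, while any other value makes token sampling reproducible:
//
//     llama_sampling_set_rng_seed(ctx_sampling, 42);                  // deterministic
//     llama_sampling_set_rng_seed(ctx_sampling, LLAMA_DEFAULT_SEED);  // random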

void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
    if (dst->grammar) {
        llama_grammar_free(dst->grammar);
        dst->grammar = nullptr;
    }

    if (src->grammar) {
        dst->grammar = llama_grammar_copy(src->grammar);
    }

    dst->prev = src->prev;
}

llama_token llama_sampling_last(llama_sampling_context * ctx) {
    return ctx->prev.back();
}

std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
    const int size = ctx_sampling->prev.size();

    n = std::min(n, size);

    std::string result;

    for (int i = size - n; i < size; i++) {
        result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
    }

    return result;
}
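
// Example (sketch): detokenize the last 8 accepted tokens, e.g. for logging:
//
//     const std::string tail = llama_sampling_prev_str(ctx_sampling, ctx, 8);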

std::string llama_sampling_print(const llama_sampling_params & params) {
    char result[1024];

    snprintf(result, sizeof(result),
            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
            params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
            params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
            params.mirostat, params.mirostat_eta, params.mirostat_tau);

    return std::string(result);
}

std::string llama_sampling_order_print(const llama_sampling_params & params) {
    std::string result = "CFG -> Penalties ";
    if (params.mirostat == 0) {
        for (auto sampler_type : params.samplers_sequence) {
            const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
            if (!sampler_type_name.empty()) {
                result += "-> " + sampler_type_name + " ";
            }
        }
    } else {
        result += "-> mirostat ";
    }

    return result;
}
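
// Example output (sketch, assuming mirostat == 0 and samplers_sequence =
// {TOP_K, TFS_Z, TYPICAL_P, TOP_P, MIN_P, TEMPERATURE}):
//
//     CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature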

std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
    switch (sampler_type) {
        case llama_sampler_type::TOP_K:       return "top_k";
        case llama_sampler_type::TFS_Z:       return "tfs_z";
        case llama_sampler_type::TYPICAL_P:   return "typical_p";
        case llama_sampler_type::TOP_P:       return "top_p";
        case llama_sampler_type::MIN_P:       return "min_p";
        case llama_sampler_type::TEMPERATURE: return "temperature";
        default:                              return "";
    }
}

std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
    std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
        {"top_k",       llama_sampler_type::TOP_K},
        {"top_p",       llama_sampler_type::TOP_P},
        {"typical_p",   llama_sampler_type::TYPICAL_P},
        {"min_p",       llama_sampler_type::MIN_P},
        {"tfs_z",       llama_sampler_type::TFS_Z},
        {"temperature", llama_sampler_type::TEMPERATURE}
    };

    // sampler names are written in multiple ways, so accept both the
    // canonical names and common alternative spellings
    std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
        {"top-k",       llama_sampler_type::TOP_K},
        {"top-p",       llama_sampler_type::TOP_P},
        {"nucleus",     llama_sampler_type::TOP_P},
        {"typical-p",   llama_sampler_type::TYPICAL_P},
        {"typical",     llama_sampler_type::TYPICAL_P},
        {"min-p",       llama_sampler_type::MIN_P},
        {"tfs-z",       llama_sampler_type::TFS_Z},
        {"tfs",         llama_sampler_type::TFS_Z},
        {"temp",        llama_sampler_type::TEMPERATURE}
    };

    std::vector<llama_sampler_type> sampler_types;
    sampler_types.reserve(names.size());

    for (const auto & name : names) {
        auto sampler_item = sampler_canonical_name_map.find(name);
        if (sampler_item != sampler_canonical_name_map.end()) {
            sampler_types.push_back(sampler_item->second);
        } else if (allow_alt_names) {
            sampler_item = sampler_alt_name_map.find(name);
            if (sampler_item != sampler_alt_name_map.end()) {
                sampler_types.push_back(sampler_item->second);
            }
        }
    }

    return sampler_types;
}
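
// Example (sketch): with allow_alt_names = true, "nucleus" and "temp" resolve
// through the alternative-name map; unrecognized names are silently dropped:
//
//     const auto types = llama_sampling_types_from_names({"top_k", "nucleus", "temp"}, true);
//     // -> {TOP_K, TOP_P, TEMPERATURE}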

std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
    std::unordered_map<char, llama_sampler_type> sampler_name_map {
        {'k', llama_sampler_type::TOP_K},
        {'p', llama_sampler_type::TOP_P},
        {'y', llama_sampler_type::TYPICAL_P},
        {'m', llama_sampler_type::MIN_P},
        {'f', llama_sampler_type::TFS_Z},
        {'t', llama_sampler_type::TEMPERATURE}
    };

    std::vector<llama_sampler_type> sampler_types;
    sampler_types.reserve(names_string.size());
    for (const auto & c : names_string) {
        const auto sampler_item = sampler_name_map.find(c);
        if (sampler_item != sampler_name_map.end()) {
            sampler_types.push_back(sampler_item->second);
        }
    }
    return sampler_types;
}
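
// Example (sketch): each character selects one sampler, in order:
//
//     const auto types = llama_sampling_types_from_chars("kfypmt");
//     // -> {TOP_K, TFS_Z, TYPICAL_P, TOP_P, MIN_P, TEMPERATURE}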

// no reason to expose this function in the header
static void sampler_queue(
                   struct llama_context * ctx_main,
            const llama_sampling_params & params,
                 llama_token_data_array & cur_p,
                                 size_t   min_keep) {
    const float         temp              = params.temp;
    const float         dynatemp_range    = params.dynatemp_range;
    const float         dynatemp_exponent = params.dynatemp_exponent;
    const int32_t       top_k             = params.top_k;
    const float         top_p             = params.top_p;
    const float         min_p             = params.min_p;
    const float         tfs_z             = params.tfs_z;
    const float         typical_p         = params.typical_p;
    const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;

    for (auto sampler_type : samplers_sequence) {
        switch (sampler_type) {
            case llama_sampler_type::TOP_K    : llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep); break;
            case llama_sampler_type::TFS_Z    : llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); break;
            case llama_sampler_type::TYPICAL_P: llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); break;
            case llama_sampler_type::TOP_P    : llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); break;
            case llama_sampler_type::MIN_P    : llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); break;
            case llama_sampler_type::TEMPERATURE:
                if (dynatemp_range > 0) {
                    // dynamic temperature: scale within [temp - range, temp + range] based on entropy
                    float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
                    float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
                    llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
                } else {
                    llama_sample_temp(ctx_main, &cur_p, temp);
                }
                break;
            default: break;
        }
    }
}
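
// Note: the samplers run in the user-specified order, each one further
// filtering cur_p before the final token is drawn; min_keep bounds how far
// the truncation samplers may shrink the candidate set.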

// When a grammar is attached, sampling is two-phase: first sample from the
// unconstrained distribution; only if the chosen token violates the grammar,
// restore the saved logits, apply the grammar to the full candidate set and
// resample (is_resampling == true on the second pass).
static llama_token llama_sampling_sample_impl(
                  struct llama_sampling_context * ctx_sampling,
                  struct llama_context * ctx_main,
                  struct llama_context * ctx_cfg,
                  const int idx,
                  bool is_resampling) {
    const llama_sampling_params & params = ctx_sampling->params;

    const float   temp            = params.temp;
    const int     mirostat        = params.mirostat;
    const float   mirostat_tau    = params.mirostat_tau;
    const float   mirostat_eta    = params.mirostat_eta;

    std::vector<float> original_logits;
    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
    if (ctx_sampling->grammar != NULL && !is_resampling) {
        GGML_ASSERT(!original_logits.empty());
    }
    llama_token id = 0;
    // Get a pointer to the logits
    float * logits = llama_get_logits_ith(ctx_main, idx);

    if (temp < 0.0) {
        // greedy sampling, with probs
        llama_sample_softmax(ctx_main, &cur_p);
        id = cur_p.data[0].id;
    } else if (temp == 0.0) {
        // greedy sampling, no probs
        id = llama_sample_token_greedy(ctx_main, &cur_p);
    } else {
        if (mirostat == 1) {
            const int mirostat_m = 100;
            llama_sample_temp(ctx_main, &cur_p, temp);
            id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
        } else if (mirostat == 2) {
            llama_sample_temp(ctx_main, &cur_p, temp);
            id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
        } else {
            // temperature sampling
            size_t min_keep = std::max(1, params.min_keep);

            sampler_queue(ctx_main, params, cur_p, min_keep);

            id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);

            //{
            //    const int n_top = 10;
            //    LOG("top %d candidates:\n", n_top);

            //    for (int i = 0; i < n_top; i++) {
            //        const llama_token id = cur_p.data[i].id;
            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
            //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
            //    }
            //}

            //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
        }
    }

    if (ctx_sampling->grammar != NULL && !is_resampling) {
        // Create an array with a single token data element for the sampled id
        llama_token_data single_token_data = { id, logits[id], 0.0f };
        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

        // Apply grammar constraints to the single token
        llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);

        // The token is valid according to the grammar if its logit was not set to -INFINITY
        bool is_valid = single_token_data_array.data[0].logit != -INFINITY;

        // If the token is not valid according to the grammar, perform resampling
        if (!is_valid) {
            LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());

            // Restore logits from the copy
            std::copy(original_logits.begin(), original_logits.end(), logits);

            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
        }
    }

    ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;

    return id;
}

static llama_token_data_array llama_sampling_prepare_impl(
                  struct llama_sampling_context * ctx_sampling,
                  struct llama_context * ctx_main,
                  struct llama_context * ctx_cfg,
                  const int idx,
                  bool apply_grammar,
                  std::vector<float> * original_logits) {
    const llama_sampling_params & params = ctx_sampling->params;

    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
    const float   penalty_repeat  = params.penalty_repeat;
    const float   penalty_freq    = params.penalty_freq;
    const float   penalty_present = params.penalty_present;

    const bool    penalize_nl     = params.penalize_nl;

    auto & prev = ctx_sampling->prev;
    auto & cur  = ctx_sampling->cur;

    // Get a pointer to the logits
    float * logits = llama_get_logits_ith(ctx_main, idx);

    if (ctx_sampling->grammar != NULL && !apply_grammar) {
        GGML_ASSERT(original_logits != NULL);
        // Back up the original logits; they are needed to restore the state if
        // the sampled token fails the grammar check and we have to resample
        *original_logits = {logits, logits + n_vocab};
    }

    // apply params.logit_bias map
    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
        logits[it->first] += it->second;
    }

    if (ctx_cfg) {
        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
    }

    cur.clear();

    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }

    llama_token_data_array cur_p = { cur.data(), cur.size(), false };

    // apply penalties
    const auto & penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
    if (penalty_tokens_used_size) {
        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];

        llama_sample_repetition_penalties(ctx_main, &cur_p,
                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

        if (!penalize_nl) {
            // restore the newline logit if the newline token is not to be penalized
            for (size_t i = 0; i < cur_p.size; i++) {
                if (cur_p.data[i].id == llama_token_nl(llama_get_model(ctx_main))) {
                    cur_p.data[i].logit = nl_logit;
                    break;
                }
            }
        }
    }

    // apply grammar checks before sampling logic
    if (apply_grammar && ctx_sampling->grammar != NULL) {
        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
    }

    return cur_p;
}
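
// Note: the returned llama_token_data_array points into ctx_sampling->cur, so
// it is only valid until the next call that modifies the sampling context.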

llama_token llama_sampling_sample(
                  struct llama_sampling_context * ctx_sampling,
                  struct llama_context * ctx_main,
                  struct llama_context * ctx_cfg,
                  const int idx) {
    // Call the implementation function with is_resampling set to false by default
    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
}

llama_token_data_array llama_sampling_prepare(
                  struct llama_sampling_context * ctx_sampling,
                  struct llama_context * ctx_main,
                  struct llama_context * ctx_cfg,
                  const int idx,
                  bool apply_grammar,
                  std::vector<float> * original_logits) {
    return llama_sampling_prepare_impl(ctx_sampling, ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
}

void llama_sampling_accept(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        llama_token id,
        bool apply_grammar) {
    // shift the ring of previous tokens and append the newly sampled one
    ctx_sampling->prev.erase(ctx_sampling->prev.begin());
    ctx_sampling->prev.push_back(id);

    if (ctx_sampling->grammar != NULL && apply_grammar) {
        llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
    }
}
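
// Typical generation loop (sketch; llama_decode and llama_batch_get_one are
// llama.cpp API calls, the surrounding loop and n_remain/n_past are hypothetical):
//
//     while (n_remain-- > 0) {
//         llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr, 0);
//         llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
//
//         printf("%s", llama_token_to_piece(ctx, id).c_str());
//
//         if (llama_decode(ctx, llama_batch_get_one(&id, 1, n_past++, 0))) {
//             break; // decode failed
//         }
//     }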