// common/sampling.h — sampling parameters, sampler context, and the sampling API used by the examples
  1  #pragma once
  2  
  3  #include "llama.h"
  4  
  5  #include "grammar-parser.h"
  6  
  7  #include <random>
  8  #include <string>
  9  #include <unordered_map>
 10  #include <vector>
 11  
 12  // sampler types
 13  enum class llama_sampler_type : char {
 14      TOP_K       = 'k',
 15      TOP_P       = 'p',
 16      MIN_P       = 'm',
 17      TFS_Z       = 'f',
 18      TYPICAL_P   = 'y',
 19      TEMPERATURE = 't'
 20  };
 21  
// sampling parameters
typedef struct llama_sampling_params {
    int32_t     n_prev                = 64;                 // number of previous tokens to remember
    int32_t     n_probs               = 0;                  // if greater than 0, output the probabilities of top n_probs tokens.
    int32_t     min_keep              = 0;                  // 0 = disabled, otherwise samplers should return at least min_keep tokens
    int32_t     top_k                 = 40;                 // <= 0 to use vocab size
    float       top_p                 = 0.95f;              // 1.0 = disabled
    float       min_p                 = 0.05f;              // 0.0 = disabled
    float       tfs_z                 = 1.00f;              // 1.0 = disabled
    float       typical_p             = 1.00f;              // 1.0 = disabled
    float       temp                  = 0.80f;              // <= 0.0 to sample greedily, 0.0 to not output probabilities
    float       dynatemp_range        = 0.00f;              // 0.0 = disabled
    float       dynatemp_exponent     = 1.00f;              // controls how entropy maps to temperature in dynamic temperature sampler
    int32_t     penalty_last_n        = 64;                 // last n tokens to penalize (0 = disable penalty, -1 = context size)
    float       penalty_repeat        = 1.00f;              // 1.0 = disabled
    float       penalty_freq          = 0.00f;              // 0.0 = disabled
    float       penalty_present       = 0.00f;              // 0.0 = disabled
    int32_t     mirostat              = 0;                  // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
    float       mirostat_tau          = 5.00f;              // target entropy
    float       mirostat_eta          = 0.10f;              // learning rate
    bool        penalize_nl           = false;              // consider newlines as a repeatable token
    uint32_t    seed                  = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context

    // order in which the samplers are applied (see llama_sampling_order_print)
    std::vector<llama_sampler_type> samplers_sequence = {
        llama_sampler_type::TOP_K,
        llama_sampler_type::TFS_Z,
        llama_sampler_type::TYPICAL_P,
        llama_sampler_type::TOP_P,
        llama_sampler_type::MIN_P,
        llama_sampler_type::TEMPERATURE
    };

    std::string grammar;  // optional BNF-like grammar to constrain sampling

    // Classifier-Free Guidance
    // https://arxiv.org/abs/2306.17806
    std::string cfg_negative_prompt; // string to help guidance
    float       cfg_scale     = 1.f; // how strong is guidance

    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

    // NOTE(review): presumably an explicit token list used in place of the sampled
    // history for the repetition penalty when the flag below is set — confirm
    // against the llama_sampling_prepare implementation
    std::vector<llama_token> penalty_prompt_tokens;
    bool                     use_penalty_prompt_tokens = false;
} llama_sampling_params;
 66  
 67  // general sampler context
 68  // TODO: move to llama.h
 69  struct llama_sampling_context {
 70      // parameters that will be used for sampling
 71      llama_sampling_params params;
 72  
 73      // mirostat sampler state
 74      float mirostat_mu;
 75  
 76      llama_grammar * grammar;
 77  
 78      // internal
 79      grammar_parser::parse_state parsed_grammar;
 80  
 81      // TODO: replace with ring-buffer
 82      std::vector<llama_token>      prev;
 83      std::vector<llama_token_data> cur;
 84      size_t n_valid; // Number of correct top tokens with correct probabilities.
 85  
 86      std::mt19937 rng;
 87  };
 88  
#include "common.h"
// NOTE(review): mid-file #include — includes conventionally live at the top of
// the header, and nothing declared below visibly requires common.h. Guarded by
// #pragma once so it is harmless, but confirm whether it can be moved up to the
// other includes (or removed).

// Create a new sampling context instance.
// The caller owns the returned context and must release it with llama_sampling_free.
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);

// Destroy a context created by llama_sampling_init.
void llama_sampling_free(struct llama_sampling_context * ctx);
 95  
// Reset the sampler context for a new sequence
// - clear prev tokens
// - reset grammar
void llama_sampling_reset(llama_sampling_context * ctx);

// Set the sampler seed (re-seeds the context's rng)
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);

// Copy the sampler context from src into dst
// NOTE(review): dst is presumably expected to already be initialized (e.g. via
// llama_sampling_init) — confirm with the implementation
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);

// Get the last sampled token
llama_token llama_sampling_last(llama_sampling_context * ctx);
109  
// Get a string representation of the last n sampled tokens
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);

// Print sampling parameters into a string
std::string llama_sampling_print(const llama_sampling_params & params);

// Print sampling order (the samplers_sequence) into a string
std::string llama_sampling_order_print(const llama_sampling_params & params);

// Get the human-readable name of a sampler type
std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);

// Parse a list of sampler names into sampler types
// - allow_alt_names: presumably accepts alternative spellings/aliases — confirm
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
// Parse a string of single-character sampler codes (the llama_sampler_type values, e.g. "kfypmt")
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);
123  
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
//       llama_sampling_reset when a sequence ends
//
// required:
//  - ctx_main:     context to use for sampling
//  - ctx_sampling: sampling-specific context
//
// optional:
//  - ctx_cfg:      context to use for classifier-free guidance
//  - idx:          sample from llama_get_logits_ith(ctx, idx)
//
// returns:
//  - token:      sampled token
//  - candidates: vector of candidate tokens
//
llama_token llama_sampling_sample(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        int idx = -1);

// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
// - apply_grammar:   when true, also constrain the candidates with the grammar
// - original_logits: when non-null, presumably receives the unmodified logits — confirm
// note: the default idx here (0) intentionally differs from llama_sampling_sample's (-1)
llama_token_data_array llama_sampling_prepare(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        int idx = 0,
        bool apply_grammar = true,
        std::vector<float> * original_logits = nullptr);

// Accept token id into the sampling history (ctx_sampling->prev) and, when
// apply_grammar is true, advance the grammar state with it
// NOTE(review): behavior inferred from the declaration and context fields — confirm
void llama_sampling_accept(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        llama_token id,
        bool apply_grammar);