sampling.h
1 #pragma once 2 3 #include "llama.h" 4 5 #include "grammar-parser.h" 6 7 #include <random> 8 #include <string> 9 #include <unordered_map> 10 #include <vector> 11 12 // sampler types 13 enum class llama_sampler_type : char { 14 TOP_K = 'k', 15 TOP_P = 'p', 16 MIN_P = 'm', 17 TFS_Z = 'f', 18 TYPICAL_P = 'y', 19 TEMPERATURE = 't' 20 }; 21 22 // sampling parameters 23 typedef struct llama_sampling_params { 24 int32_t n_prev = 64; // number of previous tokens to remember 25 int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. 26 int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens 27 int32_t top_k = 40; // <= 0 to use vocab size 28 float top_p = 0.95f; // 1.0 = disabled 29 float min_p = 0.05f; // 0.0 = disabled 30 float tfs_z = 1.00f; // 1.0 = disabled 31 float typical_p = 1.00f; // 1.0 = disabled 32 float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities 33 float dynatemp_range = 0.00f; // 0.0 = disabled 34 float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler 35 int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) 36 float penalty_repeat = 1.00f; // 1.0 = disabled 37 float penalty_freq = 0.00f; // 0.0 = disabled 38 float penalty_present = 0.00f; // 0.0 = disabled 39 int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 40 float mirostat_tau = 5.00f; // target entropy 41 float mirostat_eta = 0.10f; // learning rate 42 bool penalize_nl = false; // consider newlines as a repeatable token 43 uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context 44 45 std::vector<llama_sampler_type> samplers_sequence = { 46 llama_sampler_type::TOP_K, 47 llama_sampler_type::TFS_Z, 48 llama_sampler_type::TYPICAL_P, 49 llama_sampler_type::TOP_P, 50 llama_sampler_type::MIN_P, 51 llama_sampler_type::TEMPERATURE 52 }; 53 54 std::string grammar; // optional BNF-like grammar to constrain sampling 55 56 // Classifier-Free Guidance 57 // https://arxiv.org/abs/2306.17806 58 std::string cfg_negative_prompt; // string to help guidance 59 float cfg_scale = 1.f; // how strong is guidance 60 61 std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens 62 63 std::vector<llama_token> penalty_prompt_tokens; 64 bool use_penalty_prompt_tokens = false; 65 } llama_sampling_params; 66 67 // general sampler context 68 // TODO: move to llama.h 69 struct llama_sampling_context { 70 // parameters that will be used for sampling 71 llama_sampling_params params; 72 73 // mirostat sampler state 74 float mirostat_mu; 75 76 llama_grammar * grammar; 77 78 // internal 79 grammar_parser::parse_state parsed_grammar; 80 81 // TODO: replace with ring-buffer 82 std::vector<llama_token> prev; 83 std::vector<llama_token_data> cur; 84 size_t n_valid; // Number of correct top tokens with correct probabilities. 85 86 std::mt19937 rng; 87 }; 88 89 #include "common.h" 90 91 // Create a new sampling context instance. 92 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params); 93 94 void llama_sampling_free(struct llama_sampling_context * ctx); 95 96 // Reset the sampler context 97 // - clear prev tokens 98 // - reset grammar 99 void llama_sampling_reset(llama_sampling_context * ctx); 100 101 // Set the sampler seed 102 void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed); 103 104 // Copy the sampler context 105 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); 106 107 // Get the last sampled token 108 llama_token llama_sampling_last(llama_sampling_context * ctx); 109 110 // Get a string representation of the last sampled tokens 111 std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n); 112 113 // Print sampling parameters into a string 114 std::string llama_sampling_print(const llama_sampling_params & params); 115 116 // Print sampling order into a string 117 std::string llama_sampling_order_print(const llama_sampling_params & params); 118 119 std::string llama_sampling_type_to_str(llama_sampler_type sampler_type); 120 121 std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names); 122 std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string); 123 124 // this is a common sampling function used across the examples for convenience 125 // it can serve as a starting point for implementing your own sampling function 126 // Note: When using multiple sequences, it is the caller's responsibility to call 127 // llama_sampling_reset when a sequence ends 128 // 129 // required: 130 // - ctx_main: context to use for sampling 131 // - ctx_sampling: sampling-specific context 132 // 133 // optional: 134 // - ctx_cfg: context to use for classifier-free guidance 135 // - idx: sample from llama_get_logits_ith(ctx, idx) 136 // 137 // returns: 138 // - token: sampled token 139 // - candidates: vector of candidate tokens 140 // 141 llama_token llama_sampling_sample( 142 struct llama_sampling_context * ctx_sampling, 143 struct llama_context * ctx_main, 144 struct llama_context * ctx_cfg, 145 int idx = -1); 146 147 // Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters. 148 llama_token_data_array llama_sampling_prepare( 149 struct llama_sampling_context * ctx_sampling, 150 struct llama_context * ctx_main, 151 struct llama_context * ctx_cfg, 152 int idx = 0, 153 bool apply_grammar = true, 154 std::vector<float> * original_logits = nullptr); 155 156 void llama_sampling_accept( 157 struct llama_sampling_context * ctx_sampling, 158 struct llama_context * ctx_main, 159 llama_token id, 160 bool apply_grammar);