passkey.cpp
1 #include "common.h" 2 #include "llama.h" 3 4 #include <cmath> 5 #include <cstdio> 6 #include <string> 7 #include <vector> 8 9 static void print_usage(int argc, char ** argv, const gpt_params & params) { 10 gpt_params_print_usage(argc, argv, params); 11 12 LOG_TEE("\nexample usage:\n"); 13 LOG_TEE("\n %s -m model.gguf --junk 250 --pos 90 --keep 32 --grp-attn-n 2 [--seed 1234]\n", argv[0]); 14 LOG_TEE("\n"); 15 } 16 17 int main(int argc, char ** argv) { 18 gpt_params params; 19 20 params.n_junk = 250; 21 params.n_keep = 32; 22 params.i_pos = -1; 23 24 if (!gpt_params_parse(argc, argv, params)) { 25 print_usage(argc, argv, params); 26 return 1; 27 } 28 29 srand(params.seed == LLAMA_DEFAULT_SEED ? time(NULL) : params.seed); 30 31 int n_junk = params.n_junk; 32 int n_keep = params.n_keep; 33 int n_grp = params.grp_attn_n; 34 int i_pos = params.i_pos; 35 36 if (i_pos == -1) { 37 i_pos = rand() % n_junk; 38 } 39 40 const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."; 41 const std::string prompt_suffix = " What is the pass key? The pass key is"; 42 43 // generate junk text 44 params.prompt = prompt_prefix; 45 46 const int passkey = rand() % 50000 + 1; 47 48 for (int i = 0; i < n_junk; i++) { 49 if (i % n_junk == i_pos) { 50 params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key."; 51 } 52 53 params.prompt += " The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."; 54 } 55 56 params.prompt += prompt_suffix; 57 58 // init LLM 59 60 llama_backend_init(); 61 llama_numa_init(params.numa); 62 63 // initialize the model 64 65 llama_model_params model_params = llama_model_params_from_gpt_params(params); 66 67 llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); 68 69 if (model == NULL) { 70 fprintf(stderr , "%s: error: unable to load model\n" , __func__); 71 return 1; 72 } 73 74 // initialize the context 75 76 llama_context_params ctx_params = llama_context_params_from_gpt_params(params); 77 78 ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep; 79 80 GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); 81 82 llama_context * ctx = llama_new_context_with_model(model, ctx_params); 83 84 if (ctx == NULL) { 85 fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); 86 return 1; 87 } 88 89 // tokenize the prompt 90 std::vector<llama_token> tokens_list; 91 tokens_list = ::llama_tokenize(ctx, params.prompt, true); 92 93 // tokenize the prefix and use it as a sink 94 const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size(); 95 96 const int n_tokens_all = tokens_list.size(); 97 98 // we leave a margin of 16 tokens for the generated text - it should contain just the passkey 99 const int n_predict = 16; 100 101 // total length of the sequences including the prompt 102 const int n_len = n_tokens_all + n_predict; 103 104 const int n_ctx = llama_n_ctx(ctx) - n_keep; 105 const int n_kv_req = llama_n_ctx(ctx); 106 const int n_batch = ctx_params.n_batch; 107 const int n_batch_grp = ctx_params.n_batch/n_grp; 108 109 LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos); 110 111 // print the prompt token-by-token 112 113 LOG_TEE("\n"); 114 LOG_TEE("prefix tokens: %d\n", n_tokens_prefix); 115 LOG_TEE("prompt tokens: %d\n", n_tokens_all); 116 //LOG_TEE("prompt: %s\n", params.prompt.c_str()); 117 118 llama_batch batch = llama_batch_init(params.n_batch, 0, 1); 119 120 int n_past = 0; 121 122 // fill the KV cache 123 for (int i = 0; i < n_ctx; i += n_batch) { 124 if (i > 0 && n_grp > 1) { 125 // if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp 126 const int ib = i/n_batch - 1; 127 const int bd = n_batch_grp*(n_grp - 1); 128 129 llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); 130 llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); 131 llama_kv_cache_update (ctx); 132 133 n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; 134 } 135 136 llama_batch_clear(batch); 137 138 for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { 139 llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); 140 } 141 142 if (i + n_batch >= n_tokens_all) { 143 batch.logits[batch.n_tokens - 1] = true; 144 } 145 146 if (llama_decode(ctx, batch) != 0) { 147 LOG_TEE("%s: llama_decode() failed\n", __func__); 148 return 1; 149 } 150 151 LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); 152 153 if (i + n_batch >= n_tokens_all) { 154 break; 155 } 156 } 157 158 for (int i = n_ctx; i < n_tokens_all; i += n_batch) { 159 const int n_discard = n_batch; 160 161 LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); 162 163 llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); 164 llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); 165 //llama_kv_cache_defrag (ctx); 166 llama_kv_cache_update (ctx); 167 168 n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; 169 170 llama_batch_clear(batch); 171 172 for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { 173 llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); 174 } 175 176 if (i + n_batch >= n_tokens_all) { 177 batch.logits[batch.n_tokens - 1] = true; 178 } 179 180 if (llama_decode(ctx, batch) != 0) { 181 LOG_TEE("%s: llama_decode() failed\n", __func__); 182 return 1; 183 } 184 185 LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); 186 } 187 188 { 189 const int n_discard = n_past - n_ctx + n_predict; 190 191 if (n_discard > 0) { 192 LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); 193 194 llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); 195 llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); 196 //llama_kv_cache_defrag (ctx); 197 llama_kv_cache_update (ctx); 198 199 n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; 200 } 201 } 202 203 LOG_TEE("\n"); 204 LOG_TEE("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk); 205 LOG_TEE("\n"); 206 207 // main loop 208 209 int n_cur = n_tokens_all; 210 int n_decode = 0; 211 212 LOG_TEE("%s", prompt_suffix.c_str()); 213 fflush(stdout); 214 215 const auto t_main_start = ggml_time_us(); 216 217 while (n_cur <= n_len) { 218 // sample the next token 219 { 220 auto n_vocab = llama_n_vocab(model); 221 auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); 222 223 std::vector<llama_token_data> candidates; 224 candidates.reserve(n_vocab); 225 226 for (llama_token token_id = 0; token_id < n_vocab; token_id++) { 227 candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); 228 } 229 230 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; 231 232 // sample the most likely token 233 const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); 234 235 // is it an end of generation? 236 if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { 237 LOG_TEE("\n"); 238 239 break; 240 } 241 242 LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str()); 243 fflush(stdout); 244 245 n_decode += 1; 246 247 // prepare the next batch 248 llama_batch_clear(batch); 249 250 // push this new token for next evaluation 251 llama_batch_add(batch, new_token_id, n_past++, { 0 }, true); 252 } 253 254 n_cur += 1; 255 256 // evaluate the current batch with the transformer model 257 if (llama_decode(ctx, batch)) { 258 fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); 259 return 1; 260 } 261 } 262 263 LOG_TEE("\n"); 264 265 const auto t_main_end = ggml_time_us(); 266 267 LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", 268 __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); 269 270 llama_print_timings(ctx); 271 272 fprintf(stderr, "\n"); 273 274 llama_batch_free(batch); 275 276 llama_free(ctx); 277 llama_free_model(model); 278 279 llama_backend_free(); 280 281 return 0; 282 }