simple.cpp
1 #include "common.h" 2 #include "llama.h" 3 4 #include <cmath> 5 #include <cstdio> 6 #include <string> 7 #include <vector> 8 9 static void print_usage(int argc, char ** argv, const gpt_params & params) { 10 gpt_params_print_usage(argc, argv, params); 11 12 LOG_TEE("\nexample usage:\n"); 13 LOG_TEE("\n %s -m model.gguf -p \"Hello my name is\" -n 32\n", argv[0]); 14 LOG_TEE("\n"); 15 } 16 17 int main(int argc, char ** argv) { 18 gpt_params params; 19 20 params.prompt = "Hello my name is"; 21 params.n_predict = 32; 22 23 if (!gpt_params_parse(argc, argv, params)) { 24 print_usage(argc, argv, params); 25 return 1; 26 } 27 28 // total length of the sequence including the prompt 29 const int n_predict = params.n_predict; 30 31 // init LLM 32 33 llama_backend_init(); 34 llama_numa_init(params.numa); 35 36 // initialize the model 37 38 llama_model_params model_params = llama_model_params_from_gpt_params(params); 39 40 llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); 41 42 if (model == NULL) { 43 fprintf(stderr , "%s: error: unable to load model\n" , __func__); 44 return 1; 45 } 46 47 // initialize the context 48 49 llama_context_params ctx_params = llama_context_params_from_gpt_params(params); 50 51 llama_context * ctx = llama_new_context_with_model(model, ctx_params); 52 53 if (ctx == NULL) { 54 fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); 55 return 1; 56 } 57 58 // tokenize the prompt 59 60 std::vector<llama_token> tokens_list; 61 tokens_list = ::llama_tokenize(ctx, params.prompt, true); 62 63 const int n_ctx = llama_n_ctx(ctx); 64 const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size()); 65 66 LOG_TEE("\n%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req); 67 68 // make sure the KV cache is big enough to hold all the prompt and generated tokens 69 if (n_kv_req > n_ctx) { 70 LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__); 71 LOG_TEE("%s: either reduce n_predict or increase n_ctx\n", __func__); 72 return 1; 73 } 74 75 // print the prompt token-by-token 76 77 fprintf(stderr, "\n"); 78 79 for (auto id : tokens_list) { 80 fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); 81 } 82 83 fflush(stderr); 84 85 // create a llama_batch with size 512 86 // we use this object to submit token data for decoding 87 88 llama_batch batch = llama_batch_init(512, 0, 1); 89 90 // evaluate the initial prompt 91 for (size_t i = 0; i < tokens_list.size(); i++) { 92 llama_batch_add(batch, tokens_list[i], i, { 0 }, false); 93 } 94 95 // llama_decode will output logits only for the last token of the prompt 96 batch.logits[batch.n_tokens - 1] = true; 97 98 if (llama_decode(ctx, batch) != 0) { 99 LOG_TEE("%s: llama_decode() failed\n", __func__); 100 return 1; 101 } 102 103 // main loop 104 105 int n_cur = batch.n_tokens; 106 int n_decode = 0; 107 108 const auto t_main_start = ggml_time_us(); 109 110 while (n_cur <= n_predict) { 111 // sample the next token 112 { 113 auto n_vocab = llama_n_vocab(model); 114 auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); 115 116 std::vector<llama_token_data> candidates; 117 candidates.reserve(n_vocab); 118 119 for (llama_token token_id = 0; token_id < n_vocab; token_id++) { 120 candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); 121 } 122 123 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; 124 125 // sample the most likely token 
            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);

            // is it an end of generation?
            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
                LOG_TEE("\n");

                break;
            }

            LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
            fflush(stdout);

            // prepare the next batch
            llama_batch_clear(batch);

            // push this new token for next evaluation
            llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);

            n_decode += 1;
        }

        n_cur += 1;

        // evaluate the current batch with the transformer model
        if (llama_decode(ctx, batch)) {
            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
            return 1;
        }
    }

    LOG_TEE("\n");

    const auto t_main_end = ggml_time_us();

    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

    llama_print_timings(ctx);

    fprintf(stderr, "\n");

    llama_batch_free(batch);

    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    return 0;
}
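
Once built from the llama.cpp tree (the binary name `simple` below is an assumption about the build output), the example is invoked as shown by print_usage above:

    ./simple -m model.gguf -p "Hello my name is" -n 32

This loads model.gguf, tokenizes the prompt, and continues it with greedy sampling until the sequence reaches 32 tokens (prompt included) or an end-of-generation token is produced.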