save-load-state.cpp
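// Example: save and restore llama.cpp context state.
//
// Demonstrates three things:
//   1. serializing the full context state (rng, logits, embedding and
//      kv_cache) to a file with llama_state_get_data and restoring it into a
//      fresh context with llama_state_set_data;
//   2. verifying that generation after the restore matches the original run
//      (the saved rng makes sampling reproducible);
//   3. copying the kv cache of sequence 0 into sequence 1 with the
//      llama_state_seq_* API and verifying the generation once more.
//
// Typical invocation (binary name and model path depend on your build and
// setup; this line is illustrative, not prescriptive):
//   ./save-load-state -m models/7B/ggml-model.gguf -p "The quick brown fox"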
#include "common.h"
#include "llama.h"

#include <cstdio>
#include <tuple>  // std::tie
#include <vector>

int main(int argc, char ** argv) {
    gpt_params params;

    params.prompt = "The quick brown fox";

    if (!gpt_params_parse(argc, argv, params)) {
        gpt_params_print_usage(argc, argv, params);
        return 1;
    }

    print_build_info();

    if (params.n_predict < 0) {
        params.n_predict = 16;
    }

    auto n_past = 0;

    std::string result0;
    std::string result1;
    std::string result2;

    // init
    llama_model   * model;
    llama_context * ctx;

    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    if (model == nullptr || ctx == nullptr) {
        fprintf(stderr, "%s : failed to init\n", __func__);
        return 1;
    }

    // tokenize prompt
    auto tokens = llama_tokenize(ctx, params.prompt, true);

    // evaluate prompt
    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0));
    n_past += tokens.size();

    // save state (rng, logits, embedding and kv_cache) to file
    {
        std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
        const size_t written = llama_state_get_data(ctx, state_mem.data());

        FILE * fp_write = fopen("dump_state.bin", "wb");
        if (fp_write == nullptr) {
            fprintf(stderr, "%s : failed to open dump_state.bin for writing\n", __func__);
            llama_free(ctx);
            llama_free_model(model);
            return 1;
        }
        fwrite(state_mem.data(), 1, written, fp_write);
        fclose(fp_write);

        fprintf(stderr, "%s : serialized state into %zu out of a maximum of %zu bytes\n", __func__, written, state_mem.size());
    }

    // remember how many tokens have been evaluated, so the restored runs can
    // resume generation from the same position
    const auto n_past_saved = n_past;

    // first run
    printf("\nfirst run: %s", params.prompt.c_str());

    for (auto i = 0; i < params.n_predict; i++) {
        auto * logits  = llama_get_logits(ctx);
        auto   n_vocab = llama_n_vocab(model);

        // build the candidate list from the raw logits and sample the next token
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
        }
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
        auto next_token     = llama_sample_token(ctx, &candidates_p);
        auto next_token_str = llama_token_to_piece(ctx, next_token);

        printf("%s", next_token_str.c_str());
        result0 += next_token_str;

        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_free(ctx);
            llama_free_model(model);
            return 1;
        }
        n_past += 1;
    }

    printf("\n\n");

    // free old context
    llama_free(ctx);

    // make new context
    auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));

    printf("\nsecond run: %s", params.prompt.c_str());

    // load state (rng, logits, embedding and kv_cache) from file
    {
        std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));

        FILE * fp_read = fopen("dump_state.bin", "rb");
        if (fp_read == nullptr) {
            fprintf(stderr, "%s : failed to open dump_state.bin for reading\n", __func__);
            llama_free(ctx2);
            llama_free_model(model);
            return 1;
        }
        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
        fclose(fp_read);

        if (read != llama_state_set_data(ctx2, state_mem.data())) {
            fprintf(stderr, "\n%s : failed to read state\n", __func__);
            llama_free(ctx2);
            llama_free_model(model);
            return 1;
        }

        fprintf(stderr, "%s : deserialized state from %zu out of a maximum of %zu bytes\n", __func__, read, state_mem.size());
    }

    // restore the saved position
    n_past = n_past_saved;

    // second run - should produce the same tokens as the first, since the
    // restored state includes the rng
    for (auto i = 0; i < params.n_predict; i++) {
        auto * logits  = llama_get_logits(ctx2);
        auto   n_vocab = llama_n_vocab(model);

        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
        }
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
        auto next_token     = llama_sample_token(ctx2, &candidates_p);
        auto next_token_str = llama_token_to_piece(ctx2, next_token);

        printf("%s", next_token_str.c_str());
        result1 += next_token_str;

        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_free(ctx2);
            llama_free_model(model);
            return 1;
        }
        n_past += 1;
    }

    printf("\n\n");

    llama_free(ctx2);

    if (result0 != result1) {
        fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
        return 1;
    }

    // make new context
    auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));

    printf("\nsingle seq run: %s", params.prompt.c_str());

    // load state (rng, logits, embedding and kv_cache) from file
    {
        std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));

        FILE * fp_read = fopen("dump_state.bin", "rb");
        if (fp_read == nullptr) {
            fprintf(stderr, "%s : failed to open dump_state.bin for reading\n", __func__);
            llama_free(ctx3);
            llama_free_model(model);
            return 1;
        }
        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
        fclose(fp_read);

        if (read != llama_state_set_data(ctx3, state_mem.data())) {
            fprintf(stderr, "\n%s : failed to read state\n", __func__);
            llama_free(ctx3);
            llama_free_model(model);
            return 1;
        }

        fprintf(stderr, "%s : deserialized state from %zu out of a maximum of %zu bytes\n", __func__, read, state_mem.size());
    }

    // restore the saved position
    n_past = n_past_saved;

    // save seq 0 and load into seq 1
    {
        // save kv of seq 0
        std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
        if (ncopy != seq_store.size()) {
            fprintf(stderr, "\n%s : seq copy data length %zu does not match expected length %zu\n", __func__, ncopy, seq_store.size());
            llama_free(ctx3);
            llama_free_model(model);
            return 1;
        }
        fprintf(stderr, "%s : seq 0 copied, %zu bytes\n", __func__, ncopy);

        // erase whole kv
        llama_kv_cache_clear(ctx3);
        fprintf(stderr, "%s : kv cache cleared\n", __func__);

        // restore kv into seq 1
        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
        if (nset != seq_store.size()) {
            fprintf(stderr, "\n%s : seq set data length %zu does not match expected length %zu\n", __func__, nset, seq_store.size());
            llama_free(ctx3);
            llama_free_model(model);
            return 1;
        }
        fprintf(stderr, "%s : seq 1 restored, %zu bytes\n", __func__, nset);
    }

    // third run with seq 1 instead of 0
    for (auto i = 0; i < params.n_predict; i++) {
        auto * logits  = llama_get_logits(ctx3);
        auto   n_vocab = llama_n_vocab(model);

        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
        }
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
        auto next_token     = llama_sample_token(ctx3, &candidates_p);
        auto next_token_str = llama_token_to_piece(ctx3, next_token);

        printf("%s", next_token_str.c_str());
        result2 += next_token_str;

        // note: seq_id 1 here - the kv data was restored into sequence 1
        if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_free(ctx3);
            llama_free_model(model);
            return 1;
        }
        n_past += 1;
    }

    printf("\n");

    llama_free(ctx3);
    llama_free_model(model);

    if (result0 != result2) {
        fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
        return 1;
    }

    fprintf(stderr, "\n%s : success\n", __func__);

    return 0;
}