imatrix.cpp
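//
// computes an "importance matrix" over a text file: for every tensor that takes part in a
// matrix multiplication, the average of the squared activations feeding each weight column
// is collected (see collect_imatrix below). the resulting file can be passed to the quantize
// tool to weight the quantization error per column.
//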
#include "common.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <sstream>
#include <thread>
#include <mutex>
#include <vector>
#include <fstream>
#include <unordered_map>
#include <algorithm>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

static void print_usage(int argc, char ** argv, const gpt_params & params) {
    gpt_params_print_usage(argc, argv, params);

    LOG_TEE("\nexample usage:\n");
    LOG_TEE("\n    %s \\\n"
            "       -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] [--verbosity 1] \\\n"
            "       [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n"
            "       [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...]\n" , argv[0]);
    LOG_TEE("\n");
}

struct Stats {
    std::vector<float> values;
    std::vector<int>   counts;
    int ncall = 0;
};

class IMatrixCollector {
public:
    IMatrixCollector() = default;
    void set_params(gpt_params params) { m_params = std::move(params); }
    bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data);
    void save_imatrix(int ncall = -1) const;
    bool load_imatrix(const char * file_name);
private:
    std::unordered_map<std::string, Stats> m_stats;
    gpt_params                             m_params;
    std::mutex                             m_mutex;
    int                                    m_last_call = 0;
    std::vector<float>                     m_src1_data;
    std::vector<char>                      m_ids; // the expert ids from ggml_mul_mat_id
};

// remove any prefix and suffixes from the name
// CUDA0#blk.0.attn_k.weight#0 => blk.0.attn_k.weight
static std::string filter_tensor_name(const char * name) {
    std::string wname;
    const char * p = strchr(name, '#');
    if (p != NULL) {
        p = p + 1;
        const char * q = strchr(p, '#');
        if (q != NULL) {
            wname = std::string(p, q - p);
        } else {
            wname = p;
        }
    } else {
        wname = name;
    }
    return wname;
}

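// accumulation performed below: for every activation row x that feeds a matrix
// multiplication, we do values[j] += x[j]*x[j] and counts[j]++ for each column j.
// values[j]/counts[j] is therefore the mean squared activation seen by weight column j,
// which save_imatrix() rescales by ncall before writing it out.
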
bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
    GGML_UNUSED(user_data);

    const struct ggml_tensor * src0 = t->src[0];
    const struct ggml_tensor * src1 = t->src[1];
    std::string wname = filter_tensor_name(src0->name);

    // when ask is true, the scheduler wants to know if we are interested in data from this tensor
    // if we return true, a follow-up call will be made with ask=false in which we can do the actual collection
    if (ask) {
        if (t->op == GGML_OP_MUL_MAT_ID) return true; // collect all indirect matrix multiplications
        if (t->op != GGML_OP_MUL_MAT) return false;
        // small batches (< 16 tokens) are ignored, presumably so that statistics are collected
        // from prompt-processing batches rather than from single-token generation calls
        if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false;
        if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == "output.weight"))) return false;
        return true;
    }

    std::lock_guard<std::mutex> lock(m_mutex);

    // copy the data from the GPU memory if needed
    const bool is_host = ggml_backend_buffer_is_host(src1->buffer);

    if (!is_host) {
        m_src1_data.resize(ggml_nelements(src1));
        ggml_backend_tensor_get(src1, m_src1_data.data(), 0, ggml_nbytes(src1));
    }

    const float * data = is_host ? (const float *) src1->data : m_src1_data.data();

    // this has been adapted to the new format of storing merged experts in a single 3d tensor
    // ref: https://github.com/ggerganov/llama.cpp/pull/6387
    if (t->op == GGML_OP_MUL_MAT_ID) {
        //   ids  -> [n_experts_used, n_tokens]
        //   src1 -> [cols, n_expert_used, n_tokens]
        const ggml_tensor * ids = t->src[2];
        const int n_as  = src0->ne[2];
        const int n_ids = ids->ne[0];

        // the top-k selected expert ids are stored in the ids tensor
        // for simplicity, always copy ids to host, because it is small
        // take into account that ids is not contiguous!

        GGML_ASSERT(ids->ne[1] == src1->ne[2]);

        m_ids.resize(ggml_nbytes(ids));
        ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids));

        auto & e = m_stats[wname];

        ++e.ncall;

        if (e.values.empty()) {
            e.values.resize(src1->ne[0]*n_as, 0);
            e.counts.resize(src1->ne[0]*n_as, 0);
        }
        else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
            fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
            exit(1); //GGML_ASSERT(false);
        }
        if (m_params.verbosity > 1) {
            printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
        }
        // loop over all possible experts, regardless if they are used or not in the batch
        for (int ex = 0; ex < n_as; ++ex) {
            size_t e_start = ex*src1->ne[0];

            for (int idx = 0; idx < n_ids; ++idx) {
                for (int row = 0; row < (int)src1->ne[2]; ++row) {
                    const int excur = *(const int32_t *) (m_ids.data() + row*ids->nb[1] + idx*ids->nb[0]);

                    GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check

                    if (excur != ex) continue;

                    const int64_t i11 = idx % src1->ne[1];
                    const int64_t i12 = row;
                    const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]);

                    for (int j = 0; j < (int)src1->ne[0]; ++j) {
                        e.values[e_start + j] += x[j]*x[j];
                        e.counts[e_start + j]++;
                        if (!std::isfinite(e.values[e_start + j])) {
                            fprintf(stderr, "%f detected in %s\n", e.values[e_start + j], wname.c_str());
                            exit(1);
                        }
                    }
                }
            }
            if (e.ncall > m_last_call) {
                m_last_call = e.ncall;
                if (m_last_call % m_params.n_out_freq == 0) {
                    save_imatrix();
                }
                if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) {
                    save_imatrix(m_last_call);
                }
            }
        }
    } else {
        auto & e = m_stats[wname];
        if (e.values.empty()) {
            e.values.resize(src1->ne[0], 0);
            e.counts.resize(src1->ne[0], 0);
        }
        else if (e.values.size() != (size_t)src1->ne[0]) {
            fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
            exit(1); //GGML_ASSERT(false);
        }
        ++e.ncall;
        if (m_params.verbosity > 1) {
            printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
        }
        for (int row = 0; row < (int)src1->ne[1]; ++row) {
            const float * x = data + row * src1->ne[0];
            for (int j = 0; j < (int)src1->ne[0]; ++j) {
                e.values[j] += x[j]*x[j];
                e.counts[j]++;
                if (!std::isfinite(e.values[j])) {
                    fprintf(stderr, "%f detected in %s\n", e.values[j], wname.c_str());
                    exit(1);
                }
            }
        }
        if (e.ncall > m_last_call) {
            m_last_call = e.ncall;
            if (m_last_call % m_params.n_out_freq == 0) {
                save_imatrix();
            }
            if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) {
                save_imatrix(m_last_call);
            }
        }
    }

    return true;
}

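// on-disk layout produced by save_imatrix() (all integers are 32-bit, no padding):
//   int32        n_entries
//   for each entry:
//     int32      len          -- length of the tensor name
//     char[len]  name
//     int32      ncall        -- number of times the tensor was evaluated
//     int32      nval         -- number of values
//     float[nval] values      -- per-column mean of squared activations, scaled by ncall
//   int32        m_last_call  -- number of chunks the matrix was computed with
//   int32        len          -- length of the input file name
//   char[len]    prompt_file
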
void IMatrixCollector::save_imatrix(int ncall) const {
    auto fname = m_params.out_file;
    if (fname.empty()) {
        fname = "imatrix.dat";
    }

    if (ncall > 0) {
        fname += ".at_";
        fname += std::to_string(ncall);
    }

    // avoid writing imatrix entries that do not have full data
    // this can happen with MoE models where some of the experts end up not being exercised by the provided training data

    int n_entries = 0;
    std::vector<std::string> to_store;

    bool is_first = true; // for printing
    for (const auto & kv : m_stats) {
        const int n_all = kv.second.counts.size();

        if (n_all == 0) {
            continue;
        }

        int n_zeros = 0;
        for (const int c : kv.second.counts) {
            if (c == 0) {
                n_zeros++;
            }
        }

        if (n_zeros != 0 && is_first) {
            fprintf(stderr, "\n");
            is_first = false;
        }

        if (n_zeros == n_all) {
            fprintf(stderr, "%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str());
            continue;
        }

        if (n_zeros > 0) {
            fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
            continue;
        }

        n_entries++;
        to_store.push_back(kv.first);
    }

    if (to_store.size() < m_stats.size()) {
        fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size());
    }

    std::ofstream out(fname, std::ios::binary);
    out.write((const char *) &n_entries, sizeof(n_entries));
    for (const auto & name : to_store) {
        const auto & stat = m_stats.at(name);
        int len = name.size();
        out.write((const char *) &len, sizeof(len));
        out.write(name.c_str(), len);
        out.write((const char *) &stat.ncall, sizeof(stat.ncall));
        int nval = stat.values.size();
        out.write((const char *) &nval, sizeof(nval));
        if (nval > 0) {
            std::vector<float> tmp(nval);
            for (int i = 0; i < nval; i++) {
                tmp[i] = (stat.values[i] / static_cast<float>(stat.counts[i])) * static_cast<float>(stat.ncall);
            }
            out.write((const char *) tmp.data(), nval*sizeof(float));
        }
    }

    // write the number of calls the matrix was computed with
    out.write((const char *) &m_last_call, sizeof(m_last_call));

    // write the input filename at the end of the file so that it can be picked up later by quantize
    {
        int len = m_params.prompt_file.size();
        out.write((const char *) &len, sizeof(len));
        out.write(m_params.prompt_file.c_str(), len);
    }

    if (m_params.verbosity > 0) {
        fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname.c_str());
    }
}

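// merging semantics: since the file stores (values/counts)*ncall per entry, loading adds the
// stored values to e.values and ncall to every count. combining several files this way yields
// a weighted average: entries collected over more calls contribute proportionally more when
// save_imatrix() divides by the accumulated counts again.
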
bool IMatrixCollector::load_imatrix(const char * fname) {
    std::ifstream in(fname, std::ios::binary);
    if (!in) {
        printf("%s: failed to open %s\n", __func__, fname);
        return false;
    }
    int n_entries;
    in.read((char *) &n_entries, sizeof(n_entries));
    if (in.fail() || n_entries < 1) {
        printf("%s: no data in file %s\n", __func__, fname);
        return false;
    }
    for (int i = 0; i < n_entries; ++i) {
        int len; in.read((char *) &len, sizeof(len));
        std::vector<char> name_as_vec(len+1);
        in.read((char *) name_as_vec.data(), len);
        if (in.fail()) {
            printf("%s: failed reading name for entry %d from %s\n", __func__, i+1, fname);
            return false;
        }
        name_as_vec[len] = 0;
        std::string name{name_as_vec.data()};
        auto & e = m_stats[std::move(name)];
        int ncall;
        in.read((char *) &ncall, sizeof(ncall));
        int nval;
        in.read((char *) &nval, sizeof(nval));
        if (in.fail() || nval < 1) {
            printf("%s: failed reading number of values for entry %d\n", __func__, i+1);
            m_stats = {};
            return false;
        }

        if (e.values.empty()) {
            e.values.resize(nval, 0);
            e.counts.resize(nval, 0);
        }

        std::vector<float> tmp(nval);
        in.read((char *) tmp.data(), nval*sizeof(float));
        if (in.fail()) {
            printf("%s: failed reading data for entry %d\n", __func__, i+1);
            m_stats = {};
            return false;
        }

        // recreate the state as expected by save_imatrix(), and correct for the weighted sum
        for (int j = 0; j < nval; j++) {
            e.values[j] += tmp[j];
            e.counts[j] += ncall;
        }
        e.ncall += ncall;
    }
    return true;
}

static IMatrixCollector g_collector;

static bool ik_collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
    return g_collector.collect_imatrix(t, ask, user_data);
}

struct results_log_softmax {
    double log_softmax;
    float  logit;
    float  prob;
};

static std::vector<float> softmax(const std::vector<float> & logits) {
    std::vector<float> probs(logits.size());
    float max_logit = logits[0];
    for (float v : logits) {
        max_logit = std::max(max_logit, v);
    }
    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        // subtract the maximum logit value from the current logit value for numerical stability
        const float logit = logits[i] - max_logit;
        const float exp_logit = expf(logit);
        sum_exp += exp_logit;
        probs[i] = exp_logit;
    }
    for (size_t i = 0; i < probs.size(); i++) {
        probs[i] /= sum_exp;
    }
    return probs;
}

static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
    float max_logit = logits[0];
    for (int i = 1; i < n_vocab; ++i) {
        max_logit = std::max(max_logit, logits[i]);
    }
    double sum_exp = 0.0;
    for (int i = 0; i < n_vocab; ++i) {
        sum_exp += expf(logits[i] - max_logit);
    }
    return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
}

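// the two helpers above use the standard numerically stable formulation: with m = max_i x_i,
//   softmax(x)_t     = exp(x_t - m) / sum_i exp(x_i - m)
//   log softmax(x)_t = x_t - m - log(sum_i exp(x_i - m))
// process_logits() below sums -log softmax over the predicted tokens; exp of the mean of that
// sum gives the perplexity reported by compute_imatrix().
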
static void process_logits(
    int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
    double & nll, double & nll2, float * logit_history, float * prob_history) {
    std::mutex mutex;
    int counter = 0;
    auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
        double local_nll  = 0;
        double local_nll2 = 0;
        while (true) {
            std::unique_lock<std::mutex> lock(mutex);
            int i = counter++;
            if (i >= n_token) {
                nll += local_nll; nll2 += local_nll2;
                break;
            }
            lock.unlock();
            const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
            const double v = -results.log_softmax;
            local_nll += v;
            local_nll2 += v*v;

            logit_history[i] = results.logit;
            prob_history[i]  = results.prob;
        }
    };
    for (auto & w : workers) {
        w = std::thread(compute);
    }
    compute();
    for (auto & w : workers) {
        w.join();
    }
}

static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
    GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
    const int n_ctx = llama_n_ctx(ctx);

    auto tim1 = std::chrono::high_resolution_clock::now();
    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);

    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);

    auto tim2 = std::chrono::high_resolution_clock::now();
    fprintf(stderr, "%s: tokenization took %g ms\n", __func__, 1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());

    if (params.i_chunk > 0) {
        if (size_t((params.i_chunk + 2)*n_ctx) >= tokens.size()) {
            fprintf(stderr, "%s: there will not be enough tokens left after removing %d chunks\n", __func__, params.i_chunk);
            return false;
        }
        fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, params.i_chunk, params.i_chunk*n_ctx);
        tokens.erase(tokens.begin(), tokens.begin() + params.i_chunk*n_ctx);
    }

    if (int(tokens.size()) < 2*n_ctx) {
        fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n", __func__, 2*n_ctx, n_ctx);
        fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n", __func__, tokens.size());
        return false;
    }

    std::vector<float> logit_history;
    std::vector<float> prob_history;

    if (params.compute_ppl) {
        logit_history.resize(tokens.size());
        prob_history.resize(tokens.size());
    }

    const int n_chunk_max = tokens.size() / n_ctx;

    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
    const int n_batch = params.n_batch;

    int count = 0;
    double nll  = 0.0;
    double nll2 = 0.0;

    fprintf(stderr, "%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch);

    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);

    const int num_batches = (n_ctx + n_batch - 1) / n_batch;

    std::vector<float> logits;
    if (params.compute_ppl && num_batches > 1) {
        logits.reserve((size_t)n_ctx * n_vocab);
    }

    for (int i = 0; i < n_chunk; ++i) {
        const int start =     i * n_ctx;
        const int end   = start + n_ctx;

        const auto t_start = std::chrono::high_resolution_clock::now();

        // clear the KV cache
        llama_kv_cache_clear(ctx);

        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
            const int batch_size  = std::min(end - batch_start, n_batch);

            // save original token and restore it after eval
            const auto token_org = tokens[batch_start];

            // add BOS token for the first batch of each chunk
            if (add_bos && j == 0) {
                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
            }

            // TODO: use batch.logits to save computations instead of relying on logits_all == true
            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
                fprintf(stderr, "%s : failed to eval\n", __func__);
                return false;
            }

            // restore the original token in case it was set to BOS
            tokens[batch_start] = token_org;

            if (params.compute_ppl && num_batches > 1) {
                const auto * batch_logits = llama_get_logits(ctx);
                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
            }
        }

        const auto t_end = std::chrono::high_resolution_clock::now();

        if (i == 0) {
            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
            int total_seconds = (int)(t_total * n_chunk);
            if (total_seconds >= 60*60) {
                fprintf(stderr, "%d hours ", total_seconds / (60*60));
                total_seconds = total_seconds % (60*60);
            }
            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
        }

        if (params.compute_ppl) {
            const int first = n_ctx/2;
            const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
            count += n_ctx - first - 1;

            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
            fflush(stdout);

            logits.clear();
        }
    }
    printf("\n");

    if (params.compute_ppl) {
        nll2 /= count;
        nll  /= count;
        const double ppl = exp(nll);
        nll2 -= nll * nll;
        if (nll2 > 0) {
            nll2 = sqrt(nll2/(count-1));
            printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
        } else {
            printf("Unexpected negative standard deviation of log(prob)\n");
        }
    }

    return true;
}

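// typical invocation (the binary name depends on how the examples are built):
//
//   ./imatrix -m model.gguf -f calibration-text.txt -o imatrix.dat
//
// previously saved matrices can be merged in with repeated --in-file arguments, as shown
// in print_usage() above.
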
int main(int argc, char ** argv) {
    gpt_params params;

    params.n_ctx = 512;
    params.logits_all = true;
    params.verbosity = 1;

    if (!gpt_params_parse(argc, argv, params)) {
        print_usage(argc, argv, params);
        return 1;
    }

    params.n_batch = std::min(params.n_batch, params.n_ctx);

    g_collector.set_params(params);

    for (const auto & in_file : params.in_files) {
        printf("%s : loading imatrix from '%s'\n", __func__, in_file.c_str());
        if (!g_collector.load_imatrix(in_file.c_str())) {
            fprintf(stderr, "%s : failed to load %s\n", __func__, in_file.c_str());
            return 1;
        }
    }

    if (params.in_files.size() > 1) {
        printf("%s : saving combined imatrix to '%s'\n", __func__, params.out_file.c_str());
        g_collector.save_imatrix();
    }

    llama_backend_init();
    llama_numa_init(params.numa);

    // pass the callback to the backend scheduler
    // it will be executed for each node during the graph computation
    params.cb_eval = ik_collect_imatrix;
    params.cb_eval_user_data = NULL;
    params.warmup = false;

    // init
    llama_model * model;
    llama_context * ctx;

    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    if (model == nullptr || ctx == nullptr) {
        fprintf(stderr, "%s : failed to init\n", __func__);
        return 1;
    }

    const int n_ctx_train = llama_n_ctx_train(model);
    if (params.n_ctx > n_ctx_train) {
        fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                __func__, n_ctx_train, params.n_ctx);
    }

    // print system information
    {
        fprintf(stderr, "\n");
        fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str());
    }

    if (!compute_imatrix(ctx, params)) {
        return 1;
    }

    g_collector.save_imatrix();

    llama_print_timings(ctx);

    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    return 0;
}