// examples/perplexity/perplexity.cpp
   1  #include "common.h"
   2  #include "llama.h"
   3  
   4  #include <cmath>
   5  #include <cstdio>
   6  #include <cstring>
   7  #include <ctime>
   8  #include <sstream>
   9  #include <thread>
  10  #include <mutex>
  11  #include <atomic>
  12  #include <vector>
  13  #include <array>
  14  #include <fstream>
  16  
  17  #if defined(_MSC_VER)
  18  #pragma warning(disable: 4244 4267) // possible loss of data
  19  #endif
  20  
  21  struct results_perplexity {
  22      std::vector<llama_token> tokens;
  23      double                   ppl_value;
  24      std::vector<float>       logits;
  25      std::vector<float>       probs;
  26  };
  27  
  28  struct results_log_softmax {
  29      double log_softmax;
  30      float  logit;
  31      float  prob;
  32  };
  33  
  34  static void write_logfile(
  35      const llama_context * ctx, const gpt_params & params, const llama_model * model,
  36      const struct results_perplexity & results
  37  ) {
  38      if (params.logdir.empty()) {
  39          return;
  40      }
  41  
  42      if (params.hellaswag) {
  43          fprintf(stderr, "%s: warning: logging results is not implemented for HellaSwag. No files will be written.\n", __func__);
  44          return;
  45      }
  46  
  47      const std::string timestamp = string_get_sortable_timestamp();
  48  
  49      const bool success = fs_create_directory_with_parents(params.logdir);
  50      if (!success) {
  51          fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
  52                  __func__, params.logdir.c_str());
  53          return;
  54      }
  55  
  56      const std::string logfile_path = params.logdir + timestamp + ".yml";
  57      FILE * logfile = fopen(logfile_path.c_str(), "w");
  58  
  59      if (logfile == NULL) {
  60          fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
  61          return;
  62      }
  63  
  64      fprintf(logfile, "binary: perplexity\n");
  65      char model_desc[128];
  66      llama_model_desc(model, model_desc, sizeof(model_desc));
  67      yaml_dump_non_result_info(logfile, params, ctx, timestamp, results.tokens, model_desc);
  68  
  69      fprintf(logfile, "\n");
  70      fprintf(logfile, "######################\n");
  71      fprintf(logfile, "# Perplexity Results #\n");
  72      fprintf(logfile, "######################\n");
  73      fprintf(logfile, "\n");
  74  
  75      yaml_dump_vector_float(logfile, "logits", results.logits);
  76      fprintf(logfile, "ppl_value: %f\n", results.ppl_value);
  77      yaml_dump_vector_float(logfile, "probs", results.probs);
  78  
  79      llama_dump_timing_info_yaml(logfile, ctx);
  80      fclose(logfile);
  81  }
  82  
  83  static std::vector<float> softmax(const std::vector<float>& logits) {
  84      std::vector<float> probs(logits.size());
  85      float max_logit = logits[0];
  86      for (float v : logits) {
  87          max_logit = std::max(max_logit, v);
  88      }
  89      double sum_exp = 0.0;
  90      for (size_t i = 0; i < logits.size(); i++) {
  91          // Subtract the maximum logit value from the current logit value for numerical stability
  92          const float logit = logits[i] - max_logit;
  93          const float exp_logit = expf(logit);
  94          sum_exp += exp_logit;
  95          probs[i] = exp_logit;
  96      }
  97      for (size_t i = 0; i < probs.size(); i++) {
  98          probs[i] /= sum_exp;
  99      }
 100      return probs;
 101  }
 102  
 103  static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
 104      float max_logit = logits[0];
 105      for (int i = 1; i < n_vocab; ++i) {
 106          max_logit = std::max(max_logit, logits[i]);
 107      }
 108      double sum_exp = 0.0;
 109      for (int i = 0; i < n_vocab; ++i) {
 110          sum_exp += expf(logits[i] - max_logit);
 111      }
 112      return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
 113  }
 114  
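// Fast float -> int rounding via the IEEE-754 representation: adding 12582912.0f
// (1.5 * 2^23) pushes the value into a range where the float's spacing is exactly 1,
// so the low 23 mantissa bits end up holding 0x400000 + round(fval). Masking those
// bits and subtracting 0x00400000 recovers round-to-nearest(fval), valid for
// |fval| < ~2^22 (hence the commented-out assert below).
// Example: nearest_int(2.4f) == 2, nearest_int(-2.6f) == -3.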
 115  static inline int nearest_int(float fval) {
 116      //assert(fval <= 4194303.f);
 117      float val = fval + 12582912.f;
 118      int i; memcpy(&i, &val, sizeof(int));
 119      return (i & 0x007fffff) - 0x00400000;
 120  }
 121  
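// Variant of log_softmax that also quantizes the full row of log-probabilities so it
// can be stored in the params.logits_file output. Layout of the log_prob buffer (per
// token position): the first 4 uint16 slots are reinterpreted as two floats, d[0] = scale
// and d[1] = min_log_prob, followed by n_vocab quantized values q[i] such that
//     log p(i) ~= scale * q[i] + min_log_prob
// (logits more than 16 below the maximum are clamped to q[i] = 0). The return value is
// the negative log-likelihood of the target token, max_logit + log(sum_exp) - logits[tok].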
 122  static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) {
 123      float max_logit = logits[0];
 124      float min_logit = logits[0];
 125      for (int i = 1; i < n_vocab; ++i) {
 126          max_logit = std::max(max_logit, logits[i]);
 127          min_logit = std::min(min_logit, logits[i]);
 128      }
 129      min_logit = std::max(min_logit, max_logit - 16);
 130      double sum_exp = 0.0;
 131      for (int i = 0; i < n_vocab; ++i) {
 132          sum_exp += expf(logits[i] - max_logit);
 133      }
 134      const float log_sum_exp = log(sum_exp);
 135      const float min_log_prob = min_logit - max_logit - log_sum_exp;
 136      const float scale = (max_logit - min_logit)/65535.f;
 137      float * d = (float *)log_prob;
 138      d[0] = scale;
 139      d[1] = min_log_prob;
 140      log_prob += 4;
 141      if (scale) {
 142          const float inv_scale = 1/scale;
 143          for (int i = 0; i < n_vocab; ++i) {
 144              log_prob[i] = logits[i] > min_logit ? nearest_int(inv_scale*(logits[i] - min_logit)) : 0;
 145          }
 146      } else {
 147          std::memset(log_prob, 0, n_vocab*sizeof(uint16_t));
 148      }
 149      return max_logit + log_sum_exp - logits[tok];
 150  }
 151  
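// Multi-threaded scoring of a chunk: worker threads pull token indices from a shared
// counter (guarded by the mutex), compute -log p(tokens[i+1] | context) from the i-th
// row of logits, and accumulate local sums of the NLL and NLL^2 (the latter is used
// later for the standard-error estimate of the final perplexity). The raw logit and
// the probability of each target token are recorded in logit_history / prob_history.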
 152  static void process_logits(
 153      int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
 154      double & nll, double & nll2, float * logit_history, float * prob_history
 155  ) {
 156      std::mutex mutex;
 157      int counter = 0;
 158      auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
 159          double local_nll  = 0;
 160          double local_nll2 = 0;
 161          while (true) {
 162              std::unique_lock<std::mutex> lock(mutex);
 163              int i = counter++;
 164              if (i >= n_token) {
 165                  nll += local_nll; nll2 += local_nll2;
 166                  break;
 167              }
 168              lock.unlock();
 169              const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
 170              const double v = -results.log_softmax;
 171              local_nll += v;
 172              local_nll2 += v*v;
 173  
 174              logit_history[i] = results.logit;
 175              prob_history[i]  = results.prob;
 176          }
 177      };
 178      for (auto & w : workers) {
 179          w = std::thread(compute);
 180      }
 181      compute();
 182      for (auto & w : workers) {
 183          w.join();
 184      }
 185  }
 186  
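// Same work-sharing scheme as above, but in addition each row of log-probabilities is
// quantized into log_probs (nv = n_vocab rounded up to an even count plus 4 header uint16
// per position) and the whole block is appended to the logits file at the end.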
 187  static void process_logits(std::ostream& out, int n_vocab, const float * logits, const int * tokens, int n_token,
 188          std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) {
 189      std::mutex mutex;
 190      const int nv = 2*((n_vocab + 1)/2) + 4;
 191      int counter = 0;
 192      auto compute = [&mutex, &counter, &log_probs, &nll, &nll2, n_vocab, logits, tokens, n_token, nv] () {
 193          double local_nll  = 0;
 194          double local_nll2 = 0;
 195          while (true) {
 196              std::unique_lock<std::mutex> lock(mutex);
 197              int i = counter++;
 198              if (i >= n_token) {
 199                  nll += local_nll; nll2 += local_nll2;
 200                  break;
 201              }
 202              lock.unlock();
 203              const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
 204              local_nll += v;
 205              local_nll2 += v*v;
 206          }
 207      };
 208      for (auto & w : workers) {
 209          w = std::thread(compute);
 210      }
 211      compute();
 212      for (auto & w : workers) {
 213          w.join();
 214      }
 215      out.write((const char *)log_probs.data(), n_token*nv*sizeof(uint16_t));
 216  }
 217  
 218  struct kl_divergence_result {
 219      double sum_nll          = 0.0;
 220      double sum_nll2         = 0.0;
 221      double sum_nll_base     = 0.0;
 222      double sum_nll_base2    = 0.0;
 223      double sum_nll_nll_base = 0.0;
 224      double sum_kld          = 0.0;
 225      double sum_kld2         = 0.0;
 226      double sum_p_diff       = 0.0;
 227      double sum_p_diff2      = 0.0;
 228      double sum_p_diff4      = 0.0;
 229      float  max_p_diff       = 0.0f;
 230      size_t n_same_top       = 0;
 231      size_t count            = 0;
 232  };
 233  
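// Computes the contribution of one token position to the KL divergence between the
// base model (dequantized from base_log_prob, see the quantized log_softmax above) and
// the current model:
//     KL(P_base || P) = sum_i p_base(i) * (log p_base(i) - log p(i))
// Terms with log p_base(i) <= -16 are skipped as negligible. It also accumulates the NLL
// of both models, whether their top-1 tokens agree, and the difference in the probability
// assigned to the correct token. Returns {kl_divergence, p_diff} for this position.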
 234  static std::pair<double, float> log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
 235      float max_logit = logits[0];
 236      int imax = 0;
 237      for (int i = 1; i < n_vocab; ++i) {
 238          if (logits[i] > max_logit) {
 239              max_logit = logits[i];
 240              imax = i;
 241          }
 242      }
 243      double sum_exp = 0.0;
 244      for (int i = 0; i < n_vocab; ++i) {
 245          sum_exp += expf(logits[i] - max_logit);
 246      }
 247      const float log_sum_exp = log(sum_exp);
 248      const float * d = (const float *)base_log_prob;
 249      const float scale = d[0];
 250      const float min_log_prob = d[1];
 251      base_log_prob += 4;
 252  
 253      const float nll = max_logit + log_sum_exp - logits[tok];
 254      kld.sum_nll  += nll;
 255      kld.sum_nll2 += nll*nll;
 256  
 257      const float nll_base = -(scale*base_log_prob[tok] + min_log_prob);
 258      kld.sum_nll_base  += nll_base;
 259      kld.sum_nll_base2 += nll_base*nll_base;
 260  
 261      kld.sum_nll_nll_base += nll*nll_base;
 262  
 263      max_logit += log_sum_exp;
 264      double sum = 0;
 265      int imax_base = -1;
 266      float p_log_base_max = 0;
 267      for (int i = 0; i < n_vocab; ++i) {
 268          const float p_log_base = scale*base_log_prob[i] + min_log_prob;
 269          if (i == 0 || p_log_base > p_log_base_max) {
 270              p_log_base_max = p_log_base;
 271              imax_base = i;
 272          }
 273          if (p_log_base > -16.f) {
 274              const float p_base = expf(p_log_base);
 275              sum += p_base * (p_log_base - logits[i] + max_logit);
 276          }
 277      }
 278      kld.sum_kld  += sum;
 279      kld.sum_kld2 += sum*sum;
 280      ++kld.count;
 281      if (imax == imax_base) ++kld.n_same_top;
 282  
 283      const float p_base = expf(-nll_base);
 284      const float p = expf(-nll);
 285      const float p_diff = p - p_base;
 286      kld.sum_p_diff  += p_diff;
 287      const double p_diff2 = p_diff*p_diff;
 288      kld.sum_p_diff2 += p_diff2;
 289      kld.sum_p_diff4 += p_diff2*p_diff2;
 290      kld.max_p_diff = std::max(kld.max_p_diff, std::fabs(p_diff));
 291  
 292      return std::make_pair(sum, p_diff);
 293  }
 294  
 295  static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
 296          std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
 297          float * kld_values, float * p_diff_values) {
 298      std::mutex mutex;
 299      const int nv = 2*((n_vocab + 1)/2) + 4;
 300      int counter = 0;
 301      auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values, p_diff_values] () {
 302          kl_divergence_result local_kld;
 303          while (true) {
 304              std::unique_lock<std::mutex> lock(mutex);
 305              int i = counter++;
 306              if (i >= n_token) {
 307                  kld.sum_nll          += local_kld.sum_nll;
 308                  kld.sum_nll2         += local_kld.sum_nll2;
 309                  kld.sum_nll_base     += local_kld.sum_nll_base;
 310                  kld.sum_nll_base2    += local_kld.sum_nll_base2;
 311                  kld.sum_nll_nll_base += local_kld.sum_nll_nll_base;
 312                  kld.sum_kld          += local_kld.sum_kld;
 313                  kld.sum_kld2         += local_kld.sum_kld2;
 314                  kld.sum_p_diff       += local_kld.sum_p_diff;
 315                  kld.sum_p_diff2      += local_kld.sum_p_diff2;
 316                  kld.sum_p_diff4      += local_kld.sum_p_diff4;
 317                  kld.n_same_top       += local_kld.n_same_top;
 318                  kld.max_p_diff        = std::max(kld.max_p_diff, local_kld.max_p_diff);
 319                  kld.count            += local_kld.count;
 320                  break;
 321              }
 322              lock.unlock();
 323              std::pair<double, float> v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
 324              kld_values[i]    = (float)v.first;
 325              p_diff_values[i] = v.second;
 326          }
 327      };
 328      for (auto & w : workers) {
 329          w = std::thread(compute);
 330      }
 331      compute();
 332      for (auto & w : workers) {
 333          w.join();
 334      }
 335  }
 336  
 337  static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
 338      // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
 339      // Run `./llama-perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
 340      // Output: `perplexity: 13.5106 [114/114]`
 341      // BOS tokens will be added for each chunk before eval
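    //
    // Strided variant: a full window of n_ctx tokens is re-evaluated for every chunk,
    // but only the last ppl_stride positions of each window are scored, so every scored
    // token sees (almost) a full context. The final result is
    //     PPL = exp( (1/count) * sum_i -log p(token_i | preceding tokens) )
    // which is the same quantity as in the non-strided perplexity() below, just with
    // more compute per scored token.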
 342  
 343      const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 344      GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
 345  
 346      fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 347  
 348      std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
 349  
 350      const int n_ctx = llama_n_ctx(ctx);
 351  
 352      if (int(tokens.size()) < 2*n_ctx) {
 353          fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
 354                  n_ctx);
 355          fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
 356          return {std::move(tokens), 0., {}, {}};
 357      }
 358  
 359      std::vector<float> logit_history;
 360      std::vector<float> prob_history;
 361  
 362      logit_history.resize(tokens.size());
 363      prob_history.resize(tokens.size());
 364  
 365      if (params.ppl_stride <= 0) {
 366          fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
 367          return {tokens, -1, logit_history, prob_history};
 368      }
 369  
 370      const int calc_chunk = n_ctx;
 371  
 372      fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
 373  
 374      if (int(tokens.size()) <= calc_chunk) {
 375          fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
 376                  tokens.size(), n_ctx, params.ppl_stride);
 377          return {tokens, -1, logit_history, prob_history};
 378      }
 379  
 380      const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1)  / params.ppl_stride;
 381  
 382      const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
 383      const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 384      const int n_batch = params.n_batch;
 385  
 386      int count = 0;
 387      double nll = 0.0;
 388  
 389      fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 390  
 391      for (int i = 0; i < n_chunk; ++i) {
 392          const int start =     i * params.ppl_stride;
 393          const int end   = start + calc_chunk;
 394  
 395          const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
 396          //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
 397  
 398          std::vector<float> logits;
 399  
 400          const auto t_start = std::chrono::high_resolution_clock::now();
 401  
 402          // clear the KV cache
 403          llama_kv_cache_clear(ctx);
 404  
 405          for (int j = 0; j < num_batches; ++j) {
 406              const int batch_start = start + j * n_batch;
 407              const int batch_size  = std::min(end - batch_start, n_batch);
 408  
 409              // save original token and restore it after eval
 410              const auto token_org = tokens[batch_start];
 411  
 412              // add BOS token for the first batch of each chunk
 413              if (add_bos && j == 0) {
 414                  tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
 415              }
 416  
 417              //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
 418              // TODO: use llama_batch.logits instead of relying on logits_all == true
 419              if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
 420                  //fprintf(stderr, "%s : failed to eval\n", __func__);
 421                  return {tokens, -1, logit_history, prob_history};
 422              }
 423  
 424              // restore the original token in case it was replaced by BOS
 425              if (j == 0) {
 426                  tokens[batch_start] = token_org;
 427              }
 428              const auto batch_logits = llama_get_logits(ctx);
 429              logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
 430          }
 431  
 432          const auto t_end = std::chrono::high_resolution_clock::now();
 433  
 434          if (i == 0) {
 435              const float t_total = std::chrono::duration<float>(t_end - t_start).count();
 436              fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
 437              int total_seconds = (int)(t_total * n_chunk);
 438              if (total_seconds >= 60*60) {
 439                  fprintf(stderr, "%d hours ", total_seconds / (60*60));
 440                  total_seconds = total_seconds % (60*60);
 441              }
 442              fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
 443          }
 444  
 445          //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
 446          for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
 447  
 448              // Calculate probability of next token, given the previous ones.
 449              const std::vector<float> tok_logits(
 450                  logits.begin() + (j + 0) * n_vocab,
 451                  logits.begin() + (j + 1) * n_vocab);
 452  
 453              const float prob = softmax(tok_logits)[tokens[start + j + 1]];
 454              logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
 455              prob_history[start + j + 1]  = prob;
 456  
 457              nll += -std::log(prob);
 458              ++count;
 459          }
 460          // perplexity is e^(average negative log-likelihood)
 461          if (params.ppl_output_type == 0) {
 462              printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
 463          } else {
 464              printf("%8d  %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
 465          }
 466          fflush(stdout);
 467      }
 468      printf("\n");
 469  
 470      return {tokens, std::exp(nll / count), logit_history, prob_history};
 471  }
 472  
 473  static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
 474      if (params.ppl_stride > 0) {
 475          return perplexity_v2(ctx, params);
 476      }
 477  
 478      // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
 479      // Run `./llama-perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
 480      // Output: `perplexity: 13.5106 [114/114]`
 481      // BOS tokens will be added for each chunk before eval
 482  
 483      const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 484      GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
 485  
 486      std::ofstream logits_stream;
 487      if (!params.logits_file.empty()) {
 488          logits_stream.open(params.logits_file.c_str(), std::ios::binary);
 489          if (!logits_stream.is_open()) {
 490              fprintf(stderr, "%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
 491              return {};
 492          }
 493          fprintf(stderr, "%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
 494          logits_stream.write("_logits_", 8);
 495          logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
 496      }
 497  
 498      auto tim1 = std::chrono::high_resolution_clock::now();
 499      fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 500  
 501      std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
 502  
 503      auto tim2 = std::chrono::high_resolution_clock::now();
 504      fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
 505  
 506      if (int(tokens.size()) < 2*n_ctx) {
 507          fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
 508                  n_ctx);
 509          fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
 510          return {std::move(tokens), 0., {}, {}};
 511      }
 512  
 513      std::vector<float> logit_history;
 514      logit_history.resize(tokens.size());
 515  
 516      std::vector<float> prob_history;
 517      prob_history.resize(tokens.size());
 518  
 519      const int n_chunk_max = tokens.size() / n_ctx;
 520  
 521      const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
 522      const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 523      const int n_batch = params.n_batch;
 524  
 525      int count = 0;
 526      double nll = 0.0;
 527      double nll2 = 0.0;
 528  
 529      const int num_batches = (n_ctx + n_batch - 1) / n_batch;
 530      const int n_seq = std::max(1, n_batch / n_ctx);
 531  
 532      GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
 533      GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
 534  
 535      llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
 536  
 537      std::vector<float> logits;
 538      if (num_batches > 1) {
 539          logits.reserve((size_t)n_ctx * n_vocab);
 540      }
 541  
 542      fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
 543  
 544      std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 545  
 546      std::vector<uint16_t> log_probs;
 547      if (!params.logits_file.empty()) {
 548          logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
 549          logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
 550          logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
 551          const int nv = 2*((n_vocab + 1)/2) + 4;
 552          log_probs.resize(n_ctx * nv);
 553      }
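    // Resulting logits file layout:
    //   8 bytes   magic "_logits_"
    //   int32     n_ctx, then int32 n_vocab, int32 n_chunk
    //   int32[n_chunk*n_ctx] the evaluated tokens
    //   then, for every scored token position, a row of nv uint16 values: 2 floats
    //   (scale, min log-prob) followed by the quantized log-probabilities
    //   (see the quantized log_softmax above).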
 554  
 555      // We get the logits for all the tokens in the context window (params.n_ctx)
 556      // from the llama_decode calls below.  Now, based on https://huggingface.co/docs/transformers/perplexity,
 557      // calculate the perplexity over the last half of the window (so the model always has
 558      // some context to predict the token).
 559      //
 560      // We rely on the fact that attention in the forward pass only looks at previous
 561      // tokens here, so the logits returned for each token are an accurate representation
 562      // of what the model would have predicted at that point.
 563      //
 564      // Example, we have a context window of 512, we will compute perplexity for each of the
 565      // last 256 tokens.  Then, we split the input up into context window size chunks to
 566      // process the entire prompt.
 567      const int first = n_ctx/2;
 568  
 569      for (int i = 0; i < n_chunk; i += n_seq) {
 570          const int start =     i * n_ctx;
 571          const int end   = start + n_ctx;
 572  
 573          const int n_seq_batch = std::min(n_seq, n_chunk - i);
 574  
 575          const auto t_start = std::chrono::high_resolution_clock::now();
 576  
 577          // clear the KV cache
 578          llama_kv_cache_clear(ctx);
 579  
 580          for (int j = 0; j < num_batches; ++j) {
 581              const int batch_start = start + j * n_batch;
 582              const int batch_size  = std::min(end - batch_start, n_batch);
 583  
 584              int n_outputs = 0;
 585  
 586              batch.n_tokens = 0;
 587              for (int seq = 0; seq < n_seq_batch; seq++) {
 588                  int seq_start = batch_start + seq*n_ctx;
 589  
 590                  // save original token and restore it after eval
 591                  const auto token_org = tokens[seq_start];
 592  
 593                  // add BOS token for the first batch of each chunk
 594                  if (add_bos && j == 0) {
 595                      tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
 596                  }
 597  
 598                  for (int k = 0; k < batch_size; ++k) {
 599                      const int idx = seq*n_ctx + k;
 600                      batch.token   [idx]    = tokens[seq_start + k];
 601                      batch.pos     [idx]    = j*n_batch + k;
 602                      batch.n_seq_id[idx]    = 1;
 603                      batch.seq_id  [idx][0] = seq;
 604                      batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0;
 605  
 606                      n_outputs += batch.logits[idx] != 0;
 607                  }
 608                  batch.n_tokens += batch_size;
 609  
 610                  // restore the original token in case it was set to BOS
 611                  tokens[seq_start] = token_org;
 612              }
 613  
 614              if (llama_decode(ctx, batch)) {
 615                  fprintf(stderr, "%s : failed to eval\n", __func__);
 616                  return {tokens, -1, logit_history, prob_history};
 617              }
 618  
 619              if (num_batches > 1 && n_outputs > 0) {
 620                  const auto * batch_logits = llama_get_logits(ctx);
 621                  logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
 622              }
 623          }
 624  
 625  
 626          if (i == 0) {
 627              llama_synchronize(ctx);
 628              const auto t_end = std::chrono::high_resolution_clock::now();
 629              const float t_total = std::chrono::duration<float>(t_end - t_start).count();
 630              fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
 631              int total_seconds = (int)(t_total*n_chunk/n_seq);
 632              if (total_seconds >= 60*60) {
 633                  fprintf(stderr, "%d hours ", total_seconds / (60*60));
 634                  total_seconds = total_seconds % (60*60);
 635              }
 636              fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
 637          }
 638  
 639          for (int seq = 0; seq < n_seq_batch; seq++) {
 640              const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
 641  
 642              llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
 643              if (!params.logits_file.empty()) {
 644                  process_logits(logits_stream, n_vocab, all_logits,
 645                          tokens_data, n_ctx - 1 - first,
 646                          workers, log_probs, nll, nll2);
 647              } else {
 648                  process_logits(n_vocab, all_logits,
 649                          tokens_data, n_ctx - 1 - first,
 650                          workers, nll, nll2,
 651                          logit_history.data() + start + seq*n_ctx + first,
 652                          prob_history.data()  + start + seq*n_ctx + first);
 653              }
 654              count += n_ctx - first - 1;
 655  
 656              // perplexity is e^(average negative log-likelihood)
 657              if (params.ppl_output_type == 0) {
 658                  printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
 659              } else {
 660                  double av = nll/count;
 661                  double av2 = nll2/count - av*av;
 662                  if (av2 > 0) av2 = sqrt(av2/(count-1));
 663                  printf("%8d  %.4lf  %4lf  %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
 664              }
 665          }
 666          fflush(stdout);
 667  
 668          logits.clear();
 669      }
 670      printf("\n");
 671  
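    // Final estimate: PPL = exp(mean NLL). The uncertainty is the standard error of the
    // mean NLL, sqrt((E[nll^2] - E[nll]^2) / (count - 1)), propagated through exp() by
    // multiplying with PPL (first-order error propagation).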
 672      nll2 /= count;
 673      nll /= count;
 674      const double ppl = exp(nll);
 675      nll2 -= nll * nll;
 676      if (nll2 > 0) {
 677          nll2 = sqrt(nll2/(count-1));
 678          printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
 679      } else {
 680          printf("Unexpected negative standard deviation of log(prob)\n");
 681      }
 682  
 683      llama_batch_free(batch);
 684  
 685      return {tokens, ppl, logit_history, prob_history};
 686  }
 687  
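// Decodes a (possibly multi-sequence) logical batch by splitting it into views of at
// most n_batch tokens, and packs the logits of the positions that requested them
// (batch.logits[i] != 0) contiguously into batch_logits.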
 688  static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
 689      int prev_outputs = 0;
 690      for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
 691          const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
 692  
 693          llama_batch batch_view = {
 694              n_tokens,
 695              batch.token    + i,
 696              nullptr,
 697              batch.pos      + i,
 698              batch.n_seq_id + i,
 699              batch.seq_id   + i,
 700              batch.logits   + i,
 701              0, 0, 0, // unused
 702          };
 703  
 704          const int ret = llama_decode(ctx, batch_view);
 705          if (ret != 0) {
 706              LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
 707              return false;
 708          }
 709  
 710          int n_outputs = 0;
 711          for (int i = 0; i < n_tokens; ++i) {
 712              n_outputs += batch_view.logits[i] != 0;
 713          }
 714  
 715          memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
 716  
 717          prev_outputs += n_outputs;
 718      }
 719  
 720      return true;
 721  }
 722  
 723  #define K_TOKEN_CHUNK 4
 724  
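// For every (logits row, target token) pair in eval_pairs, computes
// log softmax(row)[target] and stores it in eval_results at the same index.
// Work is distributed over the worker threads in chunks of K_TOKEN_CHUNK pairs
// using a lock-free atomic counter.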
 725  static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
 726          const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
 727      if (eval_results.size() != eval_pairs.size()) {
 728          eval_results.resize(eval_pairs.size());
 729      }
 730      if (eval_pairs.empty()) return;
 731  
 732      size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size());
 733  
 734      std::atomic<int> counter(0);
 735      auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
 736          float local_logprobs[K_TOKEN_CHUNK];
 737          while (true) {
 738              size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
 739              if (first >= eval_results.size()) break;
 740              size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
 741              for (size_t i = first; i < last; ++i) {
 742                  auto logits = batch_logits + eval_pairs[i].first * n_vocab;
 743                  float max_logit = logits[0];
 744                  for (int j = 1; j < n_vocab; ++j) {
 745                      max_logit = std::max(max_logit, logits[j]);
 746                  }
 747                  float sum_p = 0.f;
 748                  for (int j = 0; j < n_vocab; ++j) {
 749                      sum_p += expf(logits[j] - max_logit);
 750                  }
 751                  local_logprobs[i - first] = logits[eval_pairs[i].second] - max_logit - std::log(sum_p);
 752              }
 753              std::memcpy(eval_results.data() + first, local_logprobs, (last - first)*sizeof(float));
 754          }
 755      };
 756  
 757      for (size_t it = 0; it < max_threads; ++it) {
 758          workers[it] = std::thread(compute);
 759      }
 760      for (size_t it = 0; it < max_threads; ++it) {
 761          workers[it].join();
 762      }
 763  }
 764  
 765  static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
 766      // Calculates hellaswag score (acc_norm) from prompt
 767      //
 768      // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
 769      // All used data fields are preprocessed as in https://github.com/EleutherAI/lm-evaluation-harness/blob/df3da98c5405deafd519c2ddca52bb7c3fe36bef/lm_eval/tasks/hellaswag.py#L62-L68
 770      //
 771      // All 10042 tasks should be extracted to keep the results standardized like other implementations.
 772      //
 773      // Datafile layout:
 774      // ['??'] denotes json fields
 775      // 6 lines per task:
 776      // ['activity_label'] + ": " +['ctx']  - The first part of the query, the context
 777      // ['label'] - The index of the best common sense ending, aka the gold ending
 778      // ['endings'][0] - Endings added to the first part of the query
 779      // ['endings'][1]
 780      // ['endings'][2]
 781      // ['endings'][3]
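    //
    // Scoring: each ending is scored by the mean per-token log-probability of its tokens
    // given the context (length-normalized, i.e. the acc_norm convention), and the ending
    // with the highest score is compared against the gold label. The common context is
    // evaluated once per task and shared by the 4 endings via separate sequence ids.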
 782  
 783      std::vector<std::string> prompt_lines;
 784      std::istringstream strstream(params.prompt);
 785      std::string line;
 786  
 787      while (std::getline(strstream,line,'\n')) {
 788          prompt_lines.push_back(line);
 789      }
 790  
 791      if (prompt_lines.size() % 6 != 0) {
 792          fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
 793          return;
 794      }
 795  
 796      size_t hs_task_count = prompt_lines.size()/6;
 797      fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
 798  
 799      const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
 800      fprintf(stderr, "================================= is_spm = %d\n", is_spm);
 801  
 802      // The tasks should be randomized so the score stabilizes quickly.
 803      bool randomize_tasks = true;
 804  
 805      // Number of tasks to use when computing the score
 806      if (params.hellaswag_tasks < hs_task_count) {
 807          hs_task_count = params.hellaswag_tasks;
 808      }
 809  
 810      // The random seed should not impact the final result if the computation is done over enough tasks, so it is kept hardcoded for now
 811      std::mt19937 rng(1);
 812  
 813      // Dataholder for hellaswag tasks
 814      struct hs_data_t {
 815          std::string context;
 816          size_t gold_ending_idx;
 817          std::string ending[4];
 818          size_t ending_logprob_count[4];
 819          double ending_logprob[4];
 820  
 821          size_t i_logits;        // starting index of logits in the llama_batch
 822          size_t common_prefix;   // max number of initial tokens that are the same in all sentences
 823          size_t required_tokens; // needed number of tokens to evaluate all 4 endings
 824          std::vector<llama_token> seq_tokens[4];
 825      };
 826  
 827      fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first")  );
 828  
 829      // Select and read data from prompt lines
 830      std::vector<hs_data_t> hs_data(hs_task_count);
 831      for (size_t i = 0; i < hs_task_count; i++) {
 832          size_t idx = i;
 833  
 834          auto & hs_cur = hs_data[i];
 835  
 836          // Select a random example of those left in the prompt
 837          if (randomize_tasks) {
 838              std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
 839              idx = dist(rng);
 840          }
 841  
 842          hs_cur.context = prompt_lines[idx*6];
 843          hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
 844          for (size_t j = 0; j < 4; j++) {
 845              hs_cur.ending[j] = prompt_lines[idx*6+2+j];
 846              hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true);
 847          }
 848  
 849          // determine the common prefix of the endings
 850          hs_cur.common_prefix = 0;
 851          for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
 852              if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
 853                  hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
 854                  hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) {
 855                  break;
 856              }
 857              hs_cur.common_prefix++;
 858          }
 859          hs_cur.required_tokens = hs_cur.common_prefix +
 860              hs_cur.seq_tokens[0].size() - hs_cur.common_prefix +
 861              hs_cur.seq_tokens[1].size() - hs_cur.common_prefix +
 862              hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
 863              hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
 864  
 865          //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, true).size());
 866  
 867          // Delete the selected random example from the prompt
 868          if (randomize_tasks) {
 869              prompt_lines.erase( std::next(prompt_lines.begin(),idx*6)  , std::next(prompt_lines.begin(),idx*6+6) );
 870          }
 871      }
 872  
 873      fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
 874  
 875      printf("\ntask\tacc_norm\n");
 876  
 877      double acc = 0.0f;
 878  
 879      const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 880      const int n_ctx   = llama_n_ctx(ctx);
 881      const int n_batch = params.n_batch;
 882  
 883      const int max_tasks_per_batch = 32;
 884      const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 885  
 886      llama_batch batch = llama_batch_init(n_ctx, 0, 4);
 887  
 888      std::vector<float> tok_logits(n_vocab);
 889      // TODO: this could be made smaller; it's currently the worst-case size
 890      std::vector<float> batch_logits(n_vocab*n_ctx);
 891  
 892      std::vector<std::pair<size_t, llama_token>> eval_pairs;
 893      std::vector<float> eval_results;
 894      std::vector<std::thread> workers(std::thread::hardware_concurrency());
 895  
 896      for (size_t i0 = 0; i0 < hs_task_count; i0++) {
 897          int n_cur = 0;
 898  
 899          size_t i1 = i0;
 900          size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
 901  
 902          llama_batch_clear(batch);
 903  
 904          // batch as many tasks as possible into the available context
 905          // each task has 4 unique sequence ids - one for each ending
 906          // the common prefix is shared among the 4 sequences to save tokens
 907          // we extract logits only from the last common token and from all ending tokens of each sequence
 908          while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
 909              auto & hs_cur = hs_data[i1];
 910              int n_logits = 0;
 911  
 912              const int s0 = 4*(i1 - i0);
 913              if (s0 + 4 > max_seq) {
 914                  break;
 915              }
 916  
 917              for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
 918                  llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
 919              }
 920              batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
 921              n_logits += 1;
 922  
 923              for (int s = 0; s < 4; ++s) {
 924                  const size_t seq_tokens_size = hs_cur.seq_tokens[s].size();
 925                  // TODO: don't evaluate the last token of each sequence
 926                  for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
 927                      const bool needs_logits = i < seq_tokens_size - 1;
 928                      llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
 929                      n_logits += needs_logits;
 930                  }
 931              }
 932  
 933              hs_cur.i_logits = i_logits;
 934              i_logits += n_logits;
 935  
 936              n_cur += hs_data[i1].required_tokens;
 937              if (++i1 == hs_task_count) {
 938                  break;
 939              }
 940          }
 941  
 942          if (i0 == i1) {
 943              fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
 944              return;
 945          }
 946  
 947          llama_kv_cache_clear(ctx);
 948  
 949          // decode all tasks [i0, i1)
 950          if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
 951              fprintf(stderr, "%s: llama_decode() failed\n", __func__);
 952              return;
 953          }
 954  
 955          // Compute log-probs in parallel
 956          // First we collect all tasks
 957          eval_pairs.clear();
 958          for (size_t i = i0; i < i1; ++i) {
 959              auto & hs_cur = hs_data[i];
 960              size_t li = 1; // skip the last logit of the common prefix (computed separately below)
 961              for (int s = 0; s < 4; ++s) {
 962                  for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
 963                      eval_pairs.emplace_back(hs_cur.i_logits + li++, hs_cur.seq_tokens[s][j + 1]);
 964                  }
 965              }
 966          }
 967          // Then we do the actual calculation
 968          compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
 969  
 970          size_t ir = 0;
 971  
 972          // compute the logprobs for each ending of the decoded tasks
 973          for (size_t i = i0; i < i1; ++i) {
 974              auto & hs_cur = hs_data[i];
 975  
 976              // get the logits of the last token of the common prefix
 977              std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*hs_cur.i_logits, n_vocab*sizeof(float));
 978  
 979              const auto first_probs = softmax(tok_logits);
 980  
 981              for (int s = 0; s < 4; ++s) {
 982                  hs_cur.ending_logprob_count[s] = 1;
 983                  hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
 984                  for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
 985                      hs_cur.ending_logprob[s] += eval_results[ir++];
 986                      hs_cur.ending_logprob_count[s]++;
 987                  }
 988                  hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
 989              }
 990  
 991              // Find the ending with maximum logprob
 992              size_t ending_logprob_max_idx = 0;
 993              double ending_logprob_max_val = hs_cur.ending_logprob[0];
 994              for (size_t s = 1; s < 4; s++) {
 995                  if (hs_cur.ending_logprob[s] > ending_logprob_max_val) {
 996                      ending_logprob_max_idx = s;
 997                      ending_logprob_max_val =  hs_cur.ending_logprob[s];
 998                  }
 999              }
1000  
1001              //printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);
1002  
1003              // If the gold ending got the maximum logprob, add one accuracy point
1004              if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
1005                  acc += 1.0;
1006              }
1007  
1008              // Print the accumulated accuracy mean x 100
1009              printf("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
1010              fflush(stdout);
1011          }
1012  
1013          i0 = i1 - 1;
1014      }
1015  
1016      llama_batch_free(batch);
1017  
1018      printf("\n");
1019  }
1020  
1021  struct winogrande_entry {
1022      std::string first;
1023      std::string second;
1024      std::array<std::string, 2> choices;
1025      int answer;
1026  
1027      size_t i_logits;
1028      size_t common_prefix;
1029      size_t required_tokens;
1030      size_t n_base1; // number of tokens for context + choice 1
1031      size_t n_base2; // number of tokens for context + choice 2
1032      std::vector<llama_token> seq_tokens[2];
1033  };
1034  
1035  static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string & prompt) {
1036      std::vector<winogrande_entry> result;
1037      std::istringstream in(prompt);
1038      std::string line;
1039      std::array<int, 4> comma_pos;
1040      while (true) {
1041          std::getline(in, line);
1042          if (in.fail() || in.eof()) break;
1043          int ipos = 0;
1044          bool quote_open = false;
1045          for (int i = 0; i < int(line.size()); ++i) {
1046              if (!quote_open) {
1047                  if (line[i] == ',') {
1048                      comma_pos[ipos++] = i;
1049                      if (ipos == 4) break;
1050                  }
1051                  else if (line[i] == '"') {
1052                      quote_open = true;
1053                  }
1054              }
1055              else {
1056                  if (line[i] == '"') {
1057                      quote_open = false;
1058                  }
1059              }
1060          }
1061          if (ipos != 4) {
1062              printf("%s: failed to find comma separators in <%s>\n", __func__, line.c_str());
1063              continue;
1064          }
1065          auto sentence = line[comma_pos[0]+1] == '"' ? line.substr(comma_pos[0]+2, comma_pos[1] - comma_pos[0] - 3)
1066                                                      : line.substr(comma_pos[0]+1, comma_pos[1] - comma_pos[0] - 1);
1067          auto choice1 = line.substr(comma_pos[1]+1, comma_pos[2] - comma_pos[1] - 1);
1068          auto choice2 = line.substr(comma_pos[2]+1, comma_pos[3] - comma_pos[2] - 1);
1069          auto answer  = line.substr(comma_pos[3]+1, line.size() - comma_pos[3] - 1);
1070          auto index = line.substr(0, comma_pos[0]);
1071          int where = 0;
1072          for ( ; where < int(sentence.size()); ++where) {
1073              if (sentence[where] == '_') break;
1074          }
1075          if (where == int(sentence.size())) {
1076              printf("%s: no _ in <%s>\n", __func__, sentence.c_str());
1077              continue;
1078          }
1079          std::istringstream stream(answer.c_str());
1080          int i_answer; stream >> i_answer;
1081          if (stream.fail() || i_answer < 1 || i_answer > 2) {
1082              printf("%s: failed to parse answer <%s>\n", __func__, answer.c_str());
1083              continue;
1084          }
1085          result.emplace_back();
1086          auto& wg = result.back();
1087          wg.first = sentence.substr(0, where);
1088          wg.second = sentence.substr(where + 1, sentence.size() - where - 1);
1089          wg.choices[0] = std::move(choice1);
1090          wg.choices[1] = std::move(choice2);
1091          wg.answer = i_answer;
1092      }
1093      return result;
1094  }
1095  
1096  /*
1097   * Evaluates the Winogrande score.
1098   * Uses a CSV containing task index, sentence, choice 1, choice 2, answer (1 or 2)
1099   * You can get one such dataset from e.g. https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp
1100   * As an example, the 1st row in the above dataset is
1101   *
1102   *    0,Sarah was a much better surgeon than Maria so _ always got the easier cases.,Sarah,Maria,2
1103   *
1104   */
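// Scoring (roughly): for each choice the blank "_" is filled in and the full sentence
// first + choice + second is tokenized. The continuation tokens are scored by their average
// log-probability; when both continuations extend more than k_min_trailing_ctx tokens past
// the common prefix, scoring starts after "context + choice" (n_base1/n_base2) so only the
// shared trailing text is scored, otherwise it starts right after the common prefix and
// includes the choice word itself. The choice with the higher average log-probability is
// taken as the model's answer.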
1105  static void winogrande_score(llama_context * ctx, const gpt_params & params) {
1106  
1107      constexpr int k_min_trailing_ctx = 3;
1108  
1109      auto data = load_winogrande_from_csv(params.prompt);
1110      if (data.empty()) {
1111          fprintf(stderr, "%s: no tasks\n", __func__);
1112          return;
1113      }
1114  
1115      fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, data.size());
1116  
1117      if (params.winogrande_tasks > 0 && params.winogrande_tasks < data.size()) {
1118          fprintf(stderr, "%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks);
1119          std::mt19937 rng(1);
1120          std::vector<int> aux(data.size());
1121          for (int i = 0; i < int(data.size()); ++i) {
1122              aux[i] = i;
1123          }
1124          float scale = 1/(1.f + (float)rng.max());
1125          std::vector<winogrande_entry> selected;
1126          selected.resize(params.winogrande_tasks);
1127          for (int i = 0; i < int(params.winogrande_tasks); ++i) {
1128              int j = int(scale*rng()*aux.size());
1129              selected[i] = std::move(data[aux[j]]);
1130              aux[j] = aux.back();
1131              aux.pop_back();
1132          }
1133          data = std::move(selected);
1134      }
1135  
1136      fprintf(stderr, "%s : tokenizing selected tasks\n", __func__);
1137  
1138      for (auto & task : data) {
1139          task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, true);
1140          task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, true);
1141  
1142          task.common_prefix = 0;
1143          for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
1144              if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
1145                  break;
1146              }
1147              task.common_prefix++;
1148          }
1149  
1150          // TODO: the last token of each of the sequences doesn't need to be evaluated
1151          task.required_tokens = task.common_prefix +
1152              task.seq_tokens[0].size() - task.common_prefix +
1153              task.seq_tokens[1].size() - task.common_prefix;
1154  
1155          task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], true).size();
1156          task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], true).size();
1157      }
1158  
1159      fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
1160  
1161      const int n_vocab = llama_n_vocab(llama_get_model(ctx));
1162      const int n_ctx   = llama_n_ctx(ctx);
1163      const int n_batch = params.n_batch;
1164  
1165      const int max_tasks_per_batch = 128;
1166      const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
1167  
1168      llama_batch batch = llama_batch_init(n_ctx, 0, 2);
1169  
1170      std::vector<float> tok_logits(n_vocab);
1171      // TODO: this could be made smaller; it's currently the worst-case size
1172      std::vector<float> batch_logits(n_vocab*n_ctx);
1173  
1174      std::vector<std::pair<size_t, llama_token>> eval_pairs;
1175      std::vector<float> eval_results;
1176      std::vector<std::thread> workers(std::thread::hardware_concurrency());
1177  
1178      int n_correct = 0;
1179      int n_done    = 0;
1180  
1181      for (size_t i0 = 0; i0 < data.size(); i0++) {
1182          int n_cur = 0;
1183  
1184          size_t i1 = i0;
1185          size_t i_logits = 0;
1186  
1187          llama_batch_clear(batch);
1188  
1189          while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
1190              int n_logits = 0;
1191              const int s0 = 2*(i1 - i0);
1192              if (s0 + 2 > max_seq) {
1193                  break;
1194              }
1195  
1196              for (size_t i = 0; i < data[i1].common_prefix; ++i) {
1197                  llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
1198              }
1199              batch.logits[batch.n_tokens - 1] = true;
1200              n_logits += 1;
1201  
1202              for (int s = 0; s < 2; ++s) {
1203                  // TODO: end before the last token, no need to predict past the end of the sequences
1204                  for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
1205                      llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
1206                      n_logits += 1;
1207                  }
1208              }
1209  
1210              data[i1].i_logits = i_logits;
1211              i_logits += n_logits;
1212  
1213              n_cur += data[i1].required_tokens;
1214              if (++i1 == data.size()) {
1215                  break;
1216              }
1217          }
1218  
1219          if (i0 == i1) {
1220              fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
1221              return;
1222          }
1223  
1224          llama_kv_cache_clear(ctx);
1225  
1226          // decode all tasks [i0, i1)
1227          if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
1228              fprintf(stderr, "%s: llama_decode() failed\n", __func__);
1229              return;
1230          }
1231  
1232          eval_pairs.clear();
1233          for (size_t i = i0; i < i1; ++i) {
1234              auto & task = data[i];
1235  
1236              const bool skip_choice =
1237                  task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
1238                  task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
1239  
1240              const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
1241              const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
1242              size_t li = n_base1 - task.common_prefix;
1243              for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
1244                  eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]);
1245              }
1246              const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
1247              const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
1248              // FIXME: this uses the wrong first logits when not skipping the choice word
1249              li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
1250              for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
1251                  eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]);
1252              }
1253          }
1254          compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
1255  
1256          size_t ir = 0;
1257          for (size_t i = i0; i < i1; ++i) {
1258              auto & task = data[i];
1259  
1260              const bool skip_choice =
1261                  task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
1262                  task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
1263  
1264              float score_1st = 0;
1265              const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
1266              const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
1267              for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
1268                  score_1st += eval_results[ir++];
1269              }
1270              score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
1271  
1272              float score_2nd = 0;
1273              const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
1274              const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
1275              for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
1276                  score_2nd += eval_results[ir++];
1277              }
1278              score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
1279  
1280              int result = score_1st > score_2nd ? 1 : 2;
1281  
1282              if (result == task.answer) {
1283                  ++n_correct;
1284              }
1285              ++n_done;
1286  
1287              // print the accumulated accuracy mean x 100
1288              printf("%zu\t%.4lf\t%10.6f  %10.6f  %d  %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
1289              fflush(stdout);
1290          }
1291  
1292          i0 = i1 - 1;
1293      }
1294  
1295      printf("\n");
1296  
1297      if (n_done < 100) return;
1298  
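      // p is the mean accuracy and sigma its binomial standard error, both reported in percent;
      // e.g. n_correct = 720 out of n_done = 1000 gives p = 0.72 and sigma = 100*sqrt(0.72*0.28/999) ≈ 1.42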
1299      const float p = 1.f*n_correct/n_done;
1300      const float sigma = 100.f*sqrt(p*(1-p)/(n_done-1));
1301      printf("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma);
1302  }
1303  
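      // reads a length-prefixed string from the binary stream: a uint32_t byte count followed by that many raw bytes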
1304  static bool deserialize_string(std::istream & in, std::string & str) {
1305      uint32_t size;
1306      if (!in.read((char *)&size, sizeof(size)).fail()) {
1307          str.resize(size);
1308          if (!in.read((char *)&str[0], size).fail()) return true;
1309      }
1310      return false;
1311  }
1312  
1313  struct multiple_choice_answers {
1314      std::vector<std::string> answers;
1315      std::vector<int>         labels;
1316      bool deserialize(std::istream& in) {
1317          uint32_t n;
1318          in.read((char *)&n, sizeof(n));
1319          if (in.fail() || n > 100) return false; // a maximum of 100 answers should be enough for any practical purpose
1320          answers.resize(n);
1321          labels.resize(n);
1322          for (auto& a : answers) {
1323              if (!deserialize_string(in, a)) return false;
1324          }
1325          in.read((char *)labels.data(), n*sizeof(int));
1326          return !in.fail();
1327      }
1328  };
1329  
1330  struct multiple_choice_task {
1331      std::string question;         // the question (or context that needs to be continued)
1332      multiple_choice_answers mc1;  // possible answers (continuations) with a single correct answer
1333      multiple_choice_answers mc2;  // possible answers (continuations) with multiple correct answers - not handled yet
1334      bool deserialize(std::istream& in) {
1335          if (!deserialize_string(in, question)) return false;
1336          return mc1.deserialize(in) && mc2.deserialize(in);
1337      }
1338  
1339      // For evaluation
1340      size_t i_logits;        // starting index of logits in the llama_batch
1341      size_t common_prefix;   // max number of initial tokens that are the same in all answer sequences
1342      size_t required_tokens; // number of tokens needed to evaluate all answers
1343      std::vector<std::vector<llama_token>> seq_tokens;
1344      std::vector<float> log_probs;
1345  };
1346  
1347  static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choice_task& task, bool log_error) {
1348      if (task.question.empty() || task.mc1.answers.empty()) {
1349          if (log_error) {
1350              printf("%s: found bad task with empty question and/or answers\n", __func__);
1351          }
1352          return false;
1353      }
1354      task.seq_tokens.reserve(task.mc1.answers.size());
1355      for (auto& answer : task.mc1.answers) {
1356          if (answer.empty()) {
1357              if (log_error) {
1358                  printf("%s: found empty answer\n", __func__);
1359              }
1360              return false;
1361          }
1362          task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, true));
1363      }
1364      auto min_len = task.seq_tokens.front().size();
1365      for (auto& seq : task.seq_tokens) {
1366          min_len = std::min(min_len, seq.size());
1367      }
1368      task.common_prefix = 0;
1369      for (size_t k = 0; k < min_len; ++k) {
1370          auto token = task.seq_tokens[0][k];
1371          bool all_same = true;
1372          for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
1373              if (task.seq_tokens[i][k] != token) {
1374                  all_same = false;
1375                  break;
1376              }
1377          }
1378          if (!all_same) {
1379              break;
1380          }
1381          ++task.common_prefix;
1382      }
1383      task.required_tokens = task.common_prefix;
1384      for (auto& seq : task.seq_tokens) {
1385          task.required_tokens += seq.size() - task.common_prefix;
1386      }
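      // example: two answer sequences [A, B, C, D] and [A, B, X, Y, Z] share common_prefix = 2,
      // so required_tokens = 2 + (4 - 2) + (5 - 2) = 7 tokens to evaluate the whole task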
1387      return true;
1388  }
1389  
1390  //
1391  // Calculates the score for multiple-choice tasks with a single correct answer, read from the prompt.
1392  // Commonly used LLM evaluation benchmarks of this type are
1393  //   * ARC
1394  //   * HellaSwag
1395  //   * MMLU
1396  //   * TruthfulQA
1397  //
1398  // Validation datasets for these 4 tests can be found at
1399  //     https://huggingface.co/datasets/ikawrakow/validation-datasets-for-llama.cpp
1400  // The data for these datasets was extracted from
1401  //     git@hf.co:datasets/allenai/ai2_arc
1402  //     https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
1403  //     git@hf.co:datasets/Stevross/mmlu
1404  //     https://huggingface.co/datasets/truthful_qa
1405  //
1406  static void multiple_choice_score(llama_context * ctx, const gpt_params & params) {
1407  
1408      std::istringstream strstream(params.prompt);
1409      uint32_t n_task;
1410      strstream.read((char *)&n_task, sizeof(n_task));
1411      if (strstream.fail() || n_task == 0) {
1412          printf("%s: no tasks\n", __func__);
1413          return;
1414      }
1415      printf("%s: there are %u tasks in prompt\n", __func__, n_task);
1416      std::vector<uint32_t> task_pos(n_task);
1417      strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
1418      if (strstream.fail()) {
1419          printf("%s: failed to read task positions from prompt\n", __func__);
1420          return;
1421      }
1422  
1423      std::vector<multiple_choice_task> tasks;
1424      if (params.multiple_choice_tasks == 0 || params.multiple_choice_tasks >= (size_t)n_task) {
1425          // Use all tasks
1426          tasks.resize(n_task);
1427          printf("%s: reading tasks", __func__);
1428          int n_dot = std::max((int) n_task/100, 1);
1429          int i = 0;
1430          for (auto& task : tasks) {
1431              ++i;
1432              if (!task.deserialize(strstream)) {
1433                  printf("%s: failed to read task %d of %u\n", __func__, i, n_task);
1434                  return;
1435              }
1436              if (i%n_dot == 0) printf(".");
1437          }
1438          printf("done\n");
1439      }
1440      else {
1441          printf("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task);
1442          std::mt19937 rng(1);
1443          std::vector<int> aux(n_task);
1444          for (uint32_t i = 0; i < n_task; ++i) aux[i] = i;
1445          float scale = 1.f/(1.f + (float)std::mt19937::max());
1446          tasks.resize(params.multiple_choice_tasks);
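          // sample tasks without replacement: pick a random slot in aux, overwrite it with the last element and
          // shrink the vector (a partial Fisher-Yates shuffle), then seek to the stored task position and deserialize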
1447          for (auto& task : tasks) {
1448              int j = (int)(scale * rng() * aux.size());
1449              int idx = aux[j];
1450              aux[j] = aux.back();
1451              aux.pop_back();
1452              strstream.seekg(task_pos[idx], std::ios::beg);
1453              if (!task.deserialize(strstream)) {
1454                  printf("%s: failed to read task %d at position %u\n", __func__, idx, task_pos[idx]);
1455                  return;
1456              }
1457          }
1458          n_task = params.multiple_choice_tasks;
1459      }
1460  
1461      printf("%s: preparing task data", __func__);
1462      fflush(stdout);
1463      if (n_task > 500) {
1464          printf("...");
1465          fflush(stdout);
1466          std::atomic<int> counter(0);
1467          std::atomic<int> n_bad(0);
1468          auto prepare = [&counter, &n_bad, &tasks, ctx] () {
1469              int num_tasks = tasks.size();
1470              int n_bad_local = 0;
1471              while (true) {
1472                  int first = counter.fetch_add(K_TOKEN_CHUNK);
1473                  if (first >= num_tasks) {
1474                      if (n_bad_local > 0) n_bad += n_bad_local;
1475                      break;
1476                  }
1477                  int last = std::min(first + K_TOKEN_CHUNK, num_tasks);
1478                  for (int i = first; i < last; ++i) {
1479                      if (!multiple_choice_prepare_one_task(ctx, tasks[i], false)) ++n_bad_local;
1480                  }
1481              }
1482          };
1483          size_t max_thread = std::thread::hardware_concurrency();
1484          max_thread = std::min(max_thread, (tasks.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK);
1485          std::vector<std::thread> workers(max_thread-1);
1486          for (auto& w : workers) w = std::thread(prepare);
1487          prepare();
1488          for (auto& w : workers) w.join();
1489          printf("done\n");
1490          fflush(stdout);
1491          int nbad = n_bad;
1492          if (nbad > 0) {
1493              printf("%s: found %d malformed tasks\n", __func__, nbad);
1494              return;
1495          }
1496      } else {
1497          int n_dot = std::max((int) n_task/100, 1);
1498          int i_task = 0;
1499          for (auto& task : tasks) {
1500              ++i_task;
1501              if (!multiple_choice_prepare_one_task(ctx, task, true)) {
1502                  return;
1503              }
1504              if (i_task%n_dot == 0) {
1505                  printf(".");
1506                  fflush(stdout);
1507              }
1508          }
1509          printf("done\n");
1510      }
1511  
1512      printf("%s : calculating multiple choice score over %zu tasks.\n", __func__, tasks.size());
1513  
1514      printf("\ntask\tacc_norm\n");
1515  
1516      const int n_vocab = llama_n_vocab(llama_get_model(ctx));
1517      const int n_ctx   = llama_n_ctx(ctx);
1518      const int n_batch = params.n_batch;
1519  
1520      const int max_tasks_per_batch = 32;
1521      const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
1522  
1523      llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
1524  
1525      std::vector<float> tok_logits(n_vocab);
1526      std::vector<float> batch_logits(n_vocab*n_ctx);
1527  
1528      std::vector<std::pair<size_t, llama_token>> eval_pairs;
1529      std::vector<float> eval_results;
1530      std::vector<std::thread> workers(std::thread::hardware_concurrency());
1531      std::vector<int> batch_indeces;
1532  
1533      int n_done = 0;
1534      int n_correct = 0;
1535      int n_tot_answers = 0;
1536  
1537      for (size_t i0 = 0; i0 < tasks.size(); i0++) {
1538          int n_cur = 0;
1539  
1540          size_t i1 = i0;
1541          size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
1542  
1543          llama_batch_clear(batch);
1544  
1545          // batch as many tasks as possible into the available context
1546          // each task gets one unique sequence id per answer
1547          // the common prefix is shared among a task's sequences to save tokens
1548          // we extract logits from the last token of the common prefix and from every answer token except the last one of each sequence
1549          int s0 = 0;
1550          while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) {
1551              auto& cur_task = tasks[i1];
1552              int n_logits = 0;
1553  
1554              int num_answers = cur_task.seq_tokens.size();
1555              if (s0 + num_answers > max_seq) {
1556                  break;
1557              }
1558  
1559              if (int(batch_indeces.size()) != num_answers) {
1560                  batch_indeces.resize(num_answers);
1561              }
1562              for (int s = 0; s < num_answers; ++s) batch_indeces[s] = s0 + s;
1563  
1564              for (size_t i = 0; i < cur_task.common_prefix; ++i) {
1565                  //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
1566                  llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
1567              }
1568              batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
1569              n_logits += 1;
1570  
1571              for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
1572                  const size_t seq_tokens_size = cur_task.seq_tokens[s].size();
1573                  // TODO: don't evaluate the last token of each sequence
1574                  for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
1575                      const bool needs_logits = i < seq_tokens_size - 1;
1576                      llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
1577                      n_logits += needs_logits;
1578                  }
1579              }
1580  
1581              s0 += num_answers;
1582  
1583              cur_task.i_logits = i_logits;
1584              i_logits += n_logits;
1585  
1586              n_cur += cur_task.required_tokens;
1587              if (++i1 == tasks.size()) {
1588                  break;
1589              }
1590          }
1591  
1592          if (i0 == i1) {
1593              fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
1594              return;
1595          }
1596  
1597          llama_kv_cache_clear(ctx);
1598  
1599          // decode all tasks [i0, i1)
1600          if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
1601              fprintf(stderr, "%s: llama_decode() failed\n", __func__);
1602              return;
1603          }
1604  
1605          // Compute log-probs in parallel
1606          // First we collect all tasks
1607          eval_pairs.clear();
1608          for (size_t i = i0; i < i1; ++i) {
1609              auto& cur_task = tasks[i];
1610              size_t li = 1; // skip the last logit of the common prefix (computed separately below)
1611              for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
1612                  for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
1613                      eval_pairs.emplace_back(cur_task.i_logits + li++, cur_task.seq_tokens[s][j + 1]);
1614                  }
1615              }
1616          }
1617          // Then we do the actual calculation
1618          compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
1619  
1620          size_t ir = 0;
1621  
1622          // compute the logprobs for each ending of the decoded tasks
1623          for (size_t i = i0; i < i1; ++i) {
1624              auto & cur_task = tasks[i];
1625              //printf("==== Evaluating <%s> with correct answer ", cur_task.question.c_str());
1626              //for (int j = 0; j < int(cur_task.mc1.labels.size()); ++j) {
1627              //    if (cur_task.mc1.labels[j] == 1) {
1628              //        printf("%d", j+1);
1629              //    }
1630              //}
1631              //printf("\n    common_prefix: %zu\n", cur_task.common_prefix);
1632  
1633              // get the logits of the last token of the common prefix
1634              std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*cur_task.i_logits, n_vocab*sizeof(float));
1635  
1636              const auto first_probs = softmax(tok_logits);
1637  
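            // the score of each answer is the mean log-probability of its tokens: the first answer token is
            // scored from the softmax of the last common-prefix logit, the remaining ones from eval_results;
            // dividing by count normalizes for answer length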
1638              cur_task.log_probs.resize(cur_task.seq_tokens.size());
1639              for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
1640                  size_t count = 1;
1641                  float  log_prob  = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
1642                  for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
1643                      //printf("        %zu  %g\n", ir, eval_results[ir]);
1644                      ++count;
1645                      log_prob += eval_results[ir++];
1646                  }
1647                  cur_task.log_probs[s] = log_prob / count;
1648                  //printf("        Final: %g\n", log_prob / count);
1649                  //printf("    <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count);
1650              }
1651  
1652              // Find the ending with maximum logprob
1653              size_t logprob_max_idx = 0;
1654              float  logprob_max_val = cur_task.log_probs[0];
1655              for (size_t s = 1; s < cur_task.log_probs.size(); s++) {
1656                  if (cur_task.log_probs[s] > logprob_max_val) {
1657                      logprob_max_val = cur_task.log_probs[s];
1658                      logprob_max_idx = s;
1659                  }
1660              }
1661  
1662              n_tot_answers += cur_task.log_probs.size();
1663              if (cur_task.mc1.labels[logprob_max_idx] == 1) {
1664                  ++n_correct;
1665              }
1666              ++n_done;
1667  
1668              // Print the accumulated accuracy mean x 100
1669              printf("%d\t%.8lf\n", n_done, 100.*n_correct/n_done);
1670              fflush(stdout);
1671          }
1672  
1673          i0 = i1 - 1;
1674      }
1675  
1676      llama_batch_free(batch);
1677  
1678      if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return;
1679  
1680      float p = 1.f*n_correct/n_done;
1681      float sigma = sqrt(p*(1-p)/(n_done-1));
1682      printf("\n Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
1683      p = 1.f*n_done/n_tot_answers;
1684      sigma = sqrt(p*(1-p)/(n_done-1));
1685      printf("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
1686  
1687      printf("\n");
1688  }
1689  
1690  static void kl_divergence(llama_context * ctx, const gpt_params & params) {
1691      if (params.logits_file.empty()) {
1692          fprintf(stderr, "%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
1693          return;
1694      }
1695      std::ifstream in(params.logits_file.c_str(), std::ios::binary);
1696      if (!in) {
1697          fprintf(stderr, "%s: failed to open %s\n", __func__, params.logits_file.c_str());
1698          return;
1699      }
1700      {
1701          char check[9]; check[8] = 0;
1702          in.read(check, 8);
1703          if (in.fail() || strncmp("_logits_", check, 8) != 0) {
1704              fprintf(stderr, "%s: %s does not look like a file containing log-probabilities\n", __func__, params.logits_file.c_str());
1705              return;
1706          }
1707      }
1708  
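      // remaining file layout read below: n_ctx (uint32_t), n_vocab (int), n_chunk (int), the n_ctx*n_chunk
      // evaluation tokens, and then one block of uint16_t-encoded log-probs per chunk (read inside the chunk loop)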
1709      uint32_t n_ctx;
1710      in.read((char *)&n_ctx, sizeof(n_ctx));
1711      if (n_ctx > llama_n_ctx(ctx)) {
1712          fprintf(stderr, "%s: %s has been computed with a context of %u tokens, while the current context is only %d. Increase it with -c and retry\n",
1713                  __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
              return;
1714      }
1715  
1716      int n_vocab, n_chunk;
1717      in.read((char *)&n_vocab, sizeof(n_vocab));
1718      in.read((char *)&n_chunk, sizeof(n_chunk));
1719      if (in.fail()) {
1720          fprintf(stderr, "%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
1721          return;
1722      }
1723      if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
1724          fprintf(stderr, "%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
1725      }
1726  
1727      std::vector<llama_token> tokens(n_ctx * n_chunk);
1728      if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) {
1729          fprintf(stderr, "%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
1730          return;
1731      }
1732  
1733      const int n_batch = params.n_batch;
1734      const int num_batches = (n_ctx + n_batch - 1)/n_batch;
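      // nv is the number of uint16_t values stored per evaluated position: the vocabulary size rounded up to an
      // even count, plus 4 extra values (presumably per-position metadata consumed by process_logits)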
1735      const int nv = 2*((n_vocab + 1)/2) + 4;
1736      const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
1737      GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
1738  
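      // only the second half of each chunk is scored (see `first = n_ctx/2` below), excluding the final position,
      // so each chunk contributes n_ctx - 1 - n_ctx/2 values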
1739      std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
1740      std::vector<float>    kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
1741      std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
1742      std::vector<float> logits;
1743      if (num_batches > 1) {
1744          logits.reserve(n_ctx * n_vocab);
1745      }
1746  
1747      std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
1748  
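      // returns the mean of the accumulated values and, when count > 10, the standard error of that mean:
      // mean = sum/count, error = sqrt((sum2/count - mean^2)/(count - 1))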
1749      auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) {
1750          if (count < 1) {
1751              return std::make_pair(0., 0.);
1752          }
1753          double f = sum/count;
1754          double df = sum2/count - f*f;
1755          df = df > 0 && count > 10 ? sqrt(df/(count-1)) : 0.;
1756          return std::make_pair(f, df);
1757      };
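      // covariance of the two sample means: (sumab/count - (suma/count)*(sumb/count))/(count - 1); 0 for fewer than 10 samples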
1758      auto covariance = [] (double suma, double sumb, double sumab, size_t count) {
1759          if (count < 10) {
1760              return 0.0;
1761          }
1762          double var = sumab/count - (suma/count)*(sumb/count);
1763          var /= count - 1;
1764          return var;
1765      };
1766  
1767      kl_divergence_result kld;
1768      auto    kld_ptr =    kld_values.data();
1769      auto p_diff_ptr = p_diff_values.data();
1770  
1771      for (int i = 0; i < n_chunk; ++i) {
1772          const int start =     i * n_ctx;
1773          const int end   = start + n_ctx;
1774  
1775          const auto t_start = std::chrono::high_resolution_clock::now();
1776  
1777          if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
1778              fprintf(stderr, "%s: failed reading log-probs for chunk %d\n", __func__, i);
1779              return;
1780          }
1781  
1782          // clear the KV cache
1783          llama_kv_cache_clear(ctx);
1784  
1785          for (int j = 0; j < num_batches; ++j) {
1786              const int batch_start = start + j * n_batch;
1787              const int batch_size  = std::min(end - batch_start, n_batch);
1788  
1789              // save original token and restore it after eval
1790              const auto token_org = tokens[batch_start];
1791  
1792              // add BOS token for the first batch of each chunk
1793              if (add_bos && j == 0) {
1794                  tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
1795              }
1796  
1797              // TODO: use llama_batch.logits instead of relying on logits_all == true
1798              if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
1799                  fprintf(stderr, "%s : failed to eval\n", __func__);
1800                  return;
1801              }
1802  
1803              // restore the original token in case it was set to BOS
1804              tokens[batch_start] = token_org;
1805  
1806              if (num_batches > 1) {
1807                  const auto * batch_logits = llama_get_logits(ctx);
1808                  logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
1809              }
1810          }
1811  
1812          const auto t_end = std::chrono::high_resolution_clock::now();
1813  
1814          if (i == 0) {
1815              const float t_total = std::chrono::duration<float>(t_end - t_start).count();
1816              fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
1817              int total_seconds = (int)(t_total * n_chunk);
1818              if (total_seconds >= 60*60) {
1819                  fprintf(stderr, "%d hours ", total_seconds / (60*60));
1820                  total_seconds = total_seconds % (60*60);
1821              }
1822              fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
1823  
1824              printf("\nchunk             PPL               ln(PPL(Q)/PPL(base))          KL Divergence              Δp RMS            Same top p\n");
1825          }
1826  
1827          const int first = n_ctx/2;
1828          const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
1829          process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
1830                  workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
1831          p_diff_ptr += n_ctx - 1 - first;
1832          kld_ptr    += n_ctx - 1 - first;
1833  
1834          printf("%4d", i+1);
1835  
1836          auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
1837          const double ppl_val = exp(log_ppl.first);
1838          const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
1839          printf("    %9.4lf ± %9.4lf", ppl_val, ppl_unc);
1840  
1841          auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
1842          const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
1843          const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
1844          const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
1845          printf("    %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
1846  
1847          auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
1848          printf("    %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
1849  
1850          auto p_diff_mse   = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
1851          const double p_diff_rms_val = sqrt(p_diff_mse.first);
1852          const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
1853          printf("    %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
1854  
1855          double p_top_val = 1.*kld.n_same_top/kld.count;
1856          double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
1857          printf("    %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
1858  
1859          printf("\n");
1860  
1861          fflush(stdout);
1862  
1863          logits.clear();
1864      }
1865      printf("\n");
1866  
1867      if (kld.count < 100) return; // we do not wish to do statistics on so few values
1868  
1869      std::sort(kld_values.begin(), kld_values.end());
1870      std::sort(p_diff_values.begin(), p_diff_values.end());
1871  
1872      printf("====== Perplexity statistics ======\n");
1873  
1874      auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
1875      const double ppl_val = exp(log_ppl.first);
1876      const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
1877      printf("Mean PPL(Q)                   : %10.6lf ± %10.6lf\n", ppl_val, ppl_unc);
1878  
1879      auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
1880      const double ppl_base_val = exp(log_ppl_base.first);
1881      const double ppl_base_unc = ppl_base_val * log_ppl_base.second; // ppl_base_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_base.second ** 2 )
1882      printf("Mean PPL(base)                : %10.6lf ± %10.6lf\n", ppl_base_val, ppl_base_unc);
1883  
1884      const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
1885      // printf("Cov(ln(PPL(Q)), ln(PPL(base))): %10.6lf\n", log_ppl_cov);
1886      const double log_ppl_cor = log_ppl_cov / (log_ppl.second*log_ppl_base.second);
1887      printf("Cor(ln(PPL(Q)), ln(PPL(base))): %6.2lf%%\n", 100.0*log_ppl_cor);
1888  
1889      const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
1890      const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
1891      printf("Mean ln(PPL(Q)/PPL(base))     : %10.6lf ± %10.6lf\n", log_ppl_ratio_val, log_ppl_ratio_unc);
1892  
1893      const double ppl_ratio_val = exp(log_ppl_ratio_val);
1894      const double ppl_ratio_unc = ppl_ratio_val * log_ppl_ratio_unc; // ppl_ratio_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_ratio.second ** 2 )
1895      printf("Mean PPL(Q)/PPL(base)         : %10.6lf ± %10.6lf\n", ppl_ratio_val, ppl_ratio_unc);
1896  
1897      const double ppl_cov = ppl_val * ppl_base_val * log_ppl_cov;
1898      const double ppl_diff_val = ppl_val - ppl_base_val;
1899      const double ppl_diff_unc = sqrt(ppl_unc*ppl_unc + ppl_base_unc*ppl_base_unc - 2.0*ppl_cov);
1900      printf("Mean PPL(Q)-PPL(base)         : %10.6lf ± %10.6lf\n", ppl_diff_val, ppl_diff_unc);
1901  
1902      printf("\n");
1903  
1904      printf("====== KL divergence statistics ======\n");
1905      auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
1906      printf("Mean    KLD: %10.6lf ± %10.6lf\n", kl_div.first, kl_div.second);
1907      auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1])
1908                                                 : kld_values[kld_values.size()/2];
1909  
1910      auto percentile = [] (std::vector<float> values, float fraction) {
1911          if (fraction <= 0) return values.front();
1912          if (fraction >= 1) return values.back();
1913          float p = fraction*(values.size() - 1);
1914          size_t ip = size_t(p); p -= ip;
1915          return (1 - p)*values[ip] + p*values[std::min(ip+1, values.size()-1)];
1916      };
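      // linear interpolation between the two nearest order statistics, e.g. fraction 0.25 over 11 sorted values
      // gives p = 2.5, so the result is 0.5*values[2] + 0.5*values[3]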
1917  
1918      printf("Maximum KLD: %10.6f\n", kld_values.back());
1919      printf("99.9%%   KLD: %10.6f\n", percentile(kld_values, 0.999f));
1920      printf("99.0%%   KLD: %10.6f\n", percentile(kld_values, 0.990f));
1921      printf("95.0%%   KLD: %10.6f\n", percentile(kld_values, 0.950f));
1922      printf("Median  KLD: %10.6f\n", kld_median);
1923      printf("10.0%%   KLD: %10.6f\n", percentile(kld_values, 0.100f));
1924      printf(" 5.0%%   KLD: %10.6f\n", percentile(kld_values, 0.050f));
1925      printf(" 1.0%%   KLD: %10.6f\n", percentile(kld_values, 0.010f));
1926      printf("Minimum KLD: %10.6f\n", kld_values.front());
1927  
1928      printf("\n");
1929  
1930      printf("====== Token probability statistics ======\n");
1931  
1932      auto p_diff = mean_and_uncertainty(kld.sum_p_diff, kld.sum_p_diff2, kld.count);
1933      printf("Mean    Δp: %6.3lf ± %5.3lf %%\n",  100.0*p_diff.first, 100.0*p_diff.second);
1934  
1935      auto p_diff_median = p_diff_values.size()%2 == 0 ? 0.5f*(p_diff_values[p_diff_values.size()/2] + p_diff_values[p_diff_values.size()/2-1])
1936                                                 : p_diff_values[p_diff_values.size()/2];
1937  
1938      printf("Maximum Δp: %6.3lf%%\n",  100.0*p_diff_values.back());
1939      printf("99.9%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.999f));
1940      printf("99.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.990f));
1941      printf("95.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.950f));
1942      printf("90.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.900f));
1943      printf("75.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.750f));
1944      printf("Median  Δp: %6.3lf%%\n",  100.0*p_diff_median);
1945      printf("25.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.250f));
1946      printf("10.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.100f));
1947      printf(" 5.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.050f));
1948      printf(" 1.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.010f));
1949      printf(" 0.1%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.001f));
1950      printf("Minimum Δp: %6.3lf%%\n",  100.0*p_diff_values.front());
1951  
1952      auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
1953      // printf("MSE Δp    : %10.6lf ± %10.6lf\n", p_diff_mse.first, p_diff_mse.second);
1954  
1955      const double p_diff_rms_val = sqrt(p_diff_mse.first);
1956      const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
1957      printf("RMS Δp    : %6.3lf ± %5.3lf %%\n", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
1958  
1959      const double same_top_p = 1.0*kld.n_same_top/kld.count;
1960      printf("Same top p: %6.3lf ± %5.3lf %%\n", 100.0*same_top_p, 100.0*sqrt(same_top_p*(1.0 - same_top_p)/(kld.count - 1)));
1961  
1962  }
1963  
1964  int main(int argc, char ** argv) {
1965      gpt_params params;
1966  
1967      params.n_ctx = 512;
1968      params.logits_all = true;
1969  
1970      if (!gpt_params_parse(argc, argv, params)) {
1971          gpt_params_print_usage(argc, argv, params);
1972          return 1;
1973      }
1974  
1975      const int32_t n_ctx = params.n_ctx;
1976  
1977      if (n_ctx <= 0) {
1978          fprintf(stderr, "%s: perplexity tool requires '--ctx-size' > 0\n", __func__);
1979          return 1;
1980      }
1981  
1982      const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
1983  
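      // for plain perplexity, split the logical batch into n_seq parallel sequences of n_ctx tokens each and
      // size the KV cache (n_kv) to hold all of them, so several chunks can be decoded in a single pass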
1984      if (ppl) {
1985          const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
1986          const int32_t n_kv = n_seq * n_ctx;
1987  
1988          params.n_parallel = n_seq;
1989          params.n_ctx      = n_kv;
1990  
1991          params.n_batch = std::min(params.n_batch, n_kv);
1992      } else {
1993          params.n_batch = std::min(params.n_batch, params.n_ctx);
1994      }
1995  
1996      if (params.ppl_stride > 0) {
1997          fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
1998                  params.n_ctx, params.n_ctx + params.ppl_stride/2);
1999          params.n_ctx += params.ppl_stride/2;
2000      }
2001  
2002      print_build_info();
2003  
2004      if (params.seed == LLAMA_DEFAULT_SEED) {
2005          params.seed = time(NULL);
2006      }
2007  
2008      fprintf(stderr, "%s: seed  = %u\n", __func__, params.seed);
2009  
2010      std::mt19937 rng(params.seed);
2011  
2012      llama_backend_init();
2013      llama_numa_init(params.numa);
2014  
2015      llama_model * model;
2016      llama_context * ctx;
2017  
2018      // ensure there's at least enough seq_ids for HellaSwag
2019      params.n_parallel = std::max(4, params.n_parallel);
2020  
2021      // load the model and apply lora adapter, if any
2022      std::tie(model, ctx) = llama_init_from_gpt_params(params);
2023      if (model == NULL) {
2024          fprintf(stderr, "%s: error: unable to load model\n", __func__);
2025          return 1;
2026      }
2027  
2028      const int n_ctx_train = llama_n_ctx_train(model);
2029  
2030      if (params.n_ctx > n_ctx_train) {
2031          fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
2032                  __func__, n_ctx_train, params.n_ctx);
2033      }
2034  
2035      // print system information
2036      {
2037          fprintf(stderr, "\n");
2038          fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str());
2039      }
2040  
2041      struct results_perplexity results;
2042      if (params.hellaswag) {
2043          hellaswag_score(ctx, params);
2044      } else if (params.winogrande) {
2045          winogrande_score(ctx, params);
2046      } else if (params.multiple_choice) {
2047          multiple_choice_score(ctx, params);
2048      } else if (params.kl_divergence) {
2049          kl_divergence(ctx, params);
2050      } else {
2051          results = perplexity(ctx, params, n_ctx);
2052      }
2053  
2054      llama_print_timings(ctx);
2055      write_logfile(ctx, params, model, results);
2056  
2057      llama_free(ctx);
2058      llama_free_model(model);
2059  
2060      llama_backend_free();
2061  
2062      return 0;
2063  }