/ common / train.cpp
train.cpp
#include "train.h"
#include "common.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <random>
#include <sstream>
   7  
   8  struct random_normal_distribution {
   9      std::mt19937 gen;
  10      std::normal_distribution<float> rd;
  11      float min;
  12      float max;
  13  };
  14  
  15  struct random_uniform_distribution {
  16      std::mt19937 gen;
  17      std::uniform_real_distribution<float> rd;
  18  };
  19  
  20  struct train_state  * init_train_state() {
  21      struct train_state * state = new struct train_state;
  22      state->train_its     = 0;
  23      state->train_samples = 0;
  24      state->train_tokens  = 0;
  25      state->train_epochs  = 0;
  26      state->shuffle_samples_hash  = 0;
  27      state->shuffle_sample_count  = 0;
  28      state->shuffle_next_sample   = 0;
  29      state->shuffle_rng_state_current = "";
  30      state->shuffle_rng_state_next    = "";
  31  
  32      state->opt = new struct ggml_opt_context;
  33      state->opt->ctx = NULL;
  34      state->opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
  35      state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
  36      state->opt->loss_after = 0.0f;
  37  
  38      return state;
  39  }
  40  
  41  void free_train_state(struct train_state  * state) {
  42      delete state->opt;
  43      delete state;
  44  }
  45  
  46  struct random_normal_distribution * init_random_normal_distribution(
  47      int seed, float mean, float std, float min, float max
  48  ) {
  49      struct random_normal_distribution * rnd = (struct random_normal_distribution *) malloc(sizeof(struct random_normal_distribution));
  50      rnd->gen = std::mt19937(seed);
  51      rnd->rd = std::normal_distribution<float>{mean, std};
  52      rnd->min = min;
  53      rnd->max = max;
  54      return rnd;
  55  }
  56  
  57  struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max) {
  58      struct random_uniform_distribution * rnd = (struct random_uniform_distribution *) malloc(sizeof(struct random_uniform_distribution));
  59      rnd->gen = std::mt19937(seed);
  60      rnd->rd = std::uniform_real_distribution<float>{min, max};
  61      return rnd;
  62  }
  63  
  64  void free_random_normal_distribution (struct random_normal_distribution  * rnd) {
  65      free(rnd);
  66  }
  67  
  68  void free_random_uniform_distribution(struct random_uniform_distribution * rnd) {
  69      free(rnd);
  70  }
  71  
  72  struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
  73      float scale = 1.0f; // xavier
  74      switch (ggml_n_dims(tensor)) {
  75          case 1:
  76              scale /= sqrtf((float) tensor->ne[0]);
  77              for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  78                  float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
  79                  *dst = scale * frand_normal(rnd);
  80              }
  81              break;
  82          case 2:
  83              scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
  84              for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  85                  for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  86                      float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
  87                      *dst = scale * frand_normal(rnd);
  88                  }
  89              }
  90              break;
  91          case 3:
  92              scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
  93              for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
  94                  for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  95                      for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  96                          float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
  97                          *dst = scale * frand_normal(rnd);
  98                      }
  99                  }
 100              }
 101              break;
 102          case 4:
 103              scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
 104              for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
 105                  for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
 106                      for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
 107                          for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
 108                              float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
 109                              *dst = scale * frand_normal(rnd);
 110                          }
 111                      }
 112                  }
 113              }
 114              break;
 115          default:
 116              die("Unsupported tensor->n_dims");
 117      };
 118      return tensor;
 119  }
 120  
 121  struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
 122      switch (ggml_n_dims(tensor)) {
 123          case 1:
 124              for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
 125                  float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
 126                  *dst = frand_uniform(rnd);
 127              }
 128              break;
 129          case 2:
 130              for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
 131                  for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
 132                      float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
 133                      *dst = frand_uniform(rnd);
 134                  }
 135              }
 136              break;
 137          case 3:
 138              for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
 139                  for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
 140                      for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
 141                          float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
 142                          *dst = frand_uniform(rnd);
 143                      }
 144                  }
 145              }
 146              break;
 147          case 4:
 148              for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
 149                  for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
 150                      for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
 151                          for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
 152                              float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
 153                              *dst = frand_uniform(rnd);
 154                          }
 155                      }
 156                  }
 157              }
 158              break;
 159          default:
 160              die("Unsupported tensor->n_dims");
 161      };
 162      return tensor;
 163  }
 164  
 165  float frand() {
 166      return (float)rand()/((float)(RAND_MAX) + 1.0f);
 167  }
 168  
 169  float frand_normal(struct random_normal_distribution * rnd) {
 170      return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max);
 171  }
 172  
 173  float frand_uniform(struct random_uniform_distribution * rnd) {
 174      return rnd->rd(rnd->gen);
 175  }
 176  
 177  int clamp(const int v, const int min, const int max) {
 178      return ((v < min) ? (min) : (v > max) ? (max) : v);
 179  }
 180  
 181  float fclamp(const float v, const float min, const float max) {
 182      return ((v < min) ? (min) : (v > max) ? (max) : v);
 183  }
 184  
 185  void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
 186      GGML_ASSERT(tensor->ne[0] == ne0);
 187      GGML_ASSERT(tensor->ne[1] == 1);
 188      GGML_ASSERT(tensor->ne[2] == 1);
 189      GGML_ASSERT(tensor->ne[3] == 1);
 190  }
 191  
 192  void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
 193      GGML_ASSERT(tensor->ne[0] == ne0);
 194      GGML_ASSERT(tensor->ne[1] == ne1);
 195      GGML_ASSERT(tensor->ne[2] == 1);
 196      GGML_ASSERT(tensor->ne[3] == 1);
 197  }
 198  
 199  void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
 200      GGML_ASSERT(tensor->ne[0] == ne0);
 201      GGML_ASSERT(tensor->ne[1] == ne1);
 202      GGML_ASSERT(tensor->ne[2] == ne2);
 203      GGML_ASSERT(tensor->ne[3] == 1);
 204  }
 205  
 206  void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
 207      GGML_ASSERT(tensor->ne[0] == ne0);
 208      GGML_ASSERT(tensor->ne[1] == ne1);
 209      GGML_ASSERT(tensor->ne[2] == ne2);
 210      GGML_ASSERT(tensor->ne[3] == ne3);
 211  }
 212  
 213  int64_t get_example_targets_batch(
 214      struct llama_context * lctx,
 215      struct ggml_tensor   * tokens_input,
 216      struct ggml_tensor   * target_probs,
 217      int64_t                example_id,
 218      const size_t         * samples_offs,
 219      const size_t         * samples_begin,
 220      const size_t         * samples_size,
 221            size_t           samples_count,
 222      const llama_token    * train_data,
 223      size_t                 n_train_data,
 224      bool                   separate_with_eos,
 225      bool                   separate_with_bos,
 226      bool                   fill_with_next_samples,
 227      bool                   sample_random_offsets
 228  ) {
 229      GGML_ASSERT(samples_count > 0);
 230      GGML_ASSERT(ggml_is_matrix(tokens_input));
 231      GGML_ASSERT(ggml_is_3d(target_probs));
 232      int64_t n_vocab  = target_probs->ne[0];
 233      int64_t n_tokens = tokens_input->ne[0];
 234      int64_t n_batch  = tokens_input->ne[1];
 235      GGML_ASSERT(n_vocab  == target_probs->ne[0]);
 236      GGML_ASSERT(n_tokens == target_probs->ne[1]);
 237      GGML_ASSERT(n_batch  == target_probs->ne[2]);
 238  
 239      int64_t used_samples = 0;
 240  
 241      ggml_set_f32(target_probs, 0.0f);
 242      llama_token bos = llama_token_bos(llama_get_model(lctx));
 243      llama_token eos = llama_token_eos(llama_get_model(lctx));
 244      // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
 245      for (int k=0; k<n_batch; ++k) {
 246          // printf("%s: batch %d\n", __func__, k);
 247          size_t sample_idx   = (example_id + used_samples) % samples_count;
 248          size_t sample_offs  = sample_random_offsets ? samples_offs[sample_idx] : 0;
 249          size_t sample_begin = samples_begin[sample_idx];
 250          size_t sample_size  = samples_size[sample_idx];
 251          ++used_samples;
 252  
 253          // printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
 254          GGML_ASSERT(sample_begin+sample_size-1 < n_train_data);
 255  
 256          ggml_set_i32_nd(tokens_input, 0, k, 0, 0, bos);
 257          bool sample_separation_eos = !separate_with_eos;
 258          bool sample_separation_bos = !separate_with_bos;
 259          for (int64_t i=0; i<n_tokens; ++i) {
 260              llama_token token = eos;
 261              if (sample_offs >= sample_size && fill_with_next_samples) {
 262                  if (!sample_separation_eos) {
 263                      // insert eos token to separate samples
 264                      sample_separation_eos = true;
 265                  } else if (!sample_separation_bos) {
 266                      // insert bos token to separate samples
 267                      sample_separation_bos = true;
 268                      token = bos;
 269                  } else {
 270                      // sample separation is done, continue with next sample
 271                      sample_separation_eos = !separate_with_eos;
 272                      sample_separation_bos = !separate_with_bos;
 273                      sample_offs  = 0;
 274                      sample_idx   = (example_id + used_samples) % samples_count;
 275                      sample_begin = samples_begin[sample_idx];
 276                      sample_size  = samples_size[sample_idx];
 277                      ++used_samples;
 278                  }
 279              }
 280              // note: no else-if here
 281              if (sample_offs < sample_size) {
 282                  token = clamp(train_data[sample_begin+sample_offs], 0, (llama_token) (n_vocab - 1));
 283                  ++sample_offs;
 284              }
 285              ggml_set_f32_nd(target_probs,  token, (int) i, (int) k, 0, +1.0f);
 286              if (i+1<n_tokens) {
 287                  ggml_set_i32_nd(tokens_input, (int) (i + 1), (int) k, 0, 0, token);
 288              }
 289          }
 290      }
 291  
 292      return used_samples;
 293  }
 294  
 295  void mt19937_set_state(std::mt19937& rng, const std::string& rng_state) {
 296      std::stringstream s_rng_state;
 297      s_rng_state.imbue(std::locale::classic());
 298      s_rng_state.exceptions(std::stringstream::failbit);
 299      s_rng_state.str(rng_state);
 300      s_rng_state >> rng;
 301  }
 302  
 303  std::string mt19937_get_state(const std::mt19937& rng) {
 304      std::stringstream s_rng_state;
 305      s_rng_state.imbue(std::locale::classic());
 306      s_rng_state << rng;
 307      return s_rng_state.str();
 308  }
 309  
 310  std::string mt19937_seed_to_state(unsigned seed) {
 311      std::mt19937 rng(seed);
 312      return mt19937_get_state(rng);
 313  }
 314  
 315  std::string shuffle_samples(
 316          const std::string & rng_state,
 317          size_t            * shuffled_offs,
 318          size_t            * shuffled_begins,
 319          size_t            * shuffled_sizes,
 320          const size_t      * begins,
 321          const size_t      * sizes,
 322          size_t              count) {
 323      if (count == 0) return rng_state;
 324  
 325      std::mt19937 rng;
 326      mt19937_set_state(rng, rng_state);
 327  
 328      // sort indices by random value for each index
 329      std::vector<size_t> idcs;
 330      {
 331          std::vector<unsigned> rnd;
 332          idcs.resize(count);
 333          rnd.resize(count);
 334          for (unsigned i=0; i<count; ++i) {
 335              idcs[i] = i;
 336              rnd[i]  = rng();
 337          }
 338  
 339          std::sort(idcs.begin(), idcs.end(), [&rnd](size_t a, size_t b){
 340              // stable sort for reproducibility
 341              return (rnd[a] == rnd[b]) ? (a < b) : (rnd[a] < rnd[b]);
 342          });
 343      }
 344  
 345      // create random offsets
 346      for (unsigned i=0; i<count; ++i) {
 347          shuffled_offs[i] = (size_t) ((sizes[idcs[i]] - 1) * ((double) rng() / (double) (rng.max()-1)));
 348      }
 349  
 350      // reorder begins and sizes by sorted indices
 351      for (unsigned i=0; i<count; ++i) {
 352          shuffled_begins[i] = begins[idcs[i]];
 353      }
 354  
 355      for (unsigned i=0; i<count; ++i) {
 356          shuffled_sizes[i] = sizes[idcs[i]];
 357      }
 358  
 359      return mt19937_get_state(rng);
 360  }
 361  
 362  size_t hash_combine(size_t h1, size_t h2) {
 363      return h1 ^ (h2 << 1);
 364  }
 365  
 366  size_t compute_samples_hash(const char* fn, const size_t* samples_begin, const size_t* samples_size, size_t sample_count) {
 367      std::hash<std::string> h_string;
 368      std::hash<unsigned long long> h_ull;
 369      size_t h = h_string(std::string(fn));
 370      h = hash_combine(h, h_ull((unsigned long long) sample_count));
 371      for (size_t i=0; i< sample_count; ++i) {
 372          h = hash_combine(h, h_ull((unsigned long long) samples_begin[i]));
 373          h = hash_combine(h, h_ull((unsigned long long) samples_size[i]));
 374      }
 375      return h;
 376  }
 377  
 378  std::string replace_str(const char * s, const char * needle, const char * replacement) {
 379      std::string str = s;
 380      size_t pos = str.find(needle);
 381      if (pos != std::string::npos) {
 382          str.replace(pos, strlen(needle), replacement);
 383      }
 384      return str;
 385  }
 386  
 387  void print_duration(double fmillis) {
 388      if (fmillis < 1000.0f) {
 389          printf("%.1fms", (float) fmillis);
 390          return;
 391      }
 392      const int64_t one_sec  = 1000;
 393      const int64_t one_min  = one_sec  * 60;
 394      const int64_t one_hour = one_min  * 60;
 395      const int64_t one_day  = one_hour * 24;
 396  
 397      int64_t millis  = (int64_t) fmillis;
 398      int64_t days    = millis/one_day;
 399      int64_t hours   = (millis - days*one_day)/one_hour;
 400      int64_t minutes = (millis - days*one_day - hours*one_hour)/one_min;
 401      int64_t seconds = (millis - days*one_day - hours*one_hour - minutes*one_min)/one_sec;
 402  
 403      // to print int64_t either cast to (long long int) or use macro PRId64 from <inttypes.h>
 404      if (days > 0) {
 405          printf("%lldd ", (long long int) days);
 406      }
 407      printf("%02lld:%02lld:%02lld", (long long int) hours, (long long int) minutes, (long long int) seconds);
 408  }
 409  
 410  float cosine_decay(int64_t step, int64_t decay_steps, float minimum) {
 411      if (step > decay_steps) {
 412          step = decay_steps;
 413      }
 414      const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
 415      const float decay = (1 - minimum)*cosine_decay + minimum;
 416      return decay;
 417  }
 418  
 419  float cosine_decay_restart(int64_t step, int64_t decay_steps, float minimum, float restart_step_mult) {
 420      while (step > decay_steps) {
 421          step -= decay_steps;
 422          decay_steps = (int64_t) (restart_step_mult * decay_steps);
 423      }
 424      return cosine_decay(step, decay_steps, minimum);
 425  }
 426  
 427  float learning_schedule(
 428      int64_t step,
 429      int64_t warmup_steps,
 430      int64_t cos_decay_steps,
 431      float   learning_rate,
 432      float   overall_minimum,
 433      float   cos_decay_minimum,
 434      float   cos_decay_restart_step_mult,
 435      bool    enable_restart) {
 436  
 437      float result =
 438          (step < warmup_steps)
 439              ? (float) step / (float) warmup_steps
 440              : enable_restart
 441                  ? cosine_decay_restart(
 442                      step - warmup_steps,
 443                      cos_decay_steps,
 444                      cos_decay_minimum,
 445                      cos_decay_restart_step_mult)
 446                  : cosine_decay(
 447                      step,
 448                      cos_decay_steps,
 449                      cos_decay_minimum);
 450  
 451      float min = overall_minimum / learning_rate;
 452      result = min + result * (1.0f - min);
 453      return result;
 454  }
 455  
 456  static bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) {
 457      GGML_ASSERT(a != NULL);
 458      GGML_ASSERT(b != NULL);
 459      GGML_ASSERT(a->type == b->type);
 460      GGML_ASSERT(ggml_are_same_shape(a, b));
 461      GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b));
 462  
 463      return true;
 464  }
 465  
 466  void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) {
 467      if (dst == NULL) {
 468          return;
 469      }
 470      struct ggml_tensor * t  = ggml_get_tensor(ctx, name);
 471      GGML_ASSERT(are_same_layout(dst, t));
 472      memcpy(dst->data, t->data, ggml_nbytes(t));
 473  
 474      if (strlen(ggml_get_name(dst)) == 0) {
 475          ggml_set_name(dst, name);
 476      }
 477  }
 478  
 479  // gguf constants
 480  static const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type";
 481  static const char * LLM_KV_OPTIMIZER_TYPE_ADAM  = "adam";
 482  static const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs";
 483  static const char * LLM_KV_OPTIMIZER_FILE_VERSION               = "optimizer.file_version";
 484  static const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT     = "optimizer.convergence_past_count";
 485  static const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT            = "optimizer.parameter_count";
 486  static const char * LLM_KV_OPTIMIZER_ITERATION_COUNT            = "optimizer.iteration_count";
 487  static const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED           = "optimizer.just_initialized";
 488  static const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS             = "optimizer.adam.best_loss";
 489  static const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS         = "optimizer.adam.previous_loss";
 490  static const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT  = "optimizer.adam.no_improvement_count";
 491  static const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count";
 492  static const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS            = "optimizer.lbfgs.best_loss";
 493  static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP     = "optimizer.lbfgs.line_search_step";
 494  static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J        = "optimizer.lbfgs.line_search_j";
 495  static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K        = "optimizer.lbfgs.line_search_k";
 496  static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END      = "optimizer.lbfgs.line_search_end";
 497  static const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count";
 498  
 499  static const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS    = "optimizer.adam.first_moments";
 500  static const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS   = "optimizer.adam.second_moments";
 501  static const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values";
 502  
 503  static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS  = "optimizer.lbfgs.current_parameters";
 504  static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters";
 505  static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS   = "optimizer.lbfgs.current_gradients";
 506  static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS  = "optimizer.lbfgs.previous_gradients";
 507  static const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION    = "optimizer.lbfgs.search_direction";
 508  static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES    = "optimizer.lbfgs.past_loss_values";
 509  static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA        = "optimizer.lbfgs.memory_alpha";
 510  static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS           = "optimizer.lbfgs.memory_ys";
 511  static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S            = "optimizer.lbfgs.memory_s";
 512  static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y            = "optimizer.lbfgs.memory_y";
 513  
 514  static const char * LLM_KV_TRAINING_FILE_VERSION         = "training.file_version";
 515  static const char * LLM_KV_TRAINING_ITERATION_COUNT      = "training.iteration_count";
 516  static const char * LLM_KV_TRAINING_SAMPLE_COUNT         = "training.sample_count";
 517  static const char * LLM_KV_TRAINING_TOKEN_COUNT          = "training.token_count";
 518  static const char * LLM_KV_TRAINING_EPOCH_COUNT          = "training.epoch_count";
 519  static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
 520  static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE    = "training.shuffle.rng_state";
 521  static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
 522  static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE  = "training.shuffle.next_sample";
 523  
 524  #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
 525  { \
 526      const std::string skey(key); \
 527      const int kid = gguf_find_key(ctx, skey.c_str()); \
 528      if (kid >= 0) { \
 529          enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
 530          if (ktype != (type)) { \
 531              die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
 532          } \
 533          (dst) = func(ctx, kid); \
 534      } else if (req) { \
 535          die_fmt("key not found in model: %s", skey.c_str()); \
 536      } \
 537  }
 538  
 539  void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
 540      // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
 541  
 542      uint32_t file_version;
 543      GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION);
 544      GGML_ASSERT(file_version == 0);
 545  
 546      GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
 547      GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
 548      GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
 549  
 550      uint64_t nx;
 551      GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT);
 552      opt->nx = (size_t) nx;
 553  
 554      // don't call ggml_opt_init until optimizer type and optimizer specific parameters are know
 555  
 556      std::string opt_type;
 557      GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
 558      if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
 559          opt->params.type = GGML_OPT_TYPE_ADAM;
 560  
 561          GGUF_GET_KEY(fctx, opt->adam.fx_best,          gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
 562          GGUF_GET_KEY(fctx, opt->adam.fx_prev,          gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
 563          GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32,  true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT);
 564  
 565          ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
 566  
 567          copy_tensor_by_name(opt->adam.m,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
 568          copy_tensor_by_name(opt->adam.v,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
 569          copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
 570      } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
 571          opt->params.type = GGML_OPT_TYPE_LBFGS;
 572  
 573          GGUF_GET_KEY(fctx, opt->params.lbfgs.m,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
 574          GGUF_GET_KEY(fctx, opt->lbfgs.fx_best,          gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
 575          GGUF_GET_KEY(fctx, opt->lbfgs.step,             gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP);
 576          GGUF_GET_KEY(fctx, opt->lbfgs.j,                gguf_get_val_i32, GGUF_TYPE_INT32,   true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J);
 577          GGUF_GET_KEY(fctx, opt->lbfgs.k,                gguf_get_val_i32, GGUF_TYPE_INT32,   true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K);
 578          GGUF_GET_KEY(fctx, opt->lbfgs.end,              gguf_get_val_i32, GGUF_TYPE_INT32,   true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END);
 579          GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32,  true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT);
 580  
 581          ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
 582  
 583          copy_tensor_by_name(opt->lbfgs.x,    f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
 584          copy_tensor_by_name(opt->lbfgs.xp,   f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
 585          copy_tensor_by_name(opt->lbfgs.g,    f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
 586          copy_tensor_by_name(opt->lbfgs.gp,   f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
 587          copy_tensor_by_name(opt->lbfgs.d,    f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
 588          copy_tensor_by_name(opt->lbfgs.pf,   f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
 589          copy_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
 590          copy_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
 591          copy_tensor_by_name(opt->lbfgs.lms,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
 592          copy_tensor_by_name(opt->lbfgs.lmy,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
 593      } else {
 594          die("unknown optimizer type\n");
 595      }
 596  }
 597  
 598  void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) {
 599      gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0);
 600      gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past);
 601      gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx);
 602      gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
 603      gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
 604  
 605      switch (opt->params.type) {
 606          case GGML_OPT_TYPE_ADAM:
 607              {
 608                  gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
 609                  gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS,            opt->adam.fx_best);
 610                  gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS,        opt->adam.fx_prev);
 611                  gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement);
 612  
 613                  ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
 614                  ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
 615                  if (opt->adam.pf) {
 616                      ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
 617                  }
 618  
 619                  gguf_add_tensor(fctx, opt->adam.m);
 620                  gguf_add_tensor(fctx, opt->adam.v);
 621                  if (opt->adam.pf) {
 622                      gguf_add_tensor(fctx, opt->adam.pf);
 623                  }
 624              } break;
 625          case GGML_OPT_TYPE_LBFGS:
 626              {
 627                  gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
 628                  gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
 629                  gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS,            opt->lbfgs.fx_best);
 630                  gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP,     opt->lbfgs.step);
 631                  gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J,        opt->lbfgs.j);
 632                  gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K,        opt->lbfgs.k);
 633                  gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END,      opt->lbfgs.end);
 634                  gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement);
 635  
 636                  ggml_set_name(opt->lbfgs.x,    LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
 637                  ggml_set_name(opt->lbfgs.xp,   LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
 638                  ggml_set_name(opt->lbfgs.g,    LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
 639                  ggml_set_name(opt->lbfgs.gp,   LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
 640                  ggml_set_name(opt->lbfgs.d,    LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
 641                  if (opt->lbfgs.pf) {
 642                      ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
 643                  }
 644                  ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
 645                  ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
 646                  ggml_set_name(opt->lbfgs.lms,  LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
 647                  ggml_set_name(opt->lbfgs.lmy,  LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
 648  
 649                  gguf_add_tensor(fctx, opt->lbfgs.x);
 650                  gguf_add_tensor(fctx, opt->lbfgs.xp);
 651                  gguf_add_tensor(fctx, opt->lbfgs.g);
 652                  gguf_add_tensor(fctx, opt->lbfgs.gp);
 653                  gguf_add_tensor(fctx, opt->lbfgs.d);
 654                  if (opt->lbfgs.pf) {
 655                      gguf_add_tensor(fctx, opt->lbfgs.pf);
 656                  }
 657                  gguf_add_tensor(fctx, opt->lbfgs.lmal);
 658                  gguf_add_tensor(fctx, opt->lbfgs.lmys);
 659                  gguf_add_tensor(fctx, opt->lbfgs.lms);
 660                  gguf_add_tensor(fctx, opt->lbfgs.lmy);
 661              } break;
 662      }
 663  }
 664  
 665  bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train) {
 666      if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) < 0) {
 667          return false;
 668      }
 669  
 670      uint32_t file_version;
 671      GGUF_GET_KEY(fctx, file_version,         gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
 672      GGML_ASSERT(file_version <= 1);
 673  
 674      if (file_version == 0) {
 675  
 676          GGUF_GET_KEY(fctx, train->train_its,     gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
 677          GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
 678          GGUF_GET_KEY(fctx, train->train_tokens,  gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
 679  
 680      } else if (file_version == 1) {
 681  
 682          GGUF_GET_KEY(fctx, train->train_its,     gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT);
 683          GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT);
 684          GGUF_GET_KEY(fctx, train->train_tokens,  gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT);
 685          GGUF_GET_KEY(fctx, train->train_epochs,  gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT);
 686  
 687          GGUF_GET_KEY(fctx, train->shuffle_samples_hash,      gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH);
 688          GGUF_GET_KEY(fctx, train->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE);
 689          GGUF_GET_KEY(fctx, train->shuffle_sample_count,      gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT);
 690          GGUF_GET_KEY(fctx, train->shuffle_next_sample,       gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE);
 691      }
 692  
 693      load_opt_context_gguf(fctx, f_ggml_ctx, train->opt);
 694      return true;
 695  }
 696  
 697  void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) {
 698      gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION,    1);
 699      gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its);
 700      gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT,    train->train_samples);
 701      gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT,     train->train_tokens);
 702      gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT,     train->train_epochs);
 703  
 704      gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) train->shuffle_samples_hash);
 705      gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE,    train->shuffle_rng_state_current.c_str());
 706      gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) train->shuffle_sample_count);
 707      gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE,  (uint64_t) train->shuffle_next_sample);
 708  
 709      save_opt_context_gguf(fctx, train->opt);
 710  }
 711  
 712  
 713  struct llama_file {
 714      // use FILE * so we don't have to re-open the file to mmap
 715      FILE * fp;
 716      size_t size;
 717  
 718      llama_file(const char * fname, const char * mode) {
 719          fp = std::fopen(fname, mode);
 720          if (fp == NULL) {
 721              size = 0;
 722          } else {
 723              seek(0, SEEK_END);
 724              size = tell();
 725              seek(0, SEEK_SET);
 726          }
 727      }
 728  
 729      size_t tell() const {
 730  #ifdef _WIN32
 731          __int64 ret = _ftelli64(fp);
 732  #else
 733          long ret = std::ftell(fp);
 734  #endif
 735          GGML_ASSERT(ret != -1); // this really shouldn't fail
 736          return (size_t) ret;
 737      }
 738  
 739      void seek(size_t offset, int whence) {
 740  #ifdef _WIN32
 741          int ret = _fseeki64(fp, (__int64) offset, whence);
 742  #else
 743          int ret = std::fseek(fp, (long) offset, whence);
 744  #endif
 745          GGML_ASSERT(ret == 0); // same
 746      }
 747  
 748      void read_raw(void * ptr, size_t size) {
 749          if (size == 0) {
 750              return;
 751          }
 752          errno = 0;
 753          std::size_t ret = std::fread(ptr, size, 1, fp);
 754          if (ferror(fp)) {
 755              die_fmt("read error: %s", strerror(errno));
 756          }
 757          if (ret != 1) {
 758              die("unexpectedly reached end of file");
 759          }
 760      }
 761  
 762      std::uint32_t read_u32() {
 763          std::uint32_t ret;
 764          read_raw(&ret, sizeof(ret));
 765          return ret;
 766      }
 767  
 768      std::string read_string(std::uint32_t len) {
 769          std::vector<char> chars(len);
 770          read_raw(chars.data(), len);
 771          return std::string(chars.data(), len);
 772      }
 773  
 774      void write_raw(const void * ptr, size_t size) {
 775          if (size == 0) {
 776              return;
 777          }
 778          errno = 0;
 779          size_t ret = std::fwrite(ptr, size, 1, fp);
 780          if (ret != 1) {
 781              die_fmt("write error: %s", strerror(errno));
 782          }
 783      }
 784  
 785      void write_u32(std::uint32_t val) {
 786          write_raw(&val, sizeof(val));
 787      }
 788  
 789      ~llama_file() {
 790          if (fp) {
 791              std::fclose(fp);
 792          }
 793      }
 794  };
 795  
 796  static size_t utf8_len(char src) {
 797      const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
 798      uint8_t highbits = static_cast<uint8_t>(src) >> 4;
 799      return lookup[highbits];
 800  }
 801  
 802  // mark each byte with its utf8 unit number.
 803  // returns the number of utf8 characters.
 804  // e.g. when bytes == '\x61\xD0\xB0\x62',
 805  // then utf8_units will become [0,0,1,0]
 806  // utf8_nunits will become [1,2,2,1] and 3 is returned.
 807  // bytes where utf8_units is zero, are the begin of an utf8 character.
 808  static size_t mark_utf8_units(const char* bytes, int * utf8_units, int * utf8_nunits, size_t count) {
 809      size_t offs = 0;
 810      size_t count_utf8 = 0;
 811      while(offs < count) {
 812          int len = (int) utf8_len(bytes[offs]);
 813          for (int i=0; i<len; ++i) {
 814              utf8_units[offs+i]  = i;
 815              utf8_nunits[offs+i] = len;
 816          }
 817          offs += len;
 818          ++count_utf8;
 819      }
 820      return count_utf8;
 821  }
 822  
 823  size_t tokenize_file(
 824          struct llama_context     * lctx,
 825          const char               * filename,
 826          const std::string        & sample_start,
 827          bool                       include_sample_start,
 828          bool                       overlapping_samples,
 829          unsigned                   context_length,
 830          std::vector<llama_token> & out_tokens,
 831          std::vector<size_t>      & out_samples_begin,
 832          std::vector<size_t>      & out_samples_size) {
 833      struct llama_file f(filename, "rb");
 834  
 835      if (f.size == 0) {
 836          out_tokens.clear();
 837          out_samples_begin.clear();
 838          out_samples_size.clear();
 839          printf("%s: warning: empty or not existing training data file '%s'\n",
 840              __func__, filename);
 841          return out_tokens.size();
 842      }
 843  
 844      // account for possible leading whitespace that will be added by tokenizer
 845      // e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]
 846      const int n_max_tokens_overhead = 1;
 847  
 848      std::vector<char> buf;
 849      buf.resize(f.size);
 850  
 851      f.read_raw(buf.data(), f.size);
 852  
 853      std::vector<int> utf8_units;
 854      std::vector<int> utf8_nunits;
 855      utf8_units.resize(buf.size());
 856      utf8_nunits.resize(buf.size());
 857      mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size());
 858  
 859      if (sample_start.size() == 0) {
 860          // tokenize all data at once
 861          out_tokens.resize(buf.size() + n_max_tokens_overhead);
 862  
 863          int n_tokens = llama_tokenize(
 864              llama_get_model(lctx),
 865              buf.data(),
 866              (int) buf.size(),
 867              out_tokens.data(),
 868              (int) out_tokens.size(),
 869              false, false);
 870          if (n_tokens < 0) {
 871              out_tokens.resize(-n_tokens);
 872              n_tokens = llama_tokenize(
 873                  llama_get_model(lctx),
 874                  buf.data(),
 875                  (int) buf.size(),
 876                  out_tokens.data(),
 877                  (int) out_tokens.size(),
 878                  false, false);
 879          }
 880          if (n_tokens >= 0) {
 881              out_tokens.resize(n_tokens);
 882          }
 883  
 884          // generate sample starts at all token positions
 885          out_samples_begin.clear();
 886          out_samples_begin.push_back(0);
 887          out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size()));
 888          size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0;
 889          for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) {
 890              out_samples_begin.push_back(sample_begin);
 891              out_samples_size.push_back(context_length);
 892          }
 893      } else {
 894          // split data into samples and tokenize each sample
 895          std::string data_str(buf.data(), buf.size());
 896          out_samples_begin.clear();
 897          out_samples_size.clear();
 898          out_tokens.clear();
 899  
 900          // find all positions of pattern sample_start
 901          size_t sample_begin = data_str.find(sample_start, 0);
 902          while (sample_begin != std::string::npos) {
 903              out_samples_begin.push_back(sample_begin);
 904              const size_t search_start = sample_begin + sample_start.size();
 905              sample_begin = data_str.find(sample_start, search_start);
 906          }
 907          if (out_samples_begin.size() == 0) {
 908              printf("%s: warning: sample start pattern '%s' not found. inserting single sample at data begin\n",
 909                  __func__, sample_start.c_str());
 910              out_samples_begin.push_back(0);
 911          }
 912  
 913          out_samples_size.resize(out_samples_begin.size(), 0);
 914  
 915          std::vector<char>        buf_sample;
 916          std::vector<llama_token> tok_sample;
 917  
 918          const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size());
 919          size_t found_too_big_sample   = 0;
 920          size_t found_too_small_sample = 0;
 921          size_t found_empty_sample     = 0;
 922          size_t found_min_sample_size  = SIZE_MAX;
 923          size_t found_max_sample_size  = 0;
 924  
 925          size_t max_token_text_size = 0;
 926          int n_vocab = llama_n_vocab(llama_get_model(lctx));
 927          for (llama_token token=0; token < n_vocab; ++token) {
 928              max_token_text_size = std::max(
 929                  max_token_text_size,
 930                  strlen(llama_token_get_text(llama_get_model(lctx), token)));
 931          }
 932  
 933          // upper bound of context byte length.
 934          // strings with this byte length should always tokenize to at least context_length tokens.
 935          size_t context_byte_len = max_token_text_size*context_length;
 936  
 937          for (unsigned i=0; i<out_samples_begin.size(); ++i) {
 938              // determine sample begin and end from pattern positions
 939              size_t sample_begin = out_samples_begin[i] + sample_begin_offset;
 940              size_t sample_end   = overlapping_samples
 941                                      ? std::min(
 942                                          data_str.size(),
 943                                          sample_begin + context_byte_len)
 944                                      : (i+1 < out_samples_begin.size()
 945                                          ? out_samples_begin[i+1]
 946                                          : data_str.size());
 947              if (sample_end < utf8_units.size() && utf8_units[sample_end] > 0) {
 948                  // sample end is in the middle of an utf8 character.
 949                  // advance sample_end to the begin of the next utf8 character.
 950                  sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];
 951              }
 952              size_t sample_size = sample_end - sample_begin;
 953              if (sample_size == 0) {
 954                  ++found_empty_sample;
 955              }
 956  
 957              if (sample_size > 0) {
 958                  // llama_tokenize expects zero terminated string,
 959                  // copy sample into buffer and zero terminate it.
 960                  buf_sample.resize(sample_size);
 961                  memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size);
 962  
 963                  // printf("sample: '%s'\n", buf_sample.data());
 964  
 965                  // tokenize the sample
 966                  tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
 967                  int n_tokens = llama_tokenize(llama_get_model(lctx),
 968                      buf_sample.data(),
 969                      (int) buf_sample.size(),
 970                      tok_sample.data(),
 971                      (int) tok_sample.size(),
 972                      false, false);
 973                  if (n_tokens < 0) {
 974                      tok_sample.resize(-n_tokens);
 975                      n_tokens = llama_tokenize(llama_get_model(lctx),
 976                          buf_sample.data(),
 977                          (int) buf_sample.size(),
 978                          tok_sample.data(),
 979                          (int) tok_sample.size(),
 980                          false, false);
 981                      GGML_ASSERT(n_tokens >= 0);
 982                  }
 983                  GGML_ASSERT(n_tokens <= (int) tok_sample.size());
 984  
 985                  if ((size_t) n_tokens > context_length) {
 986                      ++found_too_big_sample;
 987                  } else if ((size_t) n_tokens < context_length) {
 988                      ++found_too_small_sample;
 989                  }
 990                  found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens);
 991                  found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens);
 992  
 993                  // write out tokens, start and size of sample
 994                  // overwrite the string start position with the token start position
 995                  out_samples_begin[i] = out_tokens.size();
 996                  out_samples_size[i] = (size_t) n_tokens;
 997                  out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens);
 998              } else {
 999                  out_samples_begin[i] = out_tokens.size();
1000                  out_samples_size[i] = 0;
1001              }
1002  
1003          }
1004          if (found_too_big_sample > 0) {
1005              printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. samples will be cut off.\n",
1006                  __func__, found_too_big_sample, found_max_sample_size, context_length);
1007          }
1008  
1009          if (found_too_small_sample > 0) {
1010              printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n",
1011                  __func__, found_too_small_sample, found_min_sample_size, context_length);
1012          }
1013  
1014          if (found_empty_sample) {
1015              printf("%s: warning: found %zu empty samples.\n",
1016                  __func__, found_empty_sample);
1017          }
1018      }
1019      printf("%s: total number of samples: %zu\n",
1020          __func__, out_samples_begin.size());
1021  
1022      GGML_ASSERT(out_samples_begin.size() == out_samples_size.size());
1023  
1024      return out_tokens.size();
1025  }
1026  
1027  std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) {
1028      std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
1029      return replace_str(filename, pattern_it, sit.c_str());
1030  }
1031  
1032  struct train_params_common get_default_train_params_common() {
1033      struct train_params_common params;
1034      params.fn_train_data     = "shakespeare.txt";
1035      params.fn_checkpoint_in  = "checkpoint.gguf";
1036      params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
1037      params.pattern_fn_it     = "ITERATION";
1038      params.fn_latest         = "LATEST";
1039  
1040      params.print_usage = false;
1041  
1042      params.save_every = 10;
1043  
1044      params.seed       =   -1;
1045  
1046      params.n_ctx      =  128;
1047      params.n_threads  =    6;
1048      params.n_batch    =    8;
1049      params.n_gradient_accumulation = 1;
1050      params.n_epochs   = -1;
1051      params.n_gpu_layers = 0;
1052  
1053      params.custom_n_ctx = false;
1054  
1055      params.use_flash              = false;
1056      params.use_checkpointing      = true;
1057  
1058      params.sample_start           = "";
1059      params.include_sample_start   = false;
1060      params.escape                 = false;
1061      params.overlapping_samples    = false;
1062      params.fill_with_next_samples = false;
1063      params.separate_with_eos      = false;
1064      params.separate_with_bos      = true;
1065      params.sample_random_offsets  = false;
1066      params.force_reshuffle        = false;
1067  
1068      params.opt_past               = 0;
1069      params.opt_delta              = 1e-5f;
1070      params.opt_max_no_improvement = 0;
1071  
1072      params.warmup            =  100;
1073      params.cos_decay_steps   = 1000;
1074      params.cos_decay_restart = 1.1f;
1075      params.cos_decay_min     = 0.1f;
1076      params.enable_restart    = false;
1077  
1078      params.adam_n_iter         = 256;
1079      params.adam_alpha          = 1e-3f;
1080      params.adam_min_alpha      = 0;
1081      params.adam_decay          = 1e-1f;
1082      params.adam_decay_min_ndim = 2;
1083      params.adam_beta1          = 0.9f;
1084      params.adam_beta2          = 0.999f;
1085      params.adam_gclip          = 1.0f;
1086      params.adam_eps_f          = 0.0f;
1087  
1088      return params;
1089  }
1090  
1091  void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train_params_common * params) {
1092      // fprintf(stderr, "usage: %s [options]\n", argv[0]);
1093      // fprintf(stderr, "\n");
1094      // fprintf(stderr, "options:\n");
1095      // fprintf(stderr, "  -h, --help                 show this help message and exit\n");
1096      fprintf(stderr, "  --train-data FNAME         path from which to load training data (default '%s')\n", params->fn_train_data);
1097      fprintf(stderr, "  --checkpoint-in FNAME      path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
1098      fprintf(stderr, "  --checkpoint-out FNAME     path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
1099      fprintf(stderr, "  --pattern-fn-it STR        pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
1100      fprintf(stderr, "  --fn-latest STR            string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
1101      fprintf(stderr, "  --save-every N             save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
1102      fprintf(stderr, "  -s SEED, --seed SEED       RNG seed (default: -1, use random seed for -1)\n");
1103      fprintf(stderr, "  -c N, --ctx N              Context size used during training (default %d)\n", params->n_ctx);
1104      fprintf(stderr, "  -t N, --threads N          Number of threads (default %d)\n", params->n_threads);
1105      fprintf(stderr, "  -b N, --batch N            Parallel batch size (default %d)\n", params->n_batch);
1106      fprintf(stderr, "  --grad-acc N               Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
1107      fprintf(stderr, "  --sample-start STR         Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
1108      fprintf(stderr, "  --include-sample-start     Include the sample start in the samples. (default off)\n");
1109      fprintf(stderr, "  --escape                   process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
1110      fprintf(stderr, "  --overlapping-samples      Samples may overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
1111      fprintf(stderr, "  --fill-with-next-samples   Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
1112      fprintf(stderr, "  --separate-with-eos        When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
1113      fprintf(stderr, "  --separate-with-bos        When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
1114      fprintf(stderr, "  --no-separate-with-eos     When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
1115      fprintf(stderr, "  --no-separate-with-bos     When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
1116      fprintf(stderr, "  --sample-random-offsets    Use samples beginning at random offsets. Together with fill-with-next-samples this may help for training endless text generation.%s\n", params->sample_random_offsets ? " (default)" : "");
1117      fprintf(stderr, "  --force-reshuffle          Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
1118      fprintf(stderr, "  --no-flash                 Don't use flash attention \n");
1119      fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
1120      fprintf(stderr, "  --no-checkpointing         Don't use gradient checkpointing\n");
1121      fprintf(stderr, "  --use-checkpointing        Use gradient checkpointing (default)\n");
1122      fprintf(stderr, "  --warmup N                 Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
1123      fprintf(stderr, "  --cos-decay-steps N        Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
1124      fprintf(stderr, "  --cos-decay-restart N      Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
1125      fprintf(stderr, "  --cos-decay-min N          Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
1126      fprintf(stderr, "  --enable-restart N         Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
1127      fprintf(stderr, "  --disable-restart N        Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
1128      fprintf(stderr, "  --opt-past N               Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
1129      fprintf(stderr, "  --opt-delta N              Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
1130      fprintf(stderr, "  --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
1131      fprintf(stderr, "  --epochs N                 Maximum number epochs to process. (default %d)\n", params->n_epochs);
1132      fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
1133      fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
1134      fprintf(stderr, "  --adam-min-alpha N         Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
1135      fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
1136      fprintf(stderr, "  --adam-decay-min-ndim N    Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
1137      fprintf(stderr, "  --adam-beta1 N             AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
1138      fprintf(stderr, "  --adam-beta2 N             AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
1139      fprintf(stderr, "  --adam-gclip N             AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
1140      fprintf(stderr, "  --adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
1141      fprintf(stderr, "  -ngl N, --n-gpu-layers N   Number of model layers to offload to GPU (default %d)", params->n_gpu_layers);
1142      fprintf(stderr, "\n");
1143  }
1144  
1145  bool consume_common_train_arg(
1146      int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param
1147  ) {
1148      int& i = *idx;
1149      std::string arg = argv[i];
1150      const std::string arg_prefix = "--";
1151      if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
1152          std::replace(arg.begin(), arg.end(), '_', '-');
1153      }
1154      if (arg == "--train-data") {
1155          if (++i >= argc) {
1156              *invalid_param = true;
1157              return true;
1158          }
1159          params->fn_train_data = argv[i];
1160      } else if (arg == "--checkpoint-in") {
1161          if (++i >= argc) {
1162              *invalid_param = true;
1163              return true;
1164          }
1165          params->fn_checkpoint_in = argv[i];
1166      } else if (arg == "--checkpoint-out") {
1167          if (++i >= argc) {
1168              *invalid_param = true;
1169              return true;
1170          }
1171          params->fn_checkpoint_out = argv[i];
1172      } else if (arg == "--pattern-fn-it") {
1173          if (++i >= argc) {
1174              *invalid_param = true;
1175              return true;
1176          }
1177          params->pattern_fn_it = argv[i];
1178      } else if (arg == "--fn-latest") {
1179          if (++i >= argc) {
1180              *invalid_param = true;
1181              return true;
1182          }
1183          params->fn_latest = argv[i];
1184      } else if (arg == "--save-every") {
1185          if (++i >= argc) {
1186              *invalid_param = true;
1187              return true;
1188          }
1189          params->save_every = std::stoi(argv[i]);
1190      } else if (arg == "-s" || arg == "--seed") {
1191          if (++i >= argc) {
1192              *invalid_param = true;
1193              return true;
1194          }
1195          params->seed = std::stoi(argv[i]);
1196      } else if (arg == "-c" || arg == "--ctx") {
1197          if (++i >= argc) {
1198              *invalid_param = true;
1199              return true;
1200          }
1201          params->n_ctx = std::stoi(argv[i]);
1202          params->custom_n_ctx = true;
1203      } else if (arg == "-t" || arg == "--threads") {
1204          if (++i >= argc) {
1205              *invalid_param = true;
1206              return true;
1207          }
1208          params->n_threads = std::stoi(argv[i]);
1209      } else if (arg == "-b" || arg == "--batch") {
1210          if (++i >= argc) {
1211              *invalid_param = true;
1212              return true;
1213          }
1214          params->n_batch = std::stoi(argv[i]);
1215      } else if (arg == "--grad-acc") {
1216          if (++i >= argc) {
1217              *invalid_param = true;
1218              return true;
1219          }
1220          params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
1221      } else if (arg == "--sample-start") {
1222          if (++i >= argc) {
1223              *invalid_param = true;
1224              return true;
1225          }
1226          params->sample_start = std::string(argv[i]);
1227      } else if (arg == "--escape") {
1228          params->escape = true;
1229      } else if (arg == "--include-sample-start") {
1230          params->include_sample_start = true;
1231      } else if (arg == "--overlapping-samples") {
1232          params->overlapping_samples = true;
1233      } else if (arg == "--fill-with-next-samples") {
1234          params->fill_with_next_samples = true;
1235      } else if (arg == "--separate-with-eos") {
1236          params->separate_with_eos = true;
1237      } else if (arg == "--separate-with-bos") {
1238          params->separate_with_bos = true;
1239      } else if (arg == "--no-separate-with-eos") {
1240          params->separate_with_eos = false;
1241      } else if (arg == "--no-separate-with-bos") {
1242          params->separate_with_bos = false;
1243      } else if (arg == "--sample-random-offsets") {
1244          params->sample_random_offsets = true;
1245      } else if (arg == "--force-reshuffle") {
1246          params->force_reshuffle = true;
1247      } else if (arg == "--no-flash") {
1248          params->use_flash = false;
1249      } else if (arg == "--use-flash") {
1250          params->use_flash = true;
1251      } else if (arg == "--no-checkpointing") {
1252          params->use_checkpointing = false;
1253      } else if (arg == "--use-checkpointing") {
1254          params->use_checkpointing = true;
1255      } else if (arg == "--warmup") {
1256          if (++i >= argc) {
1257              *invalid_param = true;
1258              return true;
1259          }
1260          params->warmup = std::stoi(argv[i]);
1261      } else if (arg == "--cos-decay-steps") {
1262          if (++i >= argc) {
1263              *invalid_param = true;
1264              return true;
1265          }
1266          params->cos_decay_steps = std::stoi(argv[i]);
1267      } else if (arg == "--cos-decay-restart") {
1268          if (++i >= argc) {
1269              *invalid_param = true;
1270              return true;
1271          }
1272          params->cos_decay_restart = std::stof(argv[i]);
1273      } else if (arg == "--cos-decay-min") {
1274          if (++i >= argc) {
1275              *invalid_param = true;
1276              return true;
1277          }
1278          params->cos_decay_min = std::stof(argv[i]);
1279      } else if (arg == "--enable-restart") {
1280          params->enable_restart = true;
1281      } else if (arg == "--disable-restart") {
1282          params->enable_restart = false;
1283      } else if (arg == "--opt-past") {
1284          if (++i >= argc) {
1285              *invalid_param = true;
1286              return true;
1287          }
1288          params->opt_past = std::stoi(argv[i]);
1289      } else if (arg == "--opt-delta") {
1290          if (++i >= argc) {
1291              *invalid_param = true;
1292              return true;
1293          }
1294          params->opt_delta = std::stof(argv[i]);
1295      } else if (arg == "--opt-max-no-improvement") {
1296          if (++i >= argc) {
1297              *invalid_param = true;
1298              return true;
1299          }
1300          params->opt_max_no_improvement = std::stoi(argv[i]);
1301      } else if (arg == "--adam-epsf") {
1302          if (++i >= argc) {
1303              *invalid_param = true;
1304              return true;
1305          }
1306          params->adam_eps_f = std::stof(argv[i]);
1307      } else if (arg == "--epochs") {
1308          if (++i >= argc) {
1309              *invalid_param = true;
1310              return true;
1311          }
1312          params->n_epochs = std::stoi(argv[i]);
1313      } else if (arg == "--adam-iter") {
1314          if (++i >= argc) {
1315              *invalid_param = true;
1316              return true;
1317          }
1318          params->adam_n_iter = std::stoi(argv[i]);
1319      } else if (arg == "--adam-alpha") {
1320          if (++i >= argc) {
1321              *invalid_param = true;
1322              return true;
1323          }
1324          params->adam_alpha = std::stof(argv[i]);
1325      } else if (arg == "--adam-min-alpha") {
1326          if (++i >= argc) {
1327              *invalid_param = true;
1328              return true;
1329          }
1330          params->adam_min_alpha = std::stof(argv[i]);
1331      } else if (arg == "--adam-decay") {
1332          if (++i >= argc) {
1333              *invalid_param = true;
1334              return true;
1335          }
1336          params->adam_decay = std::stof(argv[i]);
1337      } else if (arg == "--adam-decay-min-ndim") {
1338          if (++i >= argc) {
1339              *invalid_param = true;
1340              return true;
1341          }
1342          params->adam_decay_min_ndim = std::stoi(argv[i]);
1343      } else if (arg == "--adam-beta1") {
1344          if (++i >= argc) {
1345              *invalid_param = true;
1346              return true;
1347          }
1348          params->adam_beta1 = std::stof(argv[i]);
1349      } else if (arg == "--adam-beta2") {
1350          if (++i >= argc) {
1351              *invalid_param = true;
1352              return true;
1353          }
1354          params->adam_beta2 = std::stof(argv[i]);
1355      } else if (arg == "--adam-gclip") {
1356          if (++i >= argc) {
1357              *invalid_param = true;
1358              return true;
1359          }
1360          params->adam_gclip = std::stof(argv[i]);
1361      } else if (arg == "-ngl" || arg == "--n-gpu-layers") {
1362              if (++i >= argc) {
1363                  *invalid_param = true;
1364                  return true;
1365              }
1366              if (llama_supports_gpu_offload()) {
1367                  params->n_gpu_layers = std::stoi(argv[i]);
1368              } else {
1369                  fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
1370                  fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
1371              }
1372      } else if (arg == "-h" || arg == "--help") {
1373          params->print_usage = true;
1374          return true;
1375      } else {
1376          return false;
1377      }
1378      return true;
1379  }
1380  
1381  void finish_processing_train_args(struct train_params_common * params) {
1382      if (params->escape) {
1383          string_process_escapes(params->sample_start);
1384      }
1385  }
1386  
1387  void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) {
1388      struct train_opt_callback_data * data   = (struct train_opt_callback_data *) vdata;
1389      struct train_params_common     * params = data->params;
1390      struct train_state             * train  = data->train;
1391      struct ggml_opt_context        * opt    = train->opt;
1392      int n_batch = params->n_batch;
1393      int n_ctx = params->n_ctx;
1394  
1395      if (accum_step == 0) {
1396          // time measurement
1397          int64_t now = ggml_time_ms();
1398          if (now > data->last_time && opt->iter > data->first_iter) {
1399              double dt = (double) (now - data->last_time);
1400              if (data->millis_per_iter == 0.0) {
1401                  data->millis_per_iter = dt;
1402              } else {
1403                  const double gain = 0.7;
1404                  data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain;
1405              }
1406          }
1407  
1408          double remaining_millis = 0.0;
1409          if (data->millis_per_iter > 0.0) {
1410              const int n_iter = params->adam_n_iter;
1411              const int done_iter = opt->iter - data->first_iter;
1412              const int remaining_iter = n_iter - done_iter;
1413              remaining_millis = remaining_iter * data->millis_per_iter;
1414          }
1415  
1416          // file saving
1417          const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
1418          if (save_now) {
1419              int new_iters = opt->iter - data->last_save_iter;
1420              train->train_its    += new_iters;
1421              train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
1422  
1423              if (data->save_cb) {
1424                  data->save_cb(data->save_data, train);
1425              }
1426  
1427              data->last_save_iter = opt->iter;
1428          }
1429  
1430          // exclude file saving from time measurement, by measuring last_time after saving
1431          data->last_time = ggml_time_ms();
1432  
1433          *sched = learning_schedule(
1434              opt->iter,
1435              params->warmup,
1436              params->cos_decay_steps,
1437              params->adam_alpha,
1438              params->adam_min_alpha,
1439              params->cos_decay_min,
1440              params->cos_decay_restart,
1441              params->enable_restart);
1442  
1443          int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
1444          if (impr_plot > 0) impr_plot = 0;
1445          if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0;
1446          printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
1447              __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
1448              *sched, opt->loss_after);
1449  
1450  
1451          if (data->millis_per_iter > 0) {
1452              printf(" dt=");
1453              print_duration(data->millis_per_iter);
1454              printf(" eta=");
1455              print_duration(remaining_millis);
1456          }
1457  
1458          float improvement = opt->loss_before - opt->loss_after;
1459          const float plot_scale = 10.0f;
1460          int bar_len = (int)(1 + improvement*plot_scale + 0.5);
1461          printf(" |");
1462          for (int i=0; i<bar_len; ++i) {
1463              printf("-");
1464          }
1465          printf(">");
1466          printf("\n");
1467      }
1468  
1469      int64_t used_samples = get_example_targets_batch(
1470          data->lctx,
1471          data->tokens_input,
1472          data->target_probs,
1473          train->shuffle_next_sample,
1474          data->shuffled_samples_offs,
1475          data->shuffled_samples_begin,
1476          data->shuffled_samples_size,
1477          data->samples_count,
1478          data->tokens_data,
1479          data->tokens_size,
1480          params->separate_with_eos,
1481          params->separate_with_bos,
1482          params->fill_with_next_samples,
1483          params->sample_random_offsets);
1484  
1485      train->train_samples += used_samples;
1486      train->shuffle_next_sample += used_samples;
1487  
1488      if (train->shuffle_next_sample >= train->shuffle_sample_count) {
1489          ++train->train_epochs;
1490          printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs);
1491          // note: we may have used some samples from the current shuffling more than once
1492          train->shuffle_rng_state_current = train->shuffle_rng_state_next;
1493          train->shuffle_rng_state_next = shuffle_samples(
1494              train->shuffle_rng_state_current,
1495              data->shuffled_samples_offs,
1496              data->shuffled_samples_begin,
1497              data->shuffled_samples_size,
1498              data->samples_begin,
1499              data->samples_size,
1500              data->samples_count);
1501          train->shuffle_next_sample = 0;
1502      }
1503  
1504      const bool last_epoch_reached = (params->n_epochs > 0 && (int64_t) train->train_epochs - data->first_epoch >= params->n_epochs);
1505      if (last_epoch_reached) {
1506          // allow optimization iteration at last epoch to be completed before canceling
1507          if (data->iter_at_last_epoch < 0) {
1508              data->iter_at_last_epoch = opt->iter;
1509          } else if (opt->iter > data->iter_at_last_epoch) {
1510              *cancel = true;
1511          }
1512      }
1513  }