// examples/speculative/speculative.cpp
#include "common.h"
#include "llama.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <iterator>
#include <random>
#include <set>
#include <string>
#include <tuple>
#include <vector>

 10  #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100
 11  #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 12  
 13  struct seq_draft {
 14      bool active   = false;
 15      bool drafting = false;
 16      bool skip     = false;
 17  
 18      int i_batch_dft = 0;
 19      std::vector<int> i_batch_tgt;
 20  
 21      std::vector<llama_token> tokens;
 22      std::vector<std::vector<llama_token_data>> dists;
 23  
 24      struct llama_sampling_context * ctx_sampling;
 25  };
 26  
 27  int main(int argc, char ** argv) {
 28      gpt_params params;
 29  
 30      if (!gpt_params_parse(argc, argv, params)) {
 31          gpt_params_print_usage(argc, argv, params);
 32          return 1;
 33      }
 34  
 35      if (params.model_draft.empty()) {
 36          fprintf(stderr, "%s: error: --model-draft is required\n", __func__);
 37          return 1;
 38      }
 39  
 40      // max number of parallel drafting sequences (i.e. tree branches)
 41      const int n_seq_dft = params.n_parallel;
 42  
 43      // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
 44      const float p_split  = params.p_split;
 45  
 46      if (params.seed == LLAMA_DEFAULT_SEED) {
 47          params.seed = time(NULL);
 48      }
 49      std::default_random_engine rng(params.seed);
 50      std::uniform_real_distribution<> u_dist;
 51  
 52  #ifndef LOG_DISABLE_LOGS
 53      log_set_target(log_filename_generator("speculative", "log"));
 54      LOG_TEE("Log start\n");
 55      log_dump_cmdline(argc, argv);
 56  #endif // LOG_DISABLE_LOGS
 57  
 58      // init llama.cpp
 59      llama_backend_init();
 60      llama_numa_init(params.numa);
 61  
 62      llama_model * model_tgt = NULL;
 63      llama_model * model_dft = NULL;
 64  
 65      llama_context * ctx_tgt = NULL;
 66      llama_context * ctx_dft = NULL;
 67  
 68      // load the target model
 69      std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
 70  
 71      // load the draft model
 72      params.model = params.model_draft;
 73      params.n_gpu_layers = params.n_gpu_layers_draft;
 74      if (params.n_threads_draft > 0) {
 75          params.n_threads = params.n_threads_draft;
 76      }
 77      params.n_threads_batch = params.n_threads_batch_draft;
 78      std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
 79  
 80      const bool vocab_type_tgt = llama_vocab_type(model_tgt);
 81      LOG("vocab_type tgt: %d\n", vocab_type_tgt);
 82  
 83      const bool vocab_type_dft = llama_vocab_type(model_dft);
 84      LOG("vocab_type dft: %d\n", vocab_type_dft);
 85  
 86      if (vocab_type_tgt != vocab_type_dft) {
 87          fprintf(stderr, "%s: error: draft model vocab type must match target model to use speculation but ", __func__);
 88          fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
 89          return 1;
 90      }
 91  
 92      if (
 93          llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
 94          llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
 95          llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
 96          llama_token_eos(model_tgt) != llama_token_eos(model_dft)
 97      ) {
 98          fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
 99          return 1;
100      }
101  
102      {
103          const int n_vocab_tgt = llama_n_vocab(model_tgt);
104          const int n_vocab_dft = llama_n_vocab(model_dft);
105          const int vocab_diff  = n_vocab_tgt > n_vocab_dft
106              ? n_vocab_tgt - n_vocab_dft
107              : n_vocab_dft - n_vocab_tgt;
108  
109          if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
110              fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
111              fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
112                      n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
113              return 1;
114          }
115  
116          for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
117              const char * token_text_tgt = llama_token_get_text(model_tgt, i);
118              const char * token_text_dft = llama_token_get_text(model_dft, i);
119              if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
120                  fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
121                  fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
122                          llama_token_to_piece(ctx_tgt, i).c_str(),
123                          llama_token_to_piece(ctx_dft, i).c_str());
124                  return 1;
125              }
126          }
127      }
128  
129  
130      // Tokenize the prompt
131      std::vector<llama_token> inp;
132      inp = ::llama_tokenize(ctx_tgt, params.prompt, true, true);
133  
134      const int max_context_size     = llama_n_ctx(ctx_tgt);
135      const int max_tokens_list_size = max_context_size - 4;
136  
137      if ((int) inp.size() > max_tokens_list_size) {
138          fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
139          return 1;
140      }
141  
142      fprintf(stderr, "\n\n");
143  
144      for (auto id : inp) {
145          fprintf(stderr, "%s", llama_token_to_piece(ctx_tgt, id).c_str());
146      }
147  
148      fflush(stderr);
149  
150      const int n_input = inp.size();
151  
152      const auto t_enc_start = ggml_time_us();
153  
154      // eval the prompt with both models
155      llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0,           0));
156      llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(),           1, n_input - 1, 0));
157      llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input,     0,           0));
158  
159      const auto t_enc_end = ggml_time_us();
160  
161      // the 2 models should have the same vocab
162      //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
163  
164      // how many tokens to draft each time
165      int n_draft = params.n_draft;
166  
167      int n_predict = 0;
168      int n_drafted = 0;
169      int n_accept  = 0;
170  
171      int n_past_tgt = inp.size();
172      int n_past_dft = inp.size();
173  
174      // used to determine end of generation
175      bool has_eos = false;
176  
177      // target model sampling context
178      struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
179  
180      // draft sequence data
181      std::vector<seq_draft> drafts(n_seq_dft);
182  
183      params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
184      if (params.sparams.temp == 0) {
185          params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
186      }
187  
188      for (int s = 0; s < n_seq_dft; ++s) {
189          drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
190      }
191  
192      llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
193      llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
194  
195      const auto t_dec_start = ggml_time_us();
196  
197      // sample from the last token of the prompt
198      drafts[0].i_batch_tgt.resize(1);
199      drafts[0].i_batch_tgt[0] = 0;
200  
201      while (true) {
202          std::set<int> active_seqs = {};
203  
204          // print current draft sequences
205          for (int s = 0; s < n_seq_dft; ++s) {
206              if (!drafts[s].active) {
207                  continue;
208              }
209  
210              active_seqs.insert(s);
211              const auto & tokens = drafts[s].tokens;
212  
213              LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
214          }
215  
216          int i_dft  = 0;
217          int s_keep = 0;
218  
219          llama_token token_id;
220          std::string token_str;
221  
222          // loop until we fail to accept a drafted token or we run out of drafted tokens
223          while (true) {
224  
225              // check if the target token matches any of the drafts
226              // for stochastic sampling, attempt to match the token with the drafted tokens
227              {
228                  bool accept = false;
229                  if (params.sparams.temp > 0) {
230                      // stochastic verification
231  
232                      llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
233                      llama_sample_softmax(ctx_tgt, &dist_tgt);
234                      float p_tgt = 0, p_dft = 0;
235  
236                      // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
237  
238                      while (active_seqs.size() > 0) {
239                          // randomly select a sequence to verify from active sequences
240                          std::uniform_int_distribution<unsigned int> u_int_dist(0, active_seqs.size() - 1);
241                          int s = *std::next(active_seqs.begin(), u_int_dist(rng));
242                          if (i_dft >= (int) drafts[s].tokens.size()) {
243                              drafts[s].active = false;
244                              active_seqs.erase(s);
245                              continue;
246                          }
247                          if (accept) {
248                              // if we already accepted a token, we can skip the rest
249                              if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
250                                  drafts[s].active = false;
251                                  active_seqs.erase(s);
252                              }
253                              continue;
254                          }
255                          LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
256                          float r = u_dist(rng);
257                          llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
258                          // acquire the token probabilities assigned by the draft and target models
259                          for (size_t i = 0; i < dist_tgt.size; i++) {
260                              if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
261                                  p_tgt = dist_tgt.data[i].p;
262                              }
263                              if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
264                                  p_dft = dist_dft.data[i].p;
265                              }
266                              if (p_tgt && p_dft) {
267                                  break;
268                              }
269                          }
270                          LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
271                          if (r <= p_tgt / p_dft) {
272                              s_keep = s;
273                              accept = true;
274                              token_id = drafts[s].tokens[i_dft];
275                              token_str = llama_token_to_piece(ctx_tgt, token_id);
276                              llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
277  
278                              LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
279                              break;
280                          } else {
281                              LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
282                              drafts[s].active = false;
283  
284                              // calculate residual probability
285                              GGML_ASSERT(dist_tgt.sorted);
286                              GGML_ASSERT(dist_dft.sorted);
287                              float sum_probs = 0.0f;
288  
289                              // sort dist by id
290                              std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
291                                  return a.id < b.id;
292                              });
293                              std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
294                                  return a.id < b.id;
295                              });
296  
297                              for (size_t i = 0; i < dist_tgt.size; i++) {
298                                  dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
299                                  sum_probs += dist_tgt.data[i].p;
300                              }
301                              for (size_t i = 0; i < dist_tgt.size; i++) {
302                                  dist_tgt.data[i].p /= sum_probs;
303                              }
304  
305                              // sort dist_tgt by p desc
306                              std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
307                                  return a.p > b.p;
308                              });
309                          }
310  
311                          active_seqs.erase(s);
312                          for(int i = 0; i < n_seq_dft; i++) {
313                              if (i == s) {
314                                  continue;
315                              }
316                              if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
317                                  // synchronize active status for sequences with the same drafted token
318                                  drafts[i].active = drafts[i].active && accept;
319                                  if (!drafts[i].active) {
320                                      active_seqs.erase(s);
321                                  }
322                              }
323                          }
324                      }
325  
326                      if (!accept) {
327                          // all drafted tokens were rejected
328                          // sample from the target model
329                          LOG("all drafted tokens were rejected, sampling from residual distribution\n");
330                          token_id = llama_sample_token(ctx_tgt, &dist_tgt);
331                          llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
332                          token_str = llama_token_to_piece(ctx_tgt, token_id);
333                      }
334  
335                  } else {
336                      // greedy verification
337  
338                      // sample from the target model
339                      LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
340                      token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
341  
342                      llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
343  
344                      //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
345  
346                      token_str = llama_token_to_piece(ctx_tgt, token_id);
347  
348                      for (int s = 0; s < n_seq_dft; ++s) {
349                          if (!drafts[s].active) {
350                              continue;
351                          }
352  
353                          if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
354                              LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
355  
356                              s_keep = s;
357                              accept = true;
358                          } else {
359                              drafts[s].active = false;
360                          }
361                      }
362                  }
363  
364                  if (llama_token_is_eog(model_tgt, token_id)) {
365                      has_eos = true;
366                  }
367                  ++n_predict;
368  
369                  if (accept) {
370                      ++n_accept;
371                      ++n_past_tgt;
372                      ++n_past_dft;
373                      ++i_dft;
374                      if (params.use_color) {
375                          // Color token according to its origin sequence
376                          printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
377                      } else {
378                          printf("%s", token_str.c_str());
379                      }
380                      fflush(stdout);
381                      continue;
382                  } else {
383                      printf("%s", token_str.c_str());
384                      fflush(stdout);
385                      break;
386                  }
387              }
388          }
389  
390          {
391              LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
392  
393              // TODO: simplify
394              {
395                  LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
396  
397                  llama_kv_cache_seq_keep(ctx_dft, s_keep);
398                  llama_kv_cache_seq_cp  (ctx_dft, s_keep, 0, -1, -1);
399                  llama_kv_cache_seq_keep(ctx_dft, 0);
400  
401                  llama_kv_cache_seq_rm  (ctx_tgt, s_keep, n_past_tgt, -1);
402                  llama_kv_cache_seq_keep(ctx_tgt, s_keep);
403                  llama_kv_cache_seq_cp  (ctx_tgt, s_keep, 0, -1, -1);
404                  llama_kv_cache_seq_keep(ctx_tgt, 0);
405              }
406  
407              for (int s = 0; s < n_seq_dft; ++s) {
408                  drafts[s].active = false;
409                  drafts[s].tokens.clear();
410                  drafts[s].i_batch_tgt.clear();
411                  drafts[s].dists.clear();
412              }
413              // note: will be erased after the speculation phase
414              drafts[0].tokens.push_back(token_id);
415              drafts[0].dists.push_back(std::vector<llama_token_data>());
416              drafts[0].i_batch_tgt.push_back(0);
417  
418              llama_batch_clear(batch_dft);
419              llama_batch_add  (batch_dft, token_id, n_past_dft, { 0 }, true);
420  
421              llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
422              // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
423              llama_decode(ctx_dft, batch_dft);
424  
425              ++n_past_dft;
426          }
427  
428          if (n_predict > params.n_predict || has_eos) {
429              break;
430          }
431  
432          llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
433  
434          int n_seq_cur  = 1;
435          int n_past_cur = n_past_dft;
436  
437          for (int s = 0; s < n_seq_dft; ++s) {
438              drafts[s].active   = false;
439              drafts[s].drafting = false;
440          }
441          drafts[0].active      = true;
442          drafts[0].drafting    = true;
443          drafts[0].i_batch_dft = 0;
444  
445          llama_batch_clear(batch_tgt);
446          llama_batch_add  (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
447  
448          // sample n_draft tokens from the draft model using tree-based sampling
449          for (int i = 0; i < n_draft; ++i) {
450              batch_dft.n_tokens = 0;
451  
452              for (int s = 0; s < n_seq_dft; ++s) {
453                  drafts[s].skip = false;
454              }
455  
456              for (int s = 0; s < n_seq_dft; ++s) {
457                  if (!drafts[s].drafting || drafts[s].skip) {
458                      continue;
459                  }
460  
461                  llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
462  
463                  const auto & cur_p = drafts[s].ctx_sampling->cur;
464  
465                  for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
466                      LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
467                              k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
468                  }
469  
470                  std::vector<int> sa(1, s);
471  
472                  // attempt to split the branch if the probability is high enough
473                  for (int f = 1; f < 8; ++f) {
474                      if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
475                          LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
476  
477                          llama_kv_cache_seq_rm(ctx_dft,    n_seq_cur, -1, -1);
478                          llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
479  
480                          // all previous tokens from this branch are now also part of the new branch
481                          for (int t = 0; t < batch_tgt.n_tokens; ++t) {
482                              for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
483                                  if (batch_tgt.seq_id[t][p] == s) {
484                                      batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
485                                      batch_tgt.n_seq_id[t]++;
486                                      break;
487                                  }
488                              }
489                          }
490  
491                          // copy the draft state
492                          drafts[n_seq_cur].active   = true;
493                          drafts[n_seq_cur].drafting = true;
494                          drafts[n_seq_cur].skip     = true;
495  
496                          drafts[n_seq_cur].tokens      = drafts[s].tokens;
497                          drafts[n_seq_cur].dists       = drafts[s].dists;
498                          drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
499                          drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
500  
501                          llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
502  
503                          sa.push_back(n_seq_cur);
504  
505                          n_seq_cur++;
506                      } else {
507                          break;
508                      }
509                  }
510  
511                  // add drafted token for each sequence
512                  for (int is = 0; is < (int) sa.size(); ++is) {
513                      const llama_token id = cur_p[is].id;
514  
515                      const int s = sa[is];
516  
517                      llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
518  
519                      drafts[s].tokens.push_back(id);
520                      // save cur_p.data into drafts[s].dists
521                      drafts[s].dists.push_back(cur_p);
522  
523                      // add unique drafted tokens to the target batch
524                      drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
525  
526                      llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
527  
528                      // add the token to the batch for batched decoding with the draft model
529                      drafts[s].i_batch_dft = batch_dft.n_tokens;
530  
531                      llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
532  
533                      if (batch_tgt.n_tokens > n_draft) {
534                          drafts[s].drafting = false;
535                      }
536                  }
537              }
538  
539              // no sequence is drafting anymore
540              if (batch_dft.n_tokens == 0) {
541                  break;
542              }
543  
544              // evaluate the drafted tokens on the draft model
545              llama_decode(ctx_dft, batch_dft);
546              ++n_past_cur;
547              ++n_drafted;
548  
549              if (batch_tgt.n_tokens > n_draft) {
550                  break;
551              }
552          }
553  
554          // evaluate the target model on the drafted tokens
555          {
556              llama_kv_cache_seq_keep(ctx_tgt, 0);
557              for (int s = 1; s < n_seq_dft; ++s) {
558                  llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
559              }
560  
561              // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
562              llama_decode(ctx_tgt, batch_tgt);
563              ++n_past_tgt;
564          }
565  
566          // the first token is always proposed by the target model before the speculation loop so we erase it here
567          for (int s = 0; s < n_seq_dft; ++s) {
568              if (!drafts[s].active) {
569                  continue;
570              }
571  
572              drafts[s].tokens.erase(drafts[s].tokens.begin());
573              drafts[s].dists.erase(drafts[s].dists.begin());
574          }
575      }
576  
577      auto t_dec_end = ggml_time_us();
578  
579      LOG_TEE("\n\n");
580  
581      LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
582      LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f));
583  
584      LOG_TEE("\n");
585      LOG_TEE("n_draft   = %d\n", n_draft);
586      LOG_TEE("n_predict = %d\n", n_predict);
587      LOG_TEE("n_drafted = %d\n", n_drafted);
588      LOG_TEE("n_accept  = %d\n", n_accept);
589      LOG_TEE("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);
590  
591      LOG_TEE("\ndraft:\n");
592      llama_print_timings(ctx_dft);
593  
594      LOG_TEE("\ntarget:\n");
595      llama_print_timings(ctx_tgt);
596  
597      llama_sampling_free(ctx_sampling);
598      for (int s = 0; s < n_seq_dft; ++s) {
599          llama_sampling_free(drafts[s].ctx_sampling);
600      }
601  
602      llama_batch_free(batch_dft);
603  
604      llama_free(ctx_tgt);
605      llama_free_model(model_tgt);
606  
607      llama_free(ctx_dft);
608      llama_free_model(model_dft);
609  
610      llama_backend_free();
611  
612      fprintf(stderr, "\n\n");
613  
614      return 0;
615  }