// tests/test-tokenizer-1-bpe.cpp
#include "llama.h"
#include "common.h"
#include "unicode.h"
#include "console.h"

#include <algorithm> // std::max
#include <cassert>
#include <codecvt>
#include <cstdio>
#include <cstring>
#include <locale>
#include <string>
#include <thread>
#include <vector>

int main(int argc, char **argv) {
    if (argc < 2 || argc > 3) {
        fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]);
        return 1;
    }

    const std::string fname = argv[1];
    bool ignore_merges = false;
    if (argc == 3) {
        if (std::strcmp(argv[2], "--ignore-merges") != 0) {
            fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]);
            return 1;
        }
        ignore_merges = true;
    }

    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());

    if (ignore_merges) {
        fprintf(stderr, "%s : ignoring merges for tokens inside vocab\n", __func__);
    }

    llama_model * model;
    llama_context * ctx;

    llama_backend_init();

    // load the vocab
    {
        auto mparams = llama_model_default_params();

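        // vocab_only: load just the tokenizer/vocab data and skip the tensor weights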
        mparams.vocab_only = true;

        model = llama_load_model_from_file(fname.c_str(), mparams);

        if (model == NULL) {
            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
            return 1;
        }

        auto cparams = llama_context_default_params();

        ctx = llama_new_context_with_model(model, cparams);

        if (ctx == NULL) {
            fprintf(stderr, "%s: error: failed to create context for '%s'\n", __func__, fname.c_str());
            llama_free_model(model);
            return 1;
        }
    }

    GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_BPE);

#ifdef _WIN32
    // We need this for unicode console support
    console::init(false, false);
    atexit([]() { console::cleanup(); });
#endif

    const int n_vocab = llama_n_vocab(model);

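    // round-trip check: detokenize each vocab id, tokenize the resulting text,
    // and verify that detokenizing those tokens reproduces the original string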
    for (int i = 0; i < n_vocab; ++i) {
        std::string str = llama_detokenize_bpe(ctx, std::vector<int>(1, i));
        try {
            // throws std::invalid_argument if str is not valid UTF-8
            auto cps = unicode_cpts_from_utf8(str);
            std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
            if (ignore_merges && tokens.size() > 1) {
                fprintf(stderr,
                        "%s : error: token %d detokenizes to '%s'(%zu) but "
                        "tokenizing it again produces multiple tokens: [",
                        __func__, i, str.c_str(), str.length());
                fprintf(stderr, "%d", tokens[0]);
                for (size_t j = 1; j < tokens.size(); j++) {
                    fprintf(stderr, ", %d", tokens[j]);
                }
                fprintf(stderr, "]\n");
                return 2;
            }
            std::string check = llama_detokenize_bpe(ctx, tokens);
            if (check != str) {
                fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenizing and detokenizing it again gives '%s'(%zu)\n",
                    __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
                return 2;
            }
        }
        catch (const std::invalid_argument &) {
            // the token does not decode to valid UTF-8 on its own
            // (e.g. a partial multi-byte sequence) - skip it
            //fprintf(stderr, "%s : info: utf8 conversion %d '%s'\n", __func__, i, str.c_str());
        }
    }

    // unicode
    {
        // hardware_concurrency() may return 0; fall back to one thread in that case
        const int nthread = std::max(1u, std::thread::hardware_concurrency());

        std::vector<std::thread> threads(nthread);

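        // each thread checks a strided slice of the codepoint space:
        // thread i handles cp = i, i + nthread, i + 2*nthread, ...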
        for (int i = 0; i < nthread; ++i) {
            threads[i] = std::thread([i, nthread, ctx]() {
                for (uint32_t cp = i; cp <= 0x0010ffff; cp += nthread) { // inclusive of U+10FFFF, the last valid codepoint
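                    // skip codepoints that are not expected to round-trip: a handful of
                    // C0 control characters, the UTF-16 surrogate range U+D800..U+DFFF,
                    // and the unassigned planes 4..13 (U+40000..U+DFFFF)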
                    if (!( // NOLINT
                                (cp < 0x03       || cp >  0x05)   && cp != 0x0b && cp != 0x11 &&
                                (cp < 0x13       || cp >  0x17)   && cp != 0x19 &&
                                (cp < 0x1c       || cp >  0x1e)   &&
                                (cp < 0xd800     || cp >  0xdfff) &&
                                (cp < 0x00040000 || cp >= 0x000e0000)
                        )) {
                        continue;
                    }

                    std::string str = unicode_cpt_to_utf8(cp);
                    std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
                    std::string check = llama_detokenize_bpe(ctx, tokens);
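                    // U+2581 (9601, '▁') is exempt from the check: some vocabs use it
                    // as a whitespace marker, so it is not expected to round-trip verbatim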
                    if (cp != 9601 && str != check) {
                        fprintf(stderr, "error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                                cp, check.c_str(), check.length(), str.c_str(), str.length());
                        std::exit(3);
                    }
                }
            });
        }

        for (auto & t : threads) {
            t.join();
        }
    }

    // free the context before the model it references
    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    return 0;
}