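// tests/test-tokenizer-1-spm.cpp
//
// Round-trip test for SentencePiece (SPM) vocabs: every token in the vocab
// and every Unicode codepoint must survive a detokenize -> tokenize ->
// detokenize cycle.
//
// usage: test-tokenizer-1-spm <vocab-file>
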
#include "llama.h"
#include "common.h"
#include "unicode.h"
#include "console.h"

#include <cassert>
#include <codecvt>
#include <cstdio>
#include <cstring>
#include <locale>
#include <string>
#include <thread>
#include <vector>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
        return 1;
    }

    const std::string fname = argv[1];

    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());

    llama_model * model;
    llama_context * ctx;

    llama_backend_init();

    // load the vocab
    {
        auto mparams = llama_model_default_params();

        mparams.vocab_only = true;

        model = llama_load_model_from_file(fname.c_str(), mparams);

        if (model == NULL) {
            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
            return 1;
        }

        auto cparams = llama_context_default_params();

        ctx = llama_new_context_with_model(model, cparams);

        if (ctx == NULL) {
            fprintf(stderr, "%s: error: failed to create context for '%s'\n", __func__, fname.c_str());
            llama_free_model(model);
            return 1;
        }
    }

    // this test is specific to SentencePiece-style vocabs
    GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);

#ifdef _WIN32
    // We need this for unicode console support
    console::init(false, false);
    atexit([]() { console::cleanup(); });
#endif

    const int n_vocab = llama_n_vocab(model);

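    // check that every token round-trips: detokenize the single token,
    // tokenize the resulting string, and detokenize again - the two
    // strings must match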
    for (int i = 0; i < n_vocab; ++i) {
        std::string str = llama_detokenize_spm(ctx, std::vector<int>(1, i));
        std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
        std::string check = llama_detokenize_spm(ctx, tokens);
        if (check != str) {
            fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenizing and detokenizing again gives '%s'(%zu)\n",
                __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
            return 2;
        }
    }

    // unicode: round-trip every valid codepoint, in parallel
    {
        // hardware_concurrency() can return 0; fall back to a single thread
        int nthread = std::thread::hardware_concurrency();
        if (nthread <= 0) {
            nthread = 1;
        }

        std::vector<std::thread> threads(nthread);

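        // partition the codepoint space across threads: thread i checks
        // codepoints i, i + nthread, i + 2*nthread, ...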
        for (int i = 0; i < nthread; ++i) {
            threads[i] = std::thread([i, nthread, ctx]() {
                for (uint32_t cp = i; cp < 0x0010ffff; cp += nthread) {
                    // skip UTF-16 surrogate halves - they are not valid codepoints
                    if (cp >= 0xd800 && cp <= 0xdfff) {
                        continue;
                    }

                    std::string str = unicode_cpt_to_utf8(cp);
                    std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
                    std::string check = llama_detokenize_spm(ctx, tokens);
                    // 9601 is U+2581 (LOWER ONE EIGHTH BLOCK), which SPM uses as
                    // its whitespace marker, so it cannot round-trip by design
                    if (cp != 9601 && str != check) {
                        fprintf(stderr, "error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                                cp, check.c_str(), check.length(), str.c_str(), str.length());
                        std::exit(3);
                    }
                }
            });
        }

        for (auto & t : threads) {
            t.join();
        }
    }

    // free the context before the model it was created from
    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    return 0;
}