// test-tokenizer-1-spm.cpp
1 #include "llama.h" 2 #include "common.h" 3 #include "unicode.h" 4 #include "console.h" 5 6 #include <cassert> 7 #include <codecvt> 8 #include <cstdio> 9 #include <cstring> 10 #include <locale> 11 #include <string> 12 #include <thread> 13 #include <vector> 14 15 int main(int argc, char ** argv) { 16 if (argc < 2) { 17 fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]); 18 return 1; 19 } 20 21 const std::string fname = argv[1]; 22 23 fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str()); 24 25 llama_model * model; 26 llama_context * ctx; 27 28 llama_backend_init(); 29 30 // load the vocab 31 { 32 auto mparams = llama_model_default_params(); 33 34 mparams.vocab_only = true; 35 36 model = llama_load_model_from_file(fname.c_str(), mparams); 37 38 if (model == NULL) { 39 fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); 40 return 1; 41 } 42 43 auto cparams = llama_context_default_params(); 44 45 ctx = llama_new_context_with_model(model, cparams); 46 47 if (ctx == NULL) { 48 fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); 49 llama_free_model(model); 50 return 1; 51 } 52 } 53 54 GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); 55 56 #ifdef _WIN32 57 // We need this for unicode console support 58 console::init(false, false); 59 atexit([]() { console::cleanup(); }); 60 #endif 61 62 const int n_vocab = llama_n_vocab(model); 63 64 for (int i = 0; i < n_vocab; ++i) { 65 std::string str = llama_detokenize_spm(ctx, std::vector<int>(1, i)); 66 std::vector<llama_token> tokens = llama_tokenize(ctx, str, false); 67 std::string check = llama_detokenize_spm(ctx, tokens); 68 if (check != str) { 69 fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n", 70 __func__, i, str.c_str(), str.length(), check.c_str(), check.length()); 71 return 2; 72 } 73 } 74 75 // unicode 76 { 77 const int nthread = 
std::thread::hardware_concurrency(); 78 79 std::vector<std::thread> threads(nthread); 80 81 for (int i = 0; i < nthread; ++i) { 82 threads[i] = std::thread([i, nthread, ctx]() { 83 for (uint32_t cp = i; cp < 0x0010ffff; cp += nthread) { 84 if (cp >= 0xd800 && cp <= 0xdfff) { 85 continue; 86 } 87 88 std::string str = unicode_cpt_to_utf8(cp); 89 std::vector<llama_token> tokens = llama_tokenize(ctx, str, false); 90 std::string check = llama_detokenize_spm(ctx, tokens); 91 if (cp != 9601 && str != check) { 92 fprintf(stderr, "error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n", 93 cp, check.c_str(), check.length(), str.c_str(), str.length()); 94 std::exit(3); 95 } 96 } 97 }); 98 } 99 100 for (auto & t : threads) { 101 t.join(); 102 } 103 } 104 105 llama_free_model(model); 106 llama_free(ctx); 107 108 llama_backend_free(); 109 110 return 0; 111 }