test-tokenizer-1-bpe.cpp
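// Round-trip test for BPE tokenizers. Given a vocab-only GGUF file, it checks that:
//
//   1. every token in the vocab detokenizes to a string which tokenizes and
//      detokenizes back to the same string (with --ignore-merges it must also
//      re-tokenize to a single token);
//   2. (almost) every unicode codepoint survives a tokenize/detokenize round trip.
//
// Usage: test-tokenizer-1-bpe <vocab-file> [--ignore-merges]
//
// The vocab file is a vocab-only model, e.g. one of the ggml-vocab-*.gguf files
// shipped with the repo (the exact path is illustrative).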
1 #include "llama.h" 2 #include "common.h" 3 #include "unicode.h" 4 #include "console.h" 5 6 #include <cassert> 7 #include <codecvt> 8 #include <cstdio> 9 #include <cstring> 10 #include <locale> 11 #include <string> 12 #include <thread> 13 #include <vector> 14 15 int main(int argc, char **argv) { 16 if (argc < 2 || argc > 3) { 17 fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]); 18 return 1; 19 } 20 21 const std::string fname = argv[1]; 22 bool ignore_merges = false; 23 if (argc == 3) { 24 if (std::strcmp(argv[2], "--ignore-merges") != 0) { 25 fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]); 26 return 1; 27 } 28 ignore_merges = true; 29 } 30 31 fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str()); 32 33 if (ignore_merges) { 34 fprintf(stderr, "%s : ignoring merges for tokens inside vocab\n", __func__); 35 } 36 37 llama_model * model; 38 llama_context * ctx; 39 40 llama_backend_init(); 41 42 // load the vocab 43 { 44 auto mparams = llama_model_default_params(); 45 46 mparams.vocab_only = true; 47 48 model = llama_load_model_from_file(fname.c_str(), mparams); 49 50 if (model == NULL) { 51 fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); 52 return 1; 53 } 54 55 auto cparams = llama_context_default_params(); 56 57 ctx = llama_new_context_with_model(model, cparams); 58 59 if (ctx == NULL) { 60 fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); 61 llama_free_model(model); 62 return 1; 63 } 64 } 65 66 GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_BPE); 67 68 #ifdef _WIN32 69 // We need this for unicode console support 70 console::init(false, false); 71 atexit([]() { console::cleanup(); }); 72 #endif 73 74 const int n_vocab = llama_n_vocab(model); 75 76 for (int i = 0; i < n_vocab; ++i) { 77 std::string str = llama_detokenize_bpe(ctx, std::vector<int>(1, i)); 78 try { 79 auto cps = unicode_cpts_from_utf8(str); 80 std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true); 81 if (ignore_merges && tokens.size() > 1) { 82 fprintf(stderr, 83 "%s : error: token %d detokenizes to '%s'(%zu) but " 84 "tokenization of this to multiple tokens: [", 85 __func__, i, str.c_str(), str.length()); 86 fprintf(stderr, "%d", tokens[0]); 87 for (size_t i = 1; i < tokens.size(); i++) { 88 fprintf(stderr, ", %d", tokens[i]); 89 } 90 fprintf(stderr, "]\n"); 91 return 2; 92 } 93 std::string check = llama_detokenize_bpe(ctx, tokens); 94 if (check != str) { 95 fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n", 96 __func__, i, str.c_str(), str.length(), check.c_str(), check.length()); 97 return 2; 98 } 99 } 100 catch (const std::invalid_argument &) { 101 //fprintf(stderr, "%s : info: utf8 conversion %d '%s'\n", __func__, i, str.c_str()); 102 } 103 } 104 105 // unicode 106 { 107 const int nthread = std::thread::hardware_concurrency(); 108 109 std::vector<std::thread> threads(nthread); 110 111 for (int i = 0; i < nthread; ++i) { 112 threads[i] = std::thread([i, nthread, ctx]() { 113 for (uint32_t cp = i; cp < 0x0010ffff; cp += nthread) { 114 if (!( // NOLINT 115 (cp < 0x03 || cp > 0x05) && cp != 0x0b && cp != 0x11 && 116 (cp < 0x13 || cp > 0x17) && cp != 0x19 && 117 (cp < 0x1c || cp > 0x1e) && 118 (cp < 0xd800 || cp > 0xdfff) && 119 (cp < 0x00040000 || cp >= 0x000e0000) 120 )) { 121 continue; 122 } 123 124 std::string str = unicode_cpt_to_utf8(cp); 125 std::vector<llama_token> tokens = llama_tokenize(ctx, 
str, false); 126 std::string check = llama_detokenize_bpe(ctx, tokens); 127 if (cp != 9601 && str != check) { 128 fprintf(stderr, "error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n", 129 cp, check.c_str(), check.length(), str.c_str(), str.length()); 130 std::exit(3); 131 } 132 } 133 }); 134 } 135 136 for (auto & t : threads) { 137 t.join(); 138 } 139 } 140 141 llama_free_model(model); 142 llama_free(ctx); 143 144 llama_backend_free(); 145 146 return 0; 147 }