/ tests / test-tokenizer-0.cpp
test-tokenizer-0.cpp
  1  #include "llama.h"
  2  #include "common.h"
  3  #include "console.h"
  4  
  5  #include <cstdio>
  6  #include <string>
  7  #include <map>
  8  #include <vector>
  9  #include <fstream>
 10  
 11  //static const std::map<std::string, std::vector<llama_token>> & k_tests() {
 12  //    static std::map<std::string, std::vector<llama_token>> _k_tests = {
 13  //        { ""                      , {  }, },
 14  //        { " "                     , {     220, }, },
 15  //        { "  "                    , {     256, }, },
 16  //        { "   "                   , {     262, }, },
 17  //        { "\t"                    , {     197, }, },
 18  //        { "\n"                    , {     198, }, },
 19  //        { "\n\n"                  , {     271, }, },
 20  //        { "\n\n\n"                , {    1432, }, },
 21  //        { "\t\n"                  , {    1602, }, },
 22  //        { "Hello world"           , {    9906,   1917, }, },
 23  //        { " Hello world"          , {   22691,   1917, }, },
 24  //        { "Hello World"           , {    9906,   4435, }, },
 25  //        { " Hello World"          , {   22691,   4435, }, },
 26  //        { " Hello World!"         , {   22691,   4435,      0, }, },
 27  //        { "Hello, world!"         , {    9906,     11,   1917,      0, }, },
 28  //        { " Hello, world!"        , {   22691,     11,   1917,      0, }, },
  29  //        { " this is 🦙.cpp"        , {     420,    374,  11410,     99,    247,     13,  11055, }, },
 30  //        { "w048 7tuijk dsdfhu"    , {      86,  23904,    220,     22,     83,   2005,  42908,  11729,   3013,  17156, }, },
  31  //        { "нещо на Български"     , {   79862, 102118,  13373,  64571,  34694,   3114, 112203,  80112, }, },
  32  //        { "កាន់តែពិសេសអាចខលចេញ"   , {   21549,    222,  98629,    241,  45358,    233,  21549,    237,  45358,    224,  21549,    244,  21549,    115,  21549,    253,  45358,    223,  21549,    253,  21549,     95,  98629,    227,  21549,    223,  21549,    249,  21549,    227,  45358,    223,  21549,    231, }, },
  33  //        { "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", {    9468,    248,    222,    320,   8416,      8,  27623,    114, 102470,   9468,    234,    104,  31643,    320,  36773, 100166,  98634,      8,  26602,    227,    320,   3323,  43465,    430,    706,   1202,   1866,   4037,      8, }, },
 34  //        { "Hello"                 , {    9906, }, },
 35  //        { " Hello"                , {   22691, }, },
 36  //        { "  Hello"               , {     220,  22691, }, },
 37  //        { "   Hello"              , {     256,  22691, }, },
 38  //        { "    Hello"             , {     262,  22691, }, },
 39  //        { "    Hello\n    Hello"  , {     262,  22691,    198,    262,  22691, }, },
 40  //        { " ("                    , {     320, }, },
 41  //        { "\n ="                  , {     198,    284, }, },
 42  //        { "' era"                 , {       6,  11639, }, },
  43  //        { "Hello, y'all! How are you 😁 ?我想在apple工作1314151天～", {    9906,     11,    379,  65948,      0,   2650,    527,    499,  27623,    223,    949,  37046, 101067,  19000,  23182, 102301,   9263,  18136,     16,  36827,  21909, }, },
 44  //        { "3"                     , {      18, }, },
 45  //        { "33"                    , {    1644, }, },
 46  //        { "333"                   , {    8765, }, },
 47  //        { "3333"                  , {    8765,     18, }, },
 48  //        { "33333"                 , {    8765,   1644, }, },
 49  //        { "333333"                , {    8765,   8765, }, },
 50  //        { "3333333"               , {    8765,   8765,     18, }, },
 51  //        { "33333333"              , {    8765,   8765,   1644, }, },
 52  //        { "333333333"             , {    8765,   8765,   8765, }, },
 53  //    };
 54  //
 55  //    return _k_tests;
 56  //}
 57  
 58  using llama_tests = std::map<std::string, std::vector<llama_token>>;
 59  
 60  static llama_tests read_tests(const std::string & fname_inp, const std::string & fname_out) {
 61      llama_tests tests;
 62  
 63      std::ifstream ifs_inp(fname_inp);
 64      if (!ifs_inp) {
 65          fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_inp.c_str());
 66          return tests;
 67      }
 68  
 69      std::string sraw((std::istreambuf_iterator<char>(ifs_inp)), std::istreambuf_iterator<char>());
 70  
 71      std::ifstream ifs_out(fname_out);
 72      if (!ifs_out) {
 73          fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
 74          return tests;
 75      }
 76  
 77      std::vector<std::string> sout;
 78      for (std::string line; std::getline(ifs_out, line);) {
 79          sout.push_back(line);
 80      }
 81  
 82      const std::string sep = "\n__ggml_vocab_test__\n";
 83  
 84      std::vector<std::string> sinp;
 85  
 86      size_t pos = 0;
 87      while (pos < sraw.size()) {
 88          const size_t next = sraw.find(sep, pos);
 89          if (next == std::string::npos) {
 90              sinp.push_back(sraw.substr(pos));
 91              break;
 92          }
 93          sinp.push_back(sraw.substr(pos, next - pos));
 94          pos = next + sep.size();
 95      }
 96  
 97      if (sinp.size() != sout.size()) {
 98          fprintf(stderr, "%s : error: input and output files have different number of tests\n", __func__);
 99          return tests;
100      }
101  
102      for (size_t i = 0; i < sinp.size(); ++i) {
103          const std::string & s = sinp[i];
104          const std::string & o = string_strip(sout[i]);
105  
106          std::vector<llama_token> toks;
107  
108          size_t pos = 0;
109          while (pos < o.size()) {
110              size_t next = o.find(' ', pos);
111              if (next == std::string::npos) {
112                  next = o.size();
113              }
114              const std::string stok = o.substr(pos, next - pos);
115              toks.push_back(std::stoi(stok));
116              pos = next + 1;
117          }
118  
119          tests[s] = toks;
120      }
121  
122      return tests;
123  }
124  
125  int main(int argc, char **argv) {
126      if (argc < 2) {
127          fprintf(stderr, "Usage: %s vocab-file [text-file]\n", argv[0]);
128          return 1;
129      }
130  
131      const std::string fname = argv[1];
132  
133      const std::string fname_inp = fname + ".inp";
134      const std::string fname_out = fname + ".out";
135  
136      std::string fname_text;
137      if (argc > 2) {
138          fname_text = argv[2];
139      }
140  
141      fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
142  
143      llama_model * model;
144      llama_context * ctx;
145  
146      llama_backend_init();
147  
148      // load the vocab
149      {
150          auto mparams = llama_model_default_params();
151  
152          mparams.vocab_only = true;
153  
154          model = llama_load_model_from_file(fname.c_str(), mparams);
155  
156          if (model == NULL) {
157              fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
158              return 1;
159          }
160  
161          auto cparams = llama_context_default_params();
162  
163          ctx = llama_new_context_with_model(model, cparams);
164  
165          if (ctx == NULL) {
166              fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
167              llama_free_model(model);
168              return 1;
169          }
170      }
171  
172  #ifdef _WIN32
173      // We need this for unicode console support
174      console::init(false, false);
175      atexit([]() { console::cleanup(); });
176  #endif
177  
178      bool success = true;
179  
180      const auto k_tests = [&]() -> llama_tests {
181          if (!fname_text.empty()) {
182              return {};
183          }
184  
185          const auto res = read_tests(fname_inp, fname_out);
186  
187          if (res.empty()) {
188              fprintf(stderr, "%s : error: no tests found\n", __func__);
189              exit(1);
190          }
191  
192          return res;
193      }();
194  
195      const bool add_special = false;
196  
197      for (const auto & test_kv : k_tests) {
198          const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special);
199  
200          printf("\n");
201          printf("src: '%s'\n", test_kv.first.c_str());
202          printf("res: '%s'\n", llama_detokenize_bpe(ctx, res).c_str());
203          printf("tok: ");
204          for (const auto & tok : res) {
205              printf("%d ", tok);
206          }
207          printf("\n");
208  
209          bool correct = res.size() == test_kv.second.size();
210          for (int i = 0; i < (int) res.size() && correct; ++i) {
211              if (test_kv.second[i] != res[i]) {
212                  correct = false;
213              }
214          }
215  
216          if (!correct) {
217              fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
218              fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
219                  llama_detokenize_bpe(ctx, res).c_str(),
220                  llama_detokenize_bpe(ctx, test_kv.second).c_str());
221              fprintf(stderr, "%s : expected tokens: ", __func__);
222              for (const auto & t : test_kv.second) {
223                  fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
224              }
225              fprintf(stderr, "\n");
226              fprintf(stderr, "%s : got tokens:      ", __func__);
227              for (const auto & t : res) {
228                  fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
229              }
230              fprintf(stderr, "\n");
231  
232              success = false;
233          }
234      }
235  
236      if (!fname_text.empty()) {
237          fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
238  
239          std::string text;
240          {
241              std::ifstream ifs(fname_text);
242              if (!ifs) {
243                  fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_text.c_str());
244                  return 1;
245              }
246              text = std::string(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
247          }
248  
249          fprintf(stderr, "%s : text size: %zu\n", __func__, text.size());
250  
251          std::vector<llama_token> res;
252  
253          {
254              const auto t_start = ggml_time_us();
255  
256              res = llama_tokenize(ctx, text, add_special);
257  
258              const auto t_end = ggml_time_us();
259  
260              fprintf(stderr, "%s : tokenized in %.3f ms (cpp)\n", __func__, (t_end - t_start) / 1000.0);
261          }
262  
263          fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size());
264  
265          {
266              const std::string fname_out = fname_text + ".tokcpp";
267  
268              std::ofstream ofs(fname_out);
269              if (!ofs) {
270                  fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
271                  return 1;
272              }
273  
274              for (const auto & tok : res) {
275                  //ofs << tok << " '" << string_strip(llama_detokenize_bpe(ctx, std::vector<int>{tok})) << "'" << std::endl;
276                  ofs << tok << "\n";
277              }
278          }
279  
280          fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
281      }
282  
283      llama_free_model(model);
284      llama_free(ctx);
285  
286      llama_backend_free();
287  
288      printf("\n");
289      printf("Tests %s\n", success ? "passed" : "failed");
290  
291      return success ? 0 : 3;
292  }