/ examples / tokenize / tokenize.cpp
tokenize.cpp
#include "common.h"
#include "llama.h"

#include <cerrno>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include <shellapi.h>   // For CommandLineToArgvW
#endif
 15  
 16  static void print_usage_information(const char * argv0, FILE * stream) {
 17      fprintf(stream, "usage: %s [options]\n\n", argv0);
 18      fprintf(stream, "The tokenize program tokenizes a prompt using a given model,\n");
 19      fprintf(stream, "and prints the resulting tokens to standard output.\n\n");
 20      fprintf(stream, "It needs a model file, a prompt, and optionally other flags\n");
 21      fprintf(stream, "to control the behavior of the tokenizer.\n\n");
 22      fprintf(stream, "    The possible options are:\n");
 23      fprintf(stream, "\n");
 24      fprintf(stream, "    -h, --help                           print this help and exit\n");
 25      fprintf(stream, "    -m MODEL_PATH, --model MODEL_PATH    path to model.\n");
 26      fprintf(stream, "    --ids                                if given, only print numerical token IDs, and not token strings.\n");
 27      fprintf(stream, "                                         The output format looks like [1, 2, 3], i.e. parseable by Python.\n");
 28      fprintf(stream, "    -f PROMPT_FNAME, --file PROMPT_FNAME read prompt from a file.\n");
 29      fprintf(stream, "    -p PROMPT, --prompt PROMPT           read prompt from the argument.\n");
 30      fprintf(stream, "    --stdin                              read prompt from standard input.\n");
 31      fprintf(stream, "    --no-bos                             do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n");
 32      fprintf(stream, "    --log-disable                        disable logs. Makes stderr quiet when loading the model.\n");
 33  }
 34  
 35  static void llama_log_callback_null(ggml_log_level level, const char * text, void * user_data) {
 36      (void) level;
 37      (void) text;
 38      (void) user_data;
 39  }
 40  
 41  static std::string read_prompt_from_file(const char * filepath, bool & success) {
 42      success = false;
 43  
 44      std::ifstream in(filepath, std::ios::binary);
 45      if (!in) {
 46          fprintf(stderr, "%s: could not open file '%s' for reading: %s\n", __func__, filepath, strerror(errno));
 47          return std::string();
 48      }
 49      // do not assume the file is seekable (e.g. /dev/stdin)
 50      std::stringstream buffer;
 51      buffer << in.rdbuf();
 52      if (in.fail()) {
 53          fprintf(stderr, "%s: could not read the entire file '%s': %s\n", __func__, filepath, strerror(errno));
 54          return std::string();
 55      }
 56  
 57      success = true;
 58      return buffer.str();
 59  }
 60  
 61  //
 62  // Function: ingest_args(...) -> vector<string>
 63  //
 64  //  Takes argc and argv arguments, and converts them to a vector of UTF-8 encoded
 65  //  strings, as an STL vector<string>.
 66  //
 67  //  In particular, it handles character encoding shenanigans on Windows.
 68  //
 69  // Note: raw_argc and raw_argv are not actually read at all on Windows.
 70  //       On Windows we call GetCommandLineW to get the arguments in wchar_t
 71  //       format, ignoring the regular argc/argv arguments to main().
 72  //
 73  // TODO: potential opportunity to roll common stuff into common/console.cpp
 74  //       in relation to Windows wchar_t shenanigans.
 75  static std::vector<std::string> ingest_args(int raw_argc, char ** raw_argv) {
 76      std::vector<std::string> argv;
 77  
 78      // Handle Windows, if given non-ASCII arguments.
 79      // We convert wchar_t arguments into UTF-8 char* on this platform.
 80      // Lets you invoke 'tokenize' on Windows cmd.exe with non-ASCII characters
 81      // without throwing tantrums.
 82  #if defined(_WIN32)
 83      int argc;
 84      const LPWSTR cmdline_wargv = GetCommandLineW();
 85      LPWSTR * wargv = CommandLineToArgvW(cmdline_wargv, &argc);
 86  
 87      // silence unused arg warnings
 88      (void) raw_argc;
 89      (void) raw_argv;
 90  
 91      for (int i = 0; i < argc; ++i) {
 92          int length_needed = WideCharToMultiByte(CP_UTF8, 0, wargv[i], wcslen(wargv[i]), 0, 0, NULL, NULL);
 93          char * output_buf = (char *) calloc(length_needed+1, sizeof(char));
 94          GGML_ASSERT(output_buf);
 95  
 96          WideCharToMultiByte(CP_UTF8, 0, wargv[i], wcslen(wargv[i]), output_buf, length_needed, NULL, NULL);
 97          output_buf[length_needed] = '\0';
 98  
 99          argv.push_back(output_buf);
100          free(output_buf);
101      }
102  
103      LocalFree((HLOCAL) wargv);
104  #else
105      int argc = raw_argc;
106      for (int i = 0; i < argc; ++i) {
107          argv.push_back(raw_argv[i]);
108      }
109  #endif
110  
111      GGML_ASSERT((unsigned int) argc == argv.size());
112  
113      return argv;
114  }
115  
//
// Function: write_utf8_cstr_to_stdout(const char *) -> <writes to stdout>
//
// Writes a NUL-terminated UTF-8 string to standard output. On Windows, when
// stdout is a real console, the string is converted to UTF-16 and written with
// WriteConsoleW so it displays correctly even if the user has not set a
// unicode code page in cmd.exe. When stdout is redirected, the raw bytes are
// written with printf instead.
//
// If the string is invalid UTF-8 (only detected on the Windows console path),
// invalid_utf8 is set to true and a human-readable hex dump of the bytes,
// such as "<e2 28 a1>", is written instead.
//
// On non-Windows systems this simply printf()s the string, and invalid_utf8
// is always set to false.
static void write_utf8_cstr_to_stdout(const char * str, bool & invalid_utf8) {
    invalid_utf8 = false;

#if defined(_WIN32)
    // Are we in a console?
    HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
    DWORD dwMode = 0;

    // According to Microsoft docs:
    // "WriteConsole fails if it is used with a standard handle that is redirected to a file."
    // Also according to the docs, you can use GetConsoleMode to check for that.
    if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
        printf("%s", str);
        return;
    }

    // MultiByteToWideChar reports an error if str is empty, don't report
    // them as invalid_utf8.
    if (*str == 0) {
        return;
    }

    const int byte_len = (int) strlen(str);
    const int length_needed = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, byte_len, NULL, 0);
    if (length_needed == 0) {
        DWORD err = GetLastError();
        if (err == ERROR_NO_UNICODE_TRANSLATION) {
            invalid_utf8 = true;
            printf("<");
            for (int i = 0; i < byte_len; ++i) {
                if (i > 0) {
                    printf(" ");
                }
                printf("%02x", (uint8_t) str[i]);
            }
            printf(">");
            return;
        }
        GGML_ASSERT(false && "MultiByteToWideChar() failed in an unexpected way.");
    }

    std::vector<wchar_t> wstr(length_needed);
    MultiByteToWideChar(CP_UTF8, 0, str, byte_len, wstr.data(), length_needed);

    // Callers print surrounding text with printf(), which may still sit in the
    // C runtime's stdout buffer; WriteConsoleW bypasses that buffer entirely.
    // Flush first so the console shows everything in program order.
    fflush(stdout);
    WriteConsoleW(hConsole, wstr.data(), length_needed, NULL, NULL);
#else
    // TODO: reporting invalid_utf8 would be useful on non-Windows too.
    // printf will silently just write bad unicode.
    printf("%s", str);
#endif
}
180  
181  int main(int raw_argc, char ** raw_argv) {
182      const std::vector<std::string> argv = ingest_args(raw_argc, raw_argv);
183      const int argc = argv.size();
184  
185      if (argc <= 1) {
186          print_usage_information(argv[0].c_str(), stderr);
187          return 1;
188      }
189  
190      //////
191      // Read out all the command line arguments.
192      //////
193  
194      // variables where to put any arguments we see.
195      bool printing_ids = false;
196      bool no_bos = false;
197      bool disable_logging = false;
198      const char * model_path = NULL;
199      const char * prompt_path = NULL;
200      const char * prompt_arg = NULL;
201  
202      // track which arguments were explicitly given
203      // used for sanity checking down the line
204      bool model_path_set = false;
205      bool prompt_path_set = false;
206      bool prompt_set = false;
207      bool stdin_set = false;
208  
209      int iarg = 1;
210      for (; iarg < argc; ++iarg) {
211          std::string arg{argv[iarg]};
212          if (arg == "-h" || arg == "--help") {
213              print_usage_information(argv[0].c_str(), stdout);
214              return 0;
215          }
216          else if (arg == "--ids") {
217              printing_ids = true;
218          }
219          else if (arg == "-m" || arg == "--model") {
220              if (model_path_set) {
221                  fprintf(stderr, "Error: -m or --model specified multiple times.\n");
222                  return 1;
223              }
224              model_path = argv[++iarg].c_str();
225              model_path_set = true;
226          }
227          else if (arg == "--no-bos") {
228              no_bos = true;
229          }
230          else if (arg == "-p" || arg == "--prompt") {
231              if (prompt_set) {
232                  fprintf(stderr, "Error: -p or --prompt specified multiple times.\n");
233                  return 1;
234              }
235              prompt_arg = argv[++iarg].c_str();
236              prompt_set = true;
237          }
238          else if (arg == "-f" || arg == "--file") {
239              if (prompt_path_set) {
240                  fprintf(stderr, "Error: -f or --file specified multiple times.\n");
241                  return 1;
242              }
243              prompt_path = argv[++iarg].c_str();
244              prompt_path_set = true;
245          }
246          else if (arg == "--stdin") {
247              stdin_set = true;
248          }
249          else if (arg == "--log-disable") {
250              disable_logging = true;
251          }
252          else {
253              fprintf(stderr, "Error: unknown option '%s'\n", argv[iarg].c_str());
254              return 1;
255          }
256      }
257  
258      //////
259      // Sanity check the command line arguments.
260      //////
261  
262      // Check that we have the required stuff set.
263      if (model_path_set && model_path == NULL) {
264          fprintf(stderr, "Error: --model requires an argument.\n");
265          return 1;
266      }
267      if (!model_path_set) {
268          fprintf(stderr, "Error: must specify --model.\n");
269          return 1;
270      }
271      if (prompt_path_set && prompt_path == NULL) {
272          fprintf(stderr, "Error: --file requires an argument.\n");
273          return 1;
274      }
275      if (prompt_set && prompt_arg == NULL) {
276          fprintf(stderr, "Error: --prompt requires an argument.\n");
277          return 1;
278      }
279      const int prompts_set = !!(prompt_path_set) + !!(prompt_set) + !!(stdin_set);
280      if (prompts_set > 1) {
281          fprintf(stderr, "Error: --stdin, --file and --prompt are mutually exclusive.\n");
282          return 1;
283      }
284      // Must have some prompt.
285      if (prompts_set == 0) {
286          fprintf(stderr, "Error: must specify one of: --stdin, --file or --prompt.\n");
287          return 1;
288      }
289  
290      GGML_ASSERT(model_path);
291      GGML_ASSERT(prompt_path || prompt_arg || stdin_set);
292  
293      //////
294      // Figure out where will the prompt come from.
295      //////
296  
297      std::string prompt;
298      if (prompt_path_set) {
299          bool success = false;
300          prompt = read_prompt_from_file(prompt_path, success);
301          if (!success) {
302              return 1;
303          }
304      } else if (prompt_set) {
305          prompt = prompt_arg;
306      } else {
307          GGML_ASSERT(stdin_set);
308          // we read stdin *after* loading model (early exit if model cannot
309          // be loaded, which can be a nicer user experience)
310      }
311  
312      //////
313      // Start actually doing the tokenizing stuff.
314      //////
315  
316  #ifdef LOG_DISABLE_LOGS
317      disable_logging = true;
318  #endif
319  
320      if (disable_logging) {
321          llama_log_set(llama_log_callback_null, NULL);
322      }
323  
324      llama_backend_init();
325  
326      llama_model_params model_params = llama_model_default_params();
327      model_params.vocab_only = true;
328      llama_model * model = llama_load_model_from_file(model_path, model_params);
329      if (!model) {
330          fprintf(stderr, "Error: could not load model from file '%s'.\n", model_path);
331          return 1;
332      }
333  
334      llama_context_params ctx_params = llama_context_default_params();
335      llama_context * ctx = llama_new_context_with_model(model, ctx_params);
336      if (!ctx) {
337          fprintf(stderr, "Error: could not create context.\n");
338          return 1;
339      }
340  
341      // read entire prompt from stdin?
342      if (stdin_set) {
343          GGML_ASSERT(!prompt_path_set && !prompt_set);
344  
345          std::stringstream stdin_buffer;
346          stdin_buffer << std::cin.rdbuf();
347          if (std::cin.fail()) {
348              fprintf(stderr, "Error: could not read the entire standard input.\n");
349              return 1;
350          }
351  
352          prompt = stdin_buffer.str();
353      }
354  
355      const bool model_wants_add_bos = llama_should_add_bos_token(model);
356      const bool add_bos = model_wants_add_bos && !no_bos;
357  
358      std::vector<llama_token> tokens;
359      tokens = ::llama_tokenize(model, prompt, add_bos, true);
360  
361      if (printing_ids) {
362          printf("[");
363      }
364  
365      for (int i = 0; i < (int) tokens.size(); i++) {
366          if (printing_ids) {
367              if (i > 0) {
368                  printf(", ");
369              }
370              printf("%d", tokens[i]);
371          } else {
372              bool invalid_utf8 = false;
373              printf("%6d -> '", tokens[i]);
374              write_utf8_cstr_to_stdout(llama_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8);
375              if (invalid_utf8) {
376                  printf("' (utf-8 decode failure)\n");
377              } else {
378                  printf("'\n");
379              }
380          }
381      }
382  
383      if (printing_ids) {
384          printf("]\n");
385      }
386  
387      // silence valgrind
388      llama_free(ctx);
389      llama_free_model(model);
390  
391      return 0;
392  }