// tokenize.cpp
1 #include "common.h" 2 #include "llama.h" 3 4 #include <cmath> 5 #include <cstdio> 6 #include <fstream> 7 #include <string> 8 #include <vector> 9 10 #if defined(_WIN32) 11 #define WIN32_LEAN_AND_MEAN 12 #include <windows.h> 13 #include <shellapi.h> // For CommandLineToArgvW 14 #endif 15 16 static void print_usage_information(const char * argv0, FILE * stream) { 17 fprintf(stream, "usage: %s [options]\n\n", argv0); 18 fprintf(stream, "The tokenize program tokenizes a prompt using a given model,\n"); 19 fprintf(stream, "and prints the resulting tokens to standard output.\n\n"); 20 fprintf(stream, "It needs a model file, a prompt, and optionally other flags\n"); 21 fprintf(stream, "to control the behavior of the tokenizer.\n\n"); 22 fprintf(stream, " The possible options are:\n"); 23 fprintf(stream, "\n"); 24 fprintf(stream, " -h, --help print this help and exit\n"); 25 fprintf(stream, " -m MODEL_PATH, --model MODEL_PATH path to model.\n"); 26 fprintf(stream, " --ids if given, only print numerical token IDs, and not token strings.\n"); 27 fprintf(stream, " The output format looks like [1, 2, 3], i.e. parseable by Python.\n"); 28 fprintf(stream, " -f PROMPT_FNAME, --file PROMPT_FNAME read prompt from a file.\n"); 29 fprintf(stream, " -p PROMPT, --prompt PROMPT read prompt from the argument.\n"); 30 fprintf(stream, " --stdin read prompt from standard input.\n"); 31 fprintf(stream, " --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n"); 32 fprintf(stream, " --log-disable disable logs. 
Makes stderr quiet when loading the model.\n"); 33 } 34 35 static void llama_log_callback_null(ggml_log_level level, const char * text, void * user_data) { 36 (void) level; 37 (void) text; 38 (void) user_data; 39 } 40 41 static std::string read_prompt_from_file(const char * filepath, bool & success) { 42 success = false; 43 44 std::ifstream in(filepath, std::ios::binary); 45 if (!in) { 46 fprintf(stderr, "%s: could not open file '%s' for reading: %s\n", __func__, filepath, strerror(errno)); 47 return std::string(); 48 } 49 // do not assume the file is seekable (e.g. /dev/stdin) 50 std::stringstream buffer; 51 buffer << in.rdbuf(); 52 if (in.fail()) { 53 fprintf(stderr, "%s: could not read the entire file '%s': %s\n", __func__, filepath, strerror(errno)); 54 return std::string(); 55 } 56 57 success = true; 58 return buffer.str(); 59 } 60 61 // 62 // Function: ingest_args(...) -> vector<string> 63 // 64 // Takes argc and argv arguments, and converts them to a vector of UTF-8 encoded 65 // strings, as an STL vector<string>. 66 // 67 // In particular, it handles character encoding shenanigans on Windows. 68 // 69 // Note: raw_argc and raw_argv are not actually read at all on Windows. 70 // On Windows we call GetCommandLineW to get the arguments in wchar_t 71 // format, ignoring the regular argc/argv arguments to main(). 72 // 73 // TODO: potential opportunity to roll common stuff into common/console.cpp 74 // in relation to Windows wchar_t shenanigans. 75 static std::vector<std::string> ingest_args(int raw_argc, char ** raw_argv) { 76 std::vector<std::string> argv; 77 78 // Handle Windows, if given non-ASCII arguments. 79 // We convert wchar_t arguments into UTF-8 char* on this platform. 80 // Lets you invoke 'tokenize' on Windows cmd.exe with non-ASCII characters 81 // without throwing tantrums. 
82 #if defined(_WIN32) 83 int argc; 84 const LPWSTR cmdline_wargv = GetCommandLineW(); 85 LPWSTR * wargv = CommandLineToArgvW(cmdline_wargv, &argc); 86 87 // silence unused arg warnings 88 (void) raw_argc; 89 (void) raw_argv; 90 91 for (int i = 0; i < argc; ++i) { 92 int length_needed = WideCharToMultiByte(CP_UTF8, 0, wargv[i], wcslen(wargv[i]), 0, 0, NULL, NULL); 93 char * output_buf = (char *) calloc(length_needed+1, sizeof(char)); 94 GGML_ASSERT(output_buf); 95 96 WideCharToMultiByte(CP_UTF8, 0, wargv[i], wcslen(wargv[i]), output_buf, length_needed, NULL, NULL); 97 output_buf[length_needed] = '\0'; 98 99 argv.push_back(output_buf); 100 free(output_buf); 101 } 102 103 LocalFree((HLOCAL) wargv); 104 #else 105 int argc = raw_argc; 106 for (int i = 0; i < argc; ++i) { 107 argv.push_back(raw_argv[i]); 108 } 109 #endif 110 111 GGML_ASSERT((unsigned int) argc == argv.size()); 112 113 return argv; 114 } 115 116 // 117 // Function: write_utf8_cstr_to_stdout(const char *) -> <writes to stdout> 118 // 119 // writes a string to standard output; taking into account that on Windows 120 // to display correctly you have to use special handling. Works even if the 121 // user has not set a unicode code page on a Windows cmd.exe. 122 // 123 // In case of invalid UTF-8, invalid_utf8 is set to true on Windows, and something 124 // a human-readable is written instead. 125 // 126 // On non-Windows systems, simply printfs() the string. 127 static void write_utf8_cstr_to_stdout(const char * str, bool & invalid_utf8) { 128 invalid_utf8 = false; 129 130 #if defined(_WIN32) 131 // Are we in a console? 132 HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE); 133 DWORD dwMode = 0; 134 135 // According to Microsoft docs: 136 // "WriteConsole fails if it is used with a standard handle that is redirected to a file." 137 // Also according to the docs, you can use GetConsoleMode to check for that. 
138 if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) { 139 printf("%s", str); 140 return; 141 } 142 143 // MultiByteToWideChar reports an error if str is empty, don't report 144 // them as invalid_utf8. 145 if (*str == 0) { 146 return; 147 } 148 int length_needed = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, strlen(str), NULL, 0); 149 if (length_needed == 0) { 150 DWORD err = GetLastError(); 151 if (err == ERROR_NO_UNICODE_TRANSLATION) { 152 invalid_utf8 = true; 153 int len = strlen(str); 154 printf("<"); 155 for (int i = 0; i < len; ++i) { 156 if (i > 0) { 157 printf(" "); 158 } 159 printf("%02x", (uint8_t) str[i]); 160 } 161 printf(">"); 162 return; 163 } 164 GGML_ASSERT(false && "MultiByteToWideChar() failed in an unexpected way."); 165 } 166 167 LPWSTR wstr = (LPWSTR) calloc(length_needed+1, sizeof(*wstr)); 168 GGML_ASSERT(wstr); 169 170 MultiByteToWideChar(CP_UTF8, 0, str, strlen(str), wstr, length_needed); 171 WriteConsoleW(hConsole, wstr, length_needed, NULL, NULL); 172 173 free(wstr); 174 #else 175 // TODO: reporting invalid_utf8 would be useful on non-Windows too. 176 // printf will silently just write bad unicode. 177 printf("%s", str); 178 #endif 179 } 180 181 int main(int raw_argc, char ** raw_argv) { 182 const std::vector<std::string> argv = ingest_args(raw_argc, raw_argv); 183 const int argc = argv.size(); 184 185 if (argc <= 1) { 186 print_usage_information(argv[0].c_str(), stderr); 187 return 1; 188 } 189 190 ////// 191 // Read out all the command line arguments. 192 ////// 193 194 // variables where to put any arguments we see. 
195 bool printing_ids = false; 196 bool no_bos = false; 197 bool disable_logging = false; 198 const char * model_path = NULL; 199 const char * prompt_path = NULL; 200 const char * prompt_arg = NULL; 201 202 // track which arguments were explicitly given 203 // used for sanity checking down the line 204 bool model_path_set = false; 205 bool prompt_path_set = false; 206 bool prompt_set = false; 207 bool stdin_set = false; 208 209 int iarg = 1; 210 for (; iarg < argc; ++iarg) { 211 std::string arg{argv[iarg]}; 212 if (arg == "-h" || arg == "--help") { 213 print_usage_information(argv[0].c_str(), stdout); 214 return 0; 215 } 216 else if (arg == "--ids") { 217 printing_ids = true; 218 } 219 else if (arg == "-m" || arg == "--model") { 220 if (model_path_set) { 221 fprintf(stderr, "Error: -m or --model specified multiple times.\n"); 222 return 1; 223 } 224 model_path = argv[++iarg].c_str(); 225 model_path_set = true; 226 } 227 else if (arg == "--no-bos") { 228 no_bos = true; 229 } 230 else if (arg == "-p" || arg == "--prompt") { 231 if (prompt_set) { 232 fprintf(stderr, "Error: -p or --prompt specified multiple times.\n"); 233 return 1; 234 } 235 prompt_arg = argv[++iarg].c_str(); 236 prompt_set = true; 237 } 238 else if (arg == "-f" || arg == "--file") { 239 if (prompt_path_set) { 240 fprintf(stderr, "Error: -f or --file specified multiple times.\n"); 241 return 1; 242 } 243 prompt_path = argv[++iarg].c_str(); 244 prompt_path_set = true; 245 } 246 else if (arg == "--stdin") { 247 stdin_set = true; 248 } 249 else if (arg == "--log-disable") { 250 disable_logging = true; 251 } 252 else { 253 fprintf(stderr, "Error: unknown option '%s'\n", argv[iarg].c_str()); 254 return 1; 255 } 256 } 257 258 ////// 259 // Sanity check the command line arguments. 260 ////// 261 262 // Check that we have the required stuff set. 
263 if (model_path_set && model_path == NULL) { 264 fprintf(stderr, "Error: --model requires an argument.\n"); 265 return 1; 266 } 267 if (!model_path_set) { 268 fprintf(stderr, "Error: must specify --model.\n"); 269 return 1; 270 } 271 if (prompt_path_set && prompt_path == NULL) { 272 fprintf(stderr, "Error: --file requires an argument.\n"); 273 return 1; 274 } 275 if (prompt_set && prompt_arg == NULL) { 276 fprintf(stderr, "Error: --prompt requires an argument.\n"); 277 return 1; 278 } 279 const int prompts_set = !!(prompt_path_set) + !!(prompt_set) + !!(stdin_set); 280 if (prompts_set > 1) { 281 fprintf(stderr, "Error: --stdin, --file and --prompt are mutually exclusive.\n"); 282 return 1; 283 } 284 // Must have some prompt. 285 if (prompts_set == 0) { 286 fprintf(stderr, "Error: must specify one of: --stdin, --file or --prompt.\n"); 287 return 1; 288 } 289 290 GGML_ASSERT(model_path); 291 GGML_ASSERT(prompt_path || prompt_arg || stdin_set); 292 293 ////// 294 // Figure out where will the prompt come from. 295 ////// 296 297 std::string prompt; 298 if (prompt_path_set) { 299 bool success = false; 300 prompt = read_prompt_from_file(prompt_path, success); 301 if (!success) { 302 return 1; 303 } 304 } else if (prompt_set) { 305 prompt = prompt_arg; 306 } else { 307 GGML_ASSERT(stdin_set); 308 // we read stdin *after* loading model (early exit if model cannot 309 // be loaded, which can be a nicer user experience) 310 } 311 312 ////// 313 // Start actually doing the tokenizing stuff. 
314 ////// 315 316 #ifdef LOG_DISABLE_LOGS 317 disable_logging = true; 318 #endif 319 320 if (disable_logging) { 321 llama_log_set(llama_log_callback_null, NULL); 322 } 323 324 llama_backend_init(); 325 326 llama_model_params model_params = llama_model_default_params(); 327 model_params.vocab_only = true; 328 llama_model * model = llama_load_model_from_file(model_path, model_params); 329 if (!model) { 330 fprintf(stderr, "Error: could not load model from file '%s'.\n", model_path); 331 return 1; 332 } 333 334 llama_context_params ctx_params = llama_context_default_params(); 335 llama_context * ctx = llama_new_context_with_model(model, ctx_params); 336 if (!ctx) { 337 fprintf(stderr, "Error: could not create context.\n"); 338 return 1; 339 } 340 341 // read entire prompt from stdin? 342 if (stdin_set) { 343 GGML_ASSERT(!prompt_path_set && !prompt_set); 344 345 std::stringstream stdin_buffer; 346 stdin_buffer << std::cin.rdbuf(); 347 if (std::cin.fail()) { 348 fprintf(stderr, "Error: could not read the entire standard input.\n"); 349 return 1; 350 } 351 352 prompt = stdin_buffer.str(); 353 } 354 355 const bool model_wants_add_bos = llama_should_add_bos_token(model); 356 const bool add_bos = model_wants_add_bos && !no_bos; 357 358 std::vector<llama_token> tokens; 359 tokens = ::llama_tokenize(model, prompt, add_bos, true); 360 361 if (printing_ids) { 362 printf("["); 363 } 364 365 for (int i = 0; i < (int) tokens.size(); i++) { 366 if (printing_ids) { 367 if (i > 0) { 368 printf(", "); 369 } 370 printf("%d", tokens[i]); 371 } else { 372 bool invalid_utf8 = false; 373 printf("%6d -> '", tokens[i]); 374 write_utf8_cstr_to_stdout(llama_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8); 375 if (invalid_utf8) { 376 printf("' (utf-8 decode failure)\n"); 377 } else { 378 printf("'\n"); 379 } 380 } 381 } 382 383 if (printing_ids) { 384 printf("]\n"); 385 } 386 387 // silence valgrind 388 llama_free(ctx); 389 llama_free_model(model); 390 391 return 0; 392 }