// grammar-parser.cpp
1 #include "grammar-parser.h" 2 #include <cstdint> 3 #include <cwchar> 4 #include <string> 5 #include <utility> 6 #include <stdexcept> 7 #include <exception> 8 9 namespace grammar_parser { 10 // NOTE: assumes valid utf8 (but checks for overrun) 11 // copied from llama.cpp 12 static std::pair<uint32_t, const char *> decode_utf8(const char * src) { 13 static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; 14 uint8_t first_byte = static_cast<uint8_t>(*src); 15 uint8_t highbits = first_byte >> 4; 16 int len = lookup[highbits]; 17 uint8_t mask = (1 << (8 - len)) - 1; 18 uint32_t value = first_byte & mask; 19 const char * end = src + len; // may overrun! 20 const char * pos = src + 1; 21 for ( ; pos < end && *pos; pos++) { 22 value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F); 23 } 24 return std::make_pair(value, pos); 25 } 26 27 static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { 28 uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size()); 29 auto result = state.symbol_ids.emplace(std::string(src, len), next_id); 30 return result.first->second; 31 } 32 33 static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { 34 uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size()); 35 state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; 36 return next_id; 37 } 38 39 static void add_rule( 40 parse_state & state, 41 uint32_t rule_id, 42 const std::vector<llama_grammar_element> & rule) { 43 if (state.rules.size() <= rule_id) { 44 state.rules.resize(rule_id + 1); 45 } 46 state.rules[rule_id] = rule; 47 } 48 49 static bool is_digit_char(char c) { 50 return '0' <= c && c <= '9'; 51 } 52 53 static bool is_word_char(char c) { 54 return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); 55 } 56 57 static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) { 58 const char * pos = src; 59 const char * end = src 
+ size; 60 uint32_t value = 0; 61 for ( ; pos < end && *pos; pos++) { 62 value <<= 4; 63 char c = *pos; 64 if ('a' <= c && c <= 'f') { 65 value += c - 'a' + 10; 66 } else if ('A' <= c && c <= 'F') { 67 value += c - 'A' + 10; 68 } else if ('0' <= c && c <= '9') { 69 value += c - '0'; 70 } else { 71 break; 72 } 73 } 74 if (pos != end) { 75 throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); 76 } 77 return std::make_pair(value, pos); 78 } 79 80 static const char * parse_space(const char * src, bool newline_ok) { 81 const char * pos = src; 82 while (*pos == ' ' || *pos == '\t' || *pos == '#' || 83 (newline_ok && (*pos == '\r' || *pos == '\n'))) { 84 if (*pos == '#') { 85 while (*pos && *pos != '\r' && *pos != '\n') { 86 pos++; 87 } 88 } else { 89 pos++; 90 } 91 } 92 return pos; 93 } 94 95 static const char * parse_name(const char * src) { 96 const char * pos = src; 97 while (is_word_char(*pos)) { 98 pos++; 99 } 100 if (pos == src) { 101 throw std::runtime_error(std::string("expecting name at ") + src); 102 } 103 return pos; 104 } 105 106 static const char * parse_int(const char * src) { 107 const char * pos = src; 108 while (is_digit_char(*pos)) { 109 pos++; 110 } 111 if (pos == src) { 112 throw std::runtime_error(std::string("expecting integer at ") + src); 113 } 114 return pos; 115 } 116 117 static std::pair<uint32_t, const char *> parse_char(const char * src) { 118 if (*src == '\\') { 119 switch (src[1]) { 120 case 'x': return parse_hex(src + 2, 2); 121 case 'u': return parse_hex(src + 2, 4); 122 case 'U': return parse_hex(src + 2, 8); 123 case 't': return std::make_pair('\t', src + 2); 124 case 'r': return std::make_pair('\r', src + 2); 125 case 'n': return std::make_pair('\n', src + 2); 126 case '\\': 127 case '"': 128 case '[': 129 case ']': 130 return std::make_pair(src[1], src + 2); 131 default: 132 throw std::runtime_error(std::string("unknown escape at ") + src); 133 } 134 } else if (*src) { 135 return decode_utf8(src); 136 
} 137 throw std::runtime_error("unexpected end of input"); 138 } 139 140 const char * parse_alternates( 141 parse_state & state, 142 const char * src, 143 const std::string & rule_name, 144 uint32_t rule_id, 145 bool is_nested); 146 147 static const char * parse_sequence( 148 parse_state & state, 149 const char * src, 150 const std::string & rule_name, 151 std::vector<llama_grammar_element> & out_elements, 152 bool is_nested) { 153 size_t last_sym_start = out_elements.size(); 154 const char * pos = src; 155 156 auto handle_repetitions = [&](int min_times, int max_times) { 157 158 if (last_sym_start == out_elements.size()) { 159 throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); 160 } 161 162 // apply transformation to previous symbol (last_sym_start to end) according to 163 // the following rewrite rules: 164 // S{m,n} --> S S S (m times) S'(n-m) 165 // S'(x) ::= S S'(x-1) | 166 // (... n-m definitions of these S' rules ...) 167 // S'(1) ::= S | 168 // S{m,} --> S S S (m times) S' 169 // S' ::= S S' | 170 // S* --> S{0,} 171 // --> S' ::= S S' | 172 // S+ --> S{1,} 173 // --> S S' 174 // S' ::= S S' | 175 // S? --> S{0,1} 176 // --> S' 177 // S' ::= S | 178 179 std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); 180 if (min_times == 0) { 181 out_elements.resize(last_sym_start); 182 } else { 183 // Repeat the previous elements (min_times - 1) times 184 for (int i = 1; i < min_times; i++) { 185 out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); 186 } 187 } 188 189 uint32_t last_rec_rule_id = 0; 190 auto n_opt = max_times < 0 ? 
1 : max_times - min_times; 191 192 std::vector<llama_grammar_element> rec_rule(previous_elements); 193 for (int i = 0; i < n_opt; i++) { 194 rec_rule.resize(previous_elements.size()); 195 uint32_t rec_rule_id = generate_symbol_id(state, rule_name); 196 if (i > 0 || max_times < 0) { 197 rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); 198 } 199 rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); 200 rec_rule.push_back({LLAMA_GRETYPE_END, 0}); 201 add_rule(state, rec_rule_id, rec_rule); 202 last_rec_rule_id = rec_rule_id; 203 } 204 if (n_opt > 0) { 205 out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); 206 } 207 }; 208 209 while (*pos) { 210 if (*pos == '"') { // literal string 211 pos++; 212 last_sym_start = out_elements.size(); 213 while (*pos != '"') { 214 if (!*pos) { 215 throw std::runtime_error("unexpected end of input"); 216 } 217 auto char_pair = parse_char(pos); 218 pos = char_pair.second; 219 out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); 220 } 221 pos = parse_space(pos + 1, is_nested); 222 } else if (*pos == '[') { // char range(s) 223 pos++; 224 enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; 225 if (*pos == '^') { 226 pos++; 227 start_type = LLAMA_GRETYPE_CHAR_NOT; 228 } 229 last_sym_start = out_elements.size(); 230 while (*pos != ']') { 231 if (!*pos) { 232 throw std::runtime_error("unexpected end of input"); 233 } 234 auto char_pair = parse_char(pos); 235 pos = char_pair.second; 236 enum llama_gretype type = last_sym_start < out_elements.size() 237 ? 
LLAMA_GRETYPE_CHAR_ALT 238 : start_type; 239 240 out_elements.push_back({type, char_pair.first}); 241 if (pos[0] == '-' && pos[1] != ']') { 242 if (!pos[1]) { 243 throw std::runtime_error("unexpected end of input"); 244 } 245 auto endchar_pair = parse_char(pos + 1); 246 pos = endchar_pair.second; 247 out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); 248 } 249 } 250 pos = parse_space(pos + 1, is_nested); 251 } else if (is_word_char(*pos)) { // rule reference 252 const char * name_end = parse_name(pos); 253 uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); 254 pos = parse_space(name_end, is_nested); 255 last_sym_start = out_elements.size(); 256 out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); 257 } else if (*pos == '(') { // grouping 258 // parse nested alternates into synthesized rule 259 pos = parse_space(pos + 1, true); 260 uint32_t sub_rule_id = generate_symbol_id(state, rule_name); 261 pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); 262 last_sym_start = out_elements.size(); 263 // output reference to synthesized rule 264 out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); 265 if (*pos != ')') { 266 throw std::runtime_error(std::string("expecting ')' at ") + pos); 267 } 268 pos = parse_space(pos + 1, is_nested); 269 } else if (*pos == '.') { // any char 270 last_sym_start = out_elements.size(); 271 out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); 272 pos = parse_space(pos + 1, is_nested); 273 } else if (*pos == '*') { 274 pos = parse_space(pos + 1, is_nested); 275 handle_repetitions(0, -1); 276 } else if (*pos == '+') { 277 pos = parse_space(pos + 1, is_nested); 278 handle_repetitions(1, -1); 279 } else if (*pos == '?') { 280 pos = parse_space(pos + 1, is_nested); 281 handle_repetitions(0, 1); 282 } else if (*pos == '{') { 283 pos = parse_space(pos + 1, is_nested); 284 285 if (!is_digit_char(*pos)) { 286 throw std::runtime_error(std::string("expecting an int at ") + pos); 287 
} 288 const char * int_end = parse_int(pos); 289 int min_times = std::stoul(std::string(pos, int_end - pos)); 290 pos = parse_space(int_end, is_nested); 291 292 int max_times = -1; 293 294 if (*pos == '}') { 295 max_times = min_times; 296 pos = parse_space(pos + 1, is_nested); 297 } else if (*pos == ',') { 298 pos = parse_space(pos + 1, is_nested); 299 300 if (is_digit_char(*pos)) { 301 const char * int_end = parse_int(pos); 302 max_times = std::stoul(std::string(pos, int_end - pos)); 303 pos = parse_space(int_end, is_nested); 304 } 305 306 if (*pos != '}') { 307 throw std::runtime_error(std::string("expecting '}' at ") + pos); 308 } 309 pos = parse_space(pos + 1, is_nested); 310 } else { 311 throw std::runtime_error(std::string("expecting ',' at ") + pos); 312 } 313 handle_repetitions(min_times, max_times); 314 } else { 315 break; 316 } 317 } 318 return pos; 319 } 320 321 const char * parse_alternates( 322 parse_state & state, 323 const char * src, 324 const std::string & rule_name, 325 uint32_t rule_id, 326 bool is_nested) { 327 std::vector<llama_grammar_element> rule; 328 const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); 329 while (*pos == '|') { 330 rule.push_back({LLAMA_GRETYPE_ALT, 0}); 331 pos = parse_space(pos + 1, true); 332 pos = parse_sequence(state, pos, rule_name, rule, is_nested); 333 } 334 rule.push_back({LLAMA_GRETYPE_END, 0}); 335 add_rule(state, rule_id, rule); 336 return pos; 337 } 338 339 static const char * parse_rule(parse_state & state, const char * src) { 340 const char * name_end = parse_name(src); 341 const char * pos = parse_space(name_end, false); 342 size_t name_len = name_end - src; 343 uint32_t rule_id = get_symbol_id(state, src, name_len); 344 const std::string name(src, name_len); 345 346 if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { 347 throw std::runtime_error(std::string("expecting ::= at ") + pos); 348 } 349 pos = parse_space(pos + 3, true); 350 351 pos = parse_alternates(state, pos, name, 
rule_id, false); 352 353 if (*pos == '\r') { 354 pos += pos[1] == '\n' ? 2 : 1; 355 } else if (*pos == '\n') { 356 pos++; 357 } else if (*pos) { 358 throw std::runtime_error(std::string("expecting newline or end at ") + pos); 359 } 360 return parse_space(pos, true); 361 } 362 363 parse_state parse(const char * src) { 364 try { 365 parse_state state; 366 const char * pos = parse_space(src, true); 367 while (*pos) { 368 pos = parse_rule(state, pos); 369 } 370 // Validate the state to ensure that all rules are defined 371 for (const auto & rule : state.rules) { 372 for (const auto & elem : rule) { 373 if (elem.type == LLAMA_GRETYPE_RULE_REF) { 374 // Ensure that the rule at that location exists 375 if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) { 376 // Get the name of the rule that is missing 377 for (const auto & kv : state.symbol_ids) { 378 if (kv.second == elem.value) { 379 throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); 380 } 381 } 382 } 383 } 384 } 385 } 386 return state; 387 } catch (const std::exception & err) { 388 fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); 389 return parse_state(); 390 } 391 } 392 393 static void print_grammar_char(FILE * file, uint32_t c) { 394 if (0x20 <= c && c <= 0x7f) { 395 fprintf(file, "%c", static_cast<char>(c)); 396 } else { 397 // cop out of encoding UTF-8 398 fprintf(file, "<U+%04X>", c); 399 } 400 } 401 402 static bool is_char_element(llama_grammar_element elem) { 403 switch (elem.type) { 404 case LLAMA_GRETYPE_CHAR: return true; 405 case LLAMA_GRETYPE_CHAR_NOT: return true; 406 case LLAMA_GRETYPE_CHAR_ALT: return true; 407 case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; 408 case LLAMA_GRETYPE_CHAR_ANY: return true; 409 default: return false; 410 } 411 } 412 413 static void print_rule_binary(FILE * file, const std::vector<llama_grammar_element> & rule) { 414 for (auto elem : rule) { 415 switch (elem.type) { 416 case LLAMA_GRETYPE_END: fprintf(file, 
"END"); break; 417 case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; 418 case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; 419 case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; 420 case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; 421 case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; 422 case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; 423 case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; 424 } 425 switch (elem.type) { 426 case LLAMA_GRETYPE_END: 427 case LLAMA_GRETYPE_ALT: 428 case LLAMA_GRETYPE_RULE_REF: 429 fprintf(file, "(%u) ", elem.value); 430 break; 431 case LLAMA_GRETYPE_CHAR: 432 case LLAMA_GRETYPE_CHAR_NOT: 433 case LLAMA_GRETYPE_CHAR_RNG_UPPER: 434 case LLAMA_GRETYPE_CHAR_ALT: 435 case LLAMA_GRETYPE_CHAR_ANY: 436 fprintf(file, "(\""); 437 print_grammar_char(file, elem.value); 438 fprintf(file, "\") "); 439 break; 440 } 441 } 442 fprintf(file, "\n"); 443 } 444 445 static void print_rule( 446 FILE * file, 447 uint32_t rule_id, 448 const std::vector<llama_grammar_element> & rule, 449 const std::map<uint32_t, std::string> & symbol_id_names) { 450 if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { 451 throw std::runtime_error( 452 "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); 453 } 454 fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); 455 for (size_t i = 0, end = rule.size() - 1; i < end; i++) { 456 llama_grammar_element elem = rule[i]; 457 switch (elem.type) { 458 case LLAMA_GRETYPE_END: 459 throw std::runtime_error( 460 "unexpected end of rule: " + std::to_string(rule_id) + "," + 461 std::to_string(i)); 462 case LLAMA_GRETYPE_ALT: 463 fprintf(file, "| "); 464 break; 465 case LLAMA_GRETYPE_RULE_REF: 466 fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); 467 break; 468 case LLAMA_GRETYPE_CHAR: 469 fprintf(file, "["); 470 print_grammar_char(file, elem.value); 471 break; 472 case LLAMA_GRETYPE_CHAR_NOT: 
473 fprintf(file, "[^"); 474 print_grammar_char(file, elem.value); 475 break; 476 case LLAMA_GRETYPE_CHAR_RNG_UPPER: 477 if (i == 0 || !is_char_element(rule[i - 1])) { 478 throw std::runtime_error( 479 "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + 480 std::to_string(rule_id) + "," + std::to_string(i)); 481 } 482 fprintf(file, "-"); 483 print_grammar_char(file, elem.value); 484 break; 485 case LLAMA_GRETYPE_CHAR_ALT: 486 if (i == 0 || !is_char_element(rule[i - 1])) { 487 throw std::runtime_error( 488 "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + 489 std::to_string(rule_id) + "," + std::to_string(i)); 490 } 491 print_grammar_char(file, elem.value); 492 break; 493 case LLAMA_GRETYPE_CHAR_ANY: 494 fprintf(file, "."); 495 break; 496 } 497 if (is_char_element(elem)) { 498 switch (rule[i + 1].type) { 499 case LLAMA_GRETYPE_CHAR_ALT: 500 case LLAMA_GRETYPE_CHAR_RNG_UPPER: 501 case LLAMA_GRETYPE_CHAR_ANY: 502 break; 503 default: 504 fprintf(file, "] "); 505 } 506 } 507 } 508 fprintf(file, "\n"); 509 } 510 511 void print_grammar(FILE * file, const parse_state & state) { 512 try { 513 std::map<uint32_t, std::string> symbol_id_names; 514 for (const auto & kv : state.symbol_ids) { 515 symbol_id_names[kv.second] = kv.first; 516 } 517 for (size_t i = 0, end = state.rules.size(); i < end; i++) { 518 // fprintf(file, "%zu: ", i); 519 // print_rule_binary(file, state.rules[i]); 520 print_rule(file, uint32_t(i), state.rules[i], symbol_id_names); 521 // fprintf(file, "\n"); 522 } 523 } catch (const std::exception & err) { 524 fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); 525 } 526 } 527 528 std::vector<const llama_grammar_element *> parse_state::c_rules() { 529 std::vector<const llama_grammar_element *> ret; 530 ret.reserve(rules.size()); 531 for (const auto & rule : rules) { 532 ret.push_back(rule.data()); 533 } 534 return ret; 535 } 536 }