/ tests / test-grammar-parser.cpp
test-grammar-parser.cpp
  1  #ifdef NDEBUG
  2  #undef NDEBUG
  3  #endif
  4  
  5  #include "llama.h"
  6  #include "grammar-parser.h"
  7  
  8  #include <cassert>
  9  
 10  static const char * type_str(llama_gretype type) {
 11      switch (type) {
 12          case LLAMA_GRETYPE_CHAR: return "LLAMA_GRETYPE_CHAR";
 13          case LLAMA_GRETYPE_CHAR_NOT: return "LLAMA_GRETYPE_CHAR_NOT";
 14          case LLAMA_GRETYPE_CHAR_ALT: return "LLAMA_GRETYPE_CHAR_ALT";
 15          case LLAMA_GRETYPE_CHAR_RNG_UPPER: return "LLAMA_GRETYPE_CHAR_RNG_UPPER";
 16          case LLAMA_GRETYPE_RULE_REF: return "LLAMA_GRETYPE_RULE_REF";
 17          case LLAMA_GRETYPE_ALT: return "LLAMA_GRETYPE_ALT";
 18          case LLAMA_GRETYPE_END: return "LLAMA_GRETYPE_END";
 19          default: return "?";
 20      }
 21  }
 22  
 23  static void verify_parsing(const char *grammar_bytes, const std::vector<std::pair<std::string, uint32_t>> expected, const std::vector<llama_grammar_element> &expected_rules) {
 24      uint32_t index = 0;
 25      grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes);
 26  
 27      std::map<uint32_t, std::string> symbol_names;
 28      for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
 29          symbol_names[it->second] = it->first;
 30      }
 31  
 32      auto print_all = [&]() {
 33          fprintf(stderr, "    verify_parsing(R\"\"\"(%s)\"\"\", {\n", grammar_bytes);
 34          for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
 35              fprintf(stderr, "        {\"%s\", %u},\n", it->first.c_str(), it->second);
 36          }
 37          fprintf(stderr, "    }, {\n");
 38          for (size_t i_rule = 0; i_rule < parsed_grammar.rules.size(); i_rule++) {
 39              fprintf(stderr, "        // %s (index %zu)\n", symbol_names[i_rule].c_str(), i_rule);
 40              auto & rule = parsed_grammar.rules[i_rule];
 41              for (uint32_t i = 0; i < rule.size(); i++) {
 42                  std::string rule_str;
 43                  fprintf(stderr, "        {%s, ", type_str(rule[i].type));
 44                  if (rule[i].type == LLAMA_GRETYPE_CHAR || rule[i].type == LLAMA_GRETYPE_CHAR_ALT ||
 45                      rule[i].type == LLAMA_GRETYPE_CHAR_NOT || rule[i].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
 46                      char c = rule[i].value;
 47                      if (c == '\n') {
 48                          fprintf(stderr, "'\\n'");
 49                      } else if (c == '\t') {
 50                          fprintf(stderr, "'\\t'");
 51                      } else if (c == '\r') {
 52                          fprintf(stderr, "'\\r'");
 53                      } else if (c == '\0') {
 54                          fprintf(stderr, "'\\0'");
 55                      } else {
 56                          fprintf(stderr, "'%c'", c);
 57                      }
 58                  } else if (rule[i].type == LLAMA_GRETYPE_RULE_REF) {
 59                      fprintf(stderr, "/* %s */ %u", symbol_names[rule[i].value].c_str(), rule[i].value);
 60                  } else {
 61                      fprintf(stderr, "%u", rule[i].value);
 62                  }
 63                  fprintf(stderr, "},\n");
 64              }
 65          }
 66          fprintf(stderr, "    });\n");
 67      };
 68  
 69      if (getenv("TEST_GRAMMAR_PARSER_PRINT_ALL")) {
 70          print_all();
 71          fprintf(stderr, "\n");
 72          return;
 73      }
 74  
 75      fprintf(stderr, "Testing grammar:%s\n", grammar_bytes);
 76  
 77      if (parsed_grammar.symbol_ids.size() != expected.size()) {
 78          fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
 79          print_all();
 80          assert(parsed_grammar.symbol_ids.size() == expected.size());
 81      }
 82  
 83      for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
 84      {
 85          std::string key = it->first;
 86          uint32_t value = it->second;
 87          std::pair<std::string, uint32_t> expected_pair = expected[index];
 88  
 89          // pretty print error message before asserting
 90          if (expected_pair.first != key || expected_pair.second != value)
 91          {
 92              fprintf(stderr, "index: %u\n", index);
 93              fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second);
 94              fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value);
 95              fprintf(stderr, "expected_pair != actual_pair\n");
 96              fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
 97              print_all();
 98          }
 99  
100          assert(expected_pair.first == key && expected_pair.second == value);
101  
102          index++;
103      }
104  
105      index = 0;
106      for (auto rule : parsed_grammar.rules)
107      {
108          // compare rule to expected rule
109          for (uint32_t i = 0; i < rule.size(); i++)
110          {
111              llama_grammar_element element = rule[i];
112              llama_grammar_element expected_element = expected_rules[index];
113  
114              // pretty print error message before asserting
115              if (expected_element.type != element.type || expected_element.value != element.value)
116              {
117                  fprintf(stderr, "index: %u\n", index);
118                  fprintf(stderr, "expected_element: %s, %u\n", type_str(expected_element.type), expected_element.value);
119                  fprintf(stderr, "actual_element: %s, %u\n", type_str(element.type), element.value);
120                  fprintf(stderr, "expected_element != actual_element\n");
121                  fprintf(stderr, "all elements:\n");
122                  fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
123                  print_all();
124              }
125  
126              assert(expected_element.type == element.type && expected_element.value == element.value);
127              index++;
128          }
129      }
130  }
131  
132  static void verify_failure(const char *grammar_bytes) {
133      fprintf(stderr, "Testing expected failure:%s\n", grammar_bytes);
134      auto result = grammar_parser::parse(grammar_bytes);
135      assert(result.rules.empty() && "should have failed");
136  }
137  
138  int main()
139  {
140      verify_failure(R"""(
141          root ::= "a"{,}"
142      )""");
143  
144      verify_failure(R"""(
145          root ::= "a"{,10}"
146      )""");
147  
148      verify_parsing(R"""(
149          root  ::= "a"
150      )""", {
151          {"root", 0},
152      }, {
153          // root (index 0)
154          {LLAMA_GRETYPE_CHAR, 'a'},
155          {LLAMA_GRETYPE_END, 0},
156      });
157  
158      verify_parsing(R"""(
159          root  ::= "a" | [bdx-z] | [^1-3]
160      )""", {
161          {"root", 0},
162      }, {
163          // root (index 0)
164          {LLAMA_GRETYPE_CHAR, 'a'},
165          {LLAMA_GRETYPE_ALT, 0},
166          {LLAMA_GRETYPE_CHAR, 'b'},
167          {LLAMA_GRETYPE_CHAR_ALT, 'd'},
168          {LLAMA_GRETYPE_CHAR_ALT, 'x'},
169          {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
170          {LLAMA_GRETYPE_ALT, 0},
171          {LLAMA_GRETYPE_CHAR_NOT, '1'},
172          {LLAMA_GRETYPE_CHAR_RNG_UPPER, '3'},
173          {LLAMA_GRETYPE_END, 0},
174      });
175  
176      verify_parsing(R"""(
177          root  ::= a+
178          a     ::= "a"
179      )""", {
180          {"a", 1},
181          {"root", 0},
182          {"root_2", 2},
183      }, {
184          // root (index 0)
185          {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
186          {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
187          {LLAMA_GRETYPE_END, 0},
188          // a (index 1)
189          {LLAMA_GRETYPE_CHAR, 'a'},
190          {LLAMA_GRETYPE_END, 0},
191          // root_2 (index 2)
192          {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
193          {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
194          {LLAMA_GRETYPE_ALT, 0},
195          {LLAMA_GRETYPE_END, 0},
196      });
197  
198      verify_parsing(R"""(
199          root  ::= "a"+
200      )""", {
201          {"root", 0},
202          {"root_1", 1},
203      }, {
204          // root (index 0)
205          {LLAMA_GRETYPE_CHAR, 'a'},
206          {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
207          {LLAMA_GRETYPE_END, 0},
208          // root_1 (index 1)
209          {LLAMA_GRETYPE_CHAR, 'a'},
210          {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
211          {LLAMA_GRETYPE_ALT, 0},
212          {LLAMA_GRETYPE_END, 0},
213      });
214  
215      verify_parsing(R"""(
216          root  ::= a?
217          a     ::= "a"
218      )""", {
219          {"a", 1},
220          {"root", 0},
221          {"root_2", 2},
222      }, {
223          // root (index 0)
224          {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
225          {LLAMA_GRETYPE_END, 0},
226          // a (index 1)
227          {LLAMA_GRETYPE_CHAR, 'a'},
228          {LLAMA_GRETYPE_END, 0},
229          // root_2 (index 2)
230          {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
231          {LLAMA_GRETYPE_ALT, 0},
232          {LLAMA_GRETYPE_END, 0},
233      });
234  
235      verify_parsing(R"""(
236          root  ::= "a"?
237      )""", {
238          {"root", 0},
239          {"root_1", 1},
240      }, {
241          // root (index 0)
242          {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
243          {LLAMA_GRETYPE_END, 0},
244          // root_1 (index 1)
245          {LLAMA_GRETYPE_CHAR, 'a'},
246          {LLAMA_GRETYPE_ALT, 0},
247          {LLAMA_GRETYPE_END, 0},
248      });
249  
250      verify_parsing(R"""(
251          root  ::= a*
252          a     ::= "a"
253      )""", {
254          {"a", 1},
255          {"root", 0},
256          {"root_2", 2},
257      }, {
258          // root (index 0)
259          {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
260          {LLAMA_GRETYPE_END, 0},
261          // a (index 1)
262          {LLAMA_GRETYPE_CHAR, 'a'},
263          {LLAMA_GRETYPE_END, 0},
264          // root_2 (index 2)
265          {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
266          {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
267          {LLAMA_GRETYPE_ALT, 0},
268          {LLAMA_GRETYPE_END, 0},
269      });
270  
271      verify_parsing(R"""(
272          root  ::= "a"*
273      )""", {
274          {"root", 0},
275          {"root_1", 1},
276      }, {
277          // root (index 0)
278          {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
279          {LLAMA_GRETYPE_END, 0},
280          // root_1 (index 1)
281          {LLAMA_GRETYPE_CHAR, 'a'},
282          {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
283          {LLAMA_GRETYPE_ALT, 0},
284          {LLAMA_GRETYPE_END, 0},
285      });
286  
287      verify_parsing(R"""(
288          root  ::= "a"{2}
289      )""", {
290          {"root", 0},
291      }, {
292          // root (index 0)
293          {LLAMA_GRETYPE_CHAR, 'a'},
294          {LLAMA_GRETYPE_CHAR, 'a'},
295          {LLAMA_GRETYPE_END, 0},
296      });
297  
298      verify_parsing(R"""(
299          root  ::= "a"{2,}
300      )""", {
301          {"root", 0},
302          {"root_1", 1},
303      }, {
304          // root (index 0)
305          {LLAMA_GRETYPE_CHAR, 'a'},
306          {LLAMA_GRETYPE_CHAR, 'a'},
307          {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
308          {LLAMA_GRETYPE_END, 0},
309          // root_1 (index 1)
310          {LLAMA_GRETYPE_CHAR, 'a'},
311          {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
312          {LLAMA_GRETYPE_ALT, 0},
313          {LLAMA_GRETYPE_END, 0},
314      });
315  
316      verify_parsing(R"""(
317          root  ::= "a"{ 4}
318      )""", {
319          {"root", 0},
320      }, {
321          // root (index 0)
322          {LLAMA_GRETYPE_CHAR, 'a'},
323          {LLAMA_GRETYPE_CHAR, 'a'},
324          {LLAMA_GRETYPE_CHAR, 'a'},
325          {LLAMA_GRETYPE_CHAR, 'a'},
326          {LLAMA_GRETYPE_END, 0},
327      });
328  
329      verify_parsing(R"""(
330          root  ::= "a"{2,4}
331      )""", {
332          {"root", 0},
333          {"root_1", 1},
334          {"root_2", 2},
335      }, {
336          // root (index 0)
337          {LLAMA_GRETYPE_CHAR, 'a'},
338          {LLAMA_GRETYPE_CHAR, 'a'},
339          {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
340          {LLAMA_GRETYPE_END, 0},
341          // root_1 (index 1)
342          {LLAMA_GRETYPE_CHAR, 'a'},
343          {LLAMA_GRETYPE_ALT, 0},
344          {LLAMA_GRETYPE_END, 0},
345          // root_2 (index 2)
346          {LLAMA_GRETYPE_CHAR, 'a'},
347          {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
348          {LLAMA_GRETYPE_ALT, 0},
349          {LLAMA_GRETYPE_END, 0},
350      });
351  
352      verify_parsing(R"""(
353          root  ::= (expr "=" term "\n")+
354          expr  ::= term ([-+*/] term)*
355          term  ::= [0-9]+
356      )""", {
357          {"expr", 2},
358          {"expr_5", 5},
359          {"expr_6", 6},
360          {"root", 0},
361          {"root_1", 1},
362          {"root_4", 4},
363          {"term", 3},
364          {"term_7", 7},
365      }, {
366          // root (index 0)
367          {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
368          {LLAMA_GRETYPE_RULE_REF, /* root_4 */ 4},
369          {LLAMA_GRETYPE_END, 0},
370          // root_1 (index 1)
371          {LLAMA_GRETYPE_RULE_REF, /* expr */ 2},
372          {LLAMA_GRETYPE_CHAR, '='},
373          {LLAMA_GRETYPE_RULE_REF, /* term */ 3},
374          {LLAMA_GRETYPE_CHAR, '\n'},
375          {LLAMA_GRETYPE_END, 0},
376          // expr (index 2)
377          {LLAMA_GRETYPE_RULE_REF, /* term */ 3},
378          {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6},
379          {LLAMA_GRETYPE_END, 0},
380          // term (index 3)
381          {LLAMA_GRETYPE_CHAR, '0'},
382          {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
383          {LLAMA_GRETYPE_RULE_REF, /* term_7 */ 7},
384          {LLAMA_GRETYPE_END, 0},
385          // root_4 (index 4)
386          {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
387          {LLAMA_GRETYPE_RULE_REF, /* root_4 */ 4},
388          {LLAMA_GRETYPE_ALT, 0},
389          {LLAMA_GRETYPE_END, 0},
390          // expr_5 (index 5)
391          {LLAMA_GRETYPE_CHAR, '-'},
392          {LLAMA_GRETYPE_CHAR_ALT, '+'},
393          {LLAMA_GRETYPE_CHAR_ALT, '*'},
394          {LLAMA_GRETYPE_CHAR_ALT, '/'},
395          {LLAMA_GRETYPE_RULE_REF, /* term */ 3},
396          {LLAMA_GRETYPE_END, 0},
397          // expr_6 (index 6)
398          {LLAMA_GRETYPE_RULE_REF, /* expr_5 */ 5},
399          {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6},
400          {LLAMA_GRETYPE_ALT, 0},
401          {LLAMA_GRETYPE_END, 0},
402          // term_7 (index 7)
403          {LLAMA_GRETYPE_CHAR, '0'},
404          {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
405          {LLAMA_GRETYPE_RULE_REF, /* term_7 */ 7},
406          {LLAMA_GRETYPE_ALT, 0},
407          {LLAMA_GRETYPE_END, 0},
408      });
409  
410      verify_parsing(R"""(
411          root  ::= (expr "=" ws term "\n")+
412          expr  ::= term ([-+*/] term)*
413          term  ::= ident | num | "(" ws expr ")" ws
414          ident ::= [a-z] [a-z0-9_]* ws
415          num   ::= [0-9]+ ws
416          ws    ::= [ \t\n]*
417      )""", {
418          {"expr", 2},
419          {"expr_6", 6},
420          {"expr_7", 7},
421          {"ident", 8},
422          {"ident_10", 10},
423          {"num", 9},
424          {"num_11", 11},
425          {"root", 0},
426          {"root_1", 1},
427          {"root_5", 5},
428          {"term", 4},
429          {"ws", 3},
430          {"ws_12", 12},
431      }, {
432          // root (index 0)
433          {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
434          {LLAMA_GRETYPE_RULE_REF, /* root_5 */ 5},
435          {LLAMA_GRETYPE_END, 0},
436          // root_1 (index 1)
437          {LLAMA_GRETYPE_RULE_REF, /* expr */ 2},
438          {LLAMA_GRETYPE_CHAR, '='},
439          {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
440          {LLAMA_GRETYPE_RULE_REF, /* term */ 4},
441          {LLAMA_GRETYPE_CHAR, '\n'},
442          {LLAMA_GRETYPE_END, 0},
443          // expr (index 2)
444          {LLAMA_GRETYPE_RULE_REF, /* term */ 4},
445          {LLAMA_GRETYPE_RULE_REF, /* expr_7 */ 7},
446          {LLAMA_GRETYPE_END, 0},
447          // ws (index 3)
448          {LLAMA_GRETYPE_RULE_REF, /* ws_12 */ 12},
449          {LLAMA_GRETYPE_END, 0},
450          // term (index 4)
451          {LLAMA_GRETYPE_RULE_REF, /* ident */ 8},
452          {LLAMA_GRETYPE_ALT, 0},
453          {LLAMA_GRETYPE_RULE_REF, /* num */ 9},
454          {LLAMA_GRETYPE_ALT, 0},
455          {LLAMA_GRETYPE_CHAR, '('},
456          {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
457          {LLAMA_GRETYPE_RULE_REF, /* expr */ 2},
458          {LLAMA_GRETYPE_CHAR, ')'},
459          {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
460          {LLAMA_GRETYPE_END, 0},
461          // root_5 (index 5)
462          {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
463          {LLAMA_GRETYPE_RULE_REF, /* root_5 */ 5},
464          {LLAMA_GRETYPE_ALT, 0},
465          {LLAMA_GRETYPE_END, 0},
466          // expr_6 (index 6)
467          {LLAMA_GRETYPE_CHAR, '-'},
468          {LLAMA_GRETYPE_CHAR_ALT, '+'},
469          {LLAMA_GRETYPE_CHAR_ALT, '*'},
470          {LLAMA_GRETYPE_CHAR_ALT, '/'},
471          {LLAMA_GRETYPE_RULE_REF, /* term */ 4},
472          {LLAMA_GRETYPE_END, 0},
473          // expr_7 (index 7)
474          {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6},
475          {LLAMA_GRETYPE_RULE_REF, /* expr_7 */ 7},
476          {LLAMA_GRETYPE_ALT, 0},
477          {LLAMA_GRETYPE_END, 0},
478          // ident (index 8)
479          {LLAMA_GRETYPE_CHAR, 'a'},
480          {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
481          {LLAMA_GRETYPE_RULE_REF, /* ident_10 */ 10},
482          {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
483          {LLAMA_GRETYPE_END, 0},
484          // num (index 9)
485          {LLAMA_GRETYPE_CHAR, '0'},
486          {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
487          {LLAMA_GRETYPE_RULE_REF, /* num_11 */ 11},
488          {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
489          {LLAMA_GRETYPE_END, 0},
490          // ident_10 (index 10)
491          {LLAMA_GRETYPE_CHAR, 'a'},
492          {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
493          {LLAMA_GRETYPE_CHAR_ALT, '0'},
494          {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
495          {LLAMA_GRETYPE_CHAR_ALT, '_'},
496          {LLAMA_GRETYPE_RULE_REF, /* ident_10 */ 10},
497          {LLAMA_GRETYPE_ALT, 0},
498          {LLAMA_GRETYPE_END, 0},
499          // num_11 (index 11)
500          {LLAMA_GRETYPE_CHAR, '0'},
501          {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
502          {LLAMA_GRETYPE_RULE_REF, /* num_11 */ 11},
503          {LLAMA_GRETYPE_ALT, 0},
504          {LLAMA_GRETYPE_END, 0},
505          // ws_12 (index 12)
506          {LLAMA_GRETYPE_CHAR, ' '},
507          {LLAMA_GRETYPE_CHAR_ALT, '\t'},
508          {LLAMA_GRETYPE_CHAR_ALT, '\n'},
509          {LLAMA_GRETYPE_RULE_REF, /* ws_12 */ 12},
510          {LLAMA_GRETYPE_ALT, 0},
511          {LLAMA_GRETYPE_END, 0},
512      });
513  
514      return 0;
515  }