// tests/test-grammar-integration.cpp
  1  #ifdef NDEBUG
  2  #undef NDEBUG
  3  #endif
  4  
  5  #define LLAMA_API_INTERNAL
  6  
  7  #include "ggml.h"
  8  #include "llama.h"
  9  #include "grammar-parser.h"
 10  #include "unicode.h"
 11  #include <cassert>
 12  #include <string>
 13  #include <vector>
 14  
 15  static llama_grammar* build_grammar(const std::string & grammar_str) {
 16      auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
 17  
 18      // Ensure we parsed correctly
 19      assert(!parsed_grammar.rules.empty());
 20  
 21      // Ensure we have a root node
 22      assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
 23  
 24      std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
 25      llama_grammar* grammar = llama_grammar_init(
 26          grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
 27  
 28      return grammar;
 29  }
 30  
 31  static bool test_build_grammar_fails(const std::string & grammar_str) {
 32      fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
 33      bool grammar_fails = false;
 34      try {
 35          build_grammar(grammar_str);
 36          fprintf(stderr, "  ❌ Expected build failure, but succeeded\n");
 37      } catch (const std::exception & err) {
 38          grammar_fails = true;
 39          fprintf(stdout, "  ✅︎\n");
 40      }
 41      return grammar_fails;
 42  }
 43  
 44  static bool match_string(const std::string & input, llama_grammar* grammar) {
 45      auto decoded = decode_utf8(input, {});
 46  
 47      const auto & code_points = decoded.first;
 48  
 49      for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
 50          auto prev_stacks = grammar->stacks;
 51          llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
 52          if (grammar->stacks.empty()) {
 53              // no stacks means that the grammar failed to match at this point
 54              return false;
 55          }
 56      }
 57  
 58      for (const auto & stack : grammar->stacks) {
 59          if (stack.empty()) {
 60              // An empty stack means that the grammar has been completed
 61              return true;
 62          }
 63      }
 64  
 65      return false;
 66  }
 67  
 68  static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
 69      fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str());
 70      fflush(stderr);
 71  
 72      auto grammar = build_grammar(grammar_str);
 73  
 74      // Save the original grammar stacks so that we can reset after every new string we want to test
 75      auto original_stacks = grammar->stacks;
 76  
 77      fprintf(stderr, "  🔵 Valid strings:\n");
 78  
 79      // Passing strings
 80      for (const auto & test_string : passing_strings) {
 81          fprintf(stderr, "    \"%s\" ", test_string.c_str());
 82          fflush(stderr);
 83  
 84          bool matched = match_string(test_string, grammar);
 85  
 86          if (!matched) {
 87              fprintf(stderr, "❌ (failed to match)\n");
 88          } else {
 89              fprintf(stdout, "✅︎\n");
 90          }
 91  
 92          assert(matched);
 93  
 94          // Reset the grammar stacks
 95          grammar->stacks = original_stacks;
 96      }
 97  
 98      fprintf(stderr, "  🟠 Invalid strings:\n");
 99  
100      // Failing strings
101      for (const auto & test_string : failing_strings) {
102          fprintf(stderr, "    \"%s\" ", test_string.c_str());
103          fflush(stderr);
104  
105          bool matched = match_string(test_string, grammar);
106  
107          if (matched) {
108              fprintf(stderr, "❌ (incorrectly matched)\n");
109          } else {
110              fprintf(stdout, "✅︎\n");
111          }
112          assert(!matched);
113  
114          // Reset the grammar stacks
115          grammar->stacks = original_stacks;
116      }
117  
118      // Clean up allocated memory
119      llama_grammar_free(grammar);
120  }
121  
122  static void test_simple_grammar() {
123      // Test case for a simple grammar
124      test_grammar(
125          "simple grammar",
126          R"""(
127              root ::= expr
128              expr ::= term ("+" term)*
129              term ::= number
130              number ::= [0-9]+)""",
131          // Passing strings
132          {
133              "42",
134              "1+2+3+4+5",
135              "123+456",
136          },
137          // Failing strings
138          {
139              "+",
140              "/ 3",
141              "1+2+3+4+5+",
142              "12a45",
143          }
144      );
145  }
146  
147  static void test_complex_grammar() {
148      // Test case for a more complex grammar, with both failure strings and success strings
149      test_grammar(
150          "medium complexity grammar",
151          // Grammar
152          R"""(
153              root ::= expression
154              expression ::= term ws (("+"|"-") ws term)*
155              term ::= factor ws (("*"|"/") ws factor)*
156              factor ::= number | variable | "(" expression ")" | function-call
157              number ::= [0-9]+
158              variable ::= [a-zA-Z_][a-zA-Z0-9_]*
159              function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
160              ws ::= [ \t\n\r]?)""",
161          // Passing strings
162          {
163              "42",
164              "1*2*3*4*5",
165              "x",
166              "x+10",
167              "x1+y2",
168              "(a+b)*(c-d)",
169              "func()",
170              "func(x,y+2)",
171              "a*(b+c)-d/e",
172              "f(g(x),h(y,z))",
173              "x + 10",
174              "x1 + y2",
175              "(a + b) * (c - d)",
176              "func()",
177              "func(x, y + 2)",
178              "a * (b + c) - d / e",
179              "f(g(x), h(y, z))",
180              "123+456",
181              "123*456*789-123/456+789*123",
182              "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
183          },
184          // Failing strings
185          {
186              "+",
187              "/ 3x",
188              "x + + y",
189              "a * / b",
190              "func(,)",
191              "func(x y)",
192              "(a + b",
193              "x + y)",
194              "a + b * (c - d",
195              "42 +",
196              "x +",
197              "x + 10 +",
198              "(a + b) * (c - d",
199              "func(",
200              "func(x, y + 2",
201              "a * (b + c) - d /",
202              "f(g(x), h(y, z)",
203              "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
204          }
205      );
206  }
207  
208  static void test_special_chars() {
209      // A collection of tests to exercise special characters such as "."
210      test_grammar(
211          "special characters",
212          // Grammar
213          R"""(
214              root ::= ... "abc" ...
215              )""",
216          // Passing strings
217          {
218              "abcabcabc",
219              "aaaabcccc",
220              // NOTE: Also ensures that multi-byte characters still count as a single character
221              "🔵🟠✅abc❌🟠🔵"
222          },
223          // Failing strings
224          {
225              "aaabcccc",
226              "aaaaabcccc",
227              "aaaabccc",
228              "aaaabccccc",
229              "🔵🟠✅❌abc❌✅🟠🔵"
230              "🔵🟠abc🟠🔵"
231          }
232      );
233  }
234  
235  static void test_quantifiers() {
236      // A collection of tests to exercise * + and ? quantifiers
237  
238      test_grammar(
239          "* quantifier",
240          // Grammar
241          R"""(root ::= "a"*)""",
242          // Passing strings
243          {
244              "",
245              "a",
246              "aaaaa",
247              "aaaaaaaaaaaaaaaaaa",
248              "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
249          },
250          // Failing strings
251          {
252              "b",
253              "ab",
254              "aab",
255              "ba",
256              "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
257          }
258      );
259      test_grammar(
260          "+ quantifier",
261          // Grammar
262          R"""(root ::= "a"+)""",
263          // Passing strings
264          {
265              "a",
266              "aaaaa",
267              "aaaaaaaaaaaaaaaaaa",
268              "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
269          },
270          // Failing strings
271          {
272              "",
273              "b",
274              "ab",
275              "aab",
276              "ba",
277              "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
278          }
279      );
280      test_grammar(
281          "? quantifier",
282          // Grammar
283          R"""(root ::= "a"?)""",
284          // Passing strings
285          {
286              "",
287              "a"
288          },
289          // Failing strings
290          {
291              "b",
292              "ab",
293              "aa",
294              "ba",
295          }
296      );
297      test_grammar(
298          "mixed quantifiers",
299          // Grammar
300          R"""(
301              root ::= cons+ vowel* cons? (vowel cons)*
302              vowel ::= [aeiouy]
303              cons ::= [bcdfghjklmnpqrstvwxyz]
304              )""",
305          // Passing strings
306          {
307              "yes",
308              "no",
309              "noyes",
310              "crwth",
311              "four",
312              "bryyyy",
313          },
314          // Failing strings
315          {
316              "yess",
317              "yesno",
318              "forty",
319              "catyyy",
320          }
321      );
322      test_grammar(
323          "simple exact repetition",
324          // Grammar
325          R"""(
326              root ::= [ab]{4}
327          )""",
328          // Passing strings
329          {
330              "aaaa",
331              "bbbb",
332              "abab",
333          },
334          // Failing strings
335          {
336              "a",
337              "b",
338              "aaaaa",
339          }
340      );
341      test_grammar(
342          "simple min repetition",
343          // Grammar
344          R"""(
345              root ::= [ab]{4,}
346          )""",
347          // Passing strings
348          {
349              "aaaa",
350              "aaaaab",
351              "bbbb",
352              "ababab",
353          },
354          // Failing strings
355          {
356              "",
357              "aba",
358          }
359      );
360      test_grammar(
361          "simple max repetition",
362          // Grammar
363          R"""(
364              root ::= [ab]{0,4}
365          )""",
366          // Passing strings
367          {
368              "",
369              "a",
370              "aa",
371              "aaa",
372              "aaab",
373          },
374          // Failing strings
375          {
376              "aaaaa",
377          }
378      );
379      test_grammar(
380          "min / max repetition",
381          // Grammar
382          R"""(
383              root ::= ("0x" [A-F0-9]{2} " "?){3,5}
384          )""",
385          // Passing strings
386          {
387              "0xFF 0x12 0xAB",
388              "0xFF 0x12 0xAB 0x00 0x00",
389          },
390          // Failing strings
391          {
392              "",
393              "0xFF",
394              "0xFF 0x12",
395              "0xFF 0x12 0xAB 0x00 0x00 0x00",
396          }
397      );
398  }
399  
400  static void test_failure_missing_root() {
401      fprintf(stderr, "⚫ Testing missing root node:\n");
402      // Test case for a grammar that is missing a root rule
403      const std::string grammar_str = R"""(rot ::= expr
404  expr ::= term ("+" term)*
405  term ::= number
406  number ::= [0-9]+)""";
407  
408      grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
409  
410      // Ensure we parsed correctly
411      assert(!parsed_grammar.rules.empty());
412  
413      // Ensure we do NOT have a root node
414      assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
415      fprintf(stderr, "  ✅︎ Passed\n");
416  }
417  
418  static void test_failure_missing_reference() {
419      fprintf(stderr, "⚫ Testing missing reference node:\n");
420  
421      // Test case for a grammar that is missing a referenced rule
422      const std::string grammar_str =
423  R"""(root ::= expr
424  expr ::= term ("+" term)*
425  term ::= numero
426  number ::= [0-9]+)""";
427  
428      fprintf(stderr, "    Expected error:  ");
429  
430      grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
431  
432      // Ensure we did NOT parsed correctly
433      assert(parsed_grammar.rules.empty());
434  
435      fprintf(stderr, "    End of expected error.\n");
436      fprintf(stderr, "  ✅︎ Passed\n");
437  }
438  
439  static void test_failure_left_recursion() {
440      fprintf(stderr, "⚫ Testing left recursion detection:\n");
441  
442      // Test simple left recursion detection
443      const std::string simple_str = R"""(root ::= "a" | root "a")""";
444      assert(test_build_grammar_fails(simple_str));
445  
446      // Test more complicated left recursion detection
447      const std::string medium_str = R"""(
448  root ::= asdf
449  asdf ::= "a" | asdf "a"
450  )""";
451      assert(test_build_grammar_fails(medium_str));
452  
453      // Test even more complicated left recursion detection
454      const std::string hard_str = R"""(
455  root ::= asdf
456  asdf ::= "a" | foo "b"
457  foo ::= "c" | asdf "d" | "e")""";
458      assert(test_build_grammar_fails(hard_str));
459  
460      // Test yet even more complicated left recursion detection
461      const std::string hardest_str = R"""(
462  root ::= asdf
463  asdf ::= "a" | foo "b"
464  foo ::= "c" | empty asdf "d" | "e"
465  empty ::= "blah" | )""";
466      assert(test_build_grammar_fails(hardest_str));
467  
468      fprintf(stderr, "  ✅︎ Passed\n");
469  }
470  
471  int main() {
472      fprintf(stdout, "Running grammar integration tests...\n");
473      test_simple_grammar();
474      test_complex_grammar();
475      test_special_chars();
476      test_quantifiers();
477      test_failure_missing_root();
478      test_failure_missing_reference();
479      test_failure_left_recursion();
480      fprintf(stdout, "All tests passed.\n");
481      return 0;
482  }