// test-grammar-integration.cpp
1 #ifdef NDEBUG 2 #undef NDEBUG 3 #endif 4 5 #define LLAMA_API_INTERNAL 6 7 #include "ggml.h" 8 #include "llama.h" 9 #include "grammar-parser.h" 10 #include "unicode.h" 11 #include <cassert> 12 #include <string> 13 #include <vector> 14 15 static llama_grammar* build_grammar(const std::string & grammar_str) { 16 auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); 17 18 // Ensure we parsed correctly 19 assert(!parsed_grammar.rules.empty()); 20 21 // Ensure we have a root node 22 assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end())); 23 24 std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules()); 25 llama_grammar* grammar = llama_grammar_init( 26 grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); 27 28 return grammar; 29 } 30 31 static bool test_build_grammar_fails(const std::string & grammar_str) { 32 fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str()); 33 bool grammar_fails = false; 34 try { 35 build_grammar(grammar_str); 36 fprintf(stderr, " ❌ Expected build failure, but succeeded\n"); 37 } catch (const std::exception & err) { 38 grammar_fails = true; 39 fprintf(stdout, " ✅︎\n"); 40 } 41 return grammar_fails; 42 } 43 44 static bool match_string(const std::string & input, llama_grammar* grammar) { 45 auto decoded = decode_utf8(input, {}); 46 47 const auto & code_points = decoded.first; 48 49 for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { 50 auto prev_stacks = grammar->stacks; 51 llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks); 52 if (grammar->stacks.empty()) { 53 // no stacks means that the grammar failed to match at this point 54 return false; 55 } 56 } 57 58 for (const auto & stack : grammar->stacks) { 59 if (stack.empty()) { 60 // An empty stack means that the grammar has been completed 61 return true; 62 } 63 } 64 65 return false; 66 } 67 68 static void test_grammar(const 
std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) { 69 fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str()); 70 fflush(stderr); 71 72 auto grammar = build_grammar(grammar_str); 73 74 // Save the original grammar stacks so that we can reset after every new string we want to test 75 auto original_stacks = grammar->stacks; 76 77 fprintf(stderr, " 🔵 Valid strings:\n"); 78 79 // Passing strings 80 for (const auto & test_string : passing_strings) { 81 fprintf(stderr, " \"%s\" ", test_string.c_str()); 82 fflush(stderr); 83 84 bool matched = match_string(test_string, grammar); 85 86 if (!matched) { 87 fprintf(stderr, "❌ (failed to match)\n"); 88 } else { 89 fprintf(stdout, "✅︎\n"); 90 } 91 92 assert(matched); 93 94 // Reset the grammar stacks 95 grammar->stacks = original_stacks; 96 } 97 98 fprintf(stderr, " 🟠 Invalid strings:\n"); 99 100 // Failing strings 101 for (const auto & test_string : failing_strings) { 102 fprintf(stderr, " \"%s\" ", test_string.c_str()); 103 fflush(stderr); 104 105 bool matched = match_string(test_string, grammar); 106 107 if (matched) { 108 fprintf(stderr, "❌ (incorrectly matched)\n"); 109 } else { 110 fprintf(stdout, "✅︎\n"); 111 } 112 assert(!matched); 113 114 // Reset the grammar stacks 115 grammar->stacks = original_stacks; 116 } 117 118 // Clean up allocated memory 119 llama_grammar_free(grammar); 120 } 121 122 static void test_simple_grammar() { 123 // Test case for a simple grammar 124 test_grammar( 125 "simple grammar", 126 R"""( 127 root ::= expr 128 expr ::= term ("+" term)* 129 term ::= number 130 number ::= [0-9]+)""", 131 // Passing strings 132 { 133 "42", 134 "1+2+3+4+5", 135 "123+456", 136 }, 137 // Failing strings 138 { 139 "+", 140 "/ 3", 141 "1+2+3+4+5+", 142 "12a45", 143 } 144 ); 145 } 146 147 static void test_complex_grammar() { 148 // Test case for a more complex grammar, with both 
failure strings and success strings 149 test_grammar( 150 "medium complexity grammar", 151 // Grammar 152 R"""( 153 root ::= expression 154 expression ::= term ws (("+"|"-") ws term)* 155 term ::= factor ws (("*"|"/") ws factor)* 156 factor ::= number | variable | "(" expression ")" | function-call 157 number ::= [0-9]+ 158 variable ::= [a-zA-Z_][a-zA-Z0-9_]* 159 function-call ::= variable ws "(" (expression ("," ws expression)*)? ")" 160 ws ::= [ \t\n\r]?)""", 161 // Passing strings 162 { 163 "42", 164 "1*2*3*4*5", 165 "x", 166 "x+10", 167 "x1+y2", 168 "(a+b)*(c-d)", 169 "func()", 170 "func(x,y+2)", 171 "a*(b+c)-d/e", 172 "f(g(x),h(y,z))", 173 "x + 10", 174 "x1 + y2", 175 "(a + b) * (c - d)", 176 "func()", 177 "func(x, y + 2)", 178 "a * (b + c) - d / e", 179 "f(g(x), h(y, z))", 180 "123+456", 181 "123*456*789-123/456+789*123", 182 "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456" 183 }, 184 // Failing strings 185 { 186 "+", 187 "/ 3x", 188 "x + + y", 189 "a * / b", 190 "func(,)", 191 "func(x y)", 192 "(a + b", 193 "x + y)", 194 "a + b * (c - d", 195 "42 +", 196 "x +", 197 "x + 10 +", 198 "(a + b) * (c - d", 199 "func(", 200 "func(x, y + 2", 201 "a * (b + c) - d /", 202 "f(g(x), h(y, z)", 203 "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/", 204 } 205 ); 206 } 207 208 static void test_special_chars() { 209 // A collection of tests to exercise special characters such as "." 210 test_grammar( 211 "special characters", 212 // Grammar 213 R"""( 214 root ::= ... "abc" ... 215 )""", 216 // Passing strings 217 { 218 "abcabcabc", 219 "aaaabcccc", 220 // NOTE: Also ensures that multi-byte characters still count as a single character 221 "🔵🟠✅abc❌🟠🔵" 222 }, 223 // Failing strings 224 { 225 "aaabcccc", 226 "aaaaabcccc", 227 "aaaabccc", 228 "aaaabccccc", 229 "🔵🟠✅❌abc❌✅🟠🔵" 230 "🔵🟠abc🟠🔵" 231 } 232 ); 233 } 234 235 static void test_quantifiers() { 236 // A collection of tests to exercise * + and ? 
quantifiers 237 238 test_grammar( 239 "* quantifier", 240 // Grammar 241 R"""(root ::= "a"*)""", 242 // Passing strings 243 { 244 "", 245 "a", 246 "aaaaa", 247 "aaaaaaaaaaaaaaaaaa", 248 "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" 249 }, 250 // Failing strings 251 { 252 "b", 253 "ab", 254 "aab", 255 "ba", 256 "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" 257 } 258 ); 259 test_grammar( 260 "+ quantifier", 261 // Grammar 262 R"""(root ::= "a"+)""", 263 // Passing strings 264 { 265 "a", 266 "aaaaa", 267 "aaaaaaaaaaaaaaaaaa", 268 "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" 269 }, 270 // Failing strings 271 { 272 "", 273 "b", 274 "ab", 275 "aab", 276 "ba", 277 "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" 278 } 279 ); 280 test_grammar( 281 "? quantifier", 282 // Grammar 283 R"""(root ::= "a"?)""", 284 // Passing strings 285 { 286 "", 287 "a" 288 }, 289 // Failing strings 290 { 291 "b", 292 "ab", 293 "aa", 294 "ba", 295 } 296 ); 297 test_grammar( 298 "mixed quantifiers", 299 // Grammar 300 R"""( 301 root ::= cons+ vowel* cons? 
(vowel cons)* 302 vowel ::= [aeiouy] 303 cons ::= [bcdfghjklmnpqrstvwxyz] 304 )""", 305 // Passing strings 306 { 307 "yes", 308 "no", 309 "noyes", 310 "crwth", 311 "four", 312 "bryyyy", 313 }, 314 // Failing strings 315 { 316 "yess", 317 "yesno", 318 "forty", 319 "catyyy", 320 } 321 ); 322 test_grammar( 323 "simple exact repetition", 324 // Grammar 325 R"""( 326 root ::= [ab]{4} 327 )""", 328 // Passing strings 329 { 330 "aaaa", 331 "bbbb", 332 "abab", 333 }, 334 // Failing strings 335 { 336 "a", 337 "b", 338 "aaaaa", 339 } 340 ); 341 test_grammar( 342 "simple min repetition", 343 // Grammar 344 R"""( 345 root ::= [ab]{4,} 346 )""", 347 // Passing strings 348 { 349 "aaaa", 350 "aaaaab", 351 "bbbb", 352 "ababab", 353 }, 354 // Failing strings 355 { 356 "", 357 "aba", 358 } 359 ); 360 test_grammar( 361 "simple max repetition", 362 // Grammar 363 R"""( 364 root ::= [ab]{0,4} 365 )""", 366 // Passing strings 367 { 368 "", 369 "a", 370 "aa", 371 "aaa", 372 "aaab", 373 }, 374 // Failing strings 375 { 376 "aaaaa", 377 } 378 ); 379 test_grammar( 380 "min / max repetition", 381 // Grammar 382 R"""( 383 root ::= ("0x" [A-F0-9]{2} " "?){3,5} 384 )""", 385 // Passing strings 386 { 387 "0xFF 0x12 0xAB", 388 "0xFF 0x12 0xAB 0x00 0x00", 389 }, 390 // Failing strings 391 { 392 "", 393 "0xFF", 394 "0xFF 0x12", 395 "0xFF 0x12 0xAB 0x00 0x00 0x00", 396 } 397 ); 398 } 399 400 static void test_failure_missing_root() { 401 fprintf(stderr, "⚫ Testing missing root node:\n"); 402 // Test case for a grammar that is missing a root rule 403 const std::string grammar_str = R"""(rot ::= expr 404 expr ::= term ("+" term)* 405 term ::= number 406 number ::= [0-9]+)"""; 407 408 grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); 409 410 // Ensure we parsed correctly 411 assert(!parsed_grammar.rules.empty()); 412 413 // Ensure we do NOT have a root node 414 assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()); 415 fprintf(stderr, " 
✅︎ Passed\n"); 416 } 417 418 static void test_failure_missing_reference() { 419 fprintf(stderr, "⚫ Testing missing reference node:\n"); 420 421 // Test case for a grammar that is missing a referenced rule 422 const std::string grammar_str = 423 R"""(root ::= expr 424 expr ::= term ("+" term)* 425 term ::= numero 426 number ::= [0-9]+)"""; 427 428 fprintf(stderr, " Expected error: "); 429 430 grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); 431 432 // Ensure we did NOT parsed correctly 433 assert(parsed_grammar.rules.empty()); 434 435 fprintf(stderr, " End of expected error.\n"); 436 fprintf(stderr, " ✅︎ Passed\n"); 437 } 438 439 static void test_failure_left_recursion() { 440 fprintf(stderr, "⚫ Testing left recursion detection:\n"); 441 442 // Test simple left recursion detection 443 const std::string simple_str = R"""(root ::= "a" | root "a")"""; 444 assert(test_build_grammar_fails(simple_str)); 445 446 // Test more complicated left recursion detection 447 const std::string medium_str = R"""( 448 root ::= asdf 449 asdf ::= "a" | asdf "a" 450 )"""; 451 assert(test_build_grammar_fails(medium_str)); 452 453 // Test even more complicated left recursion detection 454 const std::string hard_str = R"""( 455 root ::= asdf 456 asdf ::= "a" | foo "b" 457 foo ::= "c" | asdf "d" | "e")"""; 458 assert(test_build_grammar_fails(hard_str)); 459 460 // Test yet even more complicated left recursion detection 461 const std::string hardest_str = R"""( 462 root ::= asdf 463 asdf ::= "a" | foo "b" 464 foo ::= "c" | empty asdf "d" | "e" 465 empty ::= "blah" | )"""; 466 assert(test_build_grammar_fails(hardest_str)); 467 468 fprintf(stderr, " ✅︎ Passed\n"); 469 } 470 471 int main() { 472 fprintf(stdout, "Running grammar integration tests...\n"); 473 test_simple_grammar(); 474 test_complex_grammar(); 475 test_special_chars(); 476 test_quantifiers(); 477 test_failure_missing_root(); 478 test_failure_missing_reference(); 479 
test_failure_left_recursion(); 480 fprintf(stdout, "All tests passed.\n"); 481 return 0; 482 }