json_schema_to_grammar.py
#!/usr/bin/env python3
import argparse
import itertools
import json
import re
import sys
from typing import Any, Dict, List, Optional, Set, Tuple, Union


def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
    '''
        Builds a GBNF expression repeating item_rule between min_items and max_items times.

        max_items=None means "unbounded". When separator_rule is given, it is emitted
        between consecutive items (e.g. '"," space' for JSON arrays).
        Returns an empty string when max_items is 0 (nothing may be emitted).
    '''

    # Nothing can be repeated zero times at most; without this guard the separator
    # branch below would recurse with max_items - 1 == -1 and emit an invalid quantifier.
    if max_items == 0:
        return ''

    if min_items == 0 and max_items == 1:
        return f'{item_rule}?'

    if not separator_rule:
        if min_items == 1 and max_items is None:
            return f'{item_rule}+'
        elif min_items == 0 and max_items is None:
            return f'{item_rule}*'
        else:
            return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}'

    # First item, followed by (separator item) repeated one time less.
    rest = _build_repetition(
        f'({separator_rule} {item_rule})',
        min_items - 1 if min_items > 0 else 0,
        max_items - 1 if max_items is not None else None)
    result = f'{item_rule} {rest}' if rest else item_rule
    return f'({result})?' if min_items == 0 else result


class BuiltinRule:
    '''A predefined grammar rule body together with the names of the rules it references.'''

    def __init__(self, content: str, deps: Optional[list] = None):
        self.content = content
        self.deps = deps or []


# Constraining spaces to prevent model "running away".
SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}'

PRIMITIVE_RULES = {
    'boolean'      : BuiltinRule('("true" | "false") space', []),
    'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
    'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
    'number'       : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
    'integer'      : BuiltinRule('("-"? integral-part) space', ['integral-part']),
    'value'        : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
    'object'       : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
    'array'        : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
    'uuid'         : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
    'char'         : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
    'string'       : BuiltinRule(r'"\"" char* "\"" space', ['char']),
    'null'         : BuiltinRule('"null" space', []),
}

# TODO: support "uri", "email" string formats
STRING_FORMAT_RULES = {
    'date'            : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
    'time'            : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
    'date-time'       : BuiltinRule('date "T" time', ['date', 'time']),
    'date-string'     : BuiltinRule('"\\"" date "\\"" space', ['date']),
    'time-string'     : BuiltinRule('"\\"" time "\\"" space', ['time']),
    'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
}

DOTALL = '[\\U00000000-\\U0010FFFF]'
DOT = '[^\\x0A\\x0D]'

RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])

INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
# Backslashes must be escaped too: json.dumps() emits them (e.g. for embedded quotes),
# and an unescaped "\\" right before an escaped quote would terminate the literal early.
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}

NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')


class SchemaConverter:
    '''
        Converts a JSON schema into a set of named GBNF rules.

        Rules are accumulated in self._rules (name -> rule body) by visit() and
        rendered as text by format_grammar().
    '''

    def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
        self._prop_order = prop_order
        self._allow_fetch = allow_fetch
        self._dotall = dotall
        self._raw_pattern = raw_pattern
        self._rules = {
            'space': SPACE_RULE,
        }
        self._refs = {}
        self._refs_being_resolved = set()

    def _format_literal(self, literal):
        '''Returns literal as a double-quoted GBNF string, escaping CR, LF, quotes and backslashes.'''
        escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
            lambda m: GRAMMAR_LITERAL_ESCAPES[m.group(0)], literal
        )
        return f'"{escaped}"'

    def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
        '''
            Builds an expression from negated character classes for the given literal.

            not_literal('a') -> '([^a])'
            not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)'

            The dotall parameter is accepted but not used by this implementation.
        '''
        assert len(literal) > 0, 'Empty literal not supported'
        def recurse(i: int):
            c = literal[i]
            if maybe_escaped_underscores and c == '_':
                yield f'[^{c}\\\\]'
                yield ' | '
                yield f'"\\\\"? "{c}"'
            else:
                yield f'[^{c}]'
            if i < len(literal) - 1:
                yield ' | '
                yield self._format_literal(c)
                yield ' ('
                yield from recurse(i + 1)
                yield ')?'

        return ''.join(('(', *recurse(0), ')'))

    def _add_rule(self, name, rule):
        '''
            Registers rule under a sanitized version of name and returns the key used.

            If the sanitized name is already bound to a different rule body, a numeric
            suffix is appended until a free (or identical) slot is found.
        '''
        esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
        if esc_name not in self._rules or self._rules[esc_name] == rule:
            key = esc_name
        else:
            i = 0
            while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
                i += 1
            key = f'{esc_name}{i}'
        self._rules[key] = rule
        return key

    def resolve_refs(self, schema: dict, url: str):
        '''
            Resolves all $ref fields in the given schema, fetching any remote schemas,
            replacing $ref with absolute reference URL and populating self._refs with the
            respective referenced (sub)schema dictionaries.
        '''
        def visit(n: dict):
            if isinstance(n, list):
                return [visit(x) for x in n]
            elif isinstance(n, dict):
                ref = n.get('$ref')
                if ref is not None and ref not in self._refs:
                    if ref.startswith('https://'):
                        assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
                        import requests

                        frag_split = ref.split('#')
                        base_url = frag_split[0]

                        target = self._refs.get(base_url)
                        if target is None:
                            target = self.resolve_refs(requests.get(ref).json(), base_url)
                            self._refs[base_url] = target

                        if len(frag_split) == 1 or frag_split[-1] == '':
                            # Also record the ref as written (it may end with a bare "#"),
                            # since _resolve_ref looks it up under that exact key.
                            self._refs[ref] = target
                            return target
                    elif ref.startswith('#/'):
                        target = schema
                        ref = f'{url}{ref}'
                        n['$ref'] = ref
                    else:
                        raise ValueError(f'Unsupported ref {ref}')

                    # Walk the JSON-pointer-like fragment down into the target schema.
                    for sel in ref.split('#')[-1].split('/')[1:]:
                        assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
                        target = target[sel]

                    self._refs[ref] = target
                else:
                    for v in n.values():
                        visit(v)

            return n
        return visit(schema)

    def _generate_union_rule(self, name, alt_schemas):
        '''Visits each alternative schema and returns the resulting rule names joined with " | ".'''
        return ' | '.join((
            self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
            for i, alt_schema in enumerate(alt_schemas)
        ))

    def _visit_pattern(self, pattern, name):
        '''
            Transforms a regular expression pattern into a GBNF rule.

            Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
            Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md

            Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.

            Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
            we define sub-rules to keep the output lean.
        '''

        assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
        pattern = pattern[1:-1]
        sub_rule_ids = {}

        i = 0
        length = len(pattern)
        # Set when a '|' is seen outside of any group; the translated pattern then has to be
        # parenthesized before being placed between the surrounding quote literals.
        has_top_level_alternation = False

        def to_rule(s: Tuple[str, bool]) -> str:
            (txt, is_literal) = s
            return "\"" + txt + "\"" if is_literal else txt

        def transform() -> Tuple[str, bool]:
            '''
                Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
            '''
            nonlocal i
            nonlocal has_top_level_alternation

            start = i
            # For each component of this sequence, store its string representation and whether it's a literal.
            # We only need a flat structure here to apply repetition operators to the last item, and
            # to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
            # (GBNF's syntax is luckily very close to regular expressions!)
            seq: List[Tuple[str, bool]] = []

            def get_dot():
                if self._dotall:
                    rule = DOTALL
                else:
                    # Accept any character... except \n and \r line break chars (\x0A and \x0D)
                    rule = DOT
                return self._add_rule('dot', rule)

            def join_seq():
                # Merge runs of adjacent literals into a single literal.
                ret = []
                for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
                    if is_literal:
                        ret.append((''.join(x[0] for x in g), True))
                    else:
                        ret.extend(g)
                if len(ret) == 1:
                    return ret[0]
                return (' '.join(to_rule(x) for x in ret), False)

            while i < length:
                c = pattern[i]
                if c == '.':
                    seq.append((get_dot(), False))
                    i += 1
                elif c == '(':
                    i += 1
                    if i < length:
                        assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
                    seq.append((f'({to_rule(transform())})', False))
                elif c == ')':
                    i += 1
                    assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
                    return join_seq()
                elif c == '[':
                    square_brackets = c
                    i += 1
                    while i < length and pattern[i] != ']':
                        if pattern[i] == '\\':
                            square_brackets += pattern[i:i+2]
                            i += 2
                        else:
                            square_brackets += pattern[i]
                            i += 1
                    assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
                    square_brackets += ']'
                    i += 1
                    seq.append((square_brackets, False))
                elif c == '|':
                    # Only the outermost transform() call starts at index 0.
                    if start == 0:
                        has_top_level_alternation = True
                    seq.append(('|', False))
                    i += 1
                elif c in ('*', '+', '?'):
                    seq[-1] = (to_rule(seq[-1]) + c, False)
                    i += 1
                elif c == '{':
                    curly_brackets = c
                    i += 1
                    while i < length and pattern[i] != '}':
                        curly_brackets += pattern[i]
                        i += 1
                    assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
                    curly_brackets += '}'
                    i += 1
                    nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
                    min_times = 0
                    max_times = None
                    try:
                        if len(nums) == 1:
                            min_times = int(nums[0])
                            max_times = min_times
                        else:
                            assert len(nums) == 2
                            min_times = int(nums[0]) if nums[0] else 0
                            max_times = int(nums[1]) if nums[1] else None
                    except ValueError:
                        raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')

                    (sub, sub_is_literal) = seq[-1]

                    if not sub_is_literal:
                        # Factor the repeated non-literal expression out into its own rule (once per distinct body).
                        sub_id = sub_rule_ids.get(sub)
                        if sub_id is None:
                            sub_id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
                            sub_rule_ids[sub] = sub_id
                        sub = sub_id

                    seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False)
                else:
                    literal = ''
                    while i < length:
                        if pattern[i] == '\\' and i < length - 1:
                            next_char = pattern[i + 1]
                            if next_char in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
                                i += 1
                                literal += pattern[i]
                                i += 1
                            else:
                                literal += pattern[i:i+2]
                                i += 2
                        elif pattern[i] == '"' and not self._raw_pattern:
                            literal += '\\"'
                            i += 1
                        elif pattern[i] not in NON_LITERAL_SET and \
                                (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
                            literal += pattern[i]
                            i += 1
                        else:
                            break
                    if literal:
                        seq.append((literal, True))

            return join_seq()

        body = to_rule(transform())
        if self._raw_pattern:
            return self._add_rule(name, body)

        if has_top_level_alternation:
            # Without the parentheses, the quotes would bind to the first / last alternative only.
            body = f'({body})'
        return self._add_rule(name, "\"\\\"\" " + body + " \"\\\"\" space")


    def _resolve_ref(self, ref):
        '''
            Returns the rule name for a reference, visiting the referenced schema
            (looked up in self._refs) unless a rule of that name already exists or
            the reference is currently being resolved.
        '''
        ref_name = ref.split('/')[-1]
        if ref_name not in self._rules and ref not in self._refs_being_resolved:
            self._refs_being_resolved.add(ref)
            resolved = self._refs[ref]
            ref_name = self.visit(resolved, ref_name)
            self._refs_being_resolved.remove(ref)
        return ref_name

    def _generate_constant_rule(self, value):
        '''Returns a GBNF literal matching the JSON serialization of value.'''
        return self._format_literal(json.dumps(value))

    def visit(self, schema, name):
        '''
            Adds the rule(s) needed for schema and returns the name of its top rule.

            An empty name maps to the 'root' rule; names colliding with reserved rule
            names get a '-' suffix.
        '''
        schema_type = schema.get('type')
        schema_format = schema.get('format')
        rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'

        if (ref := schema.get('$ref')) is not None:
            return self._add_rule(rule_name, self._resolve_ref(ref))

        elif 'oneOf' in schema or 'anyOf' in schema:
            return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))

        elif isinstance(schema_type, list):
            return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))

        elif 'const' in schema:
            return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))

        elif 'enum' in schema:
            rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
            return self._add_rule(rule_name, rule)

        elif schema_type in (None, 'object') and \
             ('properties' in schema or \
              ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
            required = set(schema.get('required', []))
            properties = list(schema.get('properties', {}).items())
            return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))

        elif schema_type in (None, 'object') and 'allOf' in schema:
            required = set()
            properties = []
            hybrid_name = name
            def add_component(comp_schema, is_required):
                if (ref := comp_schema.get('$ref')) is not None:
                    comp_schema = self._refs[ref]

                if 'properties' in comp_schema:
                    for prop_name, prop_schema in comp_schema['properties'].items():
                        properties.append((prop_name, prop_schema))
                        if is_required:
                            required.add(prop_name)

            for t in schema['allOf']:
                if 'anyOf' in t:
                    for tt in t['anyOf']:
                        add_component(tt, is_required=False)
                else:
                    add_component(t, is_required=True)

            return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[]))

        elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
            items = schema.get('items')
            # Fall back to prefixItems only when it exists: a falsy "items" (e.g. the
            # empty schema {}) without "prefixItems" must not raise a KeyError.
            if not items and 'prefixItems' in schema:
                items = schema['prefixItems']
            if isinstance(items, list):
                return self._add_rule(
                    rule_name,
                    '"[" space ' +
                    ' "," space '.join(
                        self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
                        for i, item in enumerate(items)) +
                    ' "]" space')
            else:
                item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
                min_items = schema.get("minItems", 0)
                max_items = schema.get("maxItems")
                return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')

        elif schema_type in (None, 'string') and 'pattern' in schema:
            return self._visit_pattern(schema['pattern'], rule_name)

        elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
            return self._add_primitive(
                'root' if rule_name == 'root' else schema_format,
                PRIMITIVE_RULES['uuid']
            )

        elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
            prim_name = f'{schema_format}-string'
            return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))

        elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
            char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
            min_len = schema.get('minLength', 0)
            max_len = schema.get('maxLength')

            return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')

        elif (schema_type == 'object') or (len(schema) == 0):
            return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))

        else:
            assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
            # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
            return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])

    def _add_primitive(self, name: str, rule: BuiltinRule):
        '''Registers a builtin rule under name, plus (recursively) any dependency not registered yet.'''
        n = self._add_rule(name, rule.content)

        for dep in rule.deps:
            dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
            assert dep_rule, f'Rule {dep} not known'
            if dep not in self._rules:
                self._add_primitive(dep, dep_rule)
        return n

    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
        '''
            Builds the rule body for a JSON object with the given (name, schema) properties.

            Required properties are emitted first, in order; optional ones follow as a
            chain of alternatives. additional_properties set to True or to a schema dict
            adds a trailing "any other key" alternative.
        '''
        prop_order = self._prop_order
        # sort by position in prop_order (if specified) then by original order
        sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]

        prop_kv_rule_names = {}
        for prop_name, prop_schema in properties:
            prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
            prop_kv_rule_names[prop_name] = self._add_rule(
                f'{name}{"-" if name else ""}{prop_name}-kv',
                fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
            )
        required_props = [k for k in sorted_props if k in required]
        optional_props = [k for k in sorted_props if k not in required]

        if additional_properties is True or isinstance(additional_properties, dict):
            sub_name = f'{name}{"-" if name else ""}additional'
            value_rule = self.visit({} if additional_properties is True else additional_properties, f'{sub_name}-value')
            prop_kv_rule_names["*"] = self._add_rule(
                f'{sub_name}-kv',
                self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
            )
            optional_props.append("*")

        rule = '"{" space '
        rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)

        if optional_props:
            rule += ' ('
            if required_props:
                rule += ' "," space ( '

            def get_recursive_refs(ks, first_is_optional):
                [k, *rest] = ks
                kv_rule_name = prop_kv_rule_names[k]
                if k == '*':
                    res = self._add_rule(
                        f'{name}{"-" if name else ""}additional-kvs',
                        f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
                    )
                elif first_is_optional:
                    res = f'( "," space {kv_rule_name} )?'
                else:
                    res = kv_rule_name
                if len(rest) > 0:
                    res += ' ' + self._add_rule(
                        f'{name}{"-" if name else ""}{k}-rest',
                        get_recursive_refs(rest, first_is_optional=True)
                    )
                return res

            # One alternative per possible "first present optional property".
            rule += ' | '.join(
                get_recursive_refs(optional_props[i:], first_is_optional=False)
                for i in range(len(optional_props))
            )
            if required_props:
                rule += ' )'
            rule += ' )?'

        rule += ' "}" space'

        return rule

    def format_grammar(self):
        '''Returns all registered rules as "name ::= body" lines, sorted by rule name.'''
        return '\n'.join(
            f'{name} ::= {rule}'
            for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
        )


def main(args_in = None):
    '''
        Command-line entry point: reads a JSON schema (file, stdin or https URL) and
        prints the corresponding grammar to stdout.
    '''
    parser = argparse.ArgumentParser(
        description='''
            Generates a grammar (suitable for use in ./llama-cli) that produces JSON conforming to a
            given JSON schema. Only a subset of JSON schema features are supported; more may be
            added in the future.
        ''',
    )
    parser.add_argument(
        '--prop-order',
        default=[],
        type=lambda s: s.split(','),
        help='''
            comma-separated property names defining the order of precedence for object properties;
            properties not specified here are given lower precedence than those that are, and
            are kept in their original order from the schema. Required properties are always
            given precedence over optional properties.
        '''
    )
    parser.add_argument(
        '--allow-fetch',
        action='store_true',
        default=False,
        help='Whether to allow fetching referenced schemas over HTTPS')
    parser.add_argument(
        '--dotall',
        action='store_true',
        default=False,
        help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
    parser.add_argument(
        '--raw-pattern',
        action='store_true',
        default=False,
        help='Treats string patterns as raw patterns w/o quotes (or quote escapes)')

    parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
    args = parser.parse_args(args_in)

    if args.schema.startswith('https://'):
        url = args.schema
        import requests
        schema = requests.get(url).json()
    elif args.schema == '-':
        url = 'stdin'
        schema = json.load(sys.stdin)
    else:
        url = f'file://{args.schema}'
        with open(args.schema, encoding='utf-8') as f:
            schema = json.load(f)
    converter = SchemaConverter(
        prop_order={name: idx for idx, name in enumerate(args.prop_order)},
        allow_fetch=args.allow_fetch,
        dotall=args.dotall,
        raw_pattern=args.raw_pattern)
    schema = converter.resolve_refs(schema, url)
    converter.visit(schema, '')
    print(converter.format_grammar())


if __name__ == '__main__':
    main()