# _types.py
1 # SPDX-License-Identifier: MIT 2 # 3 # Copyright (c) 2021 The Anvil Extras project team members listed at 4 # https://github.com/anvilistas/anvil-extras/graphs/contributors 5 # 6 # This software is published at https://github.com/anvilistas/anvil-extras 7 8 from datetime import date as _date 9 from datetime import datetime as _datetime 10 11 from anvil import is_server_side 12 13 from ._zod_error import ZodError, ZodIssueCode 14 from .helpers import ZodParsedType, get_parsed_type, regex, util 15 from .helpers.dict_util import getitem, merge_shapes 16 from .helpers.parse_util import ( 17 ABORTED, 18 DIRTY, 19 INVALID, 20 MISSING, 21 OK, 22 VALID, 23 Common, 24 ErrorMapContext, 25 ParseContext, 26 ParseInput, 27 ParseResult, 28 ParseReturn, 29 ParseStatus, 30 add_issue_to_context, 31 is_valid, 32 ) 33 34 __version__ = "3.1.0" 35 36 any_ = any 37 isinstance_ = isinstance 38 float_ = float 39 list_ = list 40 41 42 class ParseInputLazyPath: 43 def __init__(self, parent, value, path, key): 44 self.parent = parent 45 self.data = value 46 self._path = path 47 self._key = key 48 49 @property 50 def path(self): 51 return self._path + [self._key] 52 53 54 def handle_result(ctx, result): 55 if is_valid(result): 56 return ParseResult(success=True, data=result.value, error=None) 57 else: 58 if not ctx.common.issues: 59 raise Exception("Validation failed but no issues detected") 60 error = ZodError(ctx.common.issues) 61 return ParseResult(success=False, data=None, error=error) 62 63 64 def _check_error_cb(rv): 65 assert ( 66 type(rv) is dict and "message" in rv 67 ), f"bad return type from error callback, expected str or {{'message': str}}, got {rv!r}" 68 return rv 69 70 71 def process_params(error_map=None, invalid_type_error=False, required_error=False): 72 if not any_([error_map, invalid_type_error, required_error]): 73 return {} 74 if error_map and (invalid_type_error or required_error): 75 raise Exception( 76 'Can\'t use "invalid_type_error" or "required_error" in 
conjunction with custom error' 77 ) 78 79 if error_map: 80 81 def _error_map(issue, ctx): 82 rv = error_map(issue, ctx) or ctx.default_error 83 if type(rv) is str: 84 return {"message": rv} 85 86 _check_error_cb(rv) 87 return rv 88 89 return {"error_map": _error_map} 90 91 def custom_map(issue, ctx: ErrorMapContext): 92 if issue["code"] != "invalid_type": 93 return {"message": ctx.default_error} 94 if ctx.data is MISSING: 95 return {"message": required_error or ctx.default_error} 96 return {"message": invalid_type_error or ctx.default_error} 97 98 return {"error_map": custom_map} 99 100 101 class ZodType: 102 _type = None 103 _type_name = None 104 105 @classmethod 106 def _create(cls, **params): 107 return cls(process_params(**params)) 108 109 def __init__(self, _def: dict): 110 self._def = _def 111 112 def _check_invalid_type(self, input): 113 parsed_type = self._get_type(input) 114 115 types = self._type if type(self._type) is list_ else [self._type] 116 117 if parsed_type not in types: 118 ctx = self._get_or_return_ctx(input) 119 add_issue_to_context( 120 ctx, 121 code=ZodIssueCode.invalid_type, 122 expected=self._type_name, 123 received=ctx.parsed_type, 124 ) 125 return True 126 127 def _parse(self, input): 128 raise NotImplementedError("should be implemented by subclass") 129 130 def _get_type(self, input): 131 return get_parsed_type(input.data) 132 133 def _get_or_return_ctx(self, input: ParseInput, ctx=None): 134 return ctx or ParseContext( 135 common=input.parent.common, 136 data=input.data, 137 parsed_type=get_parsed_type(input.data), 138 schema_error_map=self._def.get("error_map"), 139 path=input.path, 140 parent=input.parent, 141 ) 142 143 def _process_input_params(self, input: ParseInput): 144 return ParseStatus(), self._get_or_return_ctx(input) 145 146 def parse(self, data, **params): 147 result = self.safe_parse(data, **params) 148 if result.success: 149 return result.data 150 raise result.error 151 152 def safe_parse(self, data, **params): 153 ctx = 
ParseContext( 154 common=Common(issues=[], contextual_error_map=params.get("error_map")), 155 path=params.get("path", []), 156 schema_error_map=self._def.get("error_map"), 157 parent=None, 158 data=data, 159 parsed_type=get_parsed_type(data), 160 ) 161 input = ParseInput(data, path=ctx.path, parent=ctx) 162 result = self._parse(input) 163 return handle_result(ctx, result) 164 165 def list(self): 166 return ZodList._create(self) 167 168 array = list 169 170 def not_required(self): 171 "might not exist" 172 return ZodNotRequired._create(self) 173 174 def optional(self): 175 "can be None" 176 return ZodOptional._create(self) 177 178 def default(self, value): 179 "replace missing values with" 180 return ZodDefault._create(self, value) 181 182 def catch(self, value): 183 "if the parse fails - replace with value" 184 return ZodCatch._create(self, value) 185 186 def union(self, other): 187 "equivalent to z.union([a, b])" 188 return ZodUnion._create([self, other]) 189 190 or_ = union 191 192 def super_refine(self, refinement): 193 return ZodEffects._create( 194 schema=self, effect={"type": "refinment", "refinement": refinement} 195 ) 196 197 def refine(self, check_fn, message="", fatal=False, **issue_params): 198 """add a custom check_fn(val) -> bool, message: can be a str or a callable: (val) -> str 199 if fatal is True, no further checks for this schema will be considered 200 """ 201 202 def get_issue_props(val): 203 rv = message 204 if callable(rv): 205 rv = rv(val) 206 207 if type(rv) is str: 208 return {"fatal": fatal, **issue_params, "message": rv} 209 210 rv = _check_error_cb(rv) 211 return {"fatal": fatal, **issue_params, **rv} 212 213 def _refinement(val, ctx: CheckContext): 214 if check_fn(val): 215 return True 216 else: 217 ctx.add_issue(**get_issue_props(val)) 218 return False 219 220 return self.super_refine(_refinement) 221 222 def super_transform(self, transform_fn): 223 return ZodEffects._create( 224 schema=self, effect={"type": "transform", "transform": 
transform_fn} 225 ) 226 227 def transform(self, transform_fn): 228 "transform the input with a custom transform function, to prevent any further checks/transforms return z.NEVER" 229 230 def _transform(val, ctx: CheckContext): 231 rv = transform_fn(val) 232 if rv is INVALID: 233 ctx.add_issue(fatal=True) 234 return rv 235 236 return self.super_transform(_transform) 237 238 def pipe(self, target): 239 if not isinstance_(target, ZodType): 240 raise TypeError("expected a zod schema") 241 return ZodPipeline._create(self, target) 242 243 244 class ZodString(ZodType): 245 _type = ZodParsedType.string 246 _type_name = _type 247 248 def _parse(self, input: ParseInput): 249 if self._def["coerce"]: 250 input.data = str(input.data) 251 252 if self._check_invalid_type(input): 253 return INVALID 254 255 status = ParseStatus() 256 ctx = None 257 for check in self._def["checks"]: 258 kind = check["kind"] 259 260 if kind == "min": 261 if len(input.data) < check["value"]: 262 ctx = self._get_or_return_ctx(input, ctx) 263 add_issue_to_context( 264 ctx, 265 code=ZodIssueCode.too_small, 266 minimum=check["value"], 267 type="string", 268 inclusive=True, 269 message=check["message"], 270 ) 271 status.dirty() 272 273 elif kind == "max": 274 if len(input.data) > check["value"]: 275 ctx = self._get_or_return_ctx(input, ctx) 276 add_issue_to_context( 277 ctx, 278 code=ZodIssueCode.too_big, 279 maximum=check["value"], 280 type="string", 281 inclusive=True, 282 message=check["message"], 283 ) 284 status.dirty() 285 286 elif kind == "email": 287 if not regex.EMAIL.match(input.data): 288 ctx = self._get_or_return_ctx(input, ctx) 289 add_issue_to_context( 290 ctx, 291 code=ZodIssueCode.invalid_string, 292 validation="email", 293 message=check["message"], 294 ) 295 status.dirty() 296 297 elif kind == "uuid": 298 if not regex.UUID.match(input.data): 299 ctx = self._get_or_return_ctx(input, ctx) 300 add_issue_to_context( 301 ctx, 302 code=ZodIssueCode.invalid_string, 303 validation="uuid", 304 
message=check["message"], 305 ) 306 status.dirty() 307 308 elif kind == "url": 309 url_valid = True 310 if not is_server_side(): 311 from anvil.js.window import URL 312 313 try: 314 URL(input.data) 315 except Exception: 316 url_valid = False 317 else: 318 url_valid = regex.URL.match(input.data) 319 if not url_valid: 320 ctx = self._get_or_return_ctx(input, ctx) 321 add_issue_to_context( 322 ctx, 323 code=ZodIssueCode.invalid_string, 324 validation="url", 325 message=check["message"], 326 ) 327 status.dirty() 328 329 elif kind == "regex": 330 match = check["regex"].match(input.data) 331 if match is None: 332 ctx = self._get_or_return_ctx(input, ctx) 333 add_issue_to_context( 334 ctx, 335 code=ZodIssueCode.invalid_string, 336 validation="regex", 337 message=check["message"], 338 ) 339 status.dirty() 340 341 elif kind == "strip": 342 input.data = input.data.strip() 343 elif kind == "lower": 344 input.data = input.data.lower() 345 elif kind == "upper": 346 input.data = input.data.upper() 347 348 elif kind == "startswith": 349 if not input.data.startswith(check["value"]): 350 ctx = self._get_or_return_ctx(input, ctx) 351 add_issue_to_context( 352 ctx, 353 code=ZodIssueCode.invalid_string, 354 validation={"startswith": check["value"]}, 355 message=check["message"], 356 ) 357 status.dirty() 358 359 elif kind == "endswith": 360 if not input.data.endswith(check["value"]): 361 ctx = self._get_or_return_ctx(input, ctx) 362 add_issue_to_context( 363 ctx, 364 code=ZodIssueCode.invalid_string, 365 validation={"endswith": check["value"]}, 366 message=check["message"], 367 ) 368 status.dirty() 369 370 elif kind == "datetime" or kind == "date": 371 format = check["format"] 372 try: 373 if format is not None: 374 _datetime.strptime(input.data, format) 375 elif kind == "datetime": 376 _datetime.fromisoformat(input.data) 377 else: 378 _date.fromisoformat(input.data) 379 except Exception: 380 ctx = self._get_or_return_ctx(input, ctx) 381 add_issue_to_context( 382 ctx, 383 
code=ZodIssueCode.invalid_string, 384 validation=kind, 385 message=check["message"], 386 ) 387 status.dirty() 388 389 else: 390 assert False 391 392 return ParseReturn(status=status.value, value=input.data) 393 394 def _add_check(self, **check): 395 return ZodString({**self._def, "checks": [*self._def["checks"], check]}) 396 397 def email(self, message=""): 398 return self._add_check(kind="email", message=message) 399 400 def url(self, message=""): 401 return self._add_check(kind="url", message=message) 402 403 def uuid(self, message=""): 404 return self._add_check(kind="uuid", message=message) 405 406 def datetime(self, format=None, message=""): 407 "default format is isoformat" 408 return self._add_check(kind="datetime", format=format, message=message) 409 410 def date(self, format=None, message=""): 411 "default format is isoformat" 412 return self._add_check(kind="date", format=format, message=message) 413 414 def regex(self, regex, message=""): 415 return self._add_check(kind="regex", regex=regex, message=message) 416 417 def startswith(self, value, message=""): 418 return self._add_check(kind="startswith", value=value, message=message) 419 420 def endswith(self, value, message=""): 421 return self._add_check(kind="endswith", value=value, message=message) 422 423 def min(self, min_length: int, message=""): 424 return self._add_check(kind="min", value=min_length, message=message) 425 426 def max(self, min_length: int, message=""): 427 return self._add_check(kind="max", value=min_length, message=message) 428 429 def len(self, len: int, message=""): 430 return self.min(len, message).max(len, message) 431 432 def nonempty(self, message=""): 433 return self.min(1, message) 434 435 def strip(self): 436 "similar to z.string().transform(str.strip)" 437 return self._add_check(kind="strip") 438 439 def lower(self): 440 "similar to z.string().transform(str.lower)" 441 return self._add_check(kind="lower") 442 443 def upper(self): 444 "similar to 
z.string().transform(str.upper)" 445 return self._add_check(kind="upper") 446 447 @classmethod 448 def _create(cls, *, coerce=False, **params): 449 return cls(dict(checks=[], coerce=coerce, **process_params(**params))) 450 451 452 class ZodAbstractNumber(ZodType): 453 _type = ZodParsedType.number 454 _type_name = _type 455 456 def _parse(self, input): 457 if self._def["coerce"]: 458 try: 459 if self._type == ZodParsedType.integer: 460 input.data = int(input.data) 461 elif self._type == ZodParsedType.float: 462 input.data = float_(input.data) 463 except Exception: 464 pass 465 466 if self._check_invalid_type(input): 467 return INVALID 468 469 status = ParseStatus() 470 ctx = None 471 472 for check in self._def["checks"]: 473 kind = check["kind"] 474 475 if kind == "int": 476 data = input.data 477 if type(data) is float_ and not data.is_integer(): 478 ctx = self._get_or_return_ctx(input, ctx) 479 add_issue_to_context( 480 ctx, 481 code=ZodIssueCode.invalid_type, 482 expected="integer", 483 received="float", 484 message=check["message"], 485 ) 486 status.dirty() 487 488 elif kind == "min": 489 value = check["value"] 490 inclusive = check["inclusive"] 491 too_small = input.data < value if inclusive else input.data <= value 492 if too_small: 493 ctx = self._get_or_return_ctx(input, ctx) 494 add_issue_to_context( 495 ctx, 496 code=ZodIssueCode.too_small, 497 minimum=value, 498 type=self._type_name, 499 inclusive=inclusive, 500 message=check["message"], 501 ) 502 status.dirty() 503 504 elif kind == "max": 505 value = check["value"] 506 inclusive = check["inclusive"] 507 too_big = input.data > value if inclusive else input.data >= value 508 if too_big: 509 ctx = self._get_or_return_ctx(input, ctx) 510 add_issue_to_context( 511 ctx, 512 code=ZodIssueCode.too_big, 513 maximum=value, 514 type=self._type_name, 515 inclusive=inclusive, 516 message=check["message"], 517 ) 518 status.dirty() 519 520 else: 521 assert False 522 523 return ParseReturn(status=status.value, 
value=input.data) 524 525 @classmethod 526 def _create(cls, **params): 527 return cls(dict(checks=[], coerce=False, **process_params(**params))) 528 529 530 class ZodInteger(ZodAbstractNumber): 531 _type = ZodParsedType.integer 532 _type_name = _type 533 534 def _add_check(self, **check): 535 return ZodInteger({**self._def, "checks": [*self._def["checks"], check]}) 536 537 def set_limit(self, kind, value, inclusive, message=""): 538 return self._add_check( 539 kind=kind, value=value, inclusive=inclusive, message=message 540 ) 541 542 def ge(self, value, message=""): 543 return self.set_limit("min", value, True, message) 544 545 def min(self, value, message=""): 546 return self.set_limit("min", value, True, message) 547 548 def gt(self, value, message=""): 549 return self.set_limit("min", value, False, message) 550 551 def le(self, value, message=""): 552 return self.set_limit("max", value, True, message) 553 554 def max(self, value, message=""): 555 return self.set_limit("max", value, True, message) 556 557 def lt(self, value, message=""): 558 return self.set_limit("max", value, False, message) 559 560 def positive(self, message=""): 561 return self.set_limit("min", 0, False, message) 562 563 def negative(self, message=""): 564 return self.set_limit("max", 0, False, message) 565 566 def nonpositive(self, message=""): 567 return self.set_limit("max", 0, True, message) 568 569 def nonnegative(self, message=""): 570 return self.set_limit("min", 0, True, message) 571 572 @classmethod 573 def _create(cls, *, coerce=False, **params): 574 return cls(dict(checks=[], coerce=coerce, **process_params(**params))) 575 576 577 class ZodFloat(ZodAbstractNumber): 578 _type = ZodParsedType.float 579 _type_name = _type 580 581 def _add_check(self, **check): 582 return ZodFloat({**self._def, "checks": [*self._def["checks"], check]}) 583 584 def set_limit(self, kind, value, inclusive, message=""): 585 return self._add_check( 586 kind=kind, value=value, inclusive=inclusive, 
message=message 587 ) 588 589 def int(self, message=""): 590 return self._add_check(kind="int", message=message) 591 592 def ge(self, value, message=""): 593 return self.set_limit("min", value, True, message) 594 595 def min(self, value, message=""): 596 return self.set_limit("min", value, True, message) 597 598 def gt(self, value, message=""): 599 return self.set_limit("min", value, False, message) 600 601 def le(self, value, message=""): 602 return self.set_limit("max", value, True, message) 603 604 def max(self, value, message=""): 605 return self.set_limit("max", value, True, message) 606 607 def lt(self, value, message=""): 608 return self.set_limit("max", value, False, message) 609 610 def positive(self, message=""): 611 return self.set_limit("min", 0, False, message) 612 613 def negative(self, message=""): 614 return self.set_limit("max", 0, False, message) 615 616 def nonpositive(self, message=""): 617 return self.set_limit("max", 0, True, message) 618 619 def nonnegative(self, message=""): 620 return self.set_limit("min", 0, True, message) 621 622 @classmethod 623 def _create(cls, *, coerce=False, **params): 624 return cls(dict(checks=[], coerce=coerce, **process_params(**params))) 625 626 627 class ZodNumber(ZodAbstractNumber): 628 _type = [ZodParsedType.integer, ZodParsedType.float] 629 _type_name = ZodParsedType.number 630 631 def _add_check(self, **check): 632 return ZodNumber({**self._def, "checks": [*self._def["checks"], check]}) 633 634 def set_limit(self, kind, value, inclusive, message=""): 635 return self._add_check( 636 kind=kind, value=value, inclusive=inclusive, message=message 637 ) 638 639 def int(self, message=""): 640 return self._add_check(kind="int", message=message) 641 642 def ge(self, value, message=""): 643 return self.set_limit("min", value, True, message) 644 645 def min(self, value, message=""): 646 return self.set_limit("min", value, True, message) 647 648 def gt(self, value, message=""): 649 return self.set_limit("min", value, 
False, message) 650 651 def le(self, value, message=""): 652 return self.set_limit("max", value, True, message) 653 654 def max(self, value, message=""): 655 return self.set_limit("max", value, True, message) 656 657 def lt(self, value, message=""): 658 return self.set_limit("max", value, False, message) 659 660 def positive(self, message=""): 661 return self.set_limit("min", 0, False, message) 662 663 def negative(self, message=""): 664 return self.set_limit("max", 0, False, message) 665 666 def nonpositive(self, message=""): 667 return self.set_limit("max", 0, True, message) 668 669 def nonnegative(self, message=""): 670 return self.set_limit("min", 0, True, message) 671 672 @classmethod 673 def _create(cls, **params): 674 return cls(dict(checks=[], coerce=False, **process_params(**params))) 675 676 677 class ZodDateTime(ZodType): 678 _type = ZodParsedType.datetime 679 _type_name = _type 680 681 def _parse(self, input): 682 if self._check_invalid_type(input): 683 return INVALID 684 685 status = ParseStatus() 686 ctx = None 687 for check in self._def["checks"]: 688 kind = check["kind"] 689 690 if kind == "min": 691 if input.data < check["value"]: 692 ctx = self._get_or_return_ctx(input, ctx) 693 add_issue_to_context( 694 ctx, 695 code=ZodIssueCode.too_small, 696 minimum=check["value"].isoformat(), 697 type=self._type_name, 698 inclusive=True, 699 message=check["message"], 700 ) 701 status.dirty() 702 703 elif kind == "max": 704 if input.data > check["value"]: 705 ctx = self._get_or_return_ctx(input, ctx) 706 add_issue_to_context( 707 ctx, 708 code=ZodIssueCode.too_big, 709 maximum=check["value"].isoformat(), 710 type=self._type_name, 711 inclusive=True, 712 message=check["message"], 713 ) 714 status.dirty() 715 716 else: 717 assert False 718 719 return ParseReturn(status=status.value, value=input.data) 720 721 def _add_check(self, **check): 722 return ZodDateTime({**self._def, "checks": [*self._def["checks"], check]}) 723 724 def min(self, min_date: _datetime, 
message=""): 725 return self._add_check(kind="min", value=min_date, message=message) 726 727 def max(self, max_date: _datetime, message=""): 728 return self._add_check(kind="max", value=max_date, message=message) 729 730 @classmethod 731 def _create(cls, **params): 732 return cls(dict(checks=[], **process_params(**params))) 733 734 735 class ZodDate(ZodDateTime): 736 _type = ZodParsedType.date 737 _type_name = _type 738 739 def _add_check(self, **check): 740 return ZodDate({**self._def, "checks": [*self._def["checks"], check]}) 741 742 def min(self, min_date: _date, message=""): 743 return self._add_check(kind="min", value=min_date, message=message) 744 745 def max(self, max_date: _date, message=""): 746 return self._add_check(kind="max", value=max_date, message=message) 747 748 749 class ZodBoolean(ZodType): 750 _type = ZodParsedType.boolean 751 _type_name = _type 752 753 def _parse(self, input: ParseInput): 754 if self._def["coerce"]: 755 input.data = bool(input.data) 756 if self._check_invalid_type(input): 757 return INVALID 758 return OK(input.data) 759 760 @classmethod 761 def _create(cls, *, coerce=False, **params): 762 return cls(dict(coerce=coerce, **process_params(**params))) 763 764 765 class ZodNone(ZodType): 766 _type = ZodParsedType.none 767 _type_name = _type 768 769 def _parse(self, input: ParseInput): 770 if self._check_invalid_type(input): 771 return INVALID 772 return OK(input.data) 773 774 775 class ZodAny(ZodType): 776 def _parse(self, input): 777 return OK(input.data) 778 779 780 class ZodUnknown(ZodType): 781 _type = ZodParsedType.unknown 782 _type_name = _type 783 _unknown = True 784 785 def _parse(self, input): 786 return OK(input.data) 787 788 789 class ZodNever(ZodType): 790 _type = ZodParsedType.never 791 _type_name = _type 792 793 def _parse(self, input): 794 ctx = self._get_or_return_ctx(input) 795 add_issue_to_context( 796 ctx, 797 code=ZodIssueCode.invalid_type, 798 expected=self._type, 799 received=ctx.parsed_type, 800 ) 801 return 
INVALID 802 803 804 class ZodList(ZodType): 805 _type = [ZodParsedType.list, ZodParsedType.tuple] 806 _type_name = "list" 807 808 def _parse(self, input): 809 status, ctx = self._process_input_params(input) 810 811 if self._check_invalid_type(input): 812 return INVALID 813 814 for check in self._def["checks"]: 815 kind = check["kind"] 816 817 if kind == "min": 818 if len(ctx.data) < check["value"]: 819 add_issue_to_context( 820 ctx, 821 code=ZodIssueCode.too_small, 822 minimum=check["value"], 823 type="list", 824 inclusive=True, 825 message=check["message"], 826 ) 827 status.dirty() 828 829 elif kind == "max": 830 if len(ctx.data) > check["value"]: 831 add_issue_to_context( 832 ctx, 833 code=ZodIssueCode.too_big, 834 maximum=check["value"], 835 type="list", 836 inclusive=True, 837 message=check["message"], 838 ) 839 status.dirty() 840 841 type_schema = self._def["type"] 842 843 results = [ 844 type_schema._parse(ParseInputLazyPath(ctx, item, ctx.path, i)) 845 for i, item in enumerate(ctx.data) 846 ] 847 848 return ParseStatus.merge_list(status, results) 849 850 @property 851 def element(self): 852 return self._def["type"] 853 854 def _add_check(self, **check): 855 return ZodList({**self._def, "checks": [*self._def["checks"], check]}) 856 857 def min(self, min_length, message=""): 858 return self._add_check(kind="min", value=min_length, message=message) 859 860 def max(self, max_length, message=""): 861 return self._add_check(kind="max", value=max_length, message=message) 862 863 def len(self, len, message=""): 864 return self.min(len, message).max(len, message) 865 866 def nonempty(self, message=""): 867 return self.min(1, message) 868 869 @classmethod 870 def _create(cls, schema, **params): 871 return cls(dict(type=schema, checks=[], **process_params(**params))) 872 873 874 class ZodEnum(ZodType): 875 def _parse(self, input): 876 values = self._def["values"] 877 if input.data not in values: 878 ctx = self._get_or_return_ctx(input) 879 add_issue_to_context( 880 
ctx, 881 code=ZodIssueCode.invalid_type, 882 expected=" | ".join(repr(a) for a in values), 883 received=ctx.parsed_type, 884 ) 885 return INVALID 886 return OK(input.data) 887 888 @property 889 def options(self): 890 return self._def["values"] 891 892 @property 893 def enum(self): 894 return util.enum("ENUM", self.options) 895 896 @classmethod 897 def _create(cls, options, **params): 898 return cls(dict(values=list_(options), **process_params(**params))) 899 900 901 def deep_partialify(schema): 902 t = type(schema) 903 if t is ZodTypedDict: 904 new_shape = { 905 k: ZodNotRequired._create(deep_partialify(v)) 906 for k, v in schema.shape.items() 907 } 908 return ZodTypedDict({**schema._def, "shape": lambda: new_shape}) 909 if t is ZodList: 910 return ZodList._create(deep_partialify(schema.element)) 911 if t is ZodNotRequired: 912 return ZodNotRequired._create(deep_partialify(schema.unwrap())) 913 if t is ZodOptional: 914 return ZodOptional._create(deep_partialify(schema.unwrap())) 915 if t is ZodTuple: 916 return ZodTuple._create([deep_partialify(item) for item in schema.items]) 917 return schema 918 919 920 class ZodTypedDict(ZodType): 921 _type = ZodParsedType.mapping 922 _type_name = _type 923 924 def __init__(self, _def): 925 super().__init__(_def) 926 self._cached = None 927 928 def _parse(self, input): 929 if self._check_invalid_type(input): 930 return INVALID 931 932 status, ctx = self._process_input_params(input) 933 shape = self.shape 934 shape_keys = shape.keys() 935 extra_keys = set() 936 937 if not ( 938 type(self._def["catchall"]) is ZodNever 939 and self._def["unknown_keys"] == "strip" 940 ): 941 for key in ctx.data.keys(): 942 if key not in shape_keys: 943 extra_keys.add(key) 944 945 pairs = [] 946 947 for key in shape_keys: 948 key_validator = shape[key] 949 value = getitem(ctx.data, key, MISSING) 950 pairs.append( 951 ( 952 ParseReturn(VALID, key), 953 key_validator._parse(ParseInputLazyPath(ctx, value, ctx.path, key)), 954 key in ctx.data, 955 ) 956 
) 957 958 if type(self._def["catchall"]) is ZodNever: 959 unknown_keys = self._def["unknown_keys"] 960 if unknown_keys == "passthrough": 961 for key in extra_keys: 962 pairs.append( 963 ( 964 ParseReturn(VALID, key), 965 ParseReturn(VALID, ctx.data[key]), 966 False, 967 ) 968 ) 969 elif unknown_keys == "strict": 970 if extra_keys: 971 add_issue_to_context( 972 ctx, code=ZodIssueCode.unrecognized_keys, keys=extra_keys 973 ) 974 status.dirty() 975 elif unknown_keys == "strip": 976 pass 977 else: 978 assert False, "invalid unknown_keys value" 979 else: 980 # run cachall validation 981 catchall = self._def["catchall"] 982 983 for key in extra_keys: 984 value = ctx.data[key] 985 pairs.append( 986 ( 987 ParseReturn(VALID, key), 988 catchall._parse( 989 ParseInputLazyPath(ctx, value, ctx.path, key), 990 ), 991 key in ctx.data, 992 ) 993 ) 994 995 return ParseStatus.merge_dict(status, pairs) 996 997 @property 998 def shape(self): 999 return self._def["shape"]() 1000 1001 def strict(self, message=""): 1002 "reject if theere are extra keys" 1003 _def = {**self._def, "unknown_keys": "strict"} 1004 if message: 1005 1006 def error_map(issue, ctx): 1007 try: 1008 default_error = self._def["error_map"](issue, ctx)["message"] 1009 except TypeError: 1010 default_error = ctx.default_error 1011 if issue.code == "unrecognized_keys": 1012 return {"message": message or default_error} 1013 return {"message": default_error} 1014 1015 _def["error_map"] = error_map 1016 return ZodTypedDict(_def) 1017 1018 def strip(self): 1019 "return the data without additional keys" 1020 return ZodTypedDict({**self._def, "unknown_keys": "strip"}) 1021 1022 def passthrough(self): 1023 "ignore additional keys" 1024 return ZodTypedDict({**self._def, "unknown_keys": "passthrough"}) 1025 1026 nonstrict = passthrough 1027 1028 def extend(self, shape): 1029 "create a new schema extending the current shape" 1030 return ZodTypedDict( 1031 {**self._def, "shape": lambda: merge_shapes(self.shape, shape)} 1032 ) 1033 
1034 def set_key(self, key, schema): 1035 "returns a new schema with the additional key" 1036 return self.extend({key: schema}) 1037 1038 def merge(self, merge_with): 1039 "merge two mapping schemas" 1040 assert type(merge_with) is ZodTypedDict, "expected a zod mapping schema" 1041 merged = { 1042 "unknown_keys": merge_with._def["unknown_keys"], 1043 "catchall": merge_with._def["catchall"], 1044 "shape": lambda: merge_shapes( 1045 self._def["shape"](), merge_with._def["shape"]() 1046 ), 1047 } 1048 return ZodTypedDict(merged) 1049 1050 def catchall(self, index): 1051 return ZodTypedDict({**self._def, "catchall": index}) 1052 1053 def pick(self, mask): 1054 "mask should be an iterable of keys, retuns a new schema with only those keys" 1055 this_shape = self.shape 1056 shape = {k: this_shape[k] for k in mask if k in this_shape} 1057 return ZodTypedDict({**self._def, "shape": lambda: shape}) 1058 1059 def omit(self, mask): 1060 "mask should be an iterable of keys, retuns a new schema without those keys" 1061 this_shape = self.shape 1062 shape = {k: v for k, v in this_shape.items() if k not in mask} 1063 return ZodTypedDict({**self._def, "shape": lambda: shape}) 1064 1065 def partial(self, mask=None): 1066 "returns a new schema where values are not required. If a mask is provided, only those keys will become not required" 1067 if mask: 1068 shape = { 1069 k: (v.not_required() if k in mask else v) for k, v in self.shape.items() 1070 } 1071 else: 1072 shape = {k: v.not_required() for k, v in self.shape.items()} 1073 return ZodTypedDict({**self._def, "shape": lambda: shape}) 1074 1075 def deep_partial(self): 1076 return deep_partialify(self) 1077 1078 def required(self, mask=None): 1079 "returns a new schema where values are required. 
If a mask is provided, only those keys will become required" 1080 1081 def unwrap(field): 1082 while isinstance_(field, ZodNotRequired): 1083 field = field._def["inner_type"] 1084 return field 1085 1086 if mask: 1087 shape = {k: (unwrap(v) if k in mask else v) for k, v in self.shape.items()} 1088 else: 1089 shape = {k: unwrap(v) for k, v in self.shape.items()} 1090 return ZodTypedDict({**self._def, "shape": lambda: shape}) 1091 1092 def keyof(self): 1093 "get the keys of this mapping schema as an enum schema" 1094 return ZodEnum._create(self.shape.keys()) 1095 1096 @classmethod 1097 def _create(cls, shape, **params): 1098 return cls( 1099 dict( 1100 shape=lambda: shape, 1101 unknown_keys="strip", 1102 catchall=never(), 1103 **process_params(**params), 1104 ) 1105 ) 1106 1107 1108 class ZodTuple(ZodType): 1109 _type = [ZodParsedType.list, ZodParsedType.tuple] 1110 _type_name = _type 1111 1112 def _parse(self, input): 1113 status, ctx = self._process_input_params(input) 1114 if self._check_invalid_type(input): 1115 return INVALID 1116 1117 items = self._def["items"] 1118 rest = self._def["rest"] 1119 1120 if len(ctx.data) < len(items): 1121 add_issue_to_context( 1122 ctx, 1123 code=ZodIssueCode.too_small, 1124 minimum=len(items), 1125 inclusive=True, 1126 type="list", 1127 ) 1128 return INVALID 1129 1130 if not rest and len(ctx.data) > len(items): 1131 add_issue_to_context( 1132 ctx, 1133 code=ZodIssueCode.too_big, 1134 maximum=len(items), 1135 inclusive=True, 1136 type="list", 1137 ) 1138 return INVALID 1139 1140 from itertools import zip_longest 1141 1142 results = [ 1143 schema._parse(ParseInputLazyPath(ctx, item, ctx.path, i)) 1144 for i, (item, schema) in enumerate( 1145 zip_longest(ctx.data, items, fillvalue=rest) 1146 ) 1147 ] 1148 1149 return ParseStatus.merge_list(status, results) 1150 1151 @property 1152 def items(self): 1153 return self._def["items"] 1154 1155 def rest(self, rest): 1156 return ZodTuple({**self._def, "rest": rest}) 1157 1158 @classmethod 
1159 def _create(cls, schemas, **params): 1160 return cls(dict(items=schemas, rest=None, **process_params(**params))) 1161 1162 1163 class ZodMapping(ZodType): 1164 _type = ZodParsedType.mapping 1165 _type_name = _type 1166 1167 def _parse(self, input): 1168 status, ctx = self._process_input_params(input) 1169 if self._check_invalid_type(input): 1170 return INVALID 1171 1172 key_type = self._def["key_type"] 1173 value_type = self._def["value_type"] 1174 1175 pairs = [ 1176 ( 1177 key_type._parse(ParseInputLazyPath(ctx, key, ctx.path, key)), 1178 value_type._parse( 1179 ParseInputLazyPath( 1180 ctx, getitem(ctx.data, key, MISSING), ctx.path, key 1181 ) 1182 ), 1183 False, 1184 ) 1185 for key in ctx.data 1186 ] 1187 1188 return ParseStatus.merge_dict(status, pairs) 1189 1190 @property 1191 def key_schema(self): 1192 return self._def["key_type"] 1193 1194 @property 1195 def value_schema(self): 1196 return self._def["value_type"] 1197 1198 element = value_schema 1199 1200 @classmethod 1201 def _create(cls, keys, vals, **params): 1202 assert isinstance_(keys, ZodType) and isinstance_( 1203 keys, ZodType 1204 ), "expected schemas" 1205 return cls(dict(key_type=keys, value_type=vals, **process_params(**params))) 1206 1207 1208 class ZodLazy(ZodType): 1209 def _parse(self, input): 1210 ctx = self._get_or_return_ctx(input) 1211 return self.schema._parse(ParseInput(data=ctx.data, path=ctx.path, parent=ctx)) 1212 1213 @property 1214 def schema(self): 1215 return self._def["getter"]() 1216 1217 @classmethod 1218 def _create(cls, getter, **params): 1219 return cls(dict(getter=getter, **process_params(**params))) 1220 1221 1222 class ZodLiteral(ZodType): 1223 def _parse(self, input): 1224 value = self._def["value"] 1225 data = input.data 1226 if value is data or (type(value) is type(data) and value == data): 1227 return ParseReturn(status=VALID, value=data) 1228 else: 1229 ctx = self._get_or_return_ctx(input) 1230 add_issue_to_context(ctx, code=ZodIssueCode.invalid_literal, 
expected=value) 1231 return INVALID 1232 1233 @property 1234 def value(self): 1235 return self._def["value"] 1236 1237 @classmethod 1238 def _create(cls, value, **params): 1239 return cls(dict(value=value, **process_params(**params))) 1240 1241 1242 class CheckContext: 1243 def __init__(self, status: ParseStatus, ctx: ParseContext): 1244 self.status = status 1245 self.ctx = ctx 1246 1247 def add_issue( 1248 self, code=ZodIssueCode.custom, fatal=False, message="", **issue_data 1249 ): 1250 add_issue_to_context( 1251 self.ctx, code=code, fatal=fatal, message=message, **issue_data 1252 ) 1253 if fatal: 1254 self.status.abort() 1255 else: 1256 self.status.dirty() 1257 1258 @property 1259 def path(self): 1260 return self.ctx.path 1261 1262 1263 class ZodEffects(ZodType): 1264 def _parse(self, input): 1265 status, ctx = self._process_input_params(input) 1266 1267 effect = self._def["effect"] 1268 check_ctx = CheckContext(status, ctx) 1269 effect_type = effect["type"] 1270 1271 if effect_type == "preprocess": 1272 processed = effect["transform"](ctx.data) 1273 return self._def["schema"]._parse(ParseInput(processed, ctx.path, ctx)) 1274 1275 if effect_type == "refinment": 1276 inner = self._def["schema"]._parse(ParseInput(ctx.data, ctx.path, ctx)) 1277 if inner.status is ABORTED: 1278 return INVALID 1279 elif inner.status is DIRTY: 1280 status.dirty() 1281 effect["refinement"](inner.value, check_ctx) 1282 return ParseReturn(status.value, inner.value) 1283 1284 if effect_type == "transform": 1285 base = self._def["schema"]._parse(ParseInput(ctx.data, ctx.path, ctx)) 1286 if not is_valid(base): 1287 return base 1288 1289 result = effect["transform"](base.value, check_ctx) 1290 return ParseReturn(status.value, result) 1291 1292 assert False, "unnkown effect" 1293 1294 @classmethod 1295 def _create(cls, schema, effect, **params): 1296 return cls(dict(schema=schema, effect=effect, **process_params(**params))) 1297 1298 @classmethod 1299 def _preprocess(cls, preprocess, schema, 
**params): 1300 "transform the data before parsing it" 1301 return cls( 1302 dict( 1303 schema=schema, 1304 effect={"type": "preprocess", "transform": preprocess}, 1305 **process_params(**params), 1306 ) 1307 ) 1308 1309 1310 class ZodWraps(ZodType): 1311 _wraps = None 1312 _type = None 1313 1314 def _parse(self, input): 1315 parse_type = self._get_type(input) 1316 if parse_type is self._type: 1317 return OK(self._wraps) 1318 return self._def["inner_type"]._parse(input) 1319 1320 def unwrap(self): 1321 return self._def["inner_type"] 1322 1323 @classmethod 1324 def _create(cls, type, **params): 1325 return cls(dict(inner_type=type, **process_params(**params))) 1326 1327 1328 class ZodNotRequired(ZodWraps): 1329 _wraps = MISSING 1330 _type = ZodParsedType.missing 1331 _type_name = _type 1332 1333 1334 class ZodOptional(ZodWraps): 1335 _wraps = None 1336 _type = ZodParsedType.none 1337 _type_name = _type 1338 1339 1340 class ZodDefaultAbstract(ZodType): 1341 def remove_default(self): 1342 return self._def["inner_type"] 1343 1344 @classmethod 1345 def _create(cls, type, default, **params): 1346 default_ = default 1347 if not callable(default): 1348 default_ = lambda: default # noqa E731 1349 return cls(dict(inner_type=type, default=default_, **process_params(**params))) 1350 1351 1352 class ZodDefault(ZodDefaultAbstract): 1353 def _parse(self, input): 1354 ctx = self._get_or_return_ctx(input) 1355 data = ctx.data 1356 if ctx.parsed_type is ZodParsedType.missing: 1357 data = self._def["default"]() 1358 return self._def["inner_type"]._parse( 1359 ParseInput(data, path=ctx.path, parent=ctx) 1360 ) 1361 1362 1363 class ZodCatch(ZodDefaultAbstract): 1364 def _parse(self, input): 1365 ctx = self._get_or_return_ctx(input) 1366 result = self._def["inner_type"]._parse(ParseInput(ctx.data, ctx.path, ctx)) 1367 value = result.value if result.status is VALID else self._def["default"]() 1368 return ParseReturn(VALID, value) 1369 1370 1371 class ZodUnion(ZodType): 1372 def 
_parse(self, input): 1373 ctx = self._get_or_return_ctx(input) 1374 options = self._def["options"] 1375 dirty = None 1376 issues = [] 1377 1378 for option in options: 1379 # child_ctx = ... 1380 child_ctx = ParseContext( 1381 **{ 1382 **ctx, 1383 "common": Common(**{**ctx.common, "issues": []}), 1384 "parent": None, 1385 } 1386 ) 1387 1388 result = option._parse( 1389 ParseInput(data=ctx.data, path=ctx.path, parent=child_ctx) 1390 ) 1391 1392 if result.status is VALID: 1393 return result 1394 elif result.status is DIRTY and not dirty: 1395 dirty = {"result": result, "ctx": child_ctx} 1396 1397 if child_ctx.common.issues: 1398 issues.append(child_ctx.common.issues) # should this be extend? 1399 1400 if dirty: 1401 ctx.common.issues.extend(dirty["ctx"].common.issues) 1402 return dirty["result"] 1403 1404 add_issue_to_context(ctx, code=ZodIssueCode.invalid_union, union_issues=issues) 1405 return INVALID 1406 1407 @property 1408 def options(self): 1409 return self._def["options"] 1410 1411 @classmethod 1412 def _create(cls, types, **params): 1413 return cls(dict(options=types, **process_params(**params))) 1414 1415 1416 class ZodPipeline(ZodType): 1417 def _parse(self, input): 1418 status, ctx = self._process_input_params(input) 1419 in_result = self._def["in"]._parse( 1420 ParseInput(data=ctx.data, path=ctx.path, parent=ctx) 1421 ) 1422 if in_result.status is ABORTED: 1423 return INVALID 1424 if in_result.status is DIRTY: 1425 status.dirty() 1426 return ParseReturn(status=status.value, value=input.data) 1427 return self._def["out"]._parse( 1428 ParseInput(data=in_result.value, path=ctx.path, parent=ctx) 1429 ) 1430 1431 @classmethod 1432 def _create(cls, a: ZodType, b: ZodType, **params): 1433 return cls(dict({"in": a, "out": b}, **process_params(**params))) 1434 1435 1436 def custom(check=None, fatal=False, **params): 1437 if check is not None: 1438 1439 def custom_check(data, ctx: CheckContext): 1440 if not check(data): 1441 ctx.add_issue(fatal=fatal, **params) 1442 
1443 return ZodAny._create().super_refine(custom_check) 1444 return ZodAny._create() 1445 1446 1447 def isinstance(cls, message=""): 1448 message = message or f"Input not instance of {cls.__name__}" 1449 return custom(lambda data: isinstance_(data, cls), fatal=True, message=message) 1450 1451 1452 NEVER = INVALID 1453 1454 any = ZodAny._create 1455 array = ZodList._create 1456 boolean = ZodBoolean._create 1457 date = ZodDate._create 1458 datetime = ZodDateTime._create 1459 enum = ZodEnum._create 1460 float = ZodFloat._create 1461 integer = ZodInteger._create 1462 lazy = ZodLazy._create 1463 list = ZodList._create 1464 literal = ZodLiteral._create 1465 mapping = ZodMapping._create 1466 never = ZodNever._create 1467 none = ZodNone._create 1468 not_required = ZodNotRequired._create 1469 number = ZodNumber._create 1470 object = ZodTypedDict._create 1471 optional = ZodOptional._create 1472 preprocess = ZodEffects._preprocess 1473 record = ZodMapping._create 1474 string = ZodString._create 1475 tuple = ZodTuple._create 1476 typed_dict = ZodTypedDict._create 1477 unknown = ZodUnknown._create 1478 union = ZodUnion._create 1479 1480 1481 class ZodCoercion: 1482 @staticmethod 1483 def string(**params): 1484 return string(coerce=True, **params) 1485 1486 @staticmethod 1487 def integer(**params): 1488 return integer(coerce=True, **params) 1489 1490 @staticmethod 1491 def float(**params): 1492 return float(coerce=True, **params) 1493 1494 @staticmethod 1495 def boolean(**params): 1496 return boolean(coerce=True, **params) 1497 1498 1499 coerce = ZodCoercion()