/ client_code / zod / _types.py
_types.py
   1  # SPDX-License-Identifier: MIT
   2  #
   3  # Copyright (c) 2021 The Anvil Extras project team members listed at
   4  # https://github.com/anvilistas/anvil-extras/graphs/contributors
   5  #
   6  # This software is published at https://github.com/anvilistas/anvil-extras
   7  
   8  from datetime import date as _date
   9  from datetime import datetime as _datetime
  10  
  11  from anvil import is_server_side
  12  
  13  from ._zod_error import ZodError, ZodIssueCode
  14  from .helpers import ZodParsedType, get_parsed_type, regex, util
  15  from .helpers.dict_util import getitem, merge_shapes
  16  from .helpers.parse_util import (
  17      ABORTED,
  18      DIRTY,
  19      INVALID,
  20      MISSING,
  21      OK,
  22      VALID,
  23      Common,
  24      ErrorMapContext,
  25      ParseContext,
  26      ParseInput,
  27      ParseResult,
  28      ParseReturn,
  29      ParseStatus,
  30      add_issue_to_context,
  31      is_valid,
  32  )
  33  
  34  __version__ = "3.1.0"
  35  
  36  any_ = any
  37  isinstance_ = isinstance
  38  float_ = float
  39  list_ = list
  40  
  41  
  42  class ParseInputLazyPath:
  43      def __init__(self, parent, value, path, key):
  44          self.parent = parent
  45          self.data = value
  46          self._path = path
  47          self._key = key
  48  
  49      @property
  50      def path(self):
  51          return self._path + [self._key]
  52  
  53  
  54  def handle_result(ctx, result):
  55      if is_valid(result):
  56          return ParseResult(success=True, data=result.value, error=None)
  57      else:
  58          if not ctx.common.issues:
  59              raise Exception("Validation failed but no issues detected")
  60          error = ZodError(ctx.common.issues)
  61          return ParseResult(success=False, data=None, error=error)
  62  
  63  
  64  def _check_error_cb(rv):
  65      assert (
  66          type(rv) is dict and "message" in rv
  67      ), f"bad return type from error callback, expected str or {{'message': str}}, got {rv!r}"
  68      return rv
  69  
  70  
  71  def process_params(error_map=None, invalid_type_error=False, required_error=False):
  72      if not any_([error_map, invalid_type_error, required_error]):
  73          return {}
  74      if error_map and (invalid_type_error or required_error):
  75          raise Exception(
  76              'Can\'t use "invalid_type_error" or "required_error" in conjunction with custom error'
  77          )
  78  
  79      if error_map:
  80  
  81          def _error_map(issue, ctx):
  82              rv = error_map(issue, ctx) or ctx.default_error
  83              if type(rv) is str:
  84                  return {"message": rv}
  85  
  86              _check_error_cb(rv)
  87              return rv
  88  
  89          return {"error_map": _error_map}
  90  
  91      def custom_map(issue, ctx: ErrorMapContext):
  92          if issue["code"] != "invalid_type":
  93              return {"message": ctx.default_error}
  94          if ctx.data is MISSING:
  95              return {"message": required_error or ctx.default_error}
  96          return {"message": invalid_type_error or ctx.default_error}
  97  
  98      return {"error_map": custom_map}
  99  
 100  
 101  class ZodType:
 102      _type = None
 103      _type_name = None
 104  
 105      @classmethod
 106      def _create(cls, **params):
 107          return cls(process_params(**params))
 108  
 109      def __init__(self, _def: dict):
 110          self._def = _def
 111  
 112      def _check_invalid_type(self, input):
 113          parsed_type = self._get_type(input)
 114  
 115          types = self._type if type(self._type) is list_ else [self._type]
 116  
 117          if parsed_type not in types:
 118              ctx = self._get_or_return_ctx(input)
 119              add_issue_to_context(
 120                  ctx,
 121                  code=ZodIssueCode.invalid_type,
 122                  expected=self._type_name,
 123                  received=ctx.parsed_type,
 124              )
 125              return True
 126  
 127      def _parse(self, input):
 128          raise NotImplementedError("should be implemented by subclass")
 129  
 130      def _get_type(self, input):
 131          return get_parsed_type(input.data)
 132  
 133      def _get_or_return_ctx(self, input: ParseInput, ctx=None):
 134          return ctx or ParseContext(
 135              common=input.parent.common,
 136              data=input.data,
 137              parsed_type=get_parsed_type(input.data),
 138              schema_error_map=self._def.get("error_map"),
 139              path=input.path,
 140              parent=input.parent,
 141          )
 142  
 143      def _process_input_params(self, input: ParseInput):
 144          return ParseStatus(), self._get_or_return_ctx(input)
 145  
 146      def parse(self, data, **params):
 147          result = self.safe_parse(data, **params)
 148          if result.success:
 149              return result.data
 150          raise result.error
 151  
 152      def safe_parse(self, data, **params):
 153          ctx = ParseContext(
 154              common=Common(issues=[], contextual_error_map=params.get("error_map")),
 155              path=params.get("path", []),
 156              schema_error_map=self._def.get("error_map"),
 157              parent=None,
 158              data=data,
 159              parsed_type=get_parsed_type(data),
 160          )
 161          input = ParseInput(data, path=ctx.path, parent=ctx)
 162          result = self._parse(input)
 163          return handle_result(ctx, result)
 164  
 165      def list(self):
 166          return ZodList._create(self)
 167  
 168      array = list
 169  
 170      def not_required(self):
 171          "might not exist"
 172          return ZodNotRequired._create(self)
 173  
 174      def optional(self):
 175          "can be None"
 176          return ZodOptional._create(self)
 177  
 178      def default(self, value):
 179          "replace missing values with"
 180          return ZodDefault._create(self, value)
 181  
 182      def catch(self, value):
 183          "if the parse fails - replace with value"
 184          return ZodCatch._create(self, value)
 185  
 186      def union(self, other):
 187          "equivalent to z.union([a, b])"
 188          return ZodUnion._create([self, other])
 189  
 190      or_ = union
 191  
 192      def super_refine(self, refinement):
 193          return ZodEffects._create(
 194              schema=self, effect={"type": "refinment", "refinement": refinement}
 195          )
 196  
 197      def refine(self, check_fn, message="", fatal=False, **issue_params):
 198          """add a custom check_fn(val) -> bool, message: can be a str or a callable: (val) -> str
 199          if fatal is True, no further checks for this schema will be considered
 200          """
 201  
 202          def get_issue_props(val):
 203              rv = message
 204              if callable(rv):
 205                  rv = rv(val)
 206  
 207              if type(rv) is str:
 208                  return {"fatal": fatal, **issue_params, "message": rv}
 209  
 210              rv = _check_error_cb(rv)
 211              return {"fatal": fatal, **issue_params, **rv}
 212  
 213          def _refinement(val, ctx: CheckContext):
 214              if check_fn(val):
 215                  return True
 216              else:
 217                  ctx.add_issue(**get_issue_props(val))
 218                  return False
 219  
 220          return self.super_refine(_refinement)
 221  
 222      def super_transform(self, transform_fn):
 223          return ZodEffects._create(
 224              schema=self, effect={"type": "transform", "transform": transform_fn}
 225          )
 226  
 227      def transform(self, transform_fn):
 228          "transform the input with a custom transform function, to prevent any further checks/transforms return z.NEVER"
 229  
 230          def _transform(val, ctx: CheckContext):
 231              rv = transform_fn(val)
 232              if rv is INVALID:
 233                  ctx.add_issue(fatal=True)
 234              return rv
 235  
 236          return self.super_transform(_transform)
 237  
 238      def pipe(self, target):
 239          if not isinstance_(target, ZodType):
 240              raise TypeError("expected a zod schema")
 241          return ZodPipeline._create(self, target)
 242  
 243  
 244  class ZodString(ZodType):
 245      _type = ZodParsedType.string
 246      _type_name = _type
 247  
 248      def _parse(self, input: ParseInput):
 249          if self._def["coerce"]:
 250              input.data = str(input.data)
 251  
 252          if self._check_invalid_type(input):
 253              return INVALID
 254  
 255          status = ParseStatus()
 256          ctx = None
 257          for check in self._def["checks"]:
 258              kind = check["kind"]
 259  
 260              if kind == "min":
 261                  if len(input.data) < check["value"]:
 262                      ctx = self._get_or_return_ctx(input, ctx)
 263                      add_issue_to_context(
 264                          ctx,
 265                          code=ZodIssueCode.too_small,
 266                          minimum=check["value"],
 267                          type="string",
 268                          inclusive=True,
 269                          message=check["message"],
 270                      )
 271                      status.dirty()
 272  
 273              elif kind == "max":
 274                  if len(input.data) > check["value"]:
 275                      ctx = self._get_or_return_ctx(input, ctx)
 276                      add_issue_to_context(
 277                          ctx,
 278                          code=ZodIssueCode.too_big,
 279                          maximum=check["value"],
 280                          type="string",
 281                          inclusive=True,
 282                          message=check["message"],
 283                      )
 284                      status.dirty()
 285  
 286              elif kind == "email":
 287                  if not regex.EMAIL.match(input.data):
 288                      ctx = self._get_or_return_ctx(input, ctx)
 289                      add_issue_to_context(
 290                          ctx,
 291                          code=ZodIssueCode.invalid_string,
 292                          validation="email",
 293                          message=check["message"],
 294                      )
 295                      status.dirty()
 296  
 297              elif kind == "uuid":
 298                  if not regex.UUID.match(input.data):
 299                      ctx = self._get_or_return_ctx(input, ctx)
 300                      add_issue_to_context(
 301                          ctx,
 302                          code=ZodIssueCode.invalid_string,
 303                          validation="uuid",
 304                          message=check["message"],
 305                      )
 306                      status.dirty()
 307  
 308              elif kind == "url":
 309                  url_valid = True
 310                  if not is_server_side():
 311                      from anvil.js.window import URL
 312  
 313                      try:
 314                          URL(input.data)
 315                      except Exception:
 316                          url_valid = False
 317                  else:
 318                      url_valid = regex.URL.match(input.data)
 319                  if not url_valid:
 320                      ctx = self._get_or_return_ctx(input, ctx)
 321                      add_issue_to_context(
 322                          ctx,
 323                          code=ZodIssueCode.invalid_string,
 324                          validation="url",
 325                          message=check["message"],
 326                      )
 327                      status.dirty()
 328  
 329              elif kind == "regex":
 330                  match = check["regex"].match(input.data)
 331                  if match is None:
 332                      ctx = self._get_or_return_ctx(input, ctx)
 333                      add_issue_to_context(
 334                          ctx,
 335                          code=ZodIssueCode.invalid_string,
 336                          validation="regex",
 337                          message=check["message"],
 338                      )
 339                      status.dirty()
 340  
 341              elif kind == "strip":
 342                  input.data = input.data.strip()
 343              elif kind == "lower":
 344                  input.data = input.data.lower()
 345              elif kind == "upper":
 346                  input.data = input.data.upper()
 347  
 348              elif kind == "startswith":
 349                  if not input.data.startswith(check["value"]):
 350                      ctx = self._get_or_return_ctx(input, ctx)
 351                      add_issue_to_context(
 352                          ctx,
 353                          code=ZodIssueCode.invalid_string,
 354                          validation={"startswith": check["value"]},
 355                          message=check["message"],
 356                      )
 357                      status.dirty()
 358  
 359              elif kind == "endswith":
 360                  if not input.data.endswith(check["value"]):
 361                      ctx = self._get_or_return_ctx(input, ctx)
 362                      add_issue_to_context(
 363                          ctx,
 364                          code=ZodIssueCode.invalid_string,
 365                          validation={"endswith": check["value"]},
 366                          message=check["message"],
 367                      )
 368                      status.dirty()
 369  
 370              elif kind == "datetime" or kind == "date":
 371                  format = check["format"]
 372                  try:
 373                      if format is not None:
 374                          _datetime.strptime(input.data, format)
 375                      elif kind == "datetime":
 376                          _datetime.fromisoformat(input.data)
 377                      else:
 378                          _date.fromisoformat(input.data)
 379                  except Exception:
 380                      ctx = self._get_or_return_ctx(input, ctx)
 381                      add_issue_to_context(
 382                          ctx,
 383                          code=ZodIssueCode.invalid_string,
 384                          validation=kind,
 385                          message=check["message"],
 386                      )
 387                      status.dirty()
 388  
 389              else:
 390                  assert False
 391  
 392          return ParseReturn(status=status.value, value=input.data)
 393  
 394      def _add_check(self, **check):
 395          return ZodString({**self._def, "checks": [*self._def["checks"], check]})
 396  
 397      def email(self, message=""):
 398          return self._add_check(kind="email", message=message)
 399  
 400      def url(self, message=""):
 401          return self._add_check(kind="url", message=message)
 402  
 403      def uuid(self, message=""):
 404          return self._add_check(kind="uuid", message=message)
 405  
 406      def datetime(self, format=None, message=""):
 407          "default format is isoformat"
 408          return self._add_check(kind="datetime", format=format, message=message)
 409  
 410      def date(self, format=None, message=""):
 411          "default format is isoformat"
 412          return self._add_check(kind="date", format=format, message=message)
 413  
 414      def regex(self, regex, message=""):
 415          return self._add_check(kind="regex", regex=regex, message=message)
 416  
 417      def startswith(self, value, message=""):
 418          return self._add_check(kind="startswith", value=value, message=message)
 419  
 420      def endswith(self, value, message=""):
 421          return self._add_check(kind="endswith", value=value, message=message)
 422  
 423      def min(self, min_length: int, message=""):
 424          return self._add_check(kind="min", value=min_length, message=message)
 425  
 426      def max(self, min_length: int, message=""):
 427          return self._add_check(kind="max", value=min_length, message=message)
 428  
 429      def len(self, len: int, message=""):
 430          return self.min(len, message).max(len, message)
 431  
 432      def nonempty(self, message=""):
 433          return self.min(1, message)
 434  
 435      def strip(self):
 436          "similar to z.string().transform(str.strip)"
 437          return self._add_check(kind="strip")
 438  
 439      def lower(self):
 440          "similar to z.string().transform(str.lower)"
 441          return self._add_check(kind="lower")
 442  
 443      def upper(self):
 444          "similar to z.string().transform(str.upper)"
 445          return self._add_check(kind="upper")
 446  
 447      @classmethod
 448      def _create(cls, *, coerce=False, **params):
 449          return cls(dict(checks=[], coerce=coerce, **process_params(**params)))
 450  
 451  
 452  class ZodAbstractNumber(ZodType):
 453      _type = ZodParsedType.number
 454      _type_name = _type
 455  
 456      def _parse(self, input):
 457          if self._def["coerce"]:
 458              try:
 459                  if self._type == ZodParsedType.integer:
 460                      input.data = int(input.data)
 461                  elif self._type == ZodParsedType.float:
 462                      input.data = float_(input.data)
 463              except Exception:
 464                  pass
 465  
 466          if self._check_invalid_type(input):
 467              return INVALID
 468  
 469          status = ParseStatus()
 470          ctx = None
 471  
 472          for check in self._def["checks"]:
 473              kind = check["kind"]
 474  
 475              if kind == "int":
 476                  data = input.data
 477                  if type(data) is float_ and not data.is_integer():
 478                      ctx = self._get_or_return_ctx(input, ctx)
 479                      add_issue_to_context(
 480                          ctx,
 481                          code=ZodIssueCode.invalid_type,
 482                          expected="integer",
 483                          received="float",
 484                          message=check["message"],
 485                      )
 486                      status.dirty()
 487  
 488              elif kind == "min":
 489                  value = check["value"]
 490                  inclusive = check["inclusive"]
 491                  too_small = input.data < value if inclusive else input.data <= value
 492                  if too_small:
 493                      ctx = self._get_or_return_ctx(input, ctx)
 494                      add_issue_to_context(
 495                          ctx,
 496                          code=ZodIssueCode.too_small,
 497                          minimum=value,
 498                          type=self._type_name,
 499                          inclusive=inclusive,
 500                          message=check["message"],
 501                      )
 502                      status.dirty()
 503  
 504              elif kind == "max":
 505                  value = check["value"]
 506                  inclusive = check["inclusive"]
 507                  too_big = input.data > value if inclusive else input.data >= value
 508                  if too_big:
 509                      ctx = self._get_or_return_ctx(input, ctx)
 510                      add_issue_to_context(
 511                          ctx,
 512                          code=ZodIssueCode.too_big,
 513                          maximum=value,
 514                          type=self._type_name,
 515                          inclusive=inclusive,
 516                          message=check["message"],
 517                      )
 518                      status.dirty()
 519  
 520              else:
 521                  assert False
 522  
 523          return ParseReturn(status=status.value, value=input.data)
 524  
 525      @classmethod
 526      def _create(cls, **params):
 527          return cls(dict(checks=[], coerce=False, **process_params(**params)))
 528  
 529  
 530  class ZodInteger(ZodAbstractNumber):
 531      _type = ZodParsedType.integer
 532      _type_name = _type
 533  
 534      def _add_check(self, **check):
 535          return ZodInteger({**self._def, "checks": [*self._def["checks"], check]})
 536  
 537      def set_limit(self, kind, value, inclusive, message=""):
 538          return self._add_check(
 539              kind=kind, value=value, inclusive=inclusive, message=message
 540          )
 541  
 542      def ge(self, value, message=""):
 543          return self.set_limit("min", value, True, message)
 544  
 545      def min(self, value, message=""):
 546          return self.set_limit("min", value, True, message)
 547  
 548      def gt(self, value, message=""):
 549          return self.set_limit("min", value, False, message)
 550  
 551      def le(self, value, message=""):
 552          return self.set_limit("max", value, True, message)
 553  
 554      def max(self, value, message=""):
 555          return self.set_limit("max", value, True, message)
 556  
 557      def lt(self, value, message=""):
 558          return self.set_limit("max", value, False, message)
 559  
 560      def positive(self, message=""):
 561          return self.set_limit("min", 0, False, message)
 562  
 563      def negative(self, message=""):
 564          return self.set_limit("max", 0, False, message)
 565  
 566      def nonpositive(self, message=""):
 567          return self.set_limit("max", 0, True, message)
 568  
 569      def nonnegative(self, message=""):
 570          return self.set_limit("min", 0, True, message)
 571  
 572      @classmethod
 573      def _create(cls, *, coerce=False, **params):
 574          return cls(dict(checks=[], coerce=coerce, **process_params(**params)))
 575  
 576  
 577  class ZodFloat(ZodAbstractNumber):
 578      _type = ZodParsedType.float
 579      _type_name = _type
 580  
 581      def _add_check(self, **check):
 582          return ZodFloat({**self._def, "checks": [*self._def["checks"], check]})
 583  
 584      def set_limit(self, kind, value, inclusive, message=""):
 585          return self._add_check(
 586              kind=kind, value=value, inclusive=inclusive, message=message
 587          )
 588  
 589      def int(self, message=""):
 590          return self._add_check(kind="int", message=message)
 591  
 592      def ge(self, value, message=""):
 593          return self.set_limit("min", value, True, message)
 594  
 595      def min(self, value, message=""):
 596          return self.set_limit("min", value, True, message)
 597  
 598      def gt(self, value, message=""):
 599          return self.set_limit("min", value, False, message)
 600  
 601      def le(self, value, message=""):
 602          return self.set_limit("max", value, True, message)
 603  
 604      def max(self, value, message=""):
 605          return self.set_limit("max", value, True, message)
 606  
 607      def lt(self, value, message=""):
 608          return self.set_limit("max", value, False, message)
 609  
 610      def positive(self, message=""):
 611          return self.set_limit("min", 0, False, message)
 612  
 613      def negative(self, message=""):
 614          return self.set_limit("max", 0, False, message)
 615  
 616      def nonpositive(self, message=""):
 617          return self.set_limit("max", 0, True, message)
 618  
 619      def nonnegative(self, message=""):
 620          return self.set_limit("min", 0, True, message)
 621  
 622      @classmethod
 623      def _create(cls, *, coerce=False, **params):
 624          return cls(dict(checks=[], coerce=coerce, **process_params(**params)))
 625  
 626  
 627  class ZodNumber(ZodAbstractNumber):
 628      _type = [ZodParsedType.integer, ZodParsedType.float]
 629      _type_name = ZodParsedType.number
 630  
 631      def _add_check(self, **check):
 632          return ZodNumber({**self._def, "checks": [*self._def["checks"], check]})
 633  
 634      def set_limit(self, kind, value, inclusive, message=""):
 635          return self._add_check(
 636              kind=kind, value=value, inclusive=inclusive, message=message
 637          )
 638  
 639      def int(self, message=""):
 640          return self._add_check(kind="int", message=message)
 641  
 642      def ge(self, value, message=""):
 643          return self.set_limit("min", value, True, message)
 644  
 645      def min(self, value, message=""):
 646          return self.set_limit("min", value, True, message)
 647  
 648      def gt(self, value, message=""):
 649          return self.set_limit("min", value, False, message)
 650  
 651      def le(self, value, message=""):
 652          return self.set_limit("max", value, True, message)
 653  
 654      def max(self, value, message=""):
 655          return self.set_limit("max", value, True, message)
 656  
 657      def lt(self, value, message=""):
 658          return self.set_limit("max", value, False, message)
 659  
 660      def positive(self, message=""):
 661          return self.set_limit("min", 0, False, message)
 662  
 663      def negative(self, message=""):
 664          return self.set_limit("max", 0, False, message)
 665  
 666      def nonpositive(self, message=""):
 667          return self.set_limit("max", 0, True, message)
 668  
 669      def nonnegative(self, message=""):
 670          return self.set_limit("min", 0, True, message)
 671  
 672      @classmethod
 673      def _create(cls, **params):
 674          return cls(dict(checks=[], coerce=False, **process_params(**params)))
 675  
 676  
 677  class ZodDateTime(ZodType):
 678      _type = ZodParsedType.datetime
 679      _type_name = _type
 680  
 681      def _parse(self, input):
 682          if self._check_invalid_type(input):
 683              return INVALID
 684  
 685          status = ParseStatus()
 686          ctx = None
 687          for check in self._def["checks"]:
 688              kind = check["kind"]
 689  
 690              if kind == "min":
 691                  if input.data < check["value"]:
 692                      ctx = self._get_or_return_ctx(input, ctx)
 693                      add_issue_to_context(
 694                          ctx,
 695                          code=ZodIssueCode.too_small,
 696                          minimum=check["value"].isoformat(),
 697                          type=self._type_name,
 698                          inclusive=True,
 699                          message=check["message"],
 700                      )
 701                      status.dirty()
 702  
 703              elif kind == "max":
 704                  if input.data > check["value"]:
 705                      ctx = self._get_or_return_ctx(input, ctx)
 706                      add_issue_to_context(
 707                          ctx,
 708                          code=ZodIssueCode.too_big,
 709                          maximum=check["value"].isoformat(),
 710                          type=self._type_name,
 711                          inclusive=True,
 712                          message=check["message"],
 713                      )
 714                      status.dirty()
 715  
 716              else:
 717                  assert False
 718  
 719          return ParseReturn(status=status.value, value=input.data)
 720  
 721      def _add_check(self, **check):
 722          return ZodDateTime({**self._def, "checks": [*self._def["checks"], check]})
 723  
 724      def min(self, min_date: _datetime, message=""):
 725          return self._add_check(kind="min", value=min_date, message=message)
 726  
 727      def max(self, max_date: _datetime, message=""):
 728          return self._add_check(kind="max", value=max_date, message=message)
 729  
 730      @classmethod
 731      def _create(cls, **params):
 732          return cls(dict(checks=[], **process_params(**params)))
 733  
 734  
 735  class ZodDate(ZodDateTime):
 736      _type = ZodParsedType.date
 737      _type_name = _type
 738  
 739      def _add_check(self, **check):
 740          return ZodDate({**self._def, "checks": [*self._def["checks"], check]})
 741  
 742      def min(self, min_date: _date, message=""):
 743          return self._add_check(kind="min", value=min_date, message=message)
 744  
 745      def max(self, max_date: _date, message=""):
 746          return self._add_check(kind="max", value=max_date, message=message)
 747  
 748  
 749  class ZodBoolean(ZodType):
 750      _type = ZodParsedType.boolean
 751      _type_name = _type
 752  
 753      def _parse(self, input: ParseInput):
 754          if self._def["coerce"]:
 755              input.data = bool(input.data)
 756          if self._check_invalid_type(input):
 757              return INVALID
 758          return OK(input.data)
 759  
 760      @classmethod
 761      def _create(cls, *, coerce=False, **params):
 762          return cls(dict(coerce=coerce, **process_params(**params)))
 763  
 764  
 765  class ZodNone(ZodType):
 766      _type = ZodParsedType.none
 767      _type_name = _type
 768  
 769      def _parse(self, input: ParseInput):
 770          if self._check_invalid_type(input):
 771              return INVALID
 772          return OK(input.data)
 773  
 774  
 775  class ZodAny(ZodType):
 776      def _parse(self, input):
 777          return OK(input.data)
 778  
 779  
 780  class ZodUnknown(ZodType):
 781      _type = ZodParsedType.unknown
 782      _type_name = _type
 783      _unknown = True
 784  
 785      def _parse(self, input):
 786          return OK(input.data)
 787  
 788  
 789  class ZodNever(ZodType):
 790      _type = ZodParsedType.never
 791      _type_name = _type
 792  
 793      def _parse(self, input):
 794          ctx = self._get_or_return_ctx(input)
 795          add_issue_to_context(
 796              ctx,
 797              code=ZodIssueCode.invalid_type,
 798              expected=self._type,
 799              received=ctx.parsed_type,
 800          )
 801          return INVALID
 802  
 803  
 804  class ZodList(ZodType):
 805      _type = [ZodParsedType.list, ZodParsedType.tuple]
 806      _type_name = "list"
 807  
 808      def _parse(self, input):
 809          status, ctx = self._process_input_params(input)
 810  
 811          if self._check_invalid_type(input):
 812              return INVALID
 813  
 814          for check in self._def["checks"]:
 815              kind = check["kind"]
 816  
 817              if kind == "min":
 818                  if len(ctx.data) < check["value"]:
 819                      add_issue_to_context(
 820                          ctx,
 821                          code=ZodIssueCode.too_small,
 822                          minimum=check["value"],
 823                          type="list",
 824                          inclusive=True,
 825                          message=check["message"],
 826                      )
 827                      status.dirty()
 828  
 829              elif kind == "max":
 830                  if len(ctx.data) > check["value"]:
 831                      add_issue_to_context(
 832                          ctx,
 833                          code=ZodIssueCode.too_big,
 834                          maximum=check["value"],
 835                          type="list",
 836                          inclusive=True,
 837                          message=check["message"],
 838                      )
 839                      status.dirty()
 840  
 841          type_schema = self._def["type"]
 842  
 843          results = [
 844              type_schema._parse(ParseInputLazyPath(ctx, item, ctx.path, i))
 845              for i, item in enumerate(ctx.data)
 846          ]
 847  
 848          return ParseStatus.merge_list(status, results)
 849  
 850      @property
 851      def element(self):
 852          return self._def["type"]
 853  
 854      def _add_check(self, **check):
 855          return ZodList({**self._def, "checks": [*self._def["checks"], check]})
 856  
 857      def min(self, min_length, message=""):
 858          return self._add_check(kind="min", value=min_length, message=message)
 859  
 860      def max(self, max_length, message=""):
 861          return self._add_check(kind="max", value=max_length, message=message)
 862  
 863      def len(self, len, message=""):
 864          return self.min(len, message).max(len, message)
 865  
 866      def nonempty(self, message=""):
 867          return self.min(1, message)
 868  
 869      @classmethod
 870      def _create(cls, schema, **params):
 871          return cls(dict(type=schema, checks=[], **process_params(**params)))
 872  
 873  
 874  class ZodEnum(ZodType):
 875      def _parse(self, input):
 876          values = self._def["values"]
 877          if input.data not in values:
 878              ctx = self._get_or_return_ctx(input)
 879              add_issue_to_context(
 880                  ctx,
 881                  code=ZodIssueCode.invalid_type,
 882                  expected=" | ".join(repr(a) for a in values),
 883                  received=ctx.parsed_type,
 884              )
 885              return INVALID
 886          return OK(input.data)
 887  
 888      @property
 889      def options(self):
 890          return self._def["values"]
 891  
 892      @property
 893      def enum(self):
 894          return util.enum("ENUM", self.options)
 895  
 896      @classmethod
 897      def _create(cls, options, **params):
 898          return cls(dict(values=list_(options), **process_params(**params)))
 899  
 900  
 901  def deep_partialify(schema):
 902      t = type(schema)
 903      if t is ZodTypedDict:
 904          new_shape = {
 905              k: ZodNotRequired._create(deep_partialify(v))
 906              for k, v in schema.shape.items()
 907          }
 908          return ZodTypedDict({**schema._def, "shape": lambda: new_shape})
 909      if t is ZodList:
 910          return ZodList._create(deep_partialify(schema.element))
 911      if t is ZodNotRequired:
 912          return ZodNotRequired._create(deep_partialify(schema.unwrap()))
 913      if t is ZodOptional:
 914          return ZodOptional._create(deep_partialify(schema.unwrap()))
 915      if t is ZodTuple:
 916          return ZodTuple._create([deep_partialify(item) for item in schema.items])
 917      return schema
 918  
 919  
 920  class ZodTypedDict(ZodType):
 921      _type = ZodParsedType.mapping
 922      _type_name = _type
 923  
 924      def __init__(self, _def):
 925          super().__init__(_def)
 926          self._cached = None
 927  
 928      def _parse(self, input):
 929          if self._check_invalid_type(input):
 930              return INVALID
 931  
 932          status, ctx = self._process_input_params(input)
 933          shape = self.shape
 934          shape_keys = shape.keys()
 935          extra_keys = set()
 936  
 937          if not (
 938              type(self._def["catchall"]) is ZodNever
 939              and self._def["unknown_keys"] == "strip"
 940          ):
 941              for key in ctx.data.keys():
 942                  if key not in shape_keys:
 943                      extra_keys.add(key)
 944  
 945          pairs = []
 946  
 947          for key in shape_keys:
 948              key_validator = shape[key]
 949              value = getitem(ctx.data, key, MISSING)
 950              pairs.append(
 951                  (
 952                      ParseReturn(VALID, key),
 953                      key_validator._parse(ParseInputLazyPath(ctx, value, ctx.path, key)),
 954                      key in ctx.data,
 955                  )
 956              )
 957  
 958          if type(self._def["catchall"]) is ZodNever:
 959              unknown_keys = self._def["unknown_keys"]
 960              if unknown_keys == "passthrough":
 961                  for key in extra_keys:
 962                      pairs.append(
 963                          (
 964                              ParseReturn(VALID, key),
 965                              ParseReturn(VALID, ctx.data[key]),
 966                              False,
 967                          )
 968                      )
 969              elif unknown_keys == "strict":
 970                  if extra_keys:
 971                      add_issue_to_context(
 972                          ctx, code=ZodIssueCode.unrecognized_keys, keys=extra_keys
 973                      )
 974                      status.dirty()
 975              elif unknown_keys == "strip":
 976                  pass
 977              else:
 978                  assert False, "invalid unknown_keys value"
 979          else:
 980              # run cachall validation
 981              catchall = self._def["catchall"]
 982  
 983              for key in extra_keys:
 984                  value = ctx.data[key]
 985                  pairs.append(
 986                      (
 987                          ParseReturn(VALID, key),
 988                          catchall._parse(
 989                              ParseInputLazyPath(ctx, value, ctx.path, key),
 990                          ),
 991                          key in ctx.data,
 992                      )
 993                  )
 994  
 995          return ParseStatus.merge_dict(status, pairs)
 996  
 997      @property
 998      def shape(self):
 999          return self._def["shape"]()
1000  
1001      def strict(self, message=""):
1002          "reject if theere are extra keys"
1003          _def = {**self._def, "unknown_keys": "strict"}
1004          if message:
1005  
1006              def error_map(issue, ctx):
1007                  try:
1008                      default_error = self._def["error_map"](issue, ctx)["message"]
1009                  except TypeError:
1010                      default_error = ctx.default_error
1011                  if issue.code == "unrecognized_keys":
1012                      return {"message": message or default_error}
1013                  return {"message": default_error}
1014  
1015              _def["error_map"] = error_map
1016          return ZodTypedDict(_def)
1017  
1018      def strip(self):
1019          "return the data without additional keys"
1020          return ZodTypedDict({**self._def, "unknown_keys": "strip"})
1021  
1022      def passthrough(self):
1023          "ignore additional keys"
1024          return ZodTypedDict({**self._def, "unknown_keys": "passthrough"})
1025  
1026      nonstrict = passthrough
1027  
1028      def extend(self, shape):
1029          "create a new schema extending the current shape"
1030          return ZodTypedDict(
1031              {**self._def, "shape": lambda: merge_shapes(self.shape, shape)}
1032          )
1033  
1034      def set_key(self, key, schema):
1035          "returns a new schema with the additional key"
1036          return self.extend({key: schema})
1037  
1038      def merge(self, merge_with):
1039          "merge two mapping schemas"
1040          assert type(merge_with) is ZodTypedDict, "expected a zod mapping schema"
1041          merged = {
1042              "unknown_keys": merge_with._def["unknown_keys"],
1043              "catchall": merge_with._def["catchall"],
1044              "shape": lambda: merge_shapes(
1045                  self._def["shape"](), merge_with._def["shape"]()
1046              ),
1047          }
1048          return ZodTypedDict(merged)
1049  
1050      def catchall(self, index):
1051          return ZodTypedDict({**self._def, "catchall": index})
1052  
1053      def pick(self, mask):
1054          "mask should be an iterable of keys, retuns a new schema with only those keys"
1055          this_shape = self.shape
1056          shape = {k: this_shape[k] for k in mask if k in this_shape}
1057          return ZodTypedDict({**self._def, "shape": lambda: shape})
1058  
1059      def omit(self, mask):
1060          "mask should be an iterable of keys, retuns a new schema without those keys"
1061          this_shape = self.shape
1062          shape = {k: v for k, v in this_shape.items() if k not in mask}
1063          return ZodTypedDict({**self._def, "shape": lambda: shape})
1064  
1065      def partial(self, mask=None):
1066          "returns a new schema where values are not required. If a mask is provided, only those keys will become not required"
1067          if mask:
1068              shape = {
1069                  k: (v.not_required() if k in mask else v) for k, v in self.shape.items()
1070              }
1071          else:
1072              shape = {k: v.not_required() for k, v in self.shape.items()}
1073          return ZodTypedDict({**self._def, "shape": lambda: shape})
1074  
1075      def deep_partial(self):
1076          return deep_partialify(self)
1077  
1078      def required(self, mask=None):
1079          "returns a new schema where values are required. If a mask is provided, only those keys will become required"
1080  
1081          def unwrap(field):
1082              while isinstance_(field, ZodNotRequired):
1083                  field = field._def["inner_type"]
1084              return field
1085  
1086          if mask:
1087              shape = {k: (unwrap(v) if k in mask else v) for k, v in self.shape.items()}
1088          else:
1089              shape = {k: unwrap(v) for k, v in self.shape.items()}
1090          return ZodTypedDict({**self._def, "shape": lambda: shape})
1091  
1092      def keyof(self):
1093          "get the keys of this mapping schema as an enum schema"
1094          return ZodEnum._create(self.shape.keys())
1095  
1096      @classmethod
1097      def _create(cls, shape, **params):
1098          return cls(
1099              dict(
1100                  shape=lambda: shape,
1101                  unknown_keys="strip",
1102                  catchall=never(),
1103                  **process_params(**params),
1104              )
1105          )
1106  
1107  
class ZodTuple(ZodType):
    """Schema for a list/tuple with one schema per position.

    The input must have at least as many items as there are schemas. Extra
    trailing items are allowed only when a ``rest`` schema is set, and each
    of them is validated against it.
    """

    _type = [ZodParsedType.list, ZodParsedType.tuple]
    _type_name = _type

    def _parse(self, input):
        from itertools import zip_longest

        status, ctx = self._process_input_params(input)
        if self._check_invalid_type(input):
            return INVALID

        positional = self._def["items"]
        rest = self._def["rest"]
        data = ctx.data
        expected = len(positional)

        if len(data) < expected:
            add_issue_to_context(
                ctx,
                code=ZodIssueCode.too_small,
                minimum=expected,
                inclusive=True,
                type="list",
            )
            return INVALID

        if not rest and len(data) > expected:
            add_issue_to_context(
                ctx,
                code=ZodIssueCode.too_big,
                maximum=expected,
                inclusive=True,
                type="list",
            )
            return INVALID

        # Items beyond the positional schemas are paired with `rest`.
        results = []
        pairs = zip_longest(data, positional, fillvalue=rest)
        for index, (value, schema) in enumerate(pairs):
            child_input = ParseInputLazyPath(ctx, value, ctx.path, index)
            results.append(schema._parse(child_input))

        return ParseStatus.merge_list(status, results)

    @property
    def items(self):
        "The positional schemas."
        return self._def["items"]

    def rest(self, rest):
        "Return a new tuple schema whose extra trailing items must match rest."
        return ZodTuple({**self._def, "rest": rest})

    @classmethod
    def _create(cls, schemas, **params):
        return cls(dict(items=schemas, rest=None, **process_params(**params)))
1161  
1162  
class ZodMapping(ZodType):
    """Schema for a mapping whose keys and values each match their own schema.

    Every key of the input is parsed with ``key_type`` and its value with
    ``value_type``; there is no fixed set of required keys.
    """

    _type = ZodParsedType.mapping
    _type_name = _type

    def _parse(self, input):
        status, ctx = self._process_input_params(input)
        if self._check_invalid_type(input):
            return INVALID

        key_type = self._def["key_type"]
        value_type = self._def["value_type"]

        pairs = [
            (
                key_type._parse(ParseInputLazyPath(ctx, key, ctx.path, key)),
                value_type._parse(
                    ParseInputLazyPath(
                        ctx, getitem(ctx.data, key, MISSING), ctx.path, key
                    )
                ),
                False,
            )
            for key in ctx.data
        ]

        return ParseStatus.merge_dict(status, pairs)

    @property
    def key_schema(self):
        return self._def["key_type"]

    @property
    def value_schema(self):
        return self._def["value_type"]

    element = value_schema

    @classmethod
    def _create(cls, keys, vals, **params):
        # both the key schema and the value schema must be zod schemas
        assert isinstance_(keys, ZodType) and isinstance_(
            vals, ZodType
        ), "expected schemas"
        return cls(dict(key_type=keys, value_type=vals, **process_params(**params)))
1206  
1207  
class ZodLazy(ZodType):
    "Schema that delegates to whatever schema the stored getter returns at parse time."

    def _parse(self, input):
        parent_ctx = self._get_or_return_ctx(input)
        target = self.schema
        forwarded = ParseInput(data=parent_ctx.data, path=parent_ctx.path, parent=parent_ctx)
        return target._parse(forwarded)

    @property
    def schema(self):
        "The resolved schema; the getter is called on every access."
        getter = self._def["getter"]
        return getter()

    @classmethod
    def _create(cls, getter, **params):
        return cls(dict(getter=getter, **process_params(**params)))
1220  
1221  
class ZodLiteral(ZodType):
    """Schema accepting exactly one value.

    A match is either the identical object, or an equal value of exactly the
    same type (so e.g. ``1`` does not match ``True``).
    """

    def _parse(self, input):
        expected = self._def["value"]
        received = input.data
        matches = expected is received or (
            type(expected) is type(received) and expected == received
        )
        if not matches:
            ctx = self._get_or_return_ctx(input)
            add_issue_to_context(
                ctx, code=ZodIssueCode.invalid_literal, expected=expected
            )
            return INVALID
        return ParseReturn(status=VALID, value=received)

    @property
    def value(self):
        "The single accepted value."
        return self._def["value"]

    @classmethod
    def _create(cls, value, **params):
        return cls(dict(value=value, **process_params(**params)))
1240  
1241  
class CheckContext:
    """Handle given to refinement/transform callbacks for reporting issues.

    Adding an issue also marks the owning parse status: aborted when the
    issue is fatal, dirty otherwise.
    """

    def __init__(self, status: ParseStatus, ctx: ParseContext):
        self.status, self.ctx = status, ctx

    def add_issue(
        self, code=ZodIssueCode.custom, fatal=False, message="", **issue_data
    ):
        add_issue_to_context(
            self.ctx, code=code, fatal=fatal, message=message, **issue_data
        )
        mark = self.status.abort if fatal else self.status.dirty
        mark()

    @property
    def path(self):
        "Path of the value currently being checked."
        return self.ctx.path
1261  
1262  
class ZodEffects(ZodType):
    """Schema that runs an effect around an inner schema.

    ``_def["effect"]["type"]`` selects the behavior:
    - "preprocess": transform the raw data, then parse it with the inner schema;
    - "refinment": parse, then run the refinement callback on the value;
    - "transform": parse, then replace a valid value with the callback's result.
    """

    def _parse(self, input):
        status, ctx = self._process_input_params(input)

        effect = self._def["effect"]
        inner = self._def["schema"]
        kind = effect["type"]
        checker = CheckContext(status, ctx)

        if kind == "preprocess":
            prepared = effect["transform"](ctx.data)
            return inner._parse(ParseInput(prepared, ctx.path, ctx))

        # NOTE: the "refinment" spelling is matched literally; whatever builds
        # refinement effects (outside this view) must use the same string.
        if kind == "refinment":
            outcome = inner._parse(ParseInput(ctx.data, ctx.path, ctx))
            if outcome.status is ABORTED:
                return INVALID
            if outcome.status is DIRTY:
                status.dirty()
            # refinements still run on dirty values, and may add issues
            effect["refinement"](outcome.value, checker)
            return ParseReturn(status.value, outcome.value)

        if kind == "transform":
            outcome = inner._parse(ParseInput(ctx.data, ctx.path, ctx))
            if not is_valid(outcome):
                return outcome
            transformed = effect["transform"](outcome.value, checker)
            return ParseReturn(status.value, transformed)

        assert False, "unnkown effect"

    @classmethod
    def _create(cls, schema, effect, **params):
        return cls(dict(schema=schema, effect=effect, **process_params(**params)))

    @classmethod
    def _preprocess(cls, preprocess, schema, **params):
        "transform the data before parsing it"
        effect = {"type": "preprocess", "transform": preprocess}
        return cls._create(schema, effect, **params)
1308  
1309  
class ZodWraps(ZodType):
    """Base for wrapper schemas around ``_def["inner_type"]``.

    Input whose parsed type is ``_type`` short-circuits to ``OK(_wraps)``;
    any other input is handed to the inner schema unchanged.
    """

    _wraps = None
    _type = None

    def _parse(self, input):
        if self._get_type(input) is not self._type:
            return self._def["inner_type"]._parse(input)
        return OK(self._wraps)

    def unwrap(self):
        "The wrapped inner schema."
        inner = self._def["inner_type"]
        return inner

    @classmethod
    def _create(cls, type, **params):
        settings = process_params(**params)
        return cls(dict(inner_type=type, **settings))
1326  
1327  
class ZodNotRequired(ZodWraps):
    "Wrapper letting a value be absent: missing input parses to MISSING."

    _type = ZodParsedType.missing
    _type_name = ZodParsedType.missing
    _wraps = MISSING
1332  
1333  
class ZodOptional(ZodWraps):
    "Wrapper letting a value be None: None input parses to None."

    _type = ZodParsedType.none
    _type_name = ZodParsedType.none
    _wraps = None
1338  
1339  
class ZodDefaultAbstract(ZodType):
    """Shared base for schemas holding an inner schema plus a default factory.

    ``_def["default"]`` is always a zero-argument callable. A non-callable
    default is wrapped so that the same object is returned on every call,
    which means a mutable default value is shared between parses.
    """

    def remove_default(self):
        "The inner schema without the default."
        inner = self._def["inner_type"]
        return inner

    @classmethod
    def _create(cls, type, default, **params):
        if callable(default):
            factory = default
        else:

            def factory():
                return default

        return cls(dict(inner_type=type, default=factory, **process_params(**params)))
1350  
1351  
class ZodDefault(ZodDefaultAbstract):
    "Substitute the default for missing input, then parse with the inner schema."

    def _parse(self, input):
        ctx = self._get_or_return_ctx(input)
        is_missing = ctx.parsed_type is ZodParsedType.missing
        value = self._def["default"]() if is_missing else ctx.data
        inner = self._def["inner_type"]
        return inner._parse(ParseInput(value, path=ctx.path, parent=ctx))
1361  
1362  
class ZodCatch(ZodDefaultAbstract):
    """Parse with the inner schema and fall back to the default on failure.

    The result is always VALID: the inner value when the inner parse is
    valid, otherwise the value produced by the default factory.
    """

    def _parse(self, input):
        ctx = self._get_or_return_ctx(input)
        # Give the inner schema its own empty issue list (same pattern as
        # ZodUnion). Otherwise the issues it reports would land in the
        # caller's list and resurface in the parent's error report, even
        # though this schema swallowed the failure.
        isolated_ctx = ParseContext(
            **{**ctx, "common": Common(**{**ctx.common, "issues": []})}
        )
        result = self._def["inner_type"]._parse(
            ParseInput(isolated_ctx.data, isolated_ctx.path, isolated_ctx)
        )
        value = result.value if result.status is VALID else self._def["default"]()
        return ParseReturn(VALID, value)
1369  
1370  
class ZodUnion(ZodType):
    """Schema accepting input that matches any one of several options.

    Options are tried in order and the first VALID result wins. Failing
    that, the first DIRTY result is returned with its issues. Otherwise a
    single invalid_union issue carries the issue list of every failed option.
    """

    def _parse(self, input):
        ctx = self._get_or_return_ctx(input)
        first_dirty = None
        collected = []

        for option in self._def["options"]:
            # each option reports into its own issue list so that failures of
            # one option don't leak into the caller's issues
            option_ctx = ParseContext(
                **{
                    **ctx,
                    "common": Common(**{**ctx.common, "issues": []}),
                    "parent": None,
                }
            )
            outcome = option._parse(
                ParseInput(data=ctx.data, path=ctx.path, parent=option_ctx)
            )

            if outcome.status is VALID:
                return outcome
            if outcome.status is DIRTY and first_dirty is None:
                first_dirty = (outcome, option_ctx)

            if option_ctx.common.issues:
                # append (not extend): one issue list per failed option
                collected.append(option_ctx.common.issues)

        if first_dirty is not None:
            outcome, option_ctx = first_dirty
            ctx.common.issues.extend(option_ctx.common.issues)
            return outcome

        add_issue_to_context(
            ctx, code=ZodIssueCode.invalid_union, union_issues=collected
        )
        return INVALID

    @property
    def options(self):
        "The candidate schemas, in the order they are tried."
        return self._def["options"]

    @classmethod
    def _create(cls, types, **params):
        return cls(dict(options=types, **process_params(**params)))
1414  
1415  
class ZodPipeline(ZodType):
    """Parse with the ``in`` schema, then feed its output to the ``out`` schema.

    If ``in`` aborts, the pipeline is INVALID. If it is dirty, ``out`` is
    skipped and a dirty result carrying ``in``'s (possibly transformed)
    value is returned.
    """

    def _parse(self, input):
        status, ctx = self._process_input_params(input)
        in_result = self._def["in"]._parse(
            ParseInput(data=ctx.data, path=ctx.path, parent=ctx)
        )
        if in_result.status is ABORTED:
            return INVALID
        if in_result.status is DIRTY:
            status.dirty()
            # Propagate the value produced by "in", not the raw input. Effects
            # wrapping this pipeline (refinements run on dirty values) should
            # see the processed value.
            return ParseReturn(status=status.value, value=in_result.value)
        return self._def["out"]._parse(
            ParseInput(data=in_result.value, path=ctx.path, parent=ctx)
        )

    @classmethod
    def _create(cls, a: ZodType, b: ZodType, **params):
        return cls(dict({"in": a, "out": b}, **process_params(**params)))
1434  
1435  
def custom(check=None, fatal=False, **params):
    """Build an any-schema, optionally refined by a boolean predicate.

    When *check* is given, inputs for which ``check(data)`` is falsy get an
    issue built from *params* (e.g. ``message``). A *fatal* issue aborts the
    parse instead of just marking it dirty.
    """
    base = ZodAny._create()
    if check is None:
        return base

    def custom_check(data, ctx: CheckContext):
        if check(data):
            return
        ctx.add_issue(fatal=fatal, **params)

    return base.super_refine(custom_check)
1445  
1446  
def _describe_classinfo(classinfo):
    "Readable name for a class, a (possibly nested) tuple of classes, or other classinfo."
    name = getattr(classinfo, "__name__", None)
    if name is not None:
        return name
    # `tuple` is rebound to the zod tuple factory at module level, so the
    # builtin tuple type is obtained via type(()).
    if type(classinfo) is type(()):
        return " | ".join(_describe_classinfo(c) for c in classinfo)
    # e.g. a `A | B` union type, whose repr is already readable
    return repr(classinfo)


def isinstance(cls, message=""):
    """Schema accepting only instances of *cls*.

    *cls* may be anything the builtin ``isinstance`` accepts as classinfo,
    including a tuple of classes. A failing input produces a fatal issue with
    *message*, or a default "Input not instance of ..." message.
    """
    message = message or f"Input not instance of {_describe_classinfo(cls)}"
    return custom(lambda data: isinstance_(data, cls), fatal=True, message=message)
1450  
1451  
# Public alias of the INVALID parse result; presumably meant to be returned
# from user callbacks to signal failure (mirrors zod's z.NEVER) - TODO confirm.
NEVER = INVALID

# Public factory functions (the zod-style API). Several names deliberately
# shadow builtins (any, float, list, object, tuple), so module code must use
# the saved aliases defined at the top of the file (any_, float_, list_)
# when it needs the builtins. Some factories have two names:
# array/list -> ZodList, mapping/record -> ZodMapping,
# object/typed_dict -> ZodTypedDict.
any = ZodAny._create
array = ZodList._create
boolean = ZodBoolean._create
date = ZodDate._create
datetime = ZodDateTime._create
enum = ZodEnum._create
float = ZodFloat._create
integer = ZodInteger._create
lazy = ZodLazy._create
list = ZodList._create
literal = ZodLiteral._create
mapping = ZodMapping._create
never = ZodNever._create
none = ZodNone._create
not_required = ZodNotRequired._create
number = ZodNumber._create
object = ZodTypedDict._create
optional = ZodOptional._create
preprocess = ZodEffects._preprocess
record = ZodMapping._create
string = ZodString._create
tuple = ZodTuple._create
typed_dict = ZodTypedDict._create
unknown = ZodUnknown._create
union = ZodUnion._create
1479  
1480  
class ZodCoercion:
    """Namespace of schema factories created with ``coerce=True``.

    Used through the module-level ``coerce`` instance, e.g. ``coerce.integer()``.
    Each method forwards its keyword arguments to the module-level factory
    of the same name.
    """

    @staticmethod
    def string(**options):
        return string(coerce=True, **options)

    @staticmethod
    def integer(**options):
        return integer(coerce=True, **options)

    @staticmethod
    def float(**options):
        # `float` here resolves to the module-level zod factory, not the builtin
        return float(coerce=True, **options)

    @staticmethod
    def boolean(**options):
        return boolean(coerce=True, **options)


coerce = ZodCoercion()