/ src / evidently / pydantic_utils.py
pydantic_utils.py
  1  import dataclasses
  2  import hashlib
  3  import inspect
  4  import itertools
  5  import json
  6  import os
  7  import warnings
  8  from abc import ABC
  9  from enum import Enum
 10  from functools import lru_cache
 11  from typing import TYPE_CHECKING
 12  from typing import Any
 13  from typing import Callable
 14  from typing import ClassVar
 15  from typing import Dict
 16  from typing import FrozenSet
 17  from typing import Iterable
 18  from typing import List
 19  from typing import Literal
 20  from typing import Optional
 21  from typing import Set
 22  from typing import Tuple
 23  from typing import Type
 24  from typing import TypeVar
 25  from typing import Union
 26  from typing import get_args
 27  
 28  import numpy as np
 29  import yaml
 30  from typing_inspect import is_union_type
 31  
 32  from evidently._pydantic_compat import SHAPE_DICT
 33  from evidently._pydantic_compat import BaseConfig
 34  from evidently._pydantic_compat import BaseModel
 35  from evidently._pydantic_compat import Field
 36  from evidently._pydantic_compat import ModelMetaclass
 37  from evidently._pydantic_compat import import_string
 38  from evidently._pydantic_compat import parse_obj_as
 39  
 40  if TYPE_CHECKING:
 41      from evidently._pydantic_compat import DictStrAny
 42  
 43  md5_kwargs = {"usedforsecurity": False}
 44  
 45  
 46  T = TypeVar("T")
 47  
 48  
 49  def pydantic_type_validator(type_: Type[Any], prioritize: bool = False):
 50      def decorator(f):
 51          from evidently._pydantic_compat import _VALIDATORS
 52  
 53          for cls, validators in _VALIDATORS:
 54              if cls is type_:
 55                  if prioritize:
 56                      validators.insert(0, f)
 57                  else:
 58                      validators.append(f)
 59                  return
 60          if prioritize:
 61              _VALIDATORS.insert(0, (type_, [f]))
 62          else:
 63              _VALIDATORS.append(
 64                  (type_, [f]),
 65              )
 66  
 67      return decorator
 68  
 69  
 70  class FrozenBaseMeta(ModelMetaclass):
 71      def __new__(mcs, name, bases, namespace, **kwargs):
 72          res = super().__new__(mcs, name, bases, namespace, **kwargs)
 73          res.__config__.frozen = True
 74          return res
 75  
 76  
 77  object_setattr = object.__setattr__
 78  object_delattr = object.__delattr__
 79  
 80  
 81  class FrozenBaseModel(BaseModel, metaclass=FrozenBaseMeta):
 82      class Config:
 83          underscore_attrs_are_private = True
 84  
 85      _init_values: Optional[Dict]
 86  
 87      def __init__(self, **data: Any):
 88          super().__init__(**self.__init_values__, **data)
 89          for private_attr in self.__private_attributes__:
 90              if private_attr in self.__init_values__:
 91                  object_setattr(self, private_attr, self.__init_values__[private_attr])
 92          object_setattr(self, "_init_values", None)
 93  
 94      @property
 95      def __init_values__(self):
 96          if not hasattr(self, "_init_values"):
 97              object_setattr(self, "_init_values", {})
 98          return self._init_values
 99  
100      def __setattr__(self, key, value):
101          if self.__init_values__ is not None:
102              if key not in self.__fields__ and key not in self.__private_attributes__:
103                  raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
104              self.__init_values__[key] = value
105              return
106          super().__setattr__(key, value)
107  
108      def __hash__(self):
109          try:
110              return hash(self.__class__) + hash(tuple(self._field_hash(v) for v in self.__dict__.values()))
111          except TypeError:
112              raise
113  
114      @classmethod
115      def _field_hash(cls, value):
116          if isinstance(value, list):
117              return tuple(cls._field_hash(v) for v in value)
118          if isinstance(value, dict):
119              return tuple((k, cls._field_hash(v)) for k, v in value.items())
120          return value
121  
122  
def all_subclasses(cls: Type[T]) -> Set[Type[T]]:
    """Return every direct and indirect subclass of ``cls`` (``cls`` itself excluded)."""
    found: Set[Type[T]] = set()
    for direct in cls.__subclasses__():
        found.add(direct)
        found.update(all_subclasses(direct))
    return found
125  
126  
# Module prefixes a classpath must start with before PolymorphicModel.load_alias will import it.
ALLOWED_TYPE_PREFIXES = ["evidently."]

# Additional prefixes may be supplied as a comma-separated list in this environment variable
# (empty items are ignored).
EVIDENTLY_TYPE_PREFIXES_ENV = "EVIDENTLY_TYPE_PREFIXES"
ALLOWED_TYPE_PREFIXES += [prefix for prefix in os.environ.get(EVIDENTLY_TYPE_PREFIXES_ENV, "").split(",") if prefix]

# (base class, alias) -> classpath string; filled by register_type_alias, imported on demand.
TYPE_ALIASES: Dict[Tuple[Type["PolymorphicModel"], str], str] = {}
# (base class, alias) -> already-imported class; filled by register_loaded_alias.
LOADED_TYPE_ALIASES: Dict[Tuple[Type["PolymorphicModel"], str], Type["PolymorphicModel"]] = {}
134  
135  
def register_type_alias(base_class: Type["PolymorphicModel"], classpath: str, alias: str):
    """Map ``alias`` to ``classpath`` under ``base_class`` and, transitively, under ancestor base classes.

    After registering under the current base class, the walk moves to its parent base class and keeps
    registering only while that parent's config has ``transitive_aliases`` set; it always stops at
    ``PolymorphicModel``. Re-registering a key with a different classpath emits a warning (suppressed
    while running under pytest).
    """
    current = base_class
    while True:
        key = (current, alias)
        conflicting = key in TYPE_ALIASES and TYPE_ALIASES[key] != classpath
        if conflicting and "PYTEST_CURRENT_TEST" not in os.environ:
            warnings.warn(f"Duplicate key {key} in alias map")
        TYPE_ALIASES[key] = classpath

        if current is PolymorphicModel:
            return
        current = get_base_class(current, ensure_parent=True)  # type: ignore[arg-type]
        if not current.__config__.transitive_aliases:
            return
149  
150  
def autoregister(cls: Type["PolymorphicModel"]):
    """Class decorator: register ``cls``'s type alias under its base class and return ``cls``.

    Only use it on subclasses defined in the same file as their base class (or whose import is
    guaranteed whenever the base class is imported).
    """
    base = get_base_class(cls)  # type: ignore[arg-type]
    register_type_alias(base, get_classpath(cls), cls.__get_type__())
    return cls
158  
159  
def register_loaded_alias(base_class: Type["PolymorphicModel"], cls: Type["PolymorphicModel"], alias: str):
    """Record the already-imported ``cls`` as the class for ``alias`` under ``base_class``.

    Overwriting an existing entry with a different class warns (suppressed under pytest).

    Raises:
        ValueError: if ``cls`` is not a subclass of ``base_class``.
    """
    if not issubclass(cls, base_class):
        raise ValueError(f"Cannot register alias: {cls.__name__} is not subclass of {base_class.__name__}")

    key = (base_class, alias)
    if key in LOADED_TYPE_ALIASES:
        clash = LOADED_TYPE_ALIASES[key] != cls
        if clash and "PYTEST_CURRENT_TEST" not in os.environ:
            warnings.warn(f"Duplicate key {key} in alias map")
    LOADED_TYPE_ALIASES[key] = cls
168  
169  
@lru_cache()
def get_base_class(cls: Type["PolymorphicModel"], ensure_parent: bool = False) -> Type["PolymorphicModel"]:
    """Return the first PolymorphicModel in ``cls``'s MRO whose own ``Config`` sets ``is_base_type``.

    With ``ensure_parent=True`` ``cls`` itself is skipped during the search. Falls back to
    ``PolymorphicModel``. Results are memoized by ``lru_cache``.
    """
    candidates = (
        klass
        for klass in cls.mro()
        if not (ensure_parent and klass is cls) and issubclass(klass, PolymorphicModel)
    )
    for klass in candidates:
        # only a Config declared directly on klass counts, not an inherited one
        own_config = klass.__dict__.get("Config")
        if own_config is not None and own_config.__dict__.get("is_base_type", False):
            return klass
    return PolymorphicModel
181  
182  
def get_classpath(cls: Type) -> str:
    """Return ``"<module>.<class name>"`` for ``cls``."""
    return ".".join((cls.__module__, cls.__name__))
185  
186  
# Type variable for PolymorphicModel classmethods that return an instance of the calling subclass.
TPM = TypeVar("TPM", bound="PolymorphicModel")

# A fingerprint is the md5 hex digest string produced by EvidentlyBaseModel.get_fingerprint.
Fingerprint = str
# Building blocks a fingerprint is computed from: scalars, bytes, or (nested) tuples of them.
FingerprintPart = Union[None, int, str, float, bool, bytes, Tuple["FingerprintPart", ...]]
191  
192  
def is_not_abstract(cls):
    """Return True when ``cls`` has no unimplemented abstract methods and does not list ABC as a direct base."""
    if inspect.isabstract(cls):
        return False
    return ABC not in cls.__bases__
195  
196  
class PolymorphicModel(BaseModel):
    """Model whose subclasses are discriminated by a ``type`` field.

    Every subclass gets its ``type`` field narrowed to ``Literal[<alias>]`` and is registered under its
    base class, so ``validate`` can dispatch a dict carrying a ``"type"`` key to the matching subclass.
    """

    class Config(BaseConfig):
        # value to put into "type" field
        type_alias: ClassVar[Optional[str]] = None
        # flag to mark alias required. If not required, classpath is used by default
        alias_required: ClassVar[bool] = True
        # flag to register aliases for grand-parent base type
        # e.g. PolymorphicModel -> A -> B -> C, where A and B are base types: only if A has this flag can C be parsed as both A and B.
        transitive_aliases: ClassVar[bool] = False
        # flag to mark type as base. This means it will be possible to parse all subclasses of it as this type
        is_base_type: ClassVar[bool] = False

    __config__: ClassVar[Type[Config]] = Config

    @classmethod
    def __get_type__(cls) -> str:
        """Return the value for this class's ``type`` field.

        Uses the ``type_alias`` declared in the class's own ``Config``. Otherwise falls back to the
        classpath.

        Raises:
            ValueError: if there is no own alias, ``alias_required`` is set and the class is concrete.
        """
        config = cls.__dict__.get("Config")
        if config is not None and config.__dict__.get("type_alias") is not None:
            return config.type_alias
        if cls.__config__.alias_required and is_not_abstract(cls):
            raise ValueError(f"Alias is required for {cls.__name__}")
        return cls.__get_classpath__()

    @classmethod
    def __get_classpath__(cls):
        return get_classpath(cls)

    type: str = Field("")

    def __init_subclass__(cls):
        """Specialise the ``type`` field of each new subclass and register the subclass under its base class.

        The subclass's ``type`` field gets its alias as default and ``Literal[alias]`` as type. The base
        class's ``type`` field type is widened to a Union that also includes this literal.
        """
        super().__init_subclass__()
        if cls == PolymorphicModel:
            return

        typename = cls.__get_type__()
        literal_typename = Literal[typename]

        # the inherited ModelField is patched in place
        type_field = cls.__fields__["type"]
        type_field.default = typename
        type_field.field_info.default = typename
        type_field.type_ = type_field.outer_type_ = literal_typename

        base_class = get_base_class(cls)
        # the first class registered under an alias keeps it
        if (base_class, typename) not in LOADED_TYPE_ALIASES:
            register_loaded_alias(base_class, cls, typename)
        if base_class != cls:
            base_typefield = base_class.__fields__["type"]
            base_typefield_type = base_typefield.type_
            if is_union_type(base_typefield_type):
                subclass_literals = get_args(base_typefield_type) + (literal_typename,)
            else:
                subclass_literals = (base_typefield_type, literal_typename)
            base_typefield.type_ = base_typefield.outer_type_ = Union[subclass_literals]

    @classmethod
    def __subtypes__(cls: Type[TPM]) -> Tuple[Type["TPM"], ...]:
        """Return all direct and indirect subclasses of ``cls``."""
        return tuple(all_subclasses(cls))

    @classmethod
    def __is_base_type__(cls) -> bool:
        """Return ``is_base_type`` as declared in this class's own ``Config`` (False if not declared)."""
        config = cls.__dict__.get("Config")
        if config is not None and config.__dict__.get("is_base_type") is not None:
            return config.is_base_type
        return False

    @classmethod
    def validate(cls: Type[TPM], value: Any) -> TPM:
        """Validate ``value``, dispatching dicts with a ``"type"`` key to the subclass that alias names.

        The caller's dict is not modified: the subclass receives a shallow copy without ``"type"``.
        (Previously the key was popped and re-inserted, which mutated and reordered the input.)
        """
        if isinstance(value, dict) and "type" in value:
            typename = value["type"]
            subcls = cls.load_alias(typename)
            payload = {k: v for k, v in value.items() if k != "type"}
            return subcls.validate(payload)  # type: ignore[return-value]
        return super().validate(value)  # type: ignore[misc]

    @classmethod
    def load_alias(cls, typename):
        """Resolve ``typename`` to a class registered under ``cls``'s base class.

        Lookup order:
        1. Already-loaded aliases.
        2. Classpath aliases.
        3. A dotted ``typename`` treated as a classpath itself.

        Classpaths must start with one of ``ALLOWED_TYPE_PREFIXES`` before they are imported.

        Raises:
            ValueError: unknown alias, disallowed or unimportable classpath, or a classpath that does not
                resolve to a PolymorphicModel subclass.
        """
        key = (get_base_class(cls), typename)  # type: ignore[arg-type]
        if key in LOADED_TYPE_ALIASES:
            return LOADED_TYPE_ALIASES[key]
        if key in TYPE_ALIASES:
            classpath = TYPE_ALIASES[key]
        else:
            if "." not in typename:
                raise ValueError(f'Unknown alias "{typename}"')
            classpath = typename
        if not any(classpath.startswith(p) for p in ALLOWED_TYPE_PREFIXES):
            raise ValueError(f"{classpath} does not match any allowed prefixes")
        try:
            subcls = import_string(classpath)
        except ImportError as e:
            raise ValueError(f"Error importing subclass from '{classpath}' {e.args[0]}") from e
        # typename may come from serialized input: refuse importable objects that are not model classes
        # instead of calling .validate on them
        if not (isinstance(subcls, type) and issubclass(subcls, PolymorphicModel)):
            raise ValueError(f"{classpath} is not a PolymorphicModel subclass")
        return subcls
292  
293  
def get_value_fingerprint(value: Any) -> FingerprintPart:
    """Reduce ``value`` to a hashable ``FingerprintPart``.

    - Evidently models use their own fingerprint.
    - Other pydantic models and dataclass instances are reduced through their dict form.
    - Containers are reduced recursively (dicts and sets in sorted order).
    - Enum members become their value.
    - NumPy integer/bool/float scalars become the equal Python scalar.

    Raises:
        NotImplementedError: for values of any other type.
    """
    if isinstance(value, EvidentlyBaseModel):
        return value.get_fingerprint()
    if isinstance(value, np.integer) and not isinstance(value, np.timedelta64):
        # any NumPy integer width (np.int64 as before) fingerprints like the equal Python int;
        # timedelta64 is an np.integer subclass but int() would drop its unit, so it is excluded
        return int(value)
    if isinstance(value, np.bool_):
        return bool(value)
    if isinstance(value, BaseModel):
        return get_value_fingerprint(value.dict())
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        # is_dataclass() is also true for dataclass classes, which asdict() rejects
        return get_value_fingerprint(dataclasses.asdict(value))
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, (str, int, float, bool, type(None))):
        return value
    if isinstance(value, np.floating):
        # only NumPy floats that do not subclass float get here (np.float64 is returned above unchanged)
        return float(value)
    if isinstance(value, dict):
        try:
            items = sorted(value.items())
        except TypeError:
            # keys of mutually unorderable types (e.g. mixed int/str keys, plain Enum members)
            items = sorted(value.items(), key=lambda item: (type(item[0]).__name__, str(item[0])))
        return tuple((get_value_fingerprint(k), get_value_fingerprint(v)) for k, v in items)
    if isinstance(value, (list, tuple)):
        return tuple(get_value_fingerprint(v) for v in value)
    if isinstance(value, (set, frozenset)):
        return tuple(get_value_fingerprint(v) for v in sorted(value, key=str))
    if isinstance(value, Callable):  # type: ignore
        # NOTE(review): hash() of functions/objects is normally identity-based, so this part is
        # probably not stable across processes — confirm callers only compare within one process
        return hash(value)
    raise NotImplementedError(
        f"Not implemented for value of type {value.__class__.__module__}.{value.__class__.__name__}"
    )
318  
319  
# Type variable for EvidentlyBaseModel methods that return an instance of the calling subclass.
EBM = TypeVar("EBM", bound="EvidentlyBaseModel")
321  
322  
323  def _is_yaml_fmt(path: str, fmt: Literal["yaml", "json", None]) -> bool:
324      if fmt == "yaml":
325          return True
326      if fmt == "json":
327          return False
328      return path.endswith(".yml") or path.endswith(".yaml")
329  
330  
class EvidentlyBaseModel(FrozenBaseModel, PolymorphicModel):
    """Frozen, polymorphic base model with content fingerprints and JSON/YAML file persistence."""

    class Config:
        type_alias = "evidently:base:EvidentlyBaseModel"
        alias_required = True
        is_base_type = True

    def get_fingerprint(self) -> Fingerprint:
        """Return the md5 hex digest of the classpath plus the string form of ``get_fingerprint_parts()``."""
        classpath = self.__get_classpath__()
        if ".legacy" in classpath:
            # presumably so legacy and current copies of a class fingerprint identically
            classpath = classpath.replace(".legacy", "")
        return hashlib.md5((classpath + str(self.get_fingerprint_parts())).encode("utf8"), **md5_kwargs).hexdigest()

    def get_fingerprint_parts(self) -> Tuple[FingerprintPart, ...]:
        """Return ``(name, field fingerprint)`` pairs, sorted by name.

        Only required fields and fields whose value differs from their default are included.
        """
        return tuple(
            (name, self.get_field_fingerprint(name))
            for name, field in sorted(self.__fields__.items())
            if field.required or getattr(self, name) != field.get_default()
        )

    def get_field_fingerprint(self, field: str) -> FingerprintPart:
        """Return the fingerprint part for the value of field ``field``."""
        value = getattr(self, field)
        return get_value_fingerprint(value)

    def update(self: EBM, **kwargs) -> EBM:
        """Return a new instance built from this instance's ``dict()`` with ``kwargs`` overriding it."""
        data = self.dict()
        data.update(kwargs)
        return self.__class__(**data)

    @classmethod
    def load(cls: Type[EBM], path: str, fmt: Literal["json", "yaml", None] = None) -> EBM:
        """Read ``path`` as YAML or JSON (see ``_is_yaml_fmt``) and parse the data as ``cls``.

        The file is decoded as UTF-8, the encoding ``dump`` writes.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f) if _is_yaml_fmt(path, fmt) else json.load(f)
        return parse_obj_as(cls, data)

    def dump(self, path: str, fmt: Literal["json", "yaml", None] = None):
        """Write the model to ``path`` as YAML or indented JSON, encoded as UTF-8.

        The encoding is explicit because the JSON branch keeps non-ASCII characters
        (``ensure_ascii=False``), which a locale-dependent default encoding may fail to represent.
        """
        with open(path, "w", encoding="utf-8") as f:
            if _is_yaml_fmt(path, fmt):
                # round-trip through JSON so YAML receives plain JSON-compatible values
                yaml.safe_dump(json.loads(self.json()), f)
            else:
                f.write(self.json(indent=2, ensure_ascii=False))
374  
375  
@autoregister
class WithTestAndMetricDependencies(EvidentlyBaseModel):
    """Model that can list the legacy Metric / Test objects it holds."""

    class Config:
        type_alias = "evidently:test:WithTestAndMetricDependencies"

    def __evidently_dependencies__(self):
        """Yield ``(name, value)`` for each field or private attribute whose value is a legacy Metric or Test."""
        # imported lazily, presumably to avoid a circular import with the legacy package
        from evidently.legacy.base_metric import Metric
        from evidently.legacy.tests.base_test import Test

        field_items = self.__dict__.items()
        private_items = ((attr, getattr(self, attr, None)) for attr in self.__private_attributes__)
        for name, member in itertools.chain(field_items, private_items):
            if issubclass(type(member), (Metric, Test)):
                yield name, member
390  
391  
class EnumValueMixin(BaseModel):
    """Mixin whose ``dict()`` replaces Enum members with their ``.value``.

    Conversion also applies inside list, set and frozenset values.
    """

    def _to_enum_value(self, key, value):
        """Convert Enum members in ``value``, the ``dict()`` entry for field ``key``, to their values.

        Fields whose declared type is a class other than an Enum are returned untouched.
        """
        field = self.__fields__[key]
        if isinstance(field.type_, type) and not issubclass(field.type_, Enum):
            return value

        if isinstance(value, list):
            return [v.value if isinstance(v, Enum) else v for v in value]

        if isinstance(value, frozenset):
            return frozenset(v.value if isinstance(v, Enum) else v for v in value)

        if isinstance(value, set):
            return {v.value if isinstance(v, Enum) else v for v in value}
        return value.value if isinstance(value, Enum) else value

    def dict(self, *args, **kwargs) -> "DictStrAny":
        res = super().dict(*args, **kwargs)
        if kwargs.get("by_alias"):
            # dict(by_alias=True) keys are aliases, not __fields__ keys: map them back to field names
            alias_to_name = {field.alias: name for name, field in self.__fields__.items()}
            return {k: self._to_enum_value(alias_to_name.get(k, k), v) for k, v in res.items()}
        return {k: self._to_enum_value(k, v) for k, v in res.items()}
411  
412  
class ExcludeNoneMixin(BaseModel):
    """Mixin whose ``dict()`` always omits fields set to None."""

    def dict(self, *args, **kwargs) -> "DictStrAny":
        # exclude_none is forced on, overriding whatever the caller passed
        return super().dict(*args, **{**kwargs, "exclude_none": True})
417  
418  
class FieldTags(Enum):
    """Tags that can be attached to model fields; ``FieldPath`` collects them while walking nested fields."""

    Parameter = "parameter"
    Current = "current"
    Reference = "reference"
    Render = "render"
    TypeField = "type_field"
    Extra = "extra"


# Backward-compatible alias of FieldTags. fixme: tmp for compatibility, remove in separate PR
IncludeTags = FieldTags
429  
430  
class FieldInfo(EnumValueMixin):
    """Frozen description of one nested field: dotted ``path``, its ``tags``, and the ``classpath`` of its type.

    Instances are ordered by ``path`` only.
    """

    class Config:
        frozen = True

    path: str
    tags: FrozenSet[FieldTags]
    classpath: str

    def __lt__(self, other):
        # ordering deliberately ignores tags and classpath
        return self.path < other.path
441  
442  
443  def _to_path(path: List[Any]) -> str:
444      return ".".join(str(p) for p in path)
445  
446  
class FieldPath:
    """Attribute-style navigator over the nested fields of a pydantic model class or instance.

    ``FieldPath([], Model).a.b`` walks field ``a`` of ``Model``, then field ``b`` of the result, and
    accumulates the path ``["a", "b"]``. When built from an instance, child values are read from the
    instance, so dict-shaped fields can be walked by their actual keys.
    """

    def __init__(self, path: List[Any], cls_or_instance: Union[Type, Any], is_mapping: bool = False):
        self._path = path
        self._cls: Type
        self._instance: Any
        if is_union_type(cls_or_instance):
            # only the first member of a Union annotation is followed
            cls_or_instance = get_args(cls_or_instance)[0]
        if isinstance(cls_or_instance, type):
            self._cls = cls_or_instance
            self._instance = None
        else:
            self._cls = type(cls_or_instance)
            self._instance = cls_or_instance
        # True for a dict-shaped field: its children are keys rather than model fields
        self._is_mapping = is_mapping

    @property
    def has_instance(self):
        """Whether this node wraps a concrete (non-None) value rather than only a type."""
        return self._instance is not None

    def list_fields(self) -> List[str]:
        """Return names one level below: keys of a wrapped dict instance, or field names of a model class."""
        if self.has_instance and self._is_mapping and isinstance(self._instance, dict):
            return list(self._instance.keys())
        if isinstance(self._cls, type) and issubclass(self._cls, BaseModel):
            return list(self._cls.__fields__)
        return []

    def __getattr__(self, item) -> "FieldPath":
        """Treat an unknown attribute as a child name: ``path.x`` is ``path.child("x")``."""
        # Never treat dunder probes (copy/pickle/hasattr checks such as __deepcopy__ or __setstate__)
        # or this object's own slots as field names. The slots are missing while an instance is being
        # rebuilt (e.g. by unpickling) and would otherwise recurse forever through child().
        if item.startswith("__") or item in ("_path", "_cls", "_instance", "_is_mapping"):
            raise AttributeError(item)
        return self.child(item)

    def child(self, item: str) -> "FieldPath":
        """Return the node for ``item`` one level below this one.

        Raises:
            AttributeError: if this node is not a mapping or model, or the model has no field ``item``.
            KeyError: if this node wraps a dict instance without key ``item``.
        """
        if self._is_mapping:
            if self.has_instance and isinstance(self._instance, dict):
                return FieldPath(self._path + [item], self._instance[item])
            return FieldPath(self._path + [item], self._cls)
        if not issubclass(self._cls, BaseModel):
            raise AttributeError(f"{self._cls} does not have fields")
        if item not in self._cls.__fields__:
            raise AttributeError(f"{self._cls} type does not have '{item}' field")
        field = self._cls.__fields__[item]
        field_value = field.type_
        is_mapping = field.shape == SHAPE_DICT
        if self.has_instance:
            field_value = getattr(self._instance, item)
            if is_mapping:
                return FieldPath(self._path + [item], field_value, is_mapping=True)
        return FieldPath(self._path + [item], field_value, is_mapping=is_mapping)

    def list_nested_fields(self, exclude: Optional[Set["IncludeTags"]] = None) -> List[str]:
        """Return dotted paths of all leaves below this node, skipping fields tagged with any of ``exclude``.

        A dict field expands to one path per key when an instance is present, or to ``<name>.*`` otherwise.
        """
        if not isinstance(self._cls, type) or not issubclass(self._cls, BaseModel):
            return [repr(self)]
        res = []
        for name, field in self._cls.__fields__.items():
            field_value = field.type_
            # todo: do something with recursive imports
            from evidently.legacy.core import get_field_tags

            field_tags = get_field_tags(self._cls, name)
            if field_tags is not None and (exclude is not None and any(t in exclude for t in field_tags)):
                continue
            is_mapping = field.shape == SHAPE_DICT
            if self.has_instance:
                field_value = getattr(self._instance, name)
                if is_mapping and isinstance(field_value, dict):
                    for key, value in field_value.items():
                        res.extend(FieldPath(self._path + [name, str(key)], value).list_nested_fields(exclude=exclude))
                    continue
            else:
                if is_mapping:
                    name = f"{name}.*"
            res.extend(FieldPath(self._path + [name], field_value).list_nested_fields(exclude=exclude))
        return res

    def _list_with_tags(self, current_tags: Set["IncludeTags"]) -> List[Tuple[List[Any], Set["IncludeTags"]]]:
        """Return ``(path, tags)`` for every leaf below this node, accumulating field tags on the way down.

        ``BaseResult`` models configured with ``extract_as_obj`` are treated as leaves.
        ``ByLabelValueV1`` / ``ByLabelCountValueV1`` models additionally expose their ``values`` /
        ``counts`` and ``shares`` paths tagged ``Render``.
        """
        if not isinstance(self._cls, type) or not issubclass(self._cls, BaseModel):
            return [(self._path, current_tags)]
        from evidently.legacy.core import BaseResult

        if issubclass(self._cls, BaseResult) and self._cls.__config__.extract_as_obj:
            return [(self._path, current_tags)]
        res = []
        from evidently.ui.backport import ByLabelCountValueV1
        from evidently.ui.backport import ByLabelValueV1

        if issubclass(self._cls, ByLabelValueV1):
            res.append((self._path + ["values"], current_tags.union({IncludeTags.Render})))
        if issubclass(self._cls, ByLabelCountValueV1):
            res.append((self._path + ["counts"], current_tags.union({IncludeTags.Render})))
            res.append((self._path + ["shares"], current_tags.union({IncludeTags.Render})))
        for name, field in self._cls.__fields__.items():
            field_value = field.type_

            # todo: do something with recursive imports
            from evidently.legacy.core import get_field_tags

            field_tags = get_field_tags(self._cls, name)

            is_mapping = field.shape == SHAPE_DICT
            if self.has_instance:
                field_value = getattr(self._instance, name)
                if is_mapping and isinstance(field_value, dict):
                    # raw (possibly non-str) keys are kept in the path; joining stringifies them
                    for key, value in field_value.items():
                        res.extend(
                            FieldPath(self._path + [name, key], value)._list_with_tags(current_tags.union(field_tags))
                        )
                    continue
            else:
                if is_mapping:
                    name = f"{name}.*"
            res.extend(FieldPath(self._path + [name], field_value)._list_with_tags(current_tags.union(field_tags)))
        return res

    def list_nested_fields_with_tags(self) -> List[Tuple[str, Set["IncludeTags"]]]:
        """Return ``(dotted path, tags)`` for every leaf below this node."""
        return [(_to_path(path), tags) for path, tags in self._list_with_tags(set())]

    def list_nested_field_infos(self) -> List[FieldInfo]:
        """Return a ``FieldInfo`` (path, tags, classpath of the field type) for every leaf below this node."""
        return [
            FieldInfo(path=_to_path(path), tags=frozenset(tags), classpath=get_classpath(self._get_field_type(path)))
            for path, tags in self._list_with_tags(set())
        ]

    def _get_field_type(self, path: List[str]) -> Type:
        """Return the declared type (``outer_type_``) of the field at ``path`` below this node.

        For a dict instance, the runtime type of the stored value is used instead.

        Raises:
            ValueError: for an empty path.
            NotImplementedError: when the last step is neither a model field nor a dict key.
        """
        if len(path) == 0:
            raise ValueError("Empty path provided")
        if len(path) == 1:
            if isinstance(self._cls, type) and issubclass(self._cls, BaseModel):
                return self._cls.__fields__[path[0]].outer_type_
            if self.has_instance:
                # fixme: tmp fix
                # in case of field like f: Dict[str, A] we wont know that value was type annotated with A when we get to it
                if isinstance(self._instance, dict):
                    return type(self._instance.get(path[0]))
            raise NotImplementedError(f"Not implemented for {self._cls.__name__}")
        child, *path = path
        return self.child(child)._get_field_type(path)

    def __repr__(self):
        return self.get_path()

    def get_path(self):
        """Return the dotted path of this node.

        Segments are stringified because dict keys collected along the way need not be strings.
        """
        return _to_path(self._path)

    def __dir__(self) -> Iterable[str]:
        res: List[str] = []
        res.extend(super().__dir__())
        res.extend(self.list_fields())
        return res

    def get_field_tags(self, path: List[str]) -> Optional[Set["IncludeTags"]]:
        """Return this BaseResult's tags united with the tags of every field along ``path``.

        Returns None if this node is not a ``BaseResult`` class.
        """
        from evidently.legacy.base_metric import BaseResult

        if not isinstance(self._cls, type) or not issubclass(self._cls, BaseResult):
            return None
        self_tags = self._cls.__config__.tags
        if len(path) == 0:
            return self_tags
        field_name, *path = path
        # todo: do something with recursive imports
        from evidently.legacy.core import get_field_tags

        field_tags = get_field_tags(self._cls, field_name)
        return self_tags.union(field_tags).union(self.child(field_name).get_field_tags(path) or tuple())
608  
609  
@pydantic_type_validator(FieldPath)
def series_validator(value):
    """Validator registered for ``FieldPath`` values: convert the value to its dotted path string."""
    dotted_path = value.get_path()
    return dotted_path
613  
614  
def get_object_hash_deprecated(obj: Union[BaseModel, dict]):
    """Deprecated: return the md5 hex digest of ``obj`` serialized as JSON.

    Models are converted with ``.dict()`` first. NumPy values are encoded via the legacy ``NumpyEncoder``.
    """
    from evidently.legacy.utils import NumpyEncoder

    payload = obj.dict() if isinstance(obj, BaseModel) else obj
    serialized = json.dumps(payload, cls=NumpyEncoder)
    return hashlib.md5(serialized.encode("utf8"), **md5_kwargs).hexdigest()  # nosec: B324
621  
622  
class AutoAliasMixin:
    """Mixin that derives a default type alias ``evidently:<__alias_type__>:<ClassName>``.

    A non-None ``type_alias`` declared in the class's own ``Config`` takes precedence.
    """

    __alias_type__: ClassVar[str]

    @classmethod
    def __get_type__(cls) -> str:
        own_config = vars(cls).get("Config")
        explicit_alias = None if own_config is None else vars(own_config).get("type_alias")
        if explicit_alias is None:
            return f"evidently:{cls.__alias_type__}:{cls.__name__}"
        return own_config.type_alias