pydantic_utils.py
import dataclasses
import hashlib
import inspect
import itertools
import json
import os
import warnings
from abc import ABC
from enum import Enum
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import ClassVar
from typing import Dict
from typing import FrozenSet
from typing import Iterable
from typing import List
from typing import Literal
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
from typing import get_args

import numpy as np
import yaml
from typing_inspect import is_union_type

from evidently._pydantic_compat import SHAPE_DICT
from evidently._pydantic_compat import BaseConfig
from evidently._pydantic_compat import BaseModel
from evidently._pydantic_compat import Field
from evidently._pydantic_compat import ModelMetaclass
from evidently._pydantic_compat import import_string
from evidently._pydantic_compat import parse_obj_as

if TYPE_CHECKING:
    from evidently._pydantic_compat import DictStrAny

# Extra hashlib keyword arguments: md5 is only used for fingerprints here, not for security.
md5_kwargs = {"usedforsecurity": False}


T = TypeVar("T")


def pydantic_type_validator(type_: Type[Any], prioritize: bool = False):
    """Return a decorator that registers the decorated function as a validator for ``type_``.

    The function is added to ``evidently._pydantic_compat._VALIDATORS``: if an entry for
    ``type_`` (matched by identity) already exists, the function is appended to its
    validator list (or inserted first when ``prioritize`` is true); otherwise a new
    ``(type_, [f])`` entry is appended (or inserted first when ``prioritize`` is true).

    The decorated function itself is returned, so it stays usable under its own name.
    """

    def decorator(f):
        from evidently._pydantic_compat import _VALIDATORS

        for cls, validators in _VALIDATORS:
            if cls is type_:
                if prioritize:
                    validators.insert(0, f)
                else:
                    validators.append(f)
                return f
        if prioritize:
            _VALIDATORS.insert(0, (type_, [f]))
        else:
            _VALIDATORS.append((type_, [f]))
        return f

    return decorator


class FrozenBaseMeta(ModelMetaclass):
    """Model metaclass that sets ``frozen = True`` on the config of every class it creates."""

    def __new__(mcs, name, bases, namespace, **kwargs):
        res = super().__new__(mcs, name, bases, namespace, **kwargs)
        res.__config__.frozen = True
        return res


object_setattr = object.__setattr__
object_delattr = object.__delattr__


class FrozenBaseModel(BaseModel, metaclass=FrozenBaseMeta):
    """Frozen pydantic model that still accepts attribute assignment before ``__init__`` finishes.

    While ``_init_values`` is not ``None`` (i.e. before ``__init__`` completes), assigning
    a declared field or private attribute stores the value in ``_init_values`` instead of
    setting it. ``__init__`` passes those buffered values to pydantic together with its own
    keyword arguments, then applies buffered private attributes, and finally sets
    ``_init_values`` to ``None`` so later assignments go through pydantic's frozen
    ``__setattr__``.
    """

    class Config:
        underscore_attrs_are_private = True

    # pre-init assignment buffer; None once __init__ has completed
    _init_values: Optional[Dict]

    def __init__(self, **data: Any):
        super().__init__(**self.__init_values__, **data)
        for private_attr in self.__private_attributes__:
            if private_attr in self.__init_values__:
                object_setattr(self, private_attr, self.__init_values__[private_attr])
        object_setattr(self, "_init_values", None)

    @property
    def __init_values__(self):
        # Create the buffer lazily: it does not exist yet on the first pre-init access.
        if not hasattr(self, "_init_values"):
            object_setattr(self, "_init_values", {})
        return self._init_values

    def __setattr__(self, key, value):
        if self.__init_values__ is not None:
            # Not initialized yet: buffer the value so __init__ can pick it up.
            if key not in self.__fields__ and key not in self.__private_attributes__:
                raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
            self.__init_values__[key] = value
            return
        super().__setattr__(key, value)

    def __hash__(self):
        # Unhashable field values (unsupported by _field_hash) propagate TypeError.
        return hash(self.__class__) + hash(tuple(self._field_hash(v) for v in self.__dict__.values()))

    @classmethod
    def _field_hash(cls, value):
        """Convert a field value into a hashable equivalent.

        Lists become tuples (order matters, as in list equality). Dicts and sets become
        frozensets (order does not matter, as in dict/set equality), so values that compare
        equal also hash equally. Any other value is returned unchanged.
        """
        if isinstance(value, list):
            return tuple(cls._field_hash(v) for v in value)
        if isinstance(value, dict):
            return frozenset((k, cls._field_hash(v)) for k, v in value.items())
        if isinstance(value, set):
            return frozenset(value)
        return value


def all_subclasses(cls: Type[T]) -> Set[Type[T]]:
    """Return every direct and indirect subclass of ``cls`` (``cls`` itself excluded)."""
    return set(cls.__subclasses__()).union([s for c in cls.__subclasses__() for s in all_subclasses(c)])


# Classpath prefixes that load_alias is allowed to import from.
ALLOWED_TYPE_PREFIXES = ["evidently."]

# Comma-separated extra prefixes; surrounding whitespace is ignored and empty items are skipped.
EVIDENTLY_TYPE_PREFIXES_ENV = "EVIDENTLY_TYPE_PREFIXES"
ALLOWED_TYPE_PREFIXES.extend(
    [p.strip() for p in os.environ.get(EVIDENTLY_TYPE_PREFIXES_ENV, "").split(",") if p.strip()]
)

# (base class, alias) -> classpath, for classes that may not be imported yet.
TYPE_ALIASES: Dict[Tuple[Type["PolymorphicModel"], str], str] = {}
# (base class, alias) -> class, for classes that have already been created.
LOADED_TYPE_ALIASES: Dict[Tuple[Type["PolymorphicModel"], str], Type["PolymorphicModel"]] = {}


def register_type_alias(base_class: Type["PolymorphicModel"], classpath: str, alias: str):
    """Map ``alias`` to ``classpath`` under ``base_class`` in ``TYPE_ALIASES``.

    After registering, walks up to the next base type of ``base_class``. The alias is also
    registered there only while that parent base type has ``transitive_aliases`` set. The
    walk stops at ``PolymorphicModel``. Re-registering a key with a different classpath
    emits a warning, except while running under pytest.
    """
    while True:
        key = (base_class, alias)

        if key in TYPE_ALIASES and TYPE_ALIASES[key] != classpath and "PYTEST_CURRENT_TEST" not in os.environ:
            warnings.warn(f"Duplicate key {key} in alias map")
        TYPE_ALIASES[key] = classpath

        if base_class is PolymorphicModel:
            break
        base_class = get_base_class(base_class, ensure_parent=True)  # type: ignore[arg-type]
        if not base_class.__config__.transitive_aliases:
            break


def autoregister(cls: Type["PolymorphicModel"]):
    """Decorator that automatically registers subclass.
    Can only be used on subclasses that are defined in the same file as base class
    (or if the import of this subclass is guaranteed when base class is imported)
    """
    register_type_alias(get_base_class(cls), get_classpath(cls), cls.__get_type__())  # type: ignore[arg-type]
    return cls


def register_loaded_alias(base_class: Type["PolymorphicModel"], cls: Type["PolymorphicModel"], alias: str):
    """Map ``alias`` to the already-created class ``cls`` under ``base_class``.

    Raises:
        ValueError: if ``cls`` is not a subclass of ``base_class``.
    """
    if not issubclass(cls, base_class):
        raise ValueError(f"Cannot register alias: {cls.__name__} is not subclass of {base_class.__name__}")

    key = (base_class, alias)
    if key in LOADED_TYPE_ALIASES and LOADED_TYPE_ALIASES[key] != cls and "PYTEST_CURRENT_TEST" not in os.environ:
        warnings.warn(f"Duplicate key {key} in alias map")
    LOADED_TYPE_ALIASES[key] = cls


@lru_cache()
def get_base_class(cls: Type["PolymorphicModel"], ensure_parent: bool = False) -> Type["PolymorphicModel"]:
    """Return the closest base type of ``cls`` along its MRO.

    A base type is a ``PolymorphicModel`` subclass whose own ``Config`` (not an inherited
    one) sets ``is_base_type = True``. With ``ensure_parent`` the search skips ``cls``
    itself. Falls back to ``PolymorphicModel``. Results are memoized per argument pair.
    """
    for cls_ in cls.mro():
        if ensure_parent and cls_ is cls:
            continue
        if not issubclass(cls_, PolymorphicModel):
            continue
        config = cls_.__dict__.get("Config")
        if config is not None and config.__dict__.get("is_base_type", False):
            return cls_
    return PolymorphicModel


def get_classpath(cls: Type) -> str:
    """Return ``"<module>.<qualified-less class name>"`` for ``cls``."""
    return f"{cls.__module__}.{cls.__name__}"


TPM = TypeVar("TPM", bound="PolymorphicModel")

Fingerprint = str
FingerprintPart = Union[None, int, str, float, bool, bytes, Tuple["FingerprintPart", ...]]


def is_not_abstract(cls):
    """Return True unless ``cls`` is abstract per ``inspect`` or lists ``ABC`` as a direct base."""
    return not (inspect.isabstract(cls) or ABC in cls.__bases__)


class PolymorphicModel(BaseModel):
    """Pydantic model whose ``type`` field names the concrete subclass to parse into.

    Each subclass gets a ``Literal`` type for its ``type`` field (its alias or classpath),
    is registered in ``LOADED_TYPE_ALIASES`` under its base type, and widens the base
    type's ``type`` field to a union of all known literals. ``validate`` uses the
    ``type`` key of an input dict to dispatch to the matching subclass.
    """

    class Config(BaseConfig):
        # value to put into "type" field
        type_alias: ClassVar[Optional[str]] = None
        # flag to mark alias required. If not required, classpath is used by default
        alias_required: ClassVar[bool] = True
        # flag to register aliases for grand-parent base type
        # eg PolymorphicModel -> A -> B -> C, where A and B are base types. only if A has this flag, C can be parsed as both A and B.
        transitive_aliases: ClassVar[bool] = False
        # flag to mark type as base. This means it will be possible to parse all subclasses of it as this type
        is_base_type: ClassVar[bool] = False

    __config__: ClassVar[Type[Config]] = Config

    @classmethod
    def __get_type__(cls) -> str:
        """Return the value for this class's ``type`` field.

        Uses ``type_alias`` from a ``Config`` declared directly on the class. Otherwise
        falls back to the classpath, but raises ValueError for a non-abstract class
        whose config requires an alias.
        """
        config = cls.__dict__.get("Config")
        if config is not None and config.__dict__.get("type_alias") is not None:
            return config.type_alias
        if cls.__config__.alias_required and is_not_abstract(cls):
            raise ValueError(f"Alias is required for {cls.__name__}")
        return cls.__get_classpath__()

    @classmethod
    def __get_classpath__(cls):
        return get_classpath(cls)

    type: str = Field("")

    def __init_subclass__(cls):
        super().__init_subclass__()
        if cls == PolymorphicModel:
            return

        typename = cls.__get_type__()
        literal_typename = Literal[typename]

        # Pin this subclass's "type" field to its own literal typename.
        type_field = cls.__fields__["type"]
        type_field.default = typename
        type_field.field_info.default = typename
        type_field.type_ = type_field.outer_type_ = literal_typename

        base_class = get_base_class(cls)
        if (base_class, typename) not in LOADED_TYPE_ALIASES:
            register_loaded_alias(base_class, cls, typename)
        if base_class != cls:
            # Widen the base type's "type" field so it accepts this subclass's literal too.
            base_typefield = base_class.__fields__["type"]
            base_typefield_type = base_typefield.type_
            if is_union_type(base_typefield_type):
                subclass_literals = get_args(base_typefield_type) + (literal_typename,)
            else:
                subclass_literals = (base_typefield_type, literal_typename)
            base_typefield.type_ = base_typefield.outer_type_ = Union[subclass_literals]

    @classmethod
    def __subtypes__(cls: Type[TPM]) -> Tuple[Type["TPM"], ...]:
        return tuple(all_subclasses(cls))

    @classmethod
    def __is_base_type__(cls) -> bool:
        """Return the ``is_base_type`` flag from a ``Config`` declared directly on the class, else False."""
        config = cls.__dict__.get("Config")
        if config is not None and config.__dict__.get("is_base_type") is not None:
            return config.is_base_type
        return False

    @classmethod
    def validate(cls: Type[TPM], value: Any) -> TPM:
        """Validate ``value``, dispatching dicts with a ``type`` key to the aliased subclass.

        The caller's dict is never modified: the ``type`` key is removed from a copy,
        and that copy is validated by the class returned by ``load_alias``.
        """
        if isinstance(value, dict) and "type" in value:
            data = dict(value)
            typename = data.pop("type")
            subcls = cls.load_alias(typename)
            return subcls.validate(data)  # type: ignore[return-value]
        return super().validate(value)  # type: ignore[misc]

    @classmethod
    def load_alias(cls, typename):
        """Resolve ``typename`` to a class under this class's base type.

        Lookup order: ``LOADED_TYPE_ALIASES``, then ``TYPE_ALIASES``, then ``typename``
        itself treated as a classpath if it contains a dot. A classpath must start with
        one of ``ALLOWED_TYPE_PREFIXES`` before it is imported.

        Raises:
            ValueError: for an unknown alias, a disallowed classpath, or an import failure.
        """
        key = (get_base_class(cls), typename)  # type: ignore[arg-type]
        if key in LOADED_TYPE_ALIASES:
            subcls = LOADED_TYPE_ALIASES[key]
        else:
            if key in TYPE_ALIASES:
                classpath = TYPE_ALIASES[key]
            else:
                if "." not in typename:
                    raise ValueError(f'Unknown alias "{typename}"')
                classpath = typename
            if not any(classpath.startswith(p) for p in ALLOWED_TYPE_PREFIXES):
                raise ValueError(f"{classpath} does not match any allowed prefixes")
            try:
                subcls = import_string(classpath)
            except ImportError as e:
                raise ValueError(f"Error importing subclass from '{classpath}' {e.args[0]}") from e
        return subcls


def get_value_fingerprint(value: Any) -> FingerprintPart:
    """Convert ``value`` into a hashable fingerprint part.

    - ``EvidentlyBaseModel``: its own ``get_fingerprint()``.
    - numpy integers: ``int``.
    - other pydantic models and dataclasses: fingerprint of their dict form.
    - enums: their ``value``.
    - str/int/float/bool/None: returned as-is.
    - dicts: recursively converted to tuples of items, sorted.
    - lists/tuples: recursively converted to tuples.
    - sets/frozensets: recursively converted to tuples, sorted by ``str``.
    - callables: ``hash()``, which is generally not stable across processes.

    Raises:
        NotImplementedError: for any other type.
    """
    if isinstance(value, EvidentlyBaseModel):
        return value.get_fingerprint()
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, BaseModel):
        return get_value_fingerprint(value.dict())
    if dataclasses.is_dataclass(value):
        return get_value_fingerprint(dataclasses.asdict(value))
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, (str, int, float, bool, type(None))):
        return value
    if isinstance(value, dict):
        return tuple((get_value_fingerprint(k), get_value_fingerprint(v)) for k, v in sorted(value.items()))
    if isinstance(value, (list, tuple)):
        return tuple(get_value_fingerprint(v) for v in value)
    if isinstance(value, (set, frozenset)):
        return tuple(get_value_fingerprint(v) for v in sorted(value, key=str))
    if isinstance(value, Callable):  # type: ignore
        return hash(value)
    raise NotImplementedError(
        f"Not implemented for value of type {value.__class__.__module__}.{value.__class__.__name__}"
    )


EBM = TypeVar("EBM", bound="EvidentlyBaseModel")


def _is_yaml_fmt(path: str, fmt: Literal["yaml", "json", None]) -> bool:
    """Return True for YAML: an explicit ``fmt`` wins, otherwise decide by ``.yml``/``.yaml`` extension."""
    if fmt == "yaml":
        return True
    if fmt == "json":
        return False
    return path.endswith(".yml") or path.endswith(".yaml")


class EvidentlyBaseModel(FrozenBaseModel, PolymorphicModel):
    """Frozen, polymorphic base model with fingerprinting and JSON/YAML file I/O."""

    class Config:
        type_alias = "evidently:base:EvidentlyBaseModel"
        alias_required = True
        is_base_type = True

    def get_fingerprint(self) -> Fingerprint:
        """Return an md5 hex digest of the classpath (with ``.legacy`` removed) plus the fingerprint parts."""
        classpath = self.__get_classpath__()
        if ".legacy" in classpath:
            classpath = classpath.replace(".legacy", "")
        return hashlib.md5((classpath + str(self.get_fingerprint_parts())).encode("utf8"), **md5_kwargs).hexdigest()

    def get_fingerprint_parts(self) -> Tuple[FingerprintPart, ...]:
        """Return ``(name, fingerprint)`` pairs, sorted by field name, for fields that are
        required or whose value differs from the field default."""
        return tuple(
            (name, self.get_field_fingerprint(name))
            for name, field in sorted(self.__fields__.items())
            if field.required or getattr(self, name) != field.get_default()
        )

    def get_field_fingerprint(self, field: str) -> FingerprintPart:
        value = getattr(self, field)
        return get_value_fingerprint(value)

    def update(self: EBM, **kwargs) -> EBM:
        """Return a new instance of the same class with ``kwargs`` overriding current field values."""
        data = self.dict()
        data.update(kwargs)
        return self.__class__(**data)

    @classmethod
    def load(cls: Type[EBM], path: str, fmt: Literal["json", "yaml", None] = None) -> EBM:
        """Parse an instance from a UTF-8 JSON or YAML file (format chosen by ``_is_yaml_fmt``)."""
        with open(path, "r", encoding="utf-8") as f:
            if _is_yaml_fmt(path, fmt):
                data = yaml.safe_load(f)
            else:
                data = json.load(f)
        return parse_obj_as(cls, data)

    def dump(self, path: str, fmt: Literal["json", "yaml", None] = None):
        """Write this instance to ``path`` as UTF-8 JSON (indented, non-ASCII kept) or YAML."""
        # Explicit UTF-8: the JSON branch writes non-ASCII characters verbatim.
        with open(path, "w", encoding="utf-8") as f:
            if _is_yaml_fmt(path, fmt):
                yaml.safe_dump(json.loads(self.json()), f)
            else:
                f.write(self.json(indent=2, ensure_ascii=False))


@autoregister
class WithTestAndMetricDependencies(EvidentlyBaseModel):
    """Model whose ``Metric``/``Test`` attributes are reported as dependencies."""

    class Config:
        type_alias = "evidently:test:WithTestAndMetricDependencies"

    def __evidently_dependencies__(self):
        """Yield ``(name, value)`` for every field or private attribute holding a legacy Metric or Test."""
        from evidently.legacy.base_metric import Metric
        from evidently.legacy.tests.base_test import Test

        for field_name, field in itertools.chain(
            self.__dict__.items(), ((pa, getattr(self, pa, None)) for pa in self.__private_attributes__)
        ):
            if issubclass(type(field), (Metric, Test)):
                yield field_name, field


class EnumValueMixin(BaseModel):
    """Mixin whose ``dict()`` replaces Enum members (also inside lists/sets/frozensets) with their values."""

    def _to_enum_value(self, key, value):
        field = self.__fields__[key]
        # Fields annotated with a concrete non-Enum class are left untouched.
        if isinstance(field.type_, type) and not issubclass(field.type_, Enum):
            return value

        if isinstance(value, list):
            return [v.value if isinstance(v, Enum) else v for v in value]

        if isinstance(value, frozenset):
            return frozenset(v.value if isinstance(v, Enum) else v for v in value)

        if isinstance(value, set):
            return {v.value if isinstance(v, Enum) else v for v in value}
        return value.value if isinstance(value, Enum) else value

    def dict(self, *args, **kwargs) -> "DictStrAny":
        res = super().dict(*args, **kwargs)
        return {k: self._to_enum_value(k, v) for k, v in res.items()}


class ExcludeNoneMixin(BaseModel):
    """Mixin whose ``dict()`` always passes ``exclude_none=True``."""

    def dict(self, *args, **kwargs) -> "DictStrAny":
        kwargs["exclude_none"] = True
        return super().dict(*args, **kwargs)


class FieldTags(Enum):
    Parameter = "parameter"
    Current = "current"
    Reference = "reference"
    Render = "render"
    TypeField = "type_field"
    Extra = "extra"


IncludeTags = FieldTags  # fixme: tmp for compatibility, remove in separate PR


class FieldInfo(EnumValueMixin):
    """Dotted path of a nested field, its tags and its type's classpath; ordered by ``path``."""

    class Config:
        frozen = True

    path: str
    tags: FrozenSet[FieldTags]
    classpath: str

    def __lt__(self, other):
        return self.path < other.path


def _to_path(path: List[Any]) -> str:
    """Join path segments with dots, converting each segment with ``str``."""
    return ".".join(str(p) for p in path)


class FieldPath:
    """Attribute-style navigator over the fields of a pydantic model class or instance.

    Accessing an attribute (or calling :meth:`child`) returns a new ``FieldPath`` one level
    deeper, and :meth:`get_path` renders the traversed segments as a dotted string. For a
    union type only its first member is used. ``is_mapping`` marks a path pointing at a
    dict-shaped field, whose children are addressed by key.
    """

    def __init__(self, path: List[Any], cls_or_instance: Union[Type, Any], is_mapping: bool = False):
        self._path = path
        self._cls: Type
        self._instance: Any
        if is_union_type(cls_or_instance):
            cls_or_instance = get_args(cls_or_instance)[0]
        if isinstance(cls_or_instance, type):
            self._cls = cls_or_instance
            self._instance = None
        else:
            self._cls = type(cls_or_instance)
            self._instance = cls_or_instance
        self._is_mapping = is_mapping

    @property
    def has_instance(self):
        return self._instance is not None

    def list_fields(self) -> List[str]:
        """Return the immediate child names: dict keys for a mapping instance, model field names otherwise."""
        if self.has_instance and self._is_mapping and isinstance(self._instance, dict):
            return list(self._instance.keys())
        if isinstance(self._cls, type) and issubclass(self._cls, BaseModel):
            return list(self._cls.__fields__)
        return []

    def __getattr__(self, item) -> "FieldPath":
        # Only reached when normal lookup fails. Dunder names (probed by copy, pickle and
        # hasattr-style protocol checks) and this object's own attributes (missing only on
        # an instance that skipped __init__) must raise AttributeError. Otherwise a mapping
        # path would fabricate them and an uninitialized instance would recurse forever.
        if (item.startswith("__") and item.endswith("__")) or item in ("_path", "_cls", "_instance", "_is_mapping"):
            raise AttributeError(f"{type(self).__name__} object has no attribute {item!r}")
        return self.child(item)

    def child(self, item: str) -> "FieldPath":
        """Return the path one level deeper at ``item``.

        Raises:
            AttributeError: if this path is not a mapping and its class is not a pydantic
                model, or the model has no field named ``item``.
        """
        if self._is_mapping:
            if self.has_instance and isinstance(self._instance, dict):
                return FieldPath(self._path + [item], self._instance[item])
            return FieldPath(self._path + [item], self._cls)
        if not issubclass(self._cls, BaseModel):
            raise AttributeError(f"{self._cls} does not have fields")
        if item not in self._cls.__fields__:
            raise AttributeError(f"{self._cls} type does not have '{item}' field")
        field = self._cls.__fields__[item]
        is_mapping = field.shape == SHAPE_DICT
        field_value = getattr(self._instance, item) if self.has_instance else field.type_
        return FieldPath(self._path + [item], field_value, is_mapping=is_mapping)

    def list_nested_fields(self, exclude: Optional[Set["IncludeTags"]] = None) -> List[str]:
        """Return dotted paths of all leaf fields below this path.

        Fields whose tags intersect ``exclude`` are skipped. Dict fields are expanded per key
        when an instance is available; otherwise they contribute a ``<name>.*`` segment.
        """
        if not isinstance(self._cls, type) or not issubclass(self._cls, BaseModel):
            return [repr(self)]
        res = []
        for name, field in self._cls.__fields__.items():
            field_value = field.type_
            # todo: do something with recursive imports
            from evidently.legacy.core import get_field_tags

            field_tags = get_field_tags(self._cls, name)
            if field_tags is not None and (exclude is not None and any(t in exclude for t in field_tags)):
                continue
            is_mapping = field.shape == SHAPE_DICT
            if self.has_instance:
                field_value = getattr(self._instance, name)
                if is_mapping and isinstance(field_value, dict):
                    for key, value in field_value.items():
                        res.extend(FieldPath(self._path + [name, str(key)], value).list_nested_fields(exclude=exclude))
                    continue
            else:
                if is_mapping:
                    name = f"{name}.*"
            res.extend(FieldPath(self._path + [name], field_value).list_nested_fields(exclude=exclude))
        return res

    def _list_with_tags(self, current_tags: Set["IncludeTags"]) -> List[Tuple[List[Any], Set["IncludeTags"]]]:
        """Return ``(path segments, accumulated tags)`` for every leaf below this path.

        Non-model classes and ``BaseResult`` classes with ``extract_as_obj`` are leaves.
        ``ByLabelValueV1`` adds a ``values`` leaf and ``ByLabelCountValueV1`` adds ``counts``
        and ``shares`` leaves, all tagged ``Render``. Dict keys are kept unconverted so the
        segments can be navigated again (see ``_get_field_type``).
        """
        if not isinstance(self._cls, type) or not issubclass(self._cls, BaseModel):
            return [(self._path, current_tags)]
        from evidently.legacy.core import BaseResult

        if issubclass(self._cls, BaseResult) and self._cls.__config__.extract_as_obj:
            return [(self._path, current_tags)]
        res = []
        from evidently.ui.backport import ByLabelCountValueV1
        from evidently.ui.backport import ByLabelValueV1

        if issubclass(self._cls, ByLabelValueV1):
            res.append((self._path + ["values"], current_tags.union({IncludeTags.Render})))
        if issubclass(self._cls, ByLabelCountValueV1):
            res.append((self._path + ["counts"], current_tags.union({IncludeTags.Render})))
            res.append((self._path + ["shares"], current_tags.union({IncludeTags.Render})))
        for name, field in self._cls.__fields__.items():
            field_value = field.type_

            # todo: do something with recursive imports
            from evidently.legacy.core import get_field_tags

            field_tags = get_field_tags(self._cls, name)

            is_mapping = field.shape == SHAPE_DICT
            if self.has_instance:
                field_value = getattr(self._instance, name)
                if is_mapping and isinstance(field_value, dict):
                    for key, value in field_value.items():
                        res.extend(
                            FieldPath(self._path + [name, key], value)._list_with_tags(current_tags.union(field_tags))
                        )
                    continue
            else:
                if is_mapping:
                    name = f"{name}.*"
            res.extend(FieldPath(self._path + [name], field_value)._list_with_tags(current_tags.union(field_tags)))
        return res

    def list_nested_fields_with_tags(self) -> List[Tuple[str, Set["IncludeTags"]]]:
        return [(_to_path(path), tags) for path, tags in self._list_with_tags(set())]

    def list_nested_field_infos(self) -> List[FieldInfo]:
        return [
            FieldInfo(path=_to_path(path), tags=frozenset(tags), classpath=get_classpath(self._get_field_type(path)))
            for path, tags in self._list_with_tags(set())
        ]

    def _get_field_type(self, path: List[str]) -> Type:
        """Return the declared (outer) type of the field at ``path`` relative to this path.

        Raises:
            ValueError: for an empty path.
            NotImplementedError: when the last segment cannot be resolved.
        """
        if len(path) == 0:
            raise ValueError("Empty path provided")
        if len(path) == 1:
            if isinstance(self._cls, type) and issubclass(self._cls, BaseModel):
                return self._cls.__fields__[path[0]].outer_type_
            if self.has_instance:
                # fixme: tmp fix
                # in case of field like f: Dict[str, A] we wont know that value was type annotated with A when we get to it
                if isinstance(self._instance, dict):
                    return type(self._instance.get(path[0]))
            raise NotImplementedError(f"Not implemented for {self._cls.__name__}")
        child, *path = path
        return self.child(child)._get_field_type(path)

    def __repr__(self):
        return self.get_path()

    def get_path(self):
        """Return the traversed segments joined with dots (non-string segments are converted with ``str``)."""
        return _to_path(self._path)

    def __dir__(self) -> Iterable[str]:
        res: List[str] = []
        res.extend(super().__dir__())
        res.extend(self.list_fields())
        return res

    def get_field_tags(self, path: List[str]) -> Optional[Set["IncludeTags"]]:
        """Return the union of config and field tags along ``path``, or None if this class is not a ``BaseResult``."""
        from evidently.legacy.base_metric import BaseResult

        if not isinstance(self._cls, type) or not issubclass(self._cls, BaseResult):
            return None
        self_tags = self._cls.__config__.tags
        if len(path) == 0:
            return self_tags
        field_name, *path = path
        # todo: do something with recursive imports
        from evidently.legacy.core import get_field_tags

        field_tags = get_field_tags(self._cls, field_name)
        return self_tags.union(field_tags).union(self.child(field_name).get_field_tags(path) or tuple())


@pydantic_type_validator(FieldPath)
def series_validator(value):
    """Pydantic validator: accept a ``FieldPath`` and store its dotted path string."""
    return value.get_path()


def get_object_hash_deprecated(obj: Union[BaseModel, dict]):
    """Return the md5 hex digest of ``obj`` (a model's ``dict()`` or a dict) serialized with ``NumpyEncoder``."""
    from evidently.legacy.utils import NumpyEncoder

    if isinstance(obj, BaseModel):
        obj = obj.dict()
    return hashlib.md5(json.dumps(obj, cls=NumpyEncoder).encode("utf8"), **md5_kwargs).hexdigest()  # nosec: B324


class AutoAliasMixin:
    """Mixin that derives a default type alias ``evidently:<__alias_type__>:<ClassName>``."""

    __alias_type__: ClassVar[str]

    @classmethod
    def __get_type__(cls) -> str:
        # an explicit type_alias on a Config declared directly on the class takes precedence
        config = cls.__dict__.get("Config")
        if config is not None and config.__dict__.get("type_alias") is not None:
            return config.type_alias
        return f"evidently:{cls.__alias_type__}:{cls.__name__}"