# test_pydantic_utils.py
from abc import ABC
from typing import Dict
from typing import Optional
from typing import Union

import pytest

from evidently._pydantic_compat import ValidationError
from evidently._pydantic_compat import parse_obj_as
from evidently.legacy.base_metric import Metric
from evidently.legacy.base_metric import MetricResult
from evidently.legacy.core import IncludeTags
from evidently.legacy.core import get_all_fields_tags
from evidently.pydantic_utils import ALLOWED_TYPE_PREFIXES
from evidently.pydantic_utils import EvidentlyBaseModel
from evidently.pydantic_utils import FieldPath
from evidently.pydantic_utils import PolymorphicModel


# Leaf result model with a single string field; used as a nested value below.
class MockMetricResultField(MetricResult):
    class Config:
        alias_required = False

    nested_field: str


# Subclass of the leaf model that adds one more field, so tests can check that
# field listings follow the runtime type of a value rather than the annotation.
class ExtendedMockMetricResultField(MockMetricResultField):
    class Config:
        alias_required = False

    additional_field: str


# Result model with one nested-model field and one plain field.
class MockMetricResult(MetricResult):
    class Config:
        alias_required = False

    field1: MockMetricResultField
    field2: int


def _metric_with_result(result: MetricResult) -> Metric:
    """Return a throwaway ``Metric`` instance whose ``get_result()`` yields *result*."""

    class MockMetric(Metric):
        class Config:
            alias_required = False

        def get_result(self):
            return result

        def calculate(self, data):
            pass

    return MockMetric()


def test_field_path():
    """Field listings on a result class and on a metric instance holding a result."""
    # Class-level access: only the declared annotations are known.
    assert MockMetricResult.fields.list_fields() == ["type", "field1", "field2"]
    assert MockMetricResult.fields.field1.list_fields() == ["type", "nested_field"]
    assert MockMetricResult.fields.list_nested_fields() == ["type", "field1.type", "field1.nested_field", "field2"]

    # Accessing a field that is not declared must fail loudly.
    with pytest.raises(AttributeError):
        _ = MockMetricResult.fields.field3

    metric_result = MockMetricResult(field1=MockMetricResultField(nested_field="1"), field2=1)
    metric = _metric_with_result(metric_result)

    assert metric.fields.list_fields() == ["type", "field1", "field2"]
    assert metric.fields.field1.list_fields() == ["type", "nested_field"]
    assert metric.fields.list_nested_fields() == ["type", "field1.type", "field1.nested_field", "field2"]

    # When field1 actually holds the extended subclass, its extra field shows up.
    metric_result = MockMetricResult(
        field1=ExtendedMockMetricResultField(nested_field="1", additional_field="2"), field2=1
    )
    metric = _metric_with_result(metric_result)

    assert metric.fields.list_fields() == ["type", "field1", "field2"]
    assert metric.fields.field1.list_fields() == ["type", "nested_field", "additional_field"]
    assert metric.fields.list_nested_fields() == [
        "type",
        "field1.type",
        "field1.nested_field",
        "field1.additional_field",
        "field2",
    ]


# Result model whose only field is a dict of nested result models.
class MockMetricResultWithDict(MetricResult):
    class Config:
        alias_required = False

    d: Dict[str, MockMetricResultField]


def test_field_path_with_dict():
    """Field listings for a dict-valued field, on the class and on instances."""
    # On the class the dict keys are unknown, so nested paths use "*" and any
    # attribute name (here "lol") is accepted as a key.
    assert MockMetricResultWithDict.fields.list_fields() == ["type", "d"]
    assert MockMetricResultWithDict.fields.list_nested_fields() == ["type", "d.*.type", "d.*.nested_field"]
    assert MockMetricResultWithDict.fields.d.lol.list_fields() == ["type", "nested_field"]
    assert str(MockMetricResultWithDict.fields.d.lol.nested_field) == "d.lol.nested_field"

    metric_result = MockMetricResultWithDict(d={"a": MockMetricResultField(nested_field="1")})
    metric = _metric_with_result(metric_result)

    # With an instance, the real keys replace the wildcard.
    assert metric.fields.list_fields() == ["type", "d"]
    assert metric.fields.list_nested_fields() == ["type", "d.a.type", "d.a.nested_field"]
    assert metric.fields.d.a.list_fields() == ["type", "nested_field"]
    assert metric.fields.d.list_fields() == ["a"]

    metric_result = MockMetricResultWithDict(
        d={
            "a": MockMetricResultField(nested_field="1"),
            "b": ExtendedMockMetricResultField(nested_field="1", additional_field="2"),
        }
    )
    metric = _metric_with_result(metric_result)

    # Each dict value is listed according to its own runtime type.
    assert metric.fields.list_fields() == ["type", "d"]
    assert metric.fields.list_nested_fields() == [
        "type",
        "d.a.type",
        "d.a.nested_field",
        "d.b.type",
        "d.b.nested_field",
        "d.b.additional_field",
    ]
    assert metric.fields.d.a.list_fields() == ["type", "nested_field"]
    assert metric.fields.d.list_fields() == ["a", "b"]


def test_not_allowed_prefix():
    """Parsing a payload whose ``type`` is ``external.Class`` raises ``ValueError``."""

    class SomeModel(PolymorphicModel):
        class Config:
            alias_required = False

    with pytest.raises(ValueError):
        parse_obj_as(SomeModel, {"type": "external.Class"})


def test_type_alias():
    """``type`` values resolve to the exact class, via alias or ``__get_type__()``."""

    class SomeModel(PolymorphicModel):
        class Config:
            type_alias = "somemodel"
            alias_required = False

    class SomeModelSubclass(SomeModel):
        pass

    class SomeOtherSubclass(SomeModel):
        class Config:
            type_alias = "othersubclass"

    # Exact-type checks: an isinstance() check would also accept subclasses.
    obj = parse_obj_as(SomeModel, {"type": "somemodel"})
    assert type(obj) is SomeModel

    obj = parse_obj_as(SomeModel, {"type": SomeModelSubclass.__get_type__()})
    assert type(obj) is SomeModelSubclass

    obj = parse_obj_as(SomeModel, {"type": "othersubclass"})
    assert type(obj) is SomeOtherSubclass


def test_include_exclude():
    """``list_nested_fields(exclude=...)`` drops fields carrying any excluded tag."""

    class SomeModel(MetricResult):
        class Config:
            field_tags = {"f1": {IncludeTags.Render}}
            alias_required = False

        f1: str
        f2: str

    assert SomeModel.fields.list_nested_fields(exclude={IncludeTags.Render, IncludeTags.TypeField}) == ["f2"]

    # NOTE(review): the include= variant below was already disabled in the
    # original file; reason not visible here - confirm before re-enabling.
    # assert SomeModel.fields.list_nested_fields(include={IncludeTags.Render}) == ["f1"]

    class SomeNestedModel(MetricResult):
        class Config:
            tags = {IncludeTags.Render}
            alias_required = False

        f1: str

    class SomeOtherModel(MetricResult):
        class Config:
            alias_required = False

        f1: str
        f2: SomeNestedModel
        f3: SomeModel

    # Everything under f2 (model-level Render tag) and f3.f1 (field-level
    # Render tag) is excluded, together with all "type" fields.
    assert SomeOtherModel.fields.list_nested_fields(exclude={IncludeTags.Render, IncludeTags.TypeField}) == [
        "f1",
        "f3.f2",
    ]
    # NOTE(review): disabled in the original file as well - see note above.
    # assert SomeOtherModel.fields.list_nested_fields(include={IncludeTags.Render}) == ["f2.f1", "f3.f1"]


def test_get_field_tags():
    """``get_field_tags`` returns field-level and model-level tags along a path."""

    class SomeModel(MetricResult):
        class Config:
            field_tags = {"f1": {IncludeTags.Render}}
            alias_required = False

        f1: str
        f2: str

    assert SomeModel.fields.get_field_tags(["type"]) == {IncludeTags.TypeField}
    assert SomeModel.fields.get_field_tags(["f1"]) == {IncludeTags.Render}
    assert SomeModel.fields.get_field_tags(["f2"]) == set()

    class SomeNestedModel(MetricResult):
        class Config:
            tags = {IncludeTags.Render}
            alias_required = False

        f1: str

    class SomeOtherModel(MetricResult):
        class Config:
            alias_required = False

        f1: str
        f2: SomeNestedModel
        f3: SomeModel

    assert SomeOtherModel.fields.get_field_tags(["type"]) == {IncludeTags.TypeField}
    assert SomeOtherModel.fields.get_field_tags(["f1"]) == set()
    # Model-level tags of SomeNestedModel apply to the field holding it and to
    # the fields inside it.
    assert SomeOtherModel.fields.get_field_tags(["f2"]) == {IncludeTags.Render}
    assert SomeOtherModel.fields.get_field_tags(["f2", "f1"]) == {IncludeTags.Render}
    # Field-level tags of SomeModel apply only to the tagged field.
    assert SomeOtherModel.fields.get_field_tags(["f3"]) == set()
    assert SomeOtherModel.fields.get_field_tags(["f3", "f1"]) == {IncludeTags.Render}
    assert SomeOtherModel.fields.get_field_tags(["f3", "f2"]) == set()


def test_list_with_tags():
    """``list_nested_fields_with_tags`` pairs every nested path with its tag set."""

    class SomeModel(MetricResult):
        class Config:
            field_tags = {"f1": {IncludeTags.Render}}
            alias_required = False

        f1: str
        f2: str

    assert SomeModel.fields.list_nested_fields_with_tags() == [
        ("type", {IncludeTags.TypeField}),
        ("f1", {IncludeTags.Render}),
        ("f2", set()),
    ]

    class SomeNestedModel(MetricResult):
        class Config:
            tags = {IncludeTags.Render}
            alias_required = False

        f1: str

    class SomeOtherModel(MetricResult):
        class Config:
            alias_required = False

        f1: str
        f2: SomeNestedModel
        f3: SomeModel

    assert SomeOtherModel.fields.list_nested_fields_with_tags() == [
        ("type", {IncludeTags.TypeField}),
        ("f1", set()),
        ("f2.type", {IncludeTags.Render, IncludeTags.TypeField}),
        ("f2.f1", {IncludeTags.Render}),
        ("f3.type", {IncludeTags.TypeField}),
        ("f3.f1", {IncludeTags.Render}),
        ("f3.f2", set()),
    ]


def test_list_with_tags_with_union():
    """Tag listing for a field annotated as a ``Union`` of two result models."""

    class A(MetricResult):
        class Config:
            tags = {IncludeTags.Render}
            alias_required = False

        f1: str

    class B(MetricResult):
        class Config:
            tags = {IncludeTags.Render}
            alias_required = False

        f1: str

    # A FieldPath built from a Union has no instance and uses the first member.
    fp = FieldPath([], Union[A, B])
    assert not fp.has_instance
    assert fp._cls == A

    class SomeModel(MetricResult):
        class Config:
            alias_required = False

        f2: Union[A, B]
        f1: str

    # Sorted so the assertion does not depend on listing order; the paths are
    # unique, so tuple comparison never reaches the (unorderable) tag sets.
    assert sorted(SomeModel.fields.list_nested_fields_with_tags()) == [
        ("f1", set()),
        ("f2.f1", {IncludeTags.Render}),
        ("f2.type", {IncludeTags.Render, IncludeTags.TypeField}),
        ("type", {IncludeTags.TypeField}),
    ]


def test_get_field_tags_no_overwrite():
    """Computing tags for subclasses/containers must not alter the base class's tags."""

    class A(MetricResult):
        class Config:
            field_tags = {"f": {IncludeTags.Current}}
            alias_required = False

        f: str

    class B(A):
        class Config:
            tags = {IncludeTags.Reference}

    class C(MetricResult):
        class Config:
            field_tags = {"f": {IncludeTags.Reference}}
            alias_required = False

        f: A

    # Paths are passed as lists, like every other get_field_tags call in this
    # module (a bare "f" only worked because a 1-char str iterates to ["f"]).
    assert A.fields.get_field_tags(["f"]) == {IncludeTags.Current}
    assert B.fields.get_field_tags(["f"]) == {IncludeTags.Current, IncludeTags.Reference}
    assert C.fields.get_field_tags(["f", "f"]) == {IncludeTags.Current, IncludeTags.Reference}
    # Exercise every tag-collecting entry point, then re-check A is untouched.
    B.fields.list_nested_fields_with_tags()
    C.fields.list_nested_fields_with_tags()
    get_all_fields_tags(B)
    get_all_fields_tags(C)
    assert A.fields.get_field_tags(["f"]) == {IncludeTags.Current}


def test_fingerprint_add_new_default_field():
    """Adding a field with a default keeps the fingerprint while the default is used."""

    class A(EvidentlyBaseModel):
        class Config:
            alias_required = False

        field1: str

    f1 = A(field1="123").get_fingerprint()

    # Deliberate redefinition under the same name to model a newer version of
    # the same class (presumably the class identity feeds the fingerprint -
    # TODO confirm).
    class A(EvidentlyBaseModel):
        class Config:
            alias_required = False

        field1: str
        field2: str = "321"

    f2 = A(field1="123").get_fingerprint()

    assert f2 == f1
    assert A(field1="123", field2="123").get_fingerprint() != f1


def test_fingerprint_reorder_fields():
    """Field declaration order does not affect the fingerprint; values do."""

    class A(EvidentlyBaseModel):
        class Config:
            alias_required = False

        field1: str
        field2: str

    f1 = A(field1="123", field2="321").get_fingerprint()

    # Deliberate redefinition under the same name with the fields swapped.
    class A(EvidentlyBaseModel):
        class Config:
            alias_required = False

        field2: str
        field1: str

    f2 = A(field1="123", field2="321").get_fingerprint()

    assert f2 == f1
    assert A(field1="123", field2="123").get_fingerprint() != f1


def test_fingerprint_default_collision():
    """The same value in different optional fields yields different fingerprints."""

    class A(EvidentlyBaseModel):
        class Config:
            alias_required = False

        field1: Optional[str] = None
        field2: Optional[str] = None

    assert A(field1="a").get_fingerprint() != A(field2="a").get_fingerprint()


def test_wrong_classpath():
    """A payload whose ``type`` is not a resolvable class path fails validation."""

    class WrongClassPath(EvidentlyBaseModel):
        class Config:
            alias_required = False

        f: str

    # ALLOWED_TYPE_PREFIXES is module-global state: allow "tests." only for the
    # duration of this test so the change cannot leak into other tests.
    ALLOWED_TYPE_PREFIXES.append("tests.")
    try:
        a = WrongClassPath(f="asd")
        assert parse_obj_as(WrongClassPath, a.dict()) == a
        d = a.dict()
        d["type"] += "_"  # corrupt the type string so it no longer resolves
        with pytest.raises(ValidationError):
            parse_obj_as(WrongClassPath, d)
    finally:
        ALLOWED_TYPE_PREFIXES.remove("tests.")


def test_alias_requied():
    """With ``alias_required = True`` a subclass must declare a ``type_alias``."""

    class RequiredAlias(PolymorphicModel, ABC):
        class Config:
            alias_required = True

    # The error is raised at class-definition time, not at instantiation.
    with pytest.raises(ValueError):

        class NoAlias(RequiredAlias):
            pass

    # Declaring an alias satisfies the requirement; definition must not raise.
    class Alias(RequiredAlias):
        class Config:
            type_alias = "alias"