/ tests / utils / test_pydantic_utils.py
test_pydantic_utils.py
  1  from abc import ABC
  2  from typing import Dict
  3  from typing import Optional
  4  from typing import Union
  5  
  6  import pytest
  7  
  8  from evidently._pydantic_compat import ValidationError
  9  from evidently._pydantic_compat import parse_obj_as
 10  from evidently.legacy.base_metric import Metric
 11  from evidently.legacy.base_metric import MetricResult
 12  from evidently.legacy.core import IncludeTags
 13  from evidently.legacy.core import get_all_fields_tags
 14  from evidently.pydantic_utils import ALLOWED_TYPE_PREFIXES
 15  from evidently.pydantic_utils import EvidentlyBaseModel
 16  from evidently.pydantic_utils import FieldPath
 17  from evidently.pydantic_utils import PolymorphicModel
 18  
 19  
 20  class MockMetricResultField(MetricResult):
 21      class Config:
 22          alias_required = False
 23  
 24      nested_field: str
 25  
 26  
 27  class ExtendedMockMetricResultField(MockMetricResultField):
 28      class Config:
 29          alias_required = False
 30  
 31      additional_field: str
 32  
 33  
 34  class MockMetricResult(MetricResult):
 35      class Config:
 36          alias_required = False
 37  
 38      field1: MockMetricResultField
 39      field2: int
 40  
 41  
 42  def _metric_with_result(result: MetricResult):
 43      class MockMetric(Metric):
 44          class Config:
 45              alias_required = False
 46  
 47          def get_result(self):
 48              return result
 49  
 50          def calculate(self, data):
 51              pass
 52  
 53      return MockMetric()
 54  
 55  
 56  def test_field_path():
 57      assert MockMetricResult.fields.list_fields() == ["type", "field1", "field2"]
 58      assert MockMetricResult.fields.field1.list_fields() == ["type", "nested_field"]
 59      assert MockMetricResult.fields.list_nested_fields() == ["type", "field1.type", "field1.nested_field", "field2"]
 60  
 61      with pytest.raises(AttributeError):
 62          _ = MockMetricResult.fields.field3
 63  
 64      metric_result = MockMetricResult(field1=MockMetricResultField(nested_field="1"), field2=1)
 65      metric = _metric_with_result(metric_result)
 66  
 67      assert metric.fields.list_fields() == ["type", "field1", "field2"]
 68      assert metric.fields.field1.list_fields() == ["type", "nested_field"]
 69      assert metric.fields.list_nested_fields() == ["type", "field1.type", "field1.nested_field", "field2"]
 70  
 71      metric_result = MockMetricResult(
 72          field1=ExtendedMockMetricResultField(nested_field="1", additional_field="2"), field2=1
 73      )
 74      metric = _metric_with_result(metric_result)
 75  
 76      assert metric.fields.list_fields() == ["type", "field1", "field2"]
 77      assert metric.fields.field1.list_fields() == ["type", "nested_field", "additional_field"]
 78      assert metric.fields.list_nested_fields() == [
 79          "type",
 80          "field1.type",
 81          "field1.nested_field",
 82          "field1.additional_field",
 83          "field2",
 84      ]
 85  
 86  
 87  class MockMetricResultWithDict(MetricResult):
 88      class Config:
 89          alias_required = False
 90  
 91      d: Dict[str, MockMetricResultField]
 92  
 93  
 94  def test_field_path_with_dict():
 95      assert MockMetricResultWithDict.fields.list_fields() == ["type", "d"]
 96      assert MockMetricResultWithDict.fields.list_nested_fields() == ["type", "d.*.type", "d.*.nested_field"]
 97      assert MockMetricResultWithDict.fields.d.lol.list_fields() == ["type", "nested_field"]
 98      assert str(MockMetricResultWithDict.fields.d.lol.nested_field) == "d.lol.nested_field"
 99  
100      metric_result = MockMetricResultWithDict(d={"a": MockMetricResultField(nested_field="1")})
101      metric = _metric_with_result(metric_result)
102  
103      assert metric.fields.list_fields() == ["type", "d"]
104      assert metric.fields.list_nested_fields() == ["type", "d.a.type", "d.a.nested_field"]
105      assert metric.fields.d.a.list_fields() == ["type", "nested_field"]
106      assert metric.fields.d.list_fields() == ["a"]
107  
108      metric_result = MockMetricResultWithDict(
109          d={
110              "a": MockMetricResultField(nested_field="1"),
111              "b": ExtendedMockMetricResultField(nested_field="1", additional_field="2"),
112          }
113      )
114      metric = _metric_with_result(metric_result)
115  
116      assert metric.fields.list_fields() == ["type", "d"]
117      assert metric.fields.list_nested_fields() == [
118          "type",
119          "d.a.type",
120          "d.a.nested_field",
121          "d.b.type",
122          "d.b.nested_field",
123          "d.b.additional_field",
124      ]
125      assert metric.fields.d.a.list_fields() == ["type", "nested_field"]
126      assert metric.fields.d.list_fields() == ["a", "b"]
127  
128  
def test_not_allowed_prefix():
    """Parsing refuses a ``type`` classpath that is outside the allowed prefixes."""

    class SomeModel(PolymorphicModel):
        class Config:
            alias_required = False

    foreign_payload = {"type": "external.Class"}
    with pytest.raises(ValueError):
        parse_obj_as(SomeModel, foreign_payload)
136  
137  
def test_type_alias():
    """Polymorphic parsing resolves configured ``type_alias`` values and ``__get_type__`` strings."""

    class SomeModel(PolymorphicModel):
        class Config:
            type_alias = "somemodel"
            alias_required = False

    class SomeModelSubclass(SomeModel):
        pass

    class SomeOtherSubclass(SomeModel):
        class Config:
            type_alias = "othersubclass"

    # Alias declared on the base class resolves to the base class itself.
    parsed = parse_obj_as(SomeModel, {"type": "somemodel"})
    assert type(parsed) is SomeModel

    # Subclass without its own alias: use the type string it reports.
    parsed = parse_obj_as(SomeModel, {"type": SomeModelSubclass.__get_type__()})
    assert type(parsed) is SomeModelSubclass

    # Subclass with its own alias.
    parsed = parse_obj_as(SomeModel, {"type": "othersubclass"})
    assert type(parsed) is SomeOtherSubclass
159  
160  
def test_include_exclude():
    """``exclude`` drops fields tagged via ``field_tags`` or via a nested model's class ``tags``."""

    class SomeModel(MetricResult):
        class Config:
            field_tags = {"f1": {IncludeTags.Render}}
            alias_required = False

        f1: str
        f2: str

    visible = SomeModel.fields.list_nested_fields(exclude={IncludeTags.Render, IncludeTags.TypeField})
    assert visible == ["f2"]
    # TODO: an ``include`` check is currently disabled here; it would expect
    # list_nested_fields(include={IncludeTags.Render}) == ["f1"].

    class SomeNestedModel(MetricResult):
        class Config:
            tags = {IncludeTags.Render}
            alias_required = False

        f1: str

    class SomeOtherModel(MetricResult):
        class Config:
            alias_required = False

        f1: str
        f2: SomeNestedModel
        f3: SomeModel

    # f2 is excluded entirely (class-level Render tag); of f3 only the untagged f2 remains.
    visible = SomeOtherModel.fields.list_nested_fields(exclude={IncludeTags.Render, IncludeTags.TypeField})
    assert visible == ["f1", "f3.f2"]
    # TODO: an ``include`` check is currently disabled here; it would expect
    # list_nested_fields(include={IncludeTags.Render}) == ["f2.f1", "f3.f1"].
194  
195  
def test_get_field_tags():
    """``get_field_tags`` resolves the tag set for top-level and nested field paths."""

    class SomeModel(MetricResult):
        class Config:
            field_tags = {"f1": {IncludeTags.Render}}
            alias_required = False

        f1: str
        f2: str

    flat_cases = (
        (["type"], {IncludeTags.TypeField}),
        (["f1"], {IncludeTags.Render}),
        (["f2"], set()),
    )
    for field_path, expected_tags in flat_cases:
        assert SomeModel.fields.get_field_tags(field_path) == expected_tags

    class SomeNestedModel(MetricResult):
        class Config:
            tags = {IncludeTags.Render}
            alias_required = False

        f1: str

    class SomeOtherModel(MetricResult):
        class Config:
            alias_required = False

        f1: str
        f2: SomeNestedModel
        f3: SomeModel

    # Class-level ``tags`` of SomeNestedModel apply to f2 and everything under it;
    # ``field_tags`` of SomeModel apply only to its own f1 when reached via f3.
    nested_cases = (
        (["type"], {IncludeTags.TypeField}),
        (["f1"], set()),
        (["f2"], {IncludeTags.Render}),
        (["f2", "f1"], {IncludeTags.Render}),
        (["f3"], set()),
        (["f3", "f1"], {IncludeTags.Render}),
        (["f3", "f2"], set()),
    )
    for field_path, expected_tags in nested_cases:
        assert SomeOtherModel.fields.get_field_tags(field_path) == expected_tags
231  
232  
def test_list_with_tags():
    """``list_nested_fields_with_tags`` pairs every leaf path with its effective tag set."""

    class SomeModel(MetricResult):
        class Config:
            field_tags = {"f1": {IncludeTags.Render}}
            alias_required = False

        f1: str
        f2: str

    flat_listing = SomeModel.fields.list_nested_fields_with_tags()
    assert flat_listing == [
        ("type", {IncludeTags.TypeField}),
        ("f1", {IncludeTags.Render}),
        ("f2", set()),
    ]

    class SomeNestedModel(MetricResult):
        class Config:
            tags = {IncludeTags.Render}
            alias_required = False

        f1: str

    class SomeOtherModel(MetricResult):
        class Config:
            alias_required = False

        f1: str
        f2: SomeNestedModel
        f3: SomeModel

    # f2.* carries the nested model's class-level Render tag (combined with
    # TypeField on ``type``); under f3 only the field tagged via field_tags has it.
    nested_listing = SomeOtherModel.fields.list_nested_fields_with_tags()
    assert nested_listing == [
        ("type", {IncludeTags.TypeField}),
        ("f1", set()),
        ("f2.type", {IncludeTags.Render, IncludeTags.TypeField}),
        ("f2.f1", {IncludeTags.Render}),
        ("f3.type", {IncludeTags.TypeField}),
        ("f3.f1", {IncludeTags.Render}),
        ("f3.f2", set()),
    ]
272  
273  
def test_list_with_tags_with_union():
    """Fields typed as a ``Union`` of results expose the tags declared on the union members."""

    class A(MetricResult):
        class Config:
            tags = {IncludeTags.Render}
            alias_required = False

        f1: str

    class B(MetricResult):
        class Config:
            tags = {IncludeTags.Render}
            alias_required = False

        f1: str

    # With no instance attached, a FieldPath over Union[A, B] resolves to the
    # first member type.
    fp = FieldPath([], Union[A, B])
    assert not fp.has_instance
    assert fp._cls is A

    class SomeModel(MetricResult):
        class Config:
            alias_required = False

        f2: Union[A, B]
        f1: str

    # ``sorted`` already returns a list; the comparison is order-independent
    # with respect to how fields are listed.
    assert sorted(SomeModel.fields.list_nested_fields_with_tags()) == [
        ("f1", set()),
        ("f2.f1", {IncludeTags.Render}),
        ("f2.type", {IncludeTags.Render, IncludeTags.TypeField}),
        ("type", {IncludeTags.TypeField}),
    ]
306  
307  
def test_get_field_tags_no_overwrite():
    """Merging inherited/parent tags must not alter the tags reported for the base class ``A``."""

    class A(MetricResult):
        class Config:
            field_tags = {"f": {IncludeTags.Current}}
            alias_required = False

        f: str

    class B(A):
        class Config:
            tags = {IncludeTags.Reference}

    class C(MetricResult):
        class Config:
            field_tags = {"f": {IncludeTags.Reference}}
            alias_required = False

        f: A

    # Paths are passed as lists of field names, as everywhere else in this file.
    # A bare string only matches coincidentally for one-character names
    # (assuming the path is consumed as a sequence of segments).
    assert A.fields.get_field_tags(["f"]) == {IncludeTags.Current}
    assert B.fields.get_field_tags(["f"]) == {IncludeTags.Current, IncludeTags.Reference}
    assert C.fields.get_field_tags(["f", "f"]) == {IncludeTags.Current, IncludeTags.Reference}

    # Exercise the other tag-collecting entry points on the subclass and the
    # container model...
    B.fields.list_nested_fields_with_tags()
    C.fields.list_nested_fields_with_tags()
    get_all_fields_tags(B)
    get_all_fields_tags(C)

    # ...and verify none of them leaked B's/C's extra tags back into A.
    assert A.fields.get_field_tags(["f"]) == {IncludeTags.Current}
335  
336  
def test_fingerprint_add_new_default_field():
    """Adding a defaulted field keeps fingerprints stable for objects that use the default."""

    class A(EvidentlyBaseModel):
        class Config:
            alias_required = False

        field1: str

    fingerprint_v1 = A(field1="123").get_fingerprint()

    # Redefine ``A`` under the same name to simulate a schema change. The name is
    # kept identical on purpose, presumably because the class path feeds the
    # fingerprint and only the schema should differ.
    class A(EvidentlyBaseModel):
        class Config:
            alias_required = False

        field1: str
        field2: str = "321"

    fingerprint_v2_default = A(field1="123").get_fingerprint()
    assert fingerprint_v2_default == fingerprint_v1

    # A non-default value for the new field must change the fingerprint.
    assert A(field1="123", field2="123").get_fingerprint() != fingerprint_v1
357  
358  
def test_fingerprint_reorder_fields():
    """Reordering field declarations must not change the fingerprint."""

    class A(EvidentlyBaseModel):
        class Config:
            alias_required = False

        field1: str
        field2: str

    original_order = A(field1="123", field2="321").get_fingerprint()

    # Same class name, same fields, declared in the opposite order.
    class A(EvidentlyBaseModel):
        class Config:
            alias_required = False

        field2: str
        field1: str

    swapped_order = A(field1="123", field2="321").get_fingerprint()
    assert swapped_order == original_order

    # Different field values must still yield a different fingerprint.
    assert A(field1="123", field2="123").get_fingerprint() != original_order
380  
381  
def test_fingerprint_default_collision():
    """Setting the same value on different optional fields must give different fingerprints."""

    class A(EvidentlyBaseModel):
        class Config:
            alias_required = False

        field1: Optional[str] = None
        field2: Optional[str] = None

    first_set = A(field1="a").get_fingerprint()
    second_set = A(field2="a").get_fingerprint()
    assert first_set != second_set
391  
392  
def test_wrong_classpath():
    """A dumped model round-trips, while a corrupted ``type`` classpath fails validation."""

    class WrongClassPath(EvidentlyBaseModel):
        class Config:
            alias_required = False

        f: str

    # This class lives in the ``tests.`` package, which is presumably required to
    # be an allowed type prefix for its classpath to be resolved. Add the prefix
    # only if it is missing, and undo the change afterwards, so the module-level
    # list is neither leaked into other tests nor duplicated on repeated runs.
    added_prefix = "tests." not in ALLOWED_TYPE_PREFIXES
    if added_prefix:
        ALLOWED_TYPE_PREFIXES.append("tests.")
    try:
        a = WrongClassPath(f="asd")
        assert parse_obj_as(WrongClassPath, a.dict()) == a

        d = a.dict()
        d["type"] += "_"  # point at a class that does not exist
        with pytest.raises(ValidationError):
            parse_obj_as(WrongClassPath, d)
    finally:
        if added_prefix:
            ALLOWED_TYPE_PREFIXES.remove("tests.")
407  
408  
409  def test_alias_requied():
410      class RequiredAlias(PolymorphicModel, ABC):
411          class Config:
412              alias_required = True
413  
414      with pytest.raises(ValueError):
415  
416          class NoAlias(RequiredAlias):
417              pass
418  
419      class Alias(RequiredAlias):
420          class Config:
421              type_alias = "alias"