/ tests / future / descriptors / test_conditions.py
test_conditions.py
  1  import json
  2  from inspect import isabstract
  3  from typing import List
  4  from typing import Tuple
  5  
  6  import pandas as pd
  7  import pytest
  8  
  9  from evidently._pydantic_compat import parse_obj_as
 10  from evidently.core.datasets import ColumnCondition
 11  from evidently.core.datasets import ColumnTest
 12  from evidently.core.datasets import Dataset
 13  from evidently.tests.descriptors import EqualsColumnCondition
 14  from evidently.tests.descriptors import GreaterColumnCondition
 15  from evidently.tests.descriptors import GreaterEqualColumnCondition
 16  from evidently.tests.descriptors import IsInColumnCondition
 17  from evidently.tests.descriptors import IsNotInColumnCondition
 18  from evidently.tests.descriptors import LessColumnCondition
 19  from evidently.tests.descriptors import LessEqualColumnCondition
 20  from evidently.tests.descriptors import NotEqualsColumnCondition
 21  from tests.conftest import load_all_subtypes
 22  
 23  all_conditions: List[Tuple[ColumnCondition, pd.Series, str, pd.Series]] = [
 24      (
 25          GreaterEqualColumnCondition(threshold=2),
 26          pd.Series([1, 2, 3], name="input"),
 27          "input: greater or equal to 2.0",
 28          pd.Series([False, True, True]),
 29      ),
 30      (
 31          GreaterColumnCondition(threshold=2),
 32          pd.Series([1, 2, 3], name="input"),
 33          "input greater than 2.0",
 34          pd.Series([False, False, True]),
 35      ),
 36      (
 37          LessEqualColumnCondition(threshold=2),
 38          pd.Series([1, 2, 3], name="input"),
 39          "input: less or equal to 2.0",
 40          pd.Series([True, True, False]),
 41      ),
 42      (
 43          LessColumnCondition(threshold=2),
 44          pd.Series([1, 2, 3], name="input"),
 45          "input: less than 2.0",
 46          pd.Series([True, False, False]),
 47      ),
 48      (
 49          EqualsColumnCondition(expected=2),
 50          pd.Series([1, 2, 3], name="input"),
 51          "input: equals 2",
 52          pd.Series([False, True, False]),
 53      ),
 54      (
 55          NotEqualsColumnCondition(expected=2),
 56          pd.Series([1, 2, 3], name="input"),
 57          "input not equals 2",
 58          pd.Series([True, False, True]),
 59      ),
 60      (
 61          IsNotInColumnCondition(values={1}),
 62          pd.Series([1, 2, 3], name="input"),
 63          "input not in list {1}",
 64          pd.Series([False, True, True]),
 65      ),
 66      (
 67          IsInColumnCondition(values={2, 3}),
 68          pd.Series([1, 2, 3], name="input"),
 69          "input in list {2, 3}",
 70          pd.Series([False, True, True]),
 71      ),
 72  ]
 73  
 74  
 75  def test_all_conditions_tested():
 76      tested_cond_set = {type(p) for p, _, _, _ in all_conditions}
 77      load_all_subtypes(ColumnCondition)
 78      all_cond_types = set(s for s in ColumnCondition.__subclasses__() if not isabstract(s))
 79      assert tested_cond_set == all_cond_types, "Missing tests for conditions " + ", ".join(
 80          f'({t.__name__}(), pd.Series([], name="input"), "input_test", pd.Series([]))'
 81          for t in all_cond_types - tested_cond_set
 82      )
 83  
 84  
 85  @pytest.mark.parametrize("condition,input,column_name,result", all_conditions)
 86  def test_conditions(condition: ColumnCondition, input: pd.Series, column_name: str, result: pd.Series):
 87      df = pd.DataFrame(input)
 88  
 89      dataset = Dataset.from_pandas(df)
 90      dataset.add_descriptor(ColumnTest(str(input.name), condition))
 91      res_df = dataset.as_dataframe()
 92  
 93      assert column_name in res_df.columns, f"Wrong column name {column_name}, actual: {res_df.columns}"
 94      values = res_df[column_name].tolist()
 95      assert values == result.tolist()
 96  
 97      payload = json.loads(condition.json())
 98      condition2 = parse_obj_as(ColumnCondition, payload)
 99  
100      assert condition2 == condition