# test_conditions.py
"""Tests for the column conditions usable inside a ``ColumnTest`` descriptor.

Each condition is checked for the name of the column it produces, for the
per-row values it yields, and for surviving a JSON round trip.
"""

import json
from inspect import isabstract
from typing import List
from typing import Tuple

import pandas as pd
import pytest

from evidently._pydantic_compat import parse_obj_as
from evidently.core.datasets import ColumnCondition
from evidently.core.datasets import ColumnTest
from evidently.core.datasets import Dataset
from evidently.tests.descriptors import EqualsColumnCondition
from evidently.tests.descriptors import GreaterColumnCondition
from evidently.tests.descriptors import GreaterEqualColumnCondition
from evidently.tests.descriptors import IsInColumnCondition
from evidently.tests.descriptors import IsNotInColumnCondition
from evidently.tests.descriptors import LessColumnCondition
from evidently.tests.descriptors import LessEqualColumnCondition
from evidently.tests.descriptors import NotEqualsColumnCondition
from tests.conftest import load_all_subtypes

# One entry per condition under test:
#   (condition, input series, expected name of the produced column, expected per-row result).
all_conditions: List[Tuple[ColumnCondition, pd.Series, str, pd.Series]] = [
    (
        GreaterEqualColumnCondition(threshold=2),
        pd.Series([1, 2, 3], name="input"),
        "input: greater or equal to 2.0",
        pd.Series([False, True, True]),
    ),
    (
        GreaterColumnCondition(threshold=2),
        pd.Series([1, 2, 3], name="input"),
        "input greater than 2.0",
        pd.Series([False, False, True]),
    ),
    (
        LessEqualColumnCondition(threshold=2),
        pd.Series([1, 2, 3], name="input"),
        "input: less or equal to 2.0",
        pd.Series([True, True, False]),
    ),
    (
        LessColumnCondition(threshold=2),
        pd.Series([1, 2, 3], name="input"),
        "input: less than 2.0",
        pd.Series([True, False, False]),
    ),
    (
        EqualsColumnCondition(expected=2),
        pd.Series([1, 2, 3], name="input"),
        "input: equals 2",
        pd.Series([False, True, False]),
    ),
    (
        NotEqualsColumnCondition(expected=2),
        pd.Series([1, 2, 3], name="input"),
        "input not equals 2",
        pd.Series([True, False, True]),
    ),
    (
        IsNotInColumnCondition(values={1}),
        pd.Series([1, 2, 3], name="input"),
        "input not in list {1}",
        pd.Series([False, True, True]),
    ),
    (
        IsInColumnCondition(values={2, 3}),
        pd.Series([1, 2, 3], name="input"),
        "input in list {2, 3}",
        pd.Series([False, True, True]),
    ),
]


def test_all_conditions_tested():
    """Check that ``all_conditions`` covers every concrete ``ColumnCondition`` subclass.

    On failure the message contains ready-to-paste tuple templates for the
    missing condition types, plus any tested types that were not discovered.
    """
    tested_cond_set = {type(p) for p, _, _, _ in all_conditions}
    load_all_subtypes(ColumnCondition)
    # ``__subclasses__()`` lists immediate subclasses only; abstract ones are skipped.
    all_cond_types = {s for s in ColumnCondition.__subclasses__() if not isabstract(s)}

    missing = sorted(all_cond_types - tested_cond_set, key=lambda t: t.__name__)
    unexpected = sorted(tested_cond_set - all_cond_types, key=lambda t: t.__name__)

    message = "Missing tests for conditions " + ", ".join(
        f'({t.__name__}(), pd.Series([], name="input"), "input_test", pd.Series([]))' for t in missing
    )
    if unexpected:
        # Without this, a mismatch in the other direction fails with an empty "missing" list.
        message += "; tested types not found among concrete ColumnCondition subclasses: " + ", ".join(
            t.__name__ for t in unexpected
        )
    assert tested_cond_set == all_cond_types, message


@pytest.mark.parametrize("condition,input,column_name,result", all_conditions)
def test_conditions(condition: ColumnCondition, input: pd.Series, column_name: str, result: pd.Series):
    """Apply ``condition`` to ``input`` via a ``ColumnTest`` and verify its output.

    Verifies that the resulting dataframe has a column named ``column_name``,
    that its values equal ``result``, and that the condition is equal to itself
    after being serialized to JSON and parsed back as a ``ColumnCondition``.
    """
    df = pd.DataFrame(input)

    dataset = Dataset.from_pandas(df)
    dataset.add_descriptor(ColumnTest(str(input.name), condition))
    res_df = dataset.as_dataframe()

    assert column_name in res_df.columns, f"Wrong column name {column_name}, actual: {res_df.columns}"
    values = res_df[column_name].tolist()
    assert values == result.tolist()

    # Round trip: serialize, then parse through the base type to get the concrete condition back.
    payload = json.loads(condition.json())
    condition2 = parse_obj_as(ColumnCondition, payload)

    assert condition2 == condition