test_data_drift.py
"""Tests for DataDriftOptions: threshold, stattest and nbinsx resolution."""

import pytest

from evidently.legacy.options.data_drift import DataDriftOptions


def _assert_threshold(options, feature, expected):
    """Assert the threshold that ``options`` resolves for ``feature``.

    ``expected is None`` means "no threshold configured for this feature".
    Numeric expectations are compared approximately to tolerate float
    arithmetic inside the option resolution (e.g. ``1 - confidence``).

    NOTE(review): previously this used ``numpy.ma.testutils.approx``, an
    internal numpy masked-array helper that only handled the ``None`` case
    via an object-dtype equality fallback; ``pytest.approx`` + an explicit
    ``is None`` check expresses the same contract directly.
    """
    actual = options.get_threshold(feature, "num")
    if expected is None:
        assert actual is None
    else:
        assert actual == pytest.approx(expected)


@pytest.mark.parametrize(
    "confidence,expected",
    [
        (0.1, {"feature1": 0.9, "feature2": 0.9}),
        ({"feature1": 0.1}, {"feature1": 0.9, "feature2": None}),
        ({"feature2": 0.1}, {"feature1": None, "feature2": 0.9}),
        ({}, {"feature1": None, "feature2": None}),
    ],
)
def test_confidence_threshold_valid(confidence, expected):
    """`confidence` (scalar or per-feature dict) maps to threshold = 1 - confidence."""
    options = DataDriftOptions(confidence=confidence)
    for feature, expected_threshold in expected.items():
        _assert_threshold(options, feature, expected_threshold)


@pytest.mark.parametrize(
    "threshold,expected",
    [
        (0.1, {"feature1": 0.1, "feature2": 0.1}),
        ({"feature1": 0.1}, {"feature1": 0.1, "feature2": None}),
        ({"feature2": 0.1}, {"feature1": None, "feature2": 0.1}),
        ({}, {"feature1": None, "feature2": None}),
    ],
)
def test_threshold_valid(threshold, expected):
    """`threshold` (scalar or per-feature dict) is returned as configured."""
    options = DataDriftOptions(threshold=threshold)
    for feature, expected_threshold in expected.items():
        _assert_threshold(options, feature, expected_threshold)


def test_threshold_default():
    """With no configuration, get_threshold reports None (no threshold)."""
    options = DataDriftOptions()
    assert options.get_threshold("feature", "num") is None


# Sentinel callables: their identity (not behavior) is what the assertions
# below compare, so empty bodies are sufficient.
def _default_stattest():
    pass


def _custom_stattest():
    pass


def _another_stattest():
    pass


@pytest.mark.parametrize(
    "feature_func,expected",
    [
        (None, {"feature1": None, "feature2": None}),
        ("st1", {"feature1": "st1", "feature2": "st1"}),
        ({"feature1": _custom_stattest}, {"feature1": _custom_stattest, "feature2": None}),
        ({"feature2": _custom_stattest}, {"feature1": None, "feature2": _custom_stattest}),
        (
            {"feature1": _another_stattest, "feature2": _custom_stattest},
            {"feature1": _another_stattest, "feature2": _custom_stattest},
        ),
        ({"feature1": "st1"}, {"feature1": "st1", "feature2": None}),
        ({"feature2": "st2"}, {"feature1": None, "feature2": "st2"}),
        ({"feature1": "st1", "feature2": "st2"}, {"feature1": "st1", "feature2": "st2"}),
        ({"feature1": _another_stattest, "feature2": "st2"}, {"feature1": _another_stattest, "feature2": "st2"}),
    ],
)
def test_stattest_function_valid(feature_func, expected):
    """feature_stattest_func accepts a scalar (name or callable) or a per-feature dict."""
    options = DataDriftOptions(feature_stattest_func=feature_func)
    for feature, expected_func in expected.items():
        assert options.get_feature_stattest_func(feature, "cat") == expected_func


@pytest.mark.parametrize(
    "global_st,cat_st,num_st,per_feature_st,expected",
    [
        (None, None, None, None, {"cat1": None, "cat2": None, "num1": None, "num2": None}),
        ("st1", None, None, None, {"cat1": "st1", "cat2": "st1", "num1": "st1", "num2": "st1"}),
        (None, None, None, {"cat1": "st1"}, {"cat1": "st1", "cat2": None, "num1": None, "num2": None}),
        (
            None,
            None,
            None,
            {"cat2": _custom_stattest},
            {"cat1": None, "cat2": _custom_stattest, "num1": None, "num2": None},
        ),
        (
            None,
            None,
            None,
            {"cat1": _custom_stattest, "num1": _another_stattest},
            {"cat1": _custom_stattest, "cat2": None, "num1": _another_stattest, "num2": None},
        ),
        (None, "st1", None, None, {"cat1": "st1", "cat2": "st1", "num1": None, "num2": None}),
        (
            None,
            _custom_stattest,
            None,
            None,
            {"cat1": _custom_stattest, "cat2": _custom_stattest, "num1": None, "num2": None},
        ),
        (
            None,
            None,
            _custom_stattest,
            None,
            {"cat1": None, "cat2": None, "num1": _custom_stattest, "num2": _custom_stattest},
        ),
        ("st1", "st2", None, None, {"cat1": "st2", "cat2": "st2", "num1": "st1", "num2": "st1"}),
        ("st1", None, "st2", None, {"cat1": "st1", "cat2": "st1", "num1": "st2", "num2": "st2"}),
        (
            "st1",
            None,
            None,
            {"cat2": "st2", "num2": "st2"},
            {"cat1": "st1", "cat2": "st2", "num1": "st1", "num2": "st2"},
        ),
        (
            "st1",
            "st2",
            "st3",
            {"cat2": "st4", "num2": "st5"},
            {"cat1": "st2", "cat2": "st4", "num1": "st3", "num2": "st5"},
        ),
    ],
)
def test_stattest_function_valid_v2(global_st, cat_st, num_st, per_feature_st, expected):
    """Resolution precedence: per-feature > per-type (cat/num) > global > None."""
    features_with_types = {"cat1": "cat", "cat2": "cat", "num1": "num", "num2": "num"}
    options = DataDriftOptions(
        all_features_stattest=global_st,
        cat_features_stattest=cat_st,
        num_features_stattest=num_st,
        per_feature_stattest=per_feature_st,
    )
    for feature, expected_func in expected.items():
        assert options.get_feature_stattest_func(feature, features_with_types[feature]) == expected_func


@pytest.mark.parametrize(
    "feature_st,global_st,cat_st,num_st,per_feature_st",
    [
        ("st1", "st2", None, None, None),
        ("st1", None, "st2", None, None),
        ("st1", None, None, "st2", None),
        ("st1", None, None, None, {"f1": "st2"}),
    ],
)
def test_stattest_function_deprecated(feature_st, global_st, cat_st, num_st, per_feature_st):
    """Mixing deprecated feature_stattest_func with any new-style option must raise."""
    options = DataDriftOptions(
        feature_stattest_func=feature_st,
        all_features_stattest=global_st,
        cat_features_stattest=cat_st,
        num_features_stattest=num_st,
        per_feature_stattest=per_feature_st,
    )
    with pytest.raises(ValueError):
        options.get_feature_stattest_func("f1", "cat")


@pytest.mark.parametrize(
    "nbinsx,expected",
    [
        (20, {"feature1": 20, "feature2": 20}),
        ({"feature1": 15}, {"feature1": 15, "feature2": DataDriftOptions.DEFAULT_NBINSX}),
        ({"feature2": 11}, {"feature1": DataDriftOptions.DEFAULT_NBINSX, "feature2": 11}),
        ({"feature1": 25, "feature2": 35}, {"feature1": 25, "feature2": 35}),
    ],
)
def test_nbinsx_valid(nbinsx, expected):
    """nbinsx: scalar applies everywhere; dict overrides per feature, default otherwise."""
    options = DataDriftOptions(nbinsx=nbinsx)
    for feature, expected_nbinsx in expected.items():
        assert options.get_nbinsx(feature) == expected_nbinsx