/ tests / options / test_data_drift.py
test_data_drift.py
  1  import pytest
  2  from numpy.ma.testutils import approx
  3  
  4  from evidently.legacy.options.data_drift import DataDriftOptions
  5  
  6  
  7  @pytest.mark.parametrize(
  8      "confidence,expected",
  9      [
 10          (0.1, {"feature1": 0.9, "feature2": 0.9}),
 11          ({"feature1": 0.1}, {"feature1": 0.9, "feature2": None}),
 12          ({"feature2": 0.1}, {"feature1": None, "feature2": 0.9}),
 13          ({}, {"feature1": None, "feature2": None}),
 14      ],
 15  )
 16  def test_confidence_threshold_valid(confidence, expected):
 17      options = DataDriftOptions(confidence=confidence)
 18      for feature, expected_threshold in expected.items():
 19          assert approx(options.get_threshold(feature, "num"), expected_threshold)
 20  
 21  
 22  @pytest.mark.parametrize(
 23      "threshold,expected",
 24      [
 25          (0.1, {"feature1": 0.1, "feature2": 0.1}),
 26          ({"feature1": 0.1}, {"feature1": 0.1, "feature2": None}),
 27          ({"feature2": 0.1}, {"feature1": None, "feature2": 0.1}),
 28          ({}, {"feature1": None, "feature2": None}),
 29      ],
 30  )
 31  def test_threshold_valid(threshold, expected):
 32      options = DataDriftOptions(threshold=threshold)
 33      for feature, expected_threshold in expected.items():
 34          assert approx(options.get_threshold(feature, "num"), expected_threshold)
 35  
 36  
 37  def test_threshold_default():
 38      options = DataDriftOptions()
 39      assert options.get_threshold("feature", "num") is None
 40  
 41  
 42  def _default_stattest():
 43      pass
 44  
 45  
 46  def _custom_stattest():
 47      pass
 48  
 49  
 50  def _another_stattest():
 51      pass
 52  
 53  
 54  @pytest.mark.parametrize(
 55      "feature_func,expected",
 56      [
 57          (None, {"feature1": None, "feature2": None}),
 58          ("st1", {"feature1": "st1", "feature2": "st1"}),
 59          ({"feature1": _custom_stattest}, {"feature1": _custom_stattest, "feature2": None}),
 60          ({"feature2": _custom_stattest}, {"feature1": None, "feature2": _custom_stattest}),
 61          (
 62              {"feature1": _another_stattest, "feature2": _custom_stattest},
 63              {"feature1": _another_stattest, "feature2": _custom_stattest},
 64          ),
 65          ({"feature1": "st1"}, {"feature1": "st1", "feature2": None}),
 66          ({"feature2": "st2"}, {"feature1": None, "feature2": "st2"}),
 67          ({"feature1": "st1", "feature2": "st2"}, {"feature1": "st1", "feature2": "st2"}),
 68          ({"feature1": _another_stattest, "feature2": "st2"}, {"feature1": _another_stattest, "feature2": "st2"}),
 69      ],
 70  )
 71  def test_stattest_function_valid(feature_func, expected):
 72      options = DataDriftOptions(feature_stattest_func=feature_func)
 73      for feature, expected_func in expected.items():
 74          assert options.get_feature_stattest_func(feature, "cat") == expected_func
 75  
 76  
 77  @pytest.mark.parametrize(
 78      "global_st,cat_st,num_st,per_feature_st,expected",
 79      [
 80          (None, None, None, None, {"cat1": None, "cat2": None, "num1": None, "num2": None}),
 81          ("st1", None, None, None, {"cat1": "st1", "cat2": "st1", "num1": "st1", "num2": "st1"}),
 82          (None, None, None, {"cat1": "st1"}, {"cat1": "st1", "cat2": None, "num1": None, "num2": None}),
 83          (
 84              None,
 85              None,
 86              None,
 87              {"cat2": _custom_stattest},
 88              {"cat1": None, "cat2": _custom_stattest, "num1": None, "num2": None},
 89          ),
 90          (
 91              None,
 92              None,
 93              None,
 94              {"cat1": _custom_stattest, "num1": _another_stattest},
 95              {"cat1": _custom_stattest, "cat2": None, "num1": _another_stattest, "num2": None},
 96          ),
 97          (None, "st1", None, None, {"cat1": "st1", "cat2": "st1", "num1": None, "num2": None}),
 98          (
 99              None,
100              _custom_stattest,
101              None,
102              None,
103              {"cat1": _custom_stattest, "cat2": _custom_stattest, "num1": None, "num2": None},
104          ),
105          (
106              None,
107              None,
108              _custom_stattest,
109              None,
110              {"cat1": None, "cat2": None, "num1": _custom_stattest, "num2": _custom_stattest},
111          ),
112          ("st1", "st2", None, None, {"cat1": "st2", "cat2": "st2", "num1": "st1", "num2": "st1"}),
113          ("st1", None, "st2", None, {"cat1": "st1", "cat2": "st1", "num1": "st2", "num2": "st2"}),
114          (
115              "st1",
116              None,
117              None,
118              {"cat2": "st2", "num2": "st2"},
119              {"cat1": "st1", "cat2": "st2", "num1": "st1", "num2": "st2"},
120          ),
121          (
122              "st1",
123              "st2",
124              "st3",
125              {"cat2": "st4", "num2": "st5"},
126              {"cat1": "st2", "cat2": "st4", "num1": "st3", "num2": "st5"},
127          ),
128      ],
129  )
130  def test_stattest_function_valid_v2(global_st, cat_st, num_st, per_feature_st, expected):
131      features_with_types = {"cat1": "cat", "cat2": "cat", "num1": "num", "num2": "num"}
132      options = DataDriftOptions(
133          all_features_stattest=global_st,
134          cat_features_stattest=cat_st,
135          num_features_stattest=num_st,
136          per_feature_stattest=per_feature_st,
137      )
138      for feature, expected_func in expected.items():
139          assert options.get_feature_stattest_func(feature, features_with_types[feature]) == expected_func
140  
141  
@pytest.mark.parametrize(
    "feature_st,global_st,cat_st,num_st,per_feature_st",
    [
        # ``feature_stattest_func`` set together with exactly one other setting.
        ("st1", "st2", None, None, None),
        ("st1", None, "st2", None, None),
        ("st1", None, None, "st2", None),
        ("st1", None, None, None, {"f1": "st2"}),
    ],
)
def test_stattest_function_deprecated(feature_st, global_st, cat_st, num_st, per_feature_st):
    """Expect ``ValueError`` when ``feature_stattest_func`` is mixed with the other settings.

    The options object itself is built outside the ``pytest.raises`` block, so
    only the ``get_feature_stattest_func`` lookup is expected to raise.

    Args:
        feature_st: Value passed as ``feature_stattest_func``.
        global_st: Value passed as ``all_features_stattest``.
        cat_st: Value passed as ``cat_features_stattest``.
        num_st: Value passed as ``num_features_stattest``.
        per_feature_st: Value passed as ``per_feature_stattest``.
    """
    conflicting_options = DataDriftOptions(
        feature_stattest_func=feature_st,
        all_features_stattest=global_st,
        cat_features_stattest=cat_st,
        num_features_stattest=num_st,
        per_feature_stattest=per_feature_st,
    )
    with pytest.raises(ValueError):
        conflicting_options.get_feature_stattest_func("f1", "cat")
163  
164  
@pytest.mark.parametrize(
    "nbinsx,expected",
    [
        # A single int is expected to apply to every feature.
        (20, {"feature1": 20, "feature2": 20}),
        # Features missing from a mapping are expected to use DEFAULT_NBINSX.
        ({"feature1": 15}, {"feature1": 15, "feature2": DataDriftOptions.DEFAULT_NBINSX}),
        ({"feature2": 11}, {"feature1": DataDriftOptions.DEFAULT_NBINSX, "feature2": 11}),
        # A mapping naming every feature.
        ({"feature1": 25, "feature2": 35}, {"feature1": 25, "feature2": 35}),
    ],
)
def test_nbinsx_valid(nbinsx, expected):
    """Check the bin count ``get_nbinsx`` reports for a given ``nbinsx`` setting.

    Args:
        nbinsx: Value passed as ``DataDriftOptions(nbinsx=...)``; in the cases
            above either an int or a ``{feature_name: int}`` mapping.
        expected: Mapping of feature name to the value expected from
            ``get_nbinsx(feature)``.
    """
    drift_options = DataDriftOptions(nbinsx=nbinsx)
    for name in expected:
        bins = drift_options.get_nbinsx(name)
        assert bins == expected[name]
176      for feature, expected_nbinsx in expected.items():
177          assert options.get_nbinsx(feature) == expected_nbinsx