/ src / evidently / legacy / options / data_drift.py
data_drift.py
  1  import warnings
  2  from typing import Any
  3  from typing import ClassVar
  4  from typing import Dict
  5  from typing import Optional
  6  from typing import Union
  7  
  8  from evidently._pydantic_compat import BaseModel
  9  from evidently.legacy.calculations.stattests import PossibleStatTestType
 10  from evidently.legacy.calculations.stattests import StatTest
 11  from evidently.legacy.utils.data_drift_utils import resolve_stattest_threshold
 12  
 13  
 14  class DataDriftOptions(BaseModel):
 15      """Configuration for Data Drift calculations.
 16  
 17      Args:
 18          confidence: Defines the confidence level for statistical tests.
 19                      Applies to all features (if passed as float) or certain features (if passed as dictionary).
 20                      (Deprecated) Use `threshold` to define confidence level for statistical
 21                      tests as more universal solution.
 22          threshold: Defines thresholds for statistical tests.
 23                     Applies to all features (if passed as float) or certain features (if passed as dictionary).
 24          drift_share: Sets the share of drifting features as a condition for Dataset Drift in the Data Drift report.
 25          nbinsx: Defines the number of bins in a histogram.
 26                  Applies to all features (if passed as int) or certain features (if passed as dictionary).
 27          xbins: Defines the boundaries for the size of a specific bin in a histogram.
 28          feature_stattest_func: Defines a custom statistical test for drift detection in the Data Drift report.
 29                                 Applies to all features (if passed as a function) or individual features (if a dict).
 30                                 (Deprecated) Use `all_features_stattest` or `per_feature_stattest`.
 31          all_features_stattest: Defines a custom statistical test for drift detection in the Data Drift report
 32                                 for all features.
 33          cat_features_stattest: Defines a custom statistical test for drift detection in the Data Drift report
 34                                 for categorical features only.
 35          num_features_stattest: Defines a custom statistical test for drift detection in the Data Drift report
 36                                 for numerical features only.
 37          per_feature_stattest: Defines a custom statistical test for drift detection in the Data Drift report
 38                                per feature.
 39          cat_target_stattest_func: Defines a custom statistical test to detect target drift in category target.
 40          num_target_stattest_func: Defines a custom statistical test to detect target drift in numeric target.
 41      """
 42  
 43      DEFAULT_NBINSX: ClassVar = 10
 44  
 45      confidence: Optional[Union[float, Dict[str, float]]] = None
 46      threshold: Optional[Union[float, Dict[str, float]]] = None
 47      drift_share: float = 0.5
 48      nbinsx: Optional[Union[int, Dict[str, int]]] = None
 49      xbins: Optional[Dict[str, int]] = None
 50  
 51      feature_stattest_func: Optional[Union[PossibleStatTestType, Dict[str, PossibleStatTestType]]] = None
 52  
 53      all_features_stattest: Optional[PossibleStatTestType] = None
 54      cat_features_stattest: Optional[PossibleStatTestType] = None
 55      num_features_stattest: Optional[PossibleStatTestType] = None
 56      text_features_stattest: Optional[PossibleStatTestType] = None
 57      per_feature_stattest: Optional[Dict[str, PossibleStatTestType]] = None
 58  
 59      all_features_threshold: Optional[float] = None
 60      cat_features_threshold: Optional[float] = None
 61      num_features_threshold: Optional[float] = None
 62      text_features_threshold: Optional[float] = None
 63      per_feature_threshold: Optional[Dict[str, float]] = None
 64  
 65      cat_target_threshold: Optional[float] = None
 66      num_target_threshold: Optional[float] = None
 67  
 68      cat_target_stattest_func: Optional[PossibleStatTestType] = None
 69      num_target_stattest_func: Optional[PossibleStatTestType] = None
 70  
 71      def as_dict(self) -> Dict[str, Any]:
 72          return {
 73              "confidence": self.confidence,
 74              "drift_share": self.drift_share,
 75              "nbinsx": self.nbinsx,
 76              "xbins": self.xbins,
 77          }
 78  
 79      def _calculate_threshold(self, feature_name: str, feature_type: str) -> Optional[float]:
 80          if self.threshold is not None:
 81              if isinstance(self.threshold, float):
 82                  return self.threshold
 83  
 84              if isinstance(self.threshold, dict):
 85                  return self.threshold.get(feature_name)
 86  
 87              raise ValueError(f"DataDriftOptions.threshold is incorrect type {type(self.threshold)}")
 88  
 89          _, threshold = resolve_stattest_threshold(
 90              feature_name,
 91              feature_type,
 92              self.all_features_stattest,
 93              self.cat_features_stattest,
 94              self.num_features_stattest,
 95              self.text_features_stattest,
 96              self.per_feature_stattest,
 97              self.all_features_threshold,
 98              self.cat_features_threshold,
 99              self.num_features_threshold,
100              self.text_features_threshold,
101              self.per_feature_threshold,
102          )
103          return threshold
104  
105      def get_threshold(self, feature_name: str, feature_type: str) -> Optional[float]:
106          threshold = self._calculate_threshold(feature_name, feature_type)
107  
108          if self.confidence is not None and threshold is not None:
109              raise ValueError("Only DataDriftOptions.confidence or DataDriftOptions.threshold can be set")
110  
111          if self.confidence is not None:
112              warnings.warn("DataDriftOptions.confidence is deprecated, use DataDriftOptions.threshold instead.")
113  
114              if isinstance(self.confidence, float):
115                  return 1.0 - self.confidence
116  
117              if isinstance(self.confidence, dict):
118                  override = self.confidence.get(feature_name)
119                  return None if override is None else 1.0 - override
120  
121              raise ValueError(f"DataDriftOptions.confidence is incorrect type {type(self.confidence)}")
122  
123          return threshold
124  
125      def get_nbinsx_or_none(self, feature_name: str) -> Optional[int]:
126          if self.nbinsx is None:
127              return None
128          if isinstance(self.nbinsx, int):
129              return self.nbinsx
130          if isinstance(self.nbinsx, dict):
131              return self.nbinsx.get(feature_name)
132          raise ValueError(f"DataDriftOptions.nbinsx is incorrect type {type(self.nbinsx)}")
133  
134      def get_nbinsx(self, feature_name: str) -> int:
135          if self.nbinsx is None:
136              return DataDriftOptions.DEFAULT_NBINSX
137          if isinstance(self.nbinsx, int):
138              return self.nbinsx
139          if isinstance(self.nbinsx, dict):
140              return self.nbinsx.get(feature_name, DataDriftOptions.DEFAULT_NBINSX)
141          raise ValueError(f"DataDriftOptions.nbinsx is incorrect type {type(self.nbinsx)}")
142  
143      def get_feature_stattest_func(self, feature_name: str, feature_type: str) -> Optional[PossibleStatTestType]:
144          if self.feature_stattest_func is not None and any(
145              [
146                  self.all_features_stattest,
147                  self.cat_features_stattest,
148                  self.num_features_stattest,
149                  self.text_features_stattest,
150                  self.per_feature_stattest,
151              ]
152          ):
153              raise ValueError(
154                  "Cannot use DataDriftOptions.feature_stattest_func along with any "
155                  "of DataDriftOptions.cat_stattest_func,"
156                  " DataDriftOptions.num_stattest_func,"
157                  " DataDriftOptions.text_stattest_func,"
158                  " DataDriftOptions.per_feature_stattest_func."
159              )
160          if self.feature_stattest_func is not None:
161              warnings.warn(
162                  "DataDriftOptions.feature_stattest_func is deprecated use DataDriftOptions.stattest_func"
163                  " or DataDriftOptions.per_feature_stattest_func."
164              )
165              if callable(self.feature_stattest_func) or isinstance(self.feature_stattest_func, (StatTest, str)):
166                  return self.feature_stattest_func
167              if isinstance(self.feature_stattest_func, dict):
168                  return self.feature_stattest_func.get(feature_name)
169              return None
170          stattest, _ = resolve_stattest_threshold(
171              feature_name,
172              feature_type,
173              self.all_features_stattest,
174              self.cat_features_stattest,
175              self.num_features_stattest,
176              self.text_features_stattest,
177              self.per_feature_stattest,
178              self.all_features_threshold,
179              self.cat_features_threshold,
180              self.num_features_threshold,
181              self.text_features_threshold,
182              self.per_feature_threshold,
183          )
184          return stattest
185  
186      def __hash__(self) -> int:
187          """Calculate hash for data drift options - for using in metrics deduplication via dicts."""
188          return str(self.as_dict()).__hash__()