data_drift.py
import warnings
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Optional
from typing import Union

from evidently._pydantic_compat import BaseModel
from evidently.legacy.calculations.stattests import PossibleStatTestType
from evidently.legacy.calculations.stattests import StatTest
from evidently.legacy.utils.data_drift_utils import resolve_stattest_threshold


class DataDriftOptions(BaseModel):
    """Configuration for Data Drift calculations.

    Args:
        confidence: Defines the confidence level for statistical tests.
            Applies to all features (if passed as float) or certain features (if passed as dictionary).
            (Deprecated) Use `threshold` to define confidence level for statistical
            tests as more universal solution.
        threshold: Defines thresholds for statistical tests.
            Applies to all features (if passed as float) or certain features (if passed as dictionary).
        drift_share: Sets the share of drifting features as a condition for Dataset Drift in the Data Drift report.
        nbinsx: Defines the number of bins in a histogram.
            Applies to all features (if passed as int) or certain features (if passed as dictionary).
        xbins: Defines the boundaries for the size of a specific bin in a histogram.
        feature_stattest_func: Defines a custom statistical test for drift detection in the Data Drift report.
            Applies to all features (if passed as a function) or individual features (if a dict).
            (Deprecated) Use `all_features_stattest` or `per_feature_stattest`.
        all_features_stattest: Defines a custom statistical test for drift detection in the Data Drift report
            for all features.
        cat_features_stattest: Defines a custom statistical test for drift detection in the Data Drift report
            for categorical features only.
        num_features_stattest: Defines a custom statistical test for drift detection in the Data Drift report
            for numerical features only.
        text_features_stattest: Defines a custom statistical test for drift detection in the Data Drift report
            for text features only.
        per_feature_stattest: Defines a custom statistical test for drift detection in the Data Drift report
            per feature.
        all_features_threshold: Threshold passed to `resolve_stattest_threshold` for all features.
        cat_features_threshold: Threshold passed to `resolve_stattest_threshold` for categorical features.
        num_features_threshold: Threshold passed to `resolve_stattest_threshold` for numerical features.
        text_features_threshold: Threshold passed to `resolve_stattest_threshold` for text features.
        per_feature_threshold: Per-feature thresholds passed to `resolve_stattest_threshold`.
        cat_target_threshold: Threshold for the category target test (not read by the methods of this class).
        num_target_threshold: Threshold for the numeric target test (not read by the methods of this class).
        cat_target_stattest_func: Defines a custom statistical test to detect target drift in category target.
        num_target_stattest_func: Defines a custom statistical test to detect target drift in numeric target.
    """

    # Fallback bin count used by `get_nbinsx` when nothing is configured for a feature.
    DEFAULT_NBINSX: ClassVar[int] = 10

    confidence: Optional[Union[float, Dict[str, float]]] = None
    threshold: Optional[Union[float, Dict[str, float]]] = None
    drift_share: float = 0.5
    nbinsx: Optional[Union[int, Dict[str, int]]] = None
    xbins: Optional[Dict[str, int]] = None

    feature_stattest_func: Optional[Union[PossibleStatTestType, Dict[str, PossibleStatTestType]]] = None

    all_features_stattest: Optional[PossibleStatTestType] = None
    cat_features_stattest: Optional[PossibleStatTestType] = None
    num_features_stattest: Optional[PossibleStatTestType] = None
    text_features_stattest: Optional[PossibleStatTestType] = None
    per_feature_stattest: Optional[Dict[str, PossibleStatTestType]] = None

    all_features_threshold: Optional[float] = None
    cat_features_threshold: Optional[float] = None
    num_features_threshold: Optional[float] = None
    text_features_threshold: Optional[float] = None
    per_feature_threshold: Optional[Dict[str, float]] = None

    cat_target_threshold: Optional[float] = None
    num_target_threshold: Optional[float] = None

    cat_target_stattest_func: Optional[PossibleStatTestType] = None
    num_target_stattest_func: Optional[PossibleStatTestType] = None

    def as_dict(self) -> Dict[str, Any]:
        """Return the `confidence`, `drift_share`, `nbinsx` and `xbins` options as a plain dict.

        Only these four options are exported; stattest and threshold settings are not included.
        """
        return {
            "confidence": self.confidence,
            "drift_share": self.drift_share,
            "nbinsx": self.nbinsx,
            "xbins": self.xbins,
        }

    def _resolve_stattest_and_threshold(self, feature_name: str, feature_type: str) -> Any:
        """Call `resolve_stattest_threshold` with this object's stattest/threshold options.

        Single place that forwards the long positional argument list, so the
        stattest lookup and the threshold lookup cannot drift apart.
        Callers unpack the result as a `(stattest, threshold)` pair.
        """
        return resolve_stattest_threshold(
            feature_name,
            feature_type,
            self.all_features_stattest,
            self.cat_features_stattest,
            self.num_features_stattest,
            self.text_features_stattest,
            self.per_feature_stattest,
            self.all_features_threshold,
            self.cat_features_threshold,
            self.num_features_threshold,
            self.text_features_threshold,
            self.per_feature_threshold,
        )

    def _calculate_threshold(self, feature_name: str, feature_type: str) -> Optional[float]:
        """Return the threshold for a feature, ignoring the deprecated `confidence` option.

        An explicit `threshold` takes precedence over the `*_threshold` options:
        a float applies to every feature, a dict is looked up by feature name
        (missing names give None).

        Raises:
            ValueError: if `threshold` is set but is neither a float nor a dict.
        """
        if self.threshold is None:
            _, threshold = self._resolve_stattest_and_threshold(feature_name, feature_type)
            return threshold

        if isinstance(self.threshold, float):
            return self.threshold

        if isinstance(self.threshold, dict):
            return self.threshold.get(feature_name)

        raise ValueError(f"DataDriftOptions.threshold is incorrect type {type(self.threshold)}")

    def get_threshold(self, feature_name: str, feature_type: str) -> Optional[float]:
        """Return the stat test threshold for a feature, or None if none is configured.

        When the deprecated `confidence` option is used, the threshold is
        `1.0 - confidence` (per feature if `confidence` is a dict) and a
        deprecation warning is issued.

        Raises:
            ValueError: if both `confidence` and a threshold are set for the feature,
                or if `confidence` / `threshold` has an unsupported type.
        """
        threshold = self._calculate_threshold(feature_name, feature_type)

        if self.confidence is None:
            return threshold

        if threshold is not None:
            raise ValueError("Only DataDriftOptions.confidence or DataDriftOptions.threshold can be set")

        warnings.warn("DataDriftOptions.confidence is deprecated, use DataDriftOptions.threshold instead.")

        if isinstance(self.confidence, float):
            return 1.0 - self.confidence

        if isinstance(self.confidence, dict):
            override = self.confidence.get(feature_name)
            return None if override is None else 1.0 - override

        raise ValueError(f"DataDriftOptions.confidence is incorrect type {type(self.confidence)}")

    def get_nbinsx_or_none(self, feature_name: str) -> Optional[int]:
        """Return the configured number of histogram bins for a feature, or None if not configured.

        Raises:
            ValueError: if `nbinsx` is set but is neither an int nor a dict.
        """
        if self.nbinsx is None:
            return None
        if isinstance(self.nbinsx, int):
            return self.nbinsx
        if isinstance(self.nbinsx, dict):
            return self.nbinsx.get(feature_name)
        raise ValueError(f"DataDriftOptions.nbinsx is incorrect type {type(self.nbinsx)}")

    def get_nbinsx(self, feature_name: str) -> int:
        """Return the number of histogram bins for a feature, falling back to `DEFAULT_NBINSX`.

        Raises:
            ValueError: if `nbinsx` is set but is neither an int nor a dict.
        """
        nbinsx = self.get_nbinsx_or_none(feature_name)
        return DataDriftOptions.DEFAULT_NBINSX if nbinsx is None else nbinsx

    def get_feature_stattest_func(self, feature_name: str, feature_type: str) -> Optional[PossibleStatTestType]:
        """Return the statistical test configured for a feature, or None if none applies.

        The deprecated `feature_stattest_func` is honoured when it is the only
        stattest option set (a warning is issued); otherwise the test is resolved
        from the `*_stattest` options via `resolve_stattest_threshold`.

        Raises:
            ValueError: if `feature_stattest_func` is combined with any of the
                `*_stattest` options.
        """
        legacy_stattest = self.feature_stattest_func

        if legacy_stattest is None:
            stattest, _ = self._resolve_stattest_and_threshold(feature_name, feature_type)
            return stattest

        # Truthiness (not `is not None`) is intentional: an empty per-feature dict
        # does not count as a conflicting setting.
        if any(
            (
                self.all_features_stattest,
                self.cat_features_stattest,
                self.num_features_stattest,
                self.text_features_stattest,
                self.per_feature_stattest,
            )
        ):
            raise ValueError(
                "Cannot use DataDriftOptions.feature_stattest_func along with any "
                "of DataDriftOptions.all_features_stattest,"
                " DataDriftOptions.cat_features_stattest,"
                " DataDriftOptions.num_features_stattest,"
                " DataDriftOptions.text_features_stattest,"
                " DataDriftOptions.per_feature_stattest."
            )

        warnings.warn(
            "DataDriftOptions.feature_stattest_func is deprecated, use DataDriftOptions.all_features_stattest"
            " or DataDriftOptions.per_feature_stattest."
        )
        if callable(legacy_stattest) or isinstance(legacy_stattest, (StatTest, str)):
            return legacy_stattest
        if isinstance(legacy_stattest, dict):
            return legacy_stattest.get(feature_name)
        return None

    def __hash__(self) -> int:
        """Calculate hash for data drift options - for using in metrics deduplication via dicts.

        The hash is derived from `as_dict()` only, so options that differ solely in
        stattest or threshold settings share a hash value.
        """
        return hash(str(self.as_dict()))