drift.py
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

from evidently.core.container import MetricContainer
from evidently.core.container import MetricOrContainer
from evidently.core.metric_types import MetricId
from evidently.core.report import Context
from evidently.core.report import _default_input_data_generator
from evidently.legacy.calculations.stattests import PossibleStatTestType
from evidently.legacy.calculations.stattests import StatTest
from evidently.legacy.core import ColumnType
from evidently.legacy.metrics import DataDriftTable
from evidently.legacy.metrics import DatasetDriftMetric
from evidently.legacy.metrics.data_drift.embedding_drift_methods import DriftMethod
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.legacy.options.data_drift import DataDriftOptions
from evidently.metrics import ValueDrift
from evidently.metrics.column_statistics import DriftedColumnsCount


class DataDriftPreset(MetricContainer):
    """Preset that bundles data drift metrics for a set of columns.

    Produces a single `DriftedColumnsCount` metric followed by one `ValueDrift`
    metric per analyzed column. When `columns` is not given, every numerical,
    categorical and text column of the data definition is analyzed.
    Requires reference data.


    See Also:
    * [Drift Methods Documentation](https://docs.evidentlyai.com/metrics/customize_data_drift) for available methods.
    """

    columns: Optional[List[str]] = None
    """Column names to analyze; None means all columns."""
    embeddings: Optional[List[str]] = None
    """Embedding column names, if any."""
    embeddings_drift_method: Optional[Dict[str, DriftMethod]] = None
    """Mapping from embedding column to its drift method, if any."""
    drift_share: float = 0.5
    """Drift share threshold (0.5 = 50% of columns)."""
    method: Optional[PossibleStatTestType] = None
    """Drift detection method applied to all columns (auto-selected if None)."""
    cat_method: Optional[PossibleStatTestType] = None
    """Drift detection method for categorical columns, if any."""
    num_method: Optional[PossibleStatTestType] = None
    """Drift detection method for numerical columns, if any."""
    text_method: Optional[PossibleStatTestType] = None
    """Drift detection method for text columns, if any."""
    per_column_method: Optional[Dict[str, PossibleStatTestType]] = None
    """Mapping from column name to its drift detection method, if any."""
    threshold: Optional[float] = None
    """Drift threshold applied to all columns (method default if None)."""
    cat_threshold: Optional[float] = None
    """Drift threshold for categorical columns, if any."""
    num_threshold: Optional[float] = None
    """Drift threshold for numerical columns, if any."""
    text_threshold: Optional[float] = None
    """Drift threshold for text columns, if any."""
    per_column_threshold: Optional[Dict[str, float]] = None
    """Mapping from column name to its drift threshold, if any."""

    def __init__(
        self,
        columns: Optional[List[str]] = None,
        embeddings: Optional[List[str]] = None,
        embeddings_drift_method: Optional[Dict[str, DriftMethod]] = None,
        drift_share: float = 0.5,
        method: Optional[PossibleStatTestType] = None,
        cat_method: Optional[PossibleStatTestType] = None,
        num_method: Optional[PossibleStatTestType] = None,
        text_method: Optional[PossibleStatTestType] = None,
        per_column_method: Optional[Dict[str, PossibleStatTestType]] = None,
        threshold: Optional[float] = None,
        cat_threshold: Optional[float] = None,
        num_threshold: Optional[float] = None,
        text_threshold: Optional[float] = None,
        per_column_threshold: Optional[Dict[str, float]] = None,
        include_tests: bool = True,
    ):
        """Store the drift settings, then initialize the base container.

        Every argument except `include_tests` is kept on the instance under
        the same name; `include_tests` is forwarded to the base initializer.
        """
        # Thresholds.
        self.per_column_threshold = per_column_threshold
        self.text_threshold = text_threshold
        self.num_threshold = num_threshold
        self.cat_threshold = cat_threshold
        self.threshold = threshold
        # Drift detection methods.
        self.per_column_method = per_column_method
        self.text_method = text_method
        self.num_method = num_method
        self.cat_method = cat_method
        self.method = method
        # Remaining settings.
        self.drift_share = drift_share
        self.embeddings_drift_method = embeddings_drift_method
        self.embeddings = embeddings
        self.columns = columns
        super().__init__(include_tests=include_tests)

    def generate_metrics(self, context: Context) -> Sequence[MetricOrContainer]:
        """Build the metrics of this preset.

        Returns a list holding one `DriftedColumnsCount` configured with the
        preset settings, followed by one `ValueDrift` per analyzed column.
        The method and threshold of each `ValueDrift` are resolved through
        `DataDriftOptions` from the column's name and type.
        """
        supported_types = [ColumnType.Numerical, ColumnType.Categorical, ColumnType.Text]
        options = DataDriftOptions(
            drift_share=self.drift_share,
            all_features_stattest=self.method,
            cat_features_stattest=self.cat_method,
            num_features_stattest=self.num_method,
            text_features_stattest=self.text_method,
            per_feature_stattest=self.per_column_method,
            all_features_threshold=self.threshold,
            cat_features_threshold=self.cat_threshold,
            num_features_threshold=self.num_threshold,
            text_features_threshold=self.text_threshold,
            per_feature_threshold=self.per_column_threshold,
        )
        metrics: List[MetricOrContainer] = [
            DriftedColumnsCount(
                columns=self.columns,
                drift_share=self.drift_share,
                method=self.method,
                cat_method=self.cat_method,
                num_method=self.num_method,
                text_method=self.text_method,
                per_column_method=self.per_column_method,
                threshold=self.threshold,
                cat_threshold=self.cat_threshold,
                num_threshold=self.num_threshold,
                text_threshold=self.text_threshold,
                per_column_threshold=self.per_column_threshold,
            ),
        ]

        # An explicit column list wins; otherwise take every supported column.
        if self.columns is not None:
            selected_columns = self.columns
        else:
            selected_columns = context.data_definition.get_columns(supported_types)

        for column in selected_columns:
            column_type = context.data_definition.get_column_type(column)
            metrics.append(
                ValueDrift(
                    column=column,
                    method=self._get_drift_stattest(column, False, column_type, options),
                    threshold=options.get_threshold(column, column_type.value),
                )
            )
        return metrics

    def render(
        self,
        context: "Context",
        child_widgets: Optional[List[Tuple[Optional[MetricId], List[BaseWidgetInfo]]]] = None,
    ) -> List[BaseWidgetInfo]:
        """Render the preset using the legacy drift metrics.

        Fetches the widgets of a legacy `DatasetDriftMetric` and of a legacy
        `DataDriftTable`, both configured from the preset settings, and
        returns them concatenated in that order. `child_widgets` is not used.
        """
        dataset_drift_metric = DatasetDriftMetric(
            columns=self.columns,
            drift_share=self.drift_share,
            stattest=self.method,
            cat_stattest=self.cat_method,
            num_stattest=self.num_method,
            text_stattest=self.text_method,
            per_column_stattest=self.per_column_method,
            stattest_threshold=self.threshold,
            cat_stattest_threshold=self.cat_threshold,
            num_stattest_threshold=self.num_threshold,
            text_stattest_threshold=self.text_threshold,
            per_column_stattest_threshold=self.per_column_threshold,
        )
        # Element 1 of the returned pair holds the widgets.
        dataset_drift_widgets = context.get_legacy_metric(
            dataset_drift_metric,
            _default_input_data_generator,
            None,  # TODO: parametrize task name
        )[1]

        drift_table_metric = DataDriftTable(
            columns=self.columns,
            stattest=self.method,
            cat_stattest=self.cat_method,
            num_stattest=self.num_method,
            text_stattest=self.text_method,
            per_column_stattest=self.per_column_method,
            stattest_threshold=self.threshold,
            cat_stattest_threshold=self.cat_threshold,
            num_stattest_threshold=self.num_threshold,
            text_stattest_threshold=self.text_threshold,
            per_column_stattest_threshold=self.per_column_threshold,
        )
        drift_table_widgets = context.get_legacy_metric(
            drift_table_metric,
            _default_input_data_generator,
            None,  # TODO: parametrize task name
        )[1]

        return dataset_drift_widgets + drift_table_widgets

    def _get_drift_stattest(
        self,
        column_name: str,
        is_target: bool,
        column_type: ColumnType,
        options: DataDriftOptions,
    ):
        """Resolve the drift test to use for a column.

        For a numerical or categorical target column the target-specific test
        from `options` is tried first; if that yields nothing (or the column
        is not such a target), the feature test for the column name and type
        is used.

        Returns:
            None when no test was resolved; the `name` of the test when it is
            a `StatTest` instance; otherwise the resolved value unchanged
            (a string is returned as is).
        """
        resolved = None
        if is_target:
            if column_type == ColumnType.Numerical:
                resolved = options.num_target_stattest_func
            elif column_type == ColumnType.Categorical:
                resolved = options.cat_target_stattest_func

        # Fall back to the per-feature configuration.
        if not resolved:
            resolved = options.get_feature_stattest_func(column_name, column_type.value)

        if not resolved:
            return None
        # Strings are checked first and pass through untouched.
        if not isinstance(resolved, str) and isinstance(resolved, StatTest):
            return resolved.name
        return resolved