/ src / evidently / presets / drift.py
drift.py
  1  from typing import Dict
  2  from typing import List
  3  from typing import Optional
  4  from typing import Sequence
  5  from typing import Tuple
  6  
  7  from evidently.core.container import MetricContainer
  8  from evidently.core.container import MetricOrContainer
  9  from evidently.core.metric_types import MetricId
 10  from evidently.core.report import Context
 11  from evidently.core.report import _default_input_data_generator
 12  from evidently.legacy.calculations.stattests import PossibleStatTestType
 13  from evidently.legacy.calculations.stattests import StatTest
 14  from evidently.legacy.core import ColumnType
 15  from evidently.legacy.metrics import DataDriftTable
 16  from evidently.legacy.metrics import DatasetDriftMetric
 17  from evidently.legacy.metrics.data_drift.embedding_drift_methods import DriftMethod
 18  from evidently.legacy.model.widget import BaseWidgetInfo
 19  from evidently.legacy.options.data_drift import DataDriftOptions
 20  from evidently.metrics import ValueDrift
 21  from evidently.metrics.column_statistics import DriftedColumnsCount
 22  
 23  
 24  class DataDriftPreset(MetricContainer):
 25      """Large preset for detecting data drift across columns.
 26  
 27      Calculates data drift for all or specified columns using various drift detection methods.
 28      Generates `DriftedColumnsCount` and `ValueDrift` metrics for each column. Requires reference data.
 29  
 30  
 31      See Also:
 32      * [Drift Methods Documentation](https://docs.evidentlyai.com/metrics/customize_data_drift) for available methods.
 33      """
 34  
 35      columns: Optional[List[str]] = None
 36      """Optional list of column names to analyze (None = all columns)."""
 37      embeddings: Optional[List[str]] = None
 38      """Optional list of embedding column names."""
 39      embeddings_drift_method: Optional[Dict[str, DriftMethod]] = None
 40      """Optional dictionary mapping embedding columns to drift methods."""
 41      drift_share: float = 0.5
 42      """Threshold for drift share (0.5 = 50% of columns)."""
 43      method: Optional[PossibleStatTestType] = None
 44      """Optional drift detection method for all columns (auto-selected if None)."""
 45      cat_method: Optional[PossibleStatTestType] = None
 46      """Optional method for categorical columns."""
 47      num_method: Optional[PossibleStatTestType] = None
 48      """Optional method for numerical columns."""
 49      text_method: Optional[PossibleStatTestType] = None
 50      """Optional method for text columns."""
 51      per_column_method: Optional[Dict[str, PossibleStatTestType]] = None
 52      """Optional dictionary mapping column names to methods."""
 53      threshold: Optional[float] = None
 54      """Optional drift threshold for all columns (uses method default if None)."""
 55      cat_threshold: Optional[float] = None
 56      """Optional threshold for categorical columns."""
 57      num_threshold: Optional[float] = None
 58      """Optional threshold for numerical columns."""
 59      text_threshold: Optional[float] = None
 60      """Optional threshold for text columns."""
 61      per_column_threshold: Optional[Dict[str, float]] = None
 62      """Optional dictionary mapping column names to thresholds."""
 63  
 64      def __init__(
 65          self,
 66          columns: Optional[List[str]] = None,
 67          embeddings: Optional[List[str]] = None,
 68          embeddings_drift_method: Optional[Dict[str, DriftMethod]] = None,
 69          drift_share: float = 0.5,
 70          method: Optional[PossibleStatTestType] = None,
 71          cat_method: Optional[PossibleStatTestType] = None,
 72          num_method: Optional[PossibleStatTestType] = None,
 73          text_method: Optional[PossibleStatTestType] = None,
 74          per_column_method: Optional[Dict[str, PossibleStatTestType]] = None,
 75          threshold: Optional[float] = None,
 76          cat_threshold: Optional[float] = None,
 77          num_threshold: Optional[float] = None,
 78          text_threshold: Optional[float] = None,
 79          per_column_threshold: Optional[Dict[str, float]] = None,
 80          include_tests: bool = True,
 81      ):
 82          self.per_column_threshold = per_column_threshold
 83          self.text_threshold = text_threshold
 84          self.num_threshold = num_threshold
 85          self.cat_threshold = cat_threshold
 86          self.threshold = threshold
 87          self.per_column_method = per_column_method
 88          self.text_method = text_method
 89          self.num_method = num_method
 90          self.cat_method = cat_method
 91          self.method = method
 92          self.drift_share = drift_share
 93          self.embeddings_drift_method = embeddings_drift_method
 94          self.embeddings = embeddings
 95          self.columns = columns
 96          super().__init__(include_tests=include_tests)
 97  
 98      def generate_metrics(self, context: Context) -> Sequence[MetricOrContainer]:
 99          types = [ColumnType.Numerical, ColumnType.Categorical, ColumnType.Text]
100          options = DataDriftOptions(
101              drift_share=self.drift_share,
102              all_features_stattest=self.method,
103              cat_features_stattest=self.cat_method,
104              num_features_stattest=self.num_method,
105              text_features_stattest=self.text_method,
106              per_feature_stattest=self.per_column_method,
107              all_features_threshold=self.threshold,
108              cat_features_threshold=self.cat_threshold,
109              num_features_threshold=self.num_threshold,
110              text_features_threshold=self.text_threshold,
111              per_feature_threshold=self.per_column_threshold,
112          )
113          return [
114              DriftedColumnsCount(
115                  columns=self.columns,
116                  drift_share=self.drift_share,
117                  method=self.method,
118                  cat_method=self.cat_method,
119                  num_method=self.num_method,
120                  text_method=self.text_method,
121                  per_column_method=self.per_column_method,
122                  threshold=self.threshold,
123                  cat_threshold=self.cat_threshold,
124                  num_threshold=self.num_threshold,
125                  text_threshold=self.text_threshold,
126                  per_column_threshold=self.per_column_threshold,
127              ),
128          ] + [
129              ValueDrift(
130                  column=column,
131                  method=self._get_drift_stattest(
132                      column,
133                      False,
134                      context.data_definition.get_column_type(column),
135                      options,
136                  ),
137                  threshold=options.get_threshold(column, context.data_definition.get_column_type(column).value),
138              )
139              for column in (self.columns if self.columns is not None else context.data_definition.get_columns(types))
140          ]
141  
142      def render(
143          self,
144          context: "Context",
145          child_widgets: Optional[List[Tuple[Optional[MetricId], List[BaseWidgetInfo]]]] = None,
146      ) -> List[BaseWidgetInfo]:
147          dataset_drift = context.get_legacy_metric(
148              DatasetDriftMetric(
149                  columns=self.columns,
150                  drift_share=self.drift_share,
151                  stattest=self.method,
152                  cat_stattest=self.cat_method,
153                  num_stattest=self.num_method,
154                  text_stattest=self.text_method,
155                  per_column_stattest=self.per_column_method,
156                  stattest_threshold=self.threshold,
157                  cat_stattest_threshold=self.cat_threshold,
158                  num_stattest_threshold=self.num_threshold,
159                  text_stattest_threshold=self.text_threshold,
160                  per_column_stattest_threshold=self.per_column_threshold,
161              ),
162              _default_input_data_generator,
163              None,  # TODO: parametrize task name
164          )[1]
165          table = context.get_legacy_metric(
166              DataDriftTable(
167                  columns=self.columns,
168                  stattest=self.method,
169                  cat_stattest=self.cat_method,
170                  num_stattest=self.num_method,
171                  text_stattest=self.text_method,
172                  per_column_stattest=self.per_column_method,
173                  stattest_threshold=self.threshold,
174                  cat_stattest_threshold=self.cat_threshold,
175                  num_stattest_threshold=self.num_threshold,
176                  text_stattest_threshold=self.text_threshold,
177                  per_column_stattest_threshold=self.per_column_threshold,
178              ),
179              _default_input_data_generator,
180              None,  # TODO: parametrize task name
181          )[1]
182          return dataset_drift + table
183  
184      def _get_drift_stattest(
185          self,
186          column_name: str,
187          is_target: bool,
188          column_type: ColumnType,
189          options: DataDriftOptions,
190      ):
191          stattest = None
192  
193          if is_target and column_type == ColumnType.Numerical:
194              stattest = options.num_target_stattest_func
195  
196          elif is_target and column_type == ColumnType.Categorical:
197              stattest = options.cat_target_stattest_func
198  
199          if not stattest:
200              stattest = options.get_feature_stattest_func(column_name, column_type.value)
201          if stattest:
202              if isinstance(stattest, str):
203                  return stattest
204              if isinstance(stattest, StatTest):
205                  return stattest.name
206              return stattest
207          return None