/ src / evidently / legacy / utils / data_preprocessing.py
data_preprocessing.py
  1  import dataclasses
  2  import logging
  3  import warnings
  4  from enum import Enum
  5  from typing import Collection
  6  from typing import Dict
  7  from typing import List
  8  from typing import Optional
  9  from typing import Sequence
 10  from typing import Union
 11  
 12  import pandas as pd
 13  import pandas.api.types
 14  
 15  from evidently._pydantic_compat import BaseModel
 16  from evidently.legacy.core import ColumnType
 17  from evidently.legacy.pipeline.column_mapping import ColumnMapping
 18  from evidently.legacy.pipeline.column_mapping import RecomType
 19  from evidently.legacy.pipeline.column_mapping import TargetNames
 20  from evidently.legacy.pipeline.column_mapping import TaskType
 21  from evidently.pydantic_utils import EnumValueMixin
 22  
 23  
@dataclasses.dataclass
class _InputData:
    """Bundle of the reference and current datasets handed to the helpers in this module."""

    # Reference dataset; may be None (helpers check ``data.reference is None``).
    reference: Optional[pd.DataFrame]
    # Current dataset; never optional.
    current: pd.DataFrame
 28  
 29  
class ColumnDefinition(BaseModel):
    """Name of a column together with its ``ColumnType``."""

    column_name: str
    column_type: ColumnType

    def __init__(self, column_name: str, column_type: ColumnType):
        # Explicit signature so both fields can be supplied positionally.
        super().__init__(column_name=column_name, column_type=column_type)
 36  
 37  
class FeatureDefinition(BaseModel):
    """Description of a feature: name, optional display name, column type and feature class."""

    feature_name: str
    # Optional human-readable name.
    display_name: Optional[str]
    feature_type: ColumnType
    # NOTE(review): what ``feature_class`` identifies is not visible in this file - confirm with callers.
    feature_class: str
 43  
 44  
 45  class PredictionColumns(BaseModel):
 46      predicted_values: Optional[ColumnDefinition] = None
 47      prediction_probas: Optional[List[ColumnDefinition]] = None
 48  
 49      def __init__(
 50          self,
 51          predicted_values: Optional[ColumnDefinition] = None,
 52          prediction_probas: Optional[List[ColumnDefinition]] = None,
 53      ):
 54          super().__init__(predicted_values=predicted_values, prediction_probas=prediction_probas)
 55  
 56      def get_columns_list(self) -> List[ColumnDefinition]:
 57          result = [self.predicted_values]
 58          if self.prediction_probas is not None:
 59              result.extend(self.prediction_probas)
 60          return [col for col in result if col is not None]
 61  
 62  
 63  def _check_filter(
 64      column: ColumnDefinition, utility_columns: List[str], filter_def: ColumnType = None, features_only: bool = False
 65  ) -> bool:
 66      if filter_def is None:
 67          return column.column_name not in utility_columns if features_only else True
 68      if not features_only:
 69          return column.column_type == filter_def
 70  
 71      return column.column_type == filter_def and column.column_name not in utility_columns
 72  
 73  
 74  class DataDefinition(EnumValueMixin):
 75      columns: Dict[str, ColumnDefinition]
 76      target: Optional[ColumnDefinition]
 77      prediction_columns: Optional[PredictionColumns]
 78      id_column: Optional[ColumnDefinition]
 79      datetime_column: Optional[ColumnDefinition]
 80      embeddings: Optional[Dict[str, List[str]]]
 81      user_id: Optional[ColumnDefinition]
 82      item_id: Optional[ColumnDefinition]
 83  
 84      task: Optional[str]
 85      classification_labels: Optional[TargetNames]
 86      reference_present: bool
 87      recommendations_type: Optional[RecomType]
 88  
 89      def get_column(self, column_name: str) -> ColumnDefinition:
 90          return self.columns[column_name]
 91  
 92      def get_columns(self, filter_def: ColumnType = None, features_only: bool = False) -> List[ColumnDefinition]:
 93          if self.prediction_columns is not None:
 94              prediction = self.prediction_columns.get_columns_list()
 95          else:
 96              prediction = []
 97          utility_columns = [
 98              col.column_name
 99              for col in [
100                  self.id_column,
101                  self.datetime_column,
102                  self.target,
103                  self.user_id,
104                  self.item_id,
105                  *prediction,
106              ]
107              if col is not None
108          ]
109          return [
110              column
111              for column in self.columns.values()
112              if _check_filter(column, utility_columns, filter_def, features_only)
113          ]
114  
115      def get_column_names(self, filter_def: ColumnType = None, features_only: bool = False) -> List[str]:
116          return [x.column_name for x in self.get_columns(filter_def, features_only)]
117  
118      def get_target_column(self) -> Optional[ColumnDefinition]:
119          return self.target
120  
121      def get_prediction_columns(self) -> Optional[PredictionColumns]:
122          return self.prediction_columns
123  
124      def get_id_column(self) -> Optional[ColumnDefinition]:
125          return self.id_column
126  
127      def get_user_id_column(self) -> Optional[ColumnDefinition]:
128          return self.user_id
129  
130      def get_item_id_column(self) -> Optional[ColumnDefinition]:
131          return self.item_id
132  
133      def get_datetime_column(self) -> Optional[ColumnDefinition]:
134          return self.datetime_column
135  
136  
class DataDefinitionError(ValueError):
    """Raised in this module when a column declared as categorical breaks the cardinality limit."""

    pass
139  
140  
def _is_cardinality_exceeded(
    column_name: Optional[str],
    data: _InputData,
    limit: Optional[int],
) -> bool:
    """Check the column's cardinality against ``limit``.

    A falsy ``limit`` (``None`` or ``0``) disables the check. Note that a
    cardinality equal to the limit already counts as exceeded.
    """
    unique_count = _get_column_cardinality(column_name, data)
    return bool(limit) and bool(unique_count >= limit)
150  
151  
def _process_column(
    column_name: Optional[str],
    data: _InputData,
    if_partially_present: str = "raise",
    predefined_type: Optional[ColumnType] = None,
    mapping: Optional[ColumnMapping] = None,
    cardinality_limit: Optional[int] = None,
) -> Optional[ColumnDefinition]:
    """Turn a column name into a ``ColumnDefinition``, or ``None`` when the column is to be ignored.

    Args:
        column_name: column to resolve; ``None`` yields ``None``.
        data: datasets the column is looked up in.
        if_partially_present: policy for a column found in only one dataset:
            ``"raise"`` raises, ``"skip"`` returns ``None``, ``"keep"`` continues.
        predefined_type: type to use as is; when ``None`` the type is detected.
        mapping: column mapping forwarded to type detection.
        cardinality_limit: limit forwarded to type detection.

    Raises:
        ValueError: the column is partially present and the policy is ``"raise"``.
    """
    if column_name is None:
        return None

    state = _get_column_presence(column_name, data)
    partially_present = state == ColumnPresenceState.Partially
    if partially_present and if_partially_present == "raise":
        raise ValueError(f"Column ({column_name}) is partially present in data")
    if state == ColumnPresenceState.Missing or (partially_present and if_partially_present == "skip"):
        return None

    # Present columns, and partially present ones under any other policy, continue here.
    if predefined_type is None:
        resolved_type = _get_column_type(column_name, data, mapping, cardinality_limit)
    else:
        resolved_type = predefined_type
    return ColumnDefinition(column_name, resolved_type)
178  
179  
def _prediction_column(
    prediction: Optional[Union[str, int, Sequence[int], Sequence[str]]],
    target_type: Optional[ColumnType],
    task: Optional[str],
    data: _InputData,
    mapping: Optional[ColumnMapping] = None,
) -> Optional[PredictionColumns]:
    """Resolve the ``prediction`` entry of a column mapping into ``PredictionColumns``.

    A single column name becomes either predicted values or a one-element list
    of probability columns, depending on the task, the detected column type and
    the recommendations type. A list of names is treated as probability columns
    and must be entirely numerical.

    Returns:
        ``None`` when no prediction is configured or the column(s) are missing
        from the data, otherwise the resolved prediction columns.

    Raises:
        ValueError: the column is partially present, its type does not fit the
            task, a list contains non-numerical columns, or no rule matched.
    """
    if prediction is None:
        return None

    if isinstance(prediction, str):
        state = _get_column_presence(prediction, data)
        if state == ColumnPresenceState.Missing:
            return None
        if state == ColumnPresenceState.Partially:
            raise ValueError(f"Prediction column ({prediction}) is partially present in data")

        prediction_type = _get_column_type(prediction, data, mapping)
        is_categorical = prediction_type == ColumnType.Categorical
        is_numerical = prediction_type == ColumnType.Numerical

        def as_values() -> PredictionColumns:
            return PredictionColumns(predicted_values=ColumnDefinition(prediction, prediction_type))

        def as_probas() -> PredictionColumns:
            return PredictionColumns(prediction_probas=[ColumnDefinition(prediction, prediction_type)])

        if task == TaskType.CLASSIFICATION_TASK:
            if is_categorical:
                return as_values()
            if is_numerical:
                return as_probas()
            raise ValueError(f"Unexpected type for prediction column ({prediction}) (it is {prediction_type})")
        if task == TaskType.REGRESSION_TASK:
            if is_categorical:
                raise ValueError("Prediction type is categorical but task is regression")
            if is_numerical:
                return as_values()
            # any other type falls through to the checks below
        if mapping is not None:
            if mapping.recommendations_type == RecomType.RANK:
                return as_values()
            if task == TaskType.RECOMMENDER_SYSTEMS and mapping.recommendations_type == RecomType.SCORE:
                return as_probas()
        if task is None:
            if is_numerical and target_type == ColumnType.Categorical:
                # probably this is binary with single column of probabilities
                return as_probas()
            return as_values()

    if isinstance(prediction, list):
        states = [_get_column_presence(name, data) for name in prediction]
        if all(state == ColumnPresenceState.Missing for state in states):
            return None
        if all(state == ColumnPresenceState.Present for state in states):
            prediction_defs = [ColumnDefinition(name, _get_column_type(name, data)) for name in prediction]
            non_numerical = [item for item in prediction_defs if item.column_type != ColumnType.Numerical]
            if non_numerical:
                raise ValueError(f"Some prediction columns have incorrect types {prediction_defs}")
            return PredictionColumns(prediction_probas=prediction_defs)

    # Reached for unsupported input types and for any combination no rule above handled.
    raise ValueError("Unexpected type for prediction field in column_mapping")
230  
231  
def _filter_by_type(column: Optional[ColumnDefinition], column_type: ColumnType, exclude: List[str]) -> bool:
    """Return ``True`` for a defined column of ``column_type`` whose name is not listed in ``exclude``."""
    if column is None:
        return False
    if not (column.column_type == column_type):
        return False
    return column.column_name not in exclude
234  
235  
def _column_not_present_in_list(
    column: Optional[str],
    columns: Collection[str],
    handle_error: str,
    message: str,
) -> Optional[str]:
    """Return ``column`` unless it appears in ``columns``.

    Args:
        column: column name to test; ``None`` is passed through.
        columns: names the column must not collide with.
        handle_error: what to do on a collision: ``"error"`` raises,
            ``"warning"`` logs and returns ``None``, ``"skip"`` silently returns ``None``.
        message: template formatted with ``column=...`` for the error or warning.

    Raises:
        ValueError: on a collision with ``handle_error="error"``, or when
            ``handle_error`` is not one of the supported values.
    """
    if column is None or column not in columns:
        return column

    # From here on the column collides with ``columns``.
    if handle_error == "skip":
        return None
    if handle_error not in ("error", "warning"):
        raise ValueError(f"Unknown handle error type {handle_error}")

    text = message.format(column=column)
    if handle_error == "error":
        raise ValueError(text)
    logging.warning(text)
    return None
254  
255  
def _resolve_utility_column(
    column_name: Optional[str],
    embedding_columns: Collection[str],
    message: str,
    data: _InputData,
    mapping: Optional[ColumnMapping] = None,
) -> Optional[ColumnDefinition]:
    """Build the definition of a utility column (id, target, datetime, ...).

    A column whose name is already used by an embedding is ignored: ``message``
    is logged as a warning and ``None`` is returned.
    """
    return _process_column(
        _column_not_present_in_list(column_name, embedding_columns, "warning", message),
        data,
        mapping=mapping,
    )


def _detected_features(
    col_defs: Sequence[Optional[ColumnDefinition]],
    column_type: ColumnType,
    utility_column_names: List[str],
    embedding_columns: Collection[str],
) -> List[ColumnDefinition]:
    """Pick auto-detected columns of ``column_type`` that are neither utility nor embedding columns."""
    return [
        column
        for column in col_defs
        if column is not None
        and _filter_by_type(column, column_type, utility_column_names)
        # Compare with None: the helper returns the name itself, which may be falsy (e.g. "").
        and _column_not_present_in_list(column.column_name, embedding_columns, "skip", "") is not None
    ]


def _declared_features(
    feature_names: Sequence[str],
    column_type: ColumnType,
    kind: str,
    data: _InputData,
    mapping: ColumnMapping,
    utility_column_names: List[str],
    embedding_columns: Collection[str],
    cardinality_limit: Optional[int] = None,
) -> List[Optional[ColumnDefinition]]:
    """Build definitions for features explicitly listed in the mapping.

    Utility columns are skipped silently; embedding columns are skipped with a
    warning. Entries may be ``None`` for listed columns missing from the data.
    """
    message = (
        f"Column {{column}} is in embedding list and in {kind} features list."
        " Ignoring it in a features list."
    )
    return [
        _process_column(
            column_name,
            data,
            predefined_type=column_type,
            mapping=mapping,
            cardinality_limit=cardinality_limit,
        )
        for column_name in feature_names
        if column_name not in utility_column_names
        # Compare with None: the helper returns the name itself, which may be falsy (e.g. "").
        and _column_not_present_in_list(column_name, embedding_columns, "warning", message) is not None
    ]


def create_data_definition(
    reference_data: Optional[pd.DataFrame],
    current_data: pd.DataFrame,
    mapping: ColumnMapping,
    categorical_features_cardinality_limit: Optional[int] = None,
) -> DataDefinition:
    """Build a ``DataDefinition`` from the datasets and a column mapping.

    Steps: keep embedding columns present in all supplied data, resolve the
    utility columns (id, user id, item id, target, datetime, prediction),
    collect numerical / categorical / datetime / text features (declared in
    the mapping, or auto-detected when the mapping leaves a list as ``None``;
    text features are never auto-detected), infer the task from the target
    type when the mapping does not set it, and gather target labels.

    Args:
        reference_data: optional reference dataset.
        current_data: current dataset.
        mapping: column mapping describing the role of the columns.
        categorical_features_cardinality_limit: limit forwarded to column type detection.

    Raises:
        ValueError: a utility column or a declared feature is present in only
            one of the two datasets, or the prediction entry cannot be resolved.
        DataDefinitionError: raised by type detection when a declared
            categorical column breaks the cardinality limit.
    """
    data = _InputData(reference_data, current_data)

    # Keep only embedding columns that are fully present in the data.
    embedding_columns = set()
    embeddings: Optional[Dict[str, List[str]]] = None
    if mapping.embeddings is not None:
        embeddings = dict()
        for embedding_name, columns in mapping.embeddings.items():
            embeddings[embedding_name] = []
            for column in columns:
                presence = _get_column_presence(column, data)
                if presence != ColumnPresenceState.Present:
                    logging.warning(f"Column {column} isn't present in data. Skipping it.")
                else:
                    embeddings[embedding_name].append(column)
                    embedding_columns.add(column)

    # Utility columns; a name already claimed by an embedding is ignored with a warning.
    id_column = _resolve_utility_column(
        mapping.id,
        embedding_columns,
        "Column {column} is in embeddings list and as an ID field. Ignoring ID field.",
        data,
    )
    user_id = _resolve_utility_column(
        mapping.user_id,
        embedding_columns,
        "Column {column} is in embeddings list and as an user_id field. Ignoring user_id field.",
        data,
    )
    item_id = _resolve_utility_column(
        mapping.item_id,
        embedding_columns,
        "Column {column} is in embeddings list and as an item_id field. Ignoring item_id field.",
        data,
    )
    # Only the target needs the mapping: its type detection depends on the configured task.
    target_column = _resolve_utility_column(
        mapping.target,
        embedding_columns,
        "Column {column} is in embeddings list and as a target field. Ignoring target field.",
        data,
        mapping=mapping,
    )
    datetime_column = _resolve_utility_column(
        mapping.datetime,
        embedding_columns,
        "Column {column} is in embeddings list and as a datetime field. Ignoring datetime field.",
        data,
    )

    prediction_columns = _prediction_column(
        mapping.prediction,
        target_column.column_type if target_column is not None else None,
        mapping.task,
        data,
        mapping,
    )
    prediction_cols = prediction_columns.get_columns_list() if prediction_columns is not None else []

    all_columns: List[Optional[ColumnDefinition]] = [
        id_column,
        user_id,
        item_id,
        datetime_column,
        target_column,
        *prediction_cols,
    ]
    utility_column_names = [column.column_name for column in all_columns if column is not None]

    # Auto-detected definitions for every column seen in either dataset;
    # columns present in only one of them are skipped.
    data_columns = set(data.current.columns) | (set(data.reference.columns) if data.reference is not None else set())
    col_defs = [
        _process_column(
            column_name,
            data,
            if_partially_present="skip",
            mapping=mapping,
            cardinality_limit=categorical_features_cardinality_limit,
        )
        for column_name in data_columns
    ]

    if mapping.numerical_features is None:
        all_columns.extend(
            _detected_features(col_defs, ColumnType.Numerical, utility_column_names, embedding_columns)
        )
    else:
        all_columns.extend(
            _declared_features(
                mapping.numerical_features,
                ColumnType.Numerical,
                "numerical",
                data,
                mapping,
                utility_column_names,
                embedding_columns,
            )
        )

    if mapping.categorical_features is None:
        all_columns.extend(
            _detected_features(col_defs, ColumnType.Categorical, utility_column_names, embedding_columns)
        )
    else:
        all_columns.extend(
            _declared_features(
                mapping.categorical_features,
                ColumnType.Categorical,
                "categorical",
                data,
                mapping,
                utility_column_names,
                embedding_columns,
                cardinality_limit=categorical_features_cardinality_limit,
            )
        )

    if mapping.datetime_features is None:
        all_columns.extend(
            _detected_features(col_defs, ColumnType.Datetime, utility_column_names, embedding_columns)
        )
    else:
        all_columns.extend(
            _declared_features(
                mapping.datetime_features,
                ColumnType.Datetime,
                "datetime",
                data,
                mapping,
                utility_column_names,
                embedding_columns,
            )
        )

    # Text features are only taken from the mapping.
    if mapping.text_features is not None:
        all_columns.extend(
            _declared_features(
                mapping.text_features,
                ColumnType.Text,
                "text",
                data,
                mapping,
                utility_column_names,
                embedding_columns,
            )
        )

    # Infer the task from the target type when the mapping leaves it unset.
    task = mapping.task
    if task is None and target_column is not None:
        if target_column.column_type == ColumnType.Categorical:
            task = TaskType.CLASSIFICATION_TASK
        elif target_column.column_type == ColumnType.Numerical:
            task = TaskType.REGRESSION_TASK

    # Labels are the distinct target values across the supplied datasets.
    labels = None
    if target_column is not None:
        labels = list(data.current[target_column.column_name].unique())
        if data.reference is not None:
            labels = list(set(labels) | set(data.reference[target_column.column_name].unique()))
        if None in labels:
            warnings.warn(
                f"Target column '{target_column.column_name}' contains 'None' values, which is not supported as label value"
            )
            labels = [v for v in labels if v is not None]
    recommendations_type = mapping.recommendations_type or RecomType.SCORE

    # Names given in the mapping take precedence over the detected labels.
    classification_labels = mapping.target_names or labels
    return DataDefinition(
        columns={col.column_name: col for col in all_columns if col is not None},
        id_column=id_column,
        user_id=user_id,
        item_id=item_id,
        datetime_column=datetime_column,
        target=target_column,
        prediction_columns=prediction_columns,
        task=task,
        classification_labels=classification_labels,
        embeddings=embeddings,
        reference_present=reference_data is not None,
        recommendations_type=recommendations_type,
    )
498  
499  
def get_column_name_or_none(column: Optional[ColumnDefinition]) -> Optional[str]:
    """Return the column's name, or ``None`` when no column definition is given."""
    return None if column is None else column.column_name
504  
505  
def create_column_mapping(data_definition: DataDefinition) -> ColumnMapping:
    """Convert a ``DataDefinition`` back into a ``ColumnMapping``.

    Only the predicted-values column is carried over as ``prediction``;
    feature lists contain non-utility columns of the matching type.
    """
    prediction_columns = data_definition.get_prediction_columns()
    prediction = (
        prediction_columns.predicted_values.column_name
        if prediction_columns and prediction_columns.predicted_values
        else None
    )

    def feature_names(column_type: ColumnType) -> List[str]:
        return data_definition.get_column_names(column_type, features_only=True)

    return ColumnMapping(
        target=get_column_name_or_none(data_definition.get_target_column()),
        prediction=prediction,
        datetime=get_column_name_or_none(data_definition.get_datetime_column()),
        id=get_column_name_or_none(data_definition.get_id_column()),
        numerical_features=feature_names(ColumnType.Numerical),
        categorical_features=feature_names(ColumnType.Categorical),
        datetime_features=feature_names(ColumnType.Datetime),
        text_features=feature_names(ColumnType.Text),
        target_names=data_definition.classification_labels,
        task=data_definition.task,
        embeddings=data_definition.embeddings,
        user_id=get_column_name_or_none(data_definition.get_user_id_column()),
        item_id=get_column_name_or_none(data_definition.get_item_id_column()),
        recommendations_type=RecomType(data_definition.recommendations_type),
    )
529  
530  
class ColumnPresenceState(Enum):
    """Availability of a column across the current and (optional) reference datasets."""

    # Found in the current data and, when a reference is given, in the reference too.
    Present = 0
    # Found in exactly one of the two datasets.
    Partially = 1
    # Found in none of the supplied datasets.
    Missing = 2
535  
536  
def _get_column_presence(column_name: str, data: _InputData) -> ColumnPresenceState:
    """Classify ``column_name`` as present, partially present or missing in ``data``.

    Without a reference dataset only the current data decides; with one, the
    column is partially present when exactly one of the datasets contains it.
    """
    in_current = column_name in data.current.columns
    if data.reference is None:
        return ColumnPresenceState.Present if in_current else ColumnPresenceState.Missing

    in_reference = column_name in data.reference.columns
    if in_current and in_reference:
        return ColumnPresenceState.Present
    if in_current or in_reference:
        return ColumnPresenceState.Partially
    return ColumnPresenceState.Missing
545  
546  
def _get_column_cardinality(column_name: Optional[str], data: _InputData) -> float:
    """Return the number of distinct values of the column in the current data.

    Falls back to the count of non-null values when the values cannot be
    compared for uniqueness (``nunique`` raises ``TypeError``), and to ``0``
    when the column is absent from the current data.
    """
    if column_name not in data.current.columns:
        return 0
    values = data.current[column_name]
    try:
        return values.nunique()
    except TypeError:
        return values.count()
554  
555  
# Threshold used by _get_column_type: integer columns with at most this many
# unique values are treated as categorical rather than numerical.
NUMBER_UNIQUE_AS_CATEGORICAL = 5
557  
558  
def _get_column_type(
    column_name: str, data: _InputData, mapping: Optional[ColumnMapping] = None, cardinality_limit: Optional[int] = None
) -> ColumnType:
    """Determine the ``ColumnType`` of a column.

    Resolution order:
      1. feature lists of ``mapping`` (categorical, numerical, datetime, text);
      2. special rules for the target column;
      3. special rules for a single prediction column;
      4. the column dtype (with a unique-value threshold for integers).

    Args:
        column_name: column to classify.
        data: datasets to inspect.
        mapping: optional column mapping with declared roles and task.
        cardinality_limit: limit for categorical columns; falsy disables it.

    Raises:
        DataDefinitionError: the column is declared categorical in ``mapping``
            and the cardinality check reports the limit as exceeded.
    """
    # 1. Explicit declarations in the mapping win over any detection.
    if mapping is not None:
        if mapping.categorical_features and column_name in mapping.categorical_features:
            if cardinality_limit and _is_cardinality_exceeded(column_name, data, cardinality_limit):
                raise DataDefinitionError(f"The cardinality of column ({column_name}) has been exceeded")
            return ColumnType.Categorical
        if mapping.numerical_features and column_name in mapping.numerical_features:
            return ColumnType.Numerical
        if mapping.datetime_features and column_name in mapping.datetime_features:
            return ColumnType.Datetime
        if mapping.text_features and column_name in mapping.text_features:
            return ColumnType.Text

    # Collect dtype and unique-value count from each dataset that has the column.
    ref_type = None
    ref_unique = None
    if data.reference is not None and column_name in data.reference.columns:
        ref_type = data.reference[column_name].dtype
        try:
            ref_unique = data.reference[column_name].nunique()
        except TypeError:
            ref_unique = None
    cur_type = None
    cur_unique = None
    if column_name in data.current.columns:
        cur_type = data.current[column_name].dtype
        try:
            cur_unique = data.current[column_name].nunique()
        except TypeError:
            cur_unique = None

    if ref_type is not None and cur_type is not None and ref_type != cur_type:
        # dtype kinds for which a reference/current mismatch is tolerated without a warning
        available_set = ["i", "u", "f", "c", "m", "M"]
        if ref_type.kind not in available_set or cur_type.kind not in available_set:
            logging.warning(
                f"Column {column_name} have different types in reference {ref_type} and current {cur_type}."
                f" Returning type from reference"
            )
            cur_type = ref_type
        # TODO: add proper type check
        # (the former follow-up check could only fire right after the branch
        # above and merely repeated the warning with two identical dtypes)

    # Reference count is preferred; current is used when it is None or 0.
    nunique = ref_unique or cur_unique
    column_dtype = cur_type if cur_type is not None else ref_type

    # 2. special case: target
    if mapping is not None and (column_name == mapping.target or (mapping.target is None and column_name == "target")):
        reg_condition = mapping.task == "regression" or (
            pd.api.types.is_numeric_dtype(column_dtype)
            and mapping.task != "classification"
            and (nunique is not None and nunique > NUMBER_UNIQUE_AS_CATEGORICAL)
        )
        if reg_condition:
            return ColumnType.Numerical
        else:
            return ColumnType.Categorical

    # 3. special case: single prediction column
    if mapping is not None and (
        (isinstance(mapping.prediction, str) and column_name == mapping.prediction)
        or (mapping.prediction is None and column_name == "prediction")
    ):
        if (
            pd.api.types.is_string_dtype(column_dtype)
            or (
                pd.api.types.is_integer_dtype(column_dtype)
                and mapping.task != "regression"
                and (nunique is not None and nunique <= NUMBER_UNIQUE_AS_CATEGORICAL)
            )
            or (
                # few distinct numeric values lying outside [0, 1] are read as labels
                pd.api.types.is_numeric_dtype(column_dtype)
                and mapping.task != "regression"
                and (nunique is not None and nunique <= NUMBER_UNIQUE_AS_CATEGORICAL)
                and (data.current[column_name].max() > 1 or data.current[column_name].min() < 0)
            )
            or (
                # classification: numeric values outside [0, 1] are read as labels
                pd.api.types.is_numeric_dtype(column_dtype)
                and mapping.task == "classification"
                and (data.current[column_name].max() > 1 or data.current[column_name].min() < 0)
            )
        ):
            return ColumnType.Categorical
        else:
            return ColumnType.Numerical

    # 4. all other features
    if pd.api.types.is_integer_dtype(column_dtype):
        if nunique is not None and nunique <= NUMBER_UNIQUE_AS_CATEGORICAL:
            return ColumnType.Categorical
        return ColumnType.Numerical
    if pd.api.types.is_numeric_dtype(column_dtype):
        if column_dtype == bool:
            return ColumnType.Categorical
        return ColumnType.Numerical
    if pd.api.types.is_datetime64_dtype(column_dtype):
        return ColumnType.Datetime
    if _is_cardinality_exceeded(column_name, data, cardinality_limit):
        return ColumnType.Unknown
    return ColumnType.Categorical