src/evidently/legacy/base_metric.py
base_metric.py
  1  import abc
  2  import logging
  3  import warnings
  4  from dataclasses import dataclass
  5  from enum import Enum
  6  from typing import TYPE_CHECKING
  7  from typing import Any
  8  from typing import ClassVar
  9  from typing import Dict
 10  from typing import Generic
 11  from typing import List
 12  from typing import Optional
 13  from typing import Tuple
 14  from typing import Type
 15  from typing import TypeVar
 16  from typing import Union
 17  
 18  import pandas as pd
 19  import typing_inspect
 20  
 21  from evidently._pydantic_compat import Field
 22  from evidently._pydantic_compat import ModelMetaclass
 23  from evidently._pydantic_compat import PrivateAttr
 24  from evidently.legacy.core import BaseResult
 25  from evidently.legacy.core import ColumnType
 26  from evidently.legacy.core import IncludeTags
 27  from evidently.legacy.options.base import AnyOptions
 28  from evidently.legacy.options.base import Options
 29  from evidently.legacy.pipeline.column_mapping import ColumnMapping
 30  from evidently.legacy.utils.data_preprocessing import DataDefinition
 31  from evidently.pydantic_utils import EnumValueMixin
 32  from evidently.pydantic_utils import EvidentlyBaseModel
 33  from evidently.pydantic_utils import FieldPath
 34  from evidently.pydantic_utils import FingerprintPart
 35  from evidently.pydantic_utils import FrozenBaseMeta
 36  from evidently.pydantic_utils import PolymorphicModel
 37  from evidently.pydantic_utils import WithTestAndMetricDependencies
 38  from evidently.pydantic_utils import autoregister
 39  from evidently.pydantic_utils import get_value_fingerprint
 40  
 41  if TYPE_CHECKING:
 42      from evidently.legacy.features.generated_features import GeneratedFeatures
 43      from evidently.legacy.suite.base_suite import Context
 44  
 45  
 46  class WithFieldsPathMetaclass(ModelMetaclass):
 47      @property
 48      def fields(cls) -> FieldPath:
 49          return FieldPath([], cls)
 50  
 51  
class MetricResult(PolymorphicModel, BaseResult, metaclass=WithFieldsPathMetaclass):  # type: ignore[misc] # pydantic Config
    # Base class of all metric results. Through WithFieldsPathMetaclass every result class
    # gets a class-level ``fields`` property returning a FieldPath rooted at that class.
    # (Comments instead of a docstring: a class docstring would be inherited into the
    # pydantic schema description of every subclass.)
    class Config:
        type_alias = "evidently:metric_result:MetricResult"
        # tags the ``type`` field as the polymorphic type field
        field_tags = {"type": {IncludeTags.TypeField}}
        is_base_type = True
        alias_required = True
 58  
 59  
class ErrorResult(BaseResult):
    # Result object that carries an exception instead of metric values;
    # Metric.get_result() re-raises the stored exception when it finds one of these.
    class Config:
        underscore_attrs_are_private = True

    _exception: Optional[BaseException] = None  # todo: fix serialization of exceptions

    def __init__(self, exception: Optional[BaseException]):
        super().__init__()
        # Assigned after super().__init__() -- NOTE(review): presumably so the base model's
        # private-attribute initialization does not reset it; confirm before reordering.
        self._exception = exception

    @property
    def exception(self) -> Optional[BaseException]:
        """Return the stored exception (may be None)."""
        return self._exception
 73  
 74  
class DatasetType(Enum):
    # Which frame a ColumnName is read from (see InputData._get_by_column_name):
    # MAIN -> the current/reference data itself, ADDITIONAL -> the additional-features frame.
    MAIN = "main"
    ADDITIONAL = "additional"


# Alias for the human-readable column name type used by ColumnName.display_name.
DisplayName = str
 81  
 82  
 83  @autoregister
 84  class ColumnName(EnumValueMixin, EvidentlyBaseModel):
 85      class Config:
 86          type_alias = "evidently:base:ColumnName"
 87  
 88      name: str
 89      display_name: DisplayName
 90      dataset: DatasetType
 91      _feature_class: Optional["GeneratedFeatures"] = PrivateAttr(None)
 92  
 93      def __init__(
 94          self, name: str, display_name: str, dataset: DatasetType, feature_class: Optional["GeneratedFeatures"] = None
 95      ):
 96          self._feature_class = feature_class
 97          super().__init__(name=name, display_name=display_name, dataset=dataset)
 98  
 99      def is_main_dataset(self):
100          return self.dataset == DatasetType.MAIN
101  
102      @staticmethod
103      def main_dataset(name: str):
104          return ColumnName(name, name, DatasetType.MAIN, None)
105  
106      def __str__(self):
107          return self.display_name
108  
109      @classmethod
110      def from_any(cls, column_name: Union[str, "ColumnName"]):
111          return column_name if not isinstance(column_name, str) else ColumnName.main_dataset(column_name)
112  
113      @property
114      def feature_class(self) -> Optional["GeneratedFeatures"]:
115          return self._feature_class
116  
117      def get_fingerprint_parts(self) -> Tuple[FingerprintPart, ...]:
118          return tuple(
119              (name, self.get_field_fingerprint(name))
120              for name, field in sorted(self.__fields__.items())
121              if field.required or getattr(self, name) != field.get_default() and field.name != "display_name"
122          )
123  
124  
class ColumnNotFound(BaseException):
    """Raised when a requested column is missing from the dataset it is looked up in.

    The missing column's name is available as ``column_name`` and is also passed to the base
    exception, so ``args``/``str()`` carry it and the exception survives a pickle round-trip
    even when constructed with a keyword argument.

    NOTE(review): this derives from ``BaseException`` rather than ``Exception``, so generic
    ``except Exception`` handlers do not catch it; kept as-is because callers may rely on that.
    """

    def __init__(self, column_name: str):
        # Forward to the base initializer: without it, ``ColumnNotFound(column_name="x")``
        # ends up with ``args == ()``, an empty message, and fails to unpickle.
        super().__init__(column_name)
        self.column_name = column_name
128  
129  
TEngineDataType = TypeVar("TEngineDataType")


@dataclass
class GenericInputData(Generic[TEngineDataType]):
    """Datasets and metadata a metric is calculated on, generic over the engine's data type.

    ``TEngineDataType`` is the engine's frame type (``InputData`` binds it to
    ``pd.DataFrame``). ``reference_data`` and both additional-feature frames may be None.
    """

    # reference dataset (may be None) and current dataset
    reference_data: Optional[TEngineDataType]
    current_data: TEngineDataType
    column_mapping: ColumnMapping
    data_definition: DataDefinition
    additional_data: Dict[str, Any]
    # frames holding additional (generated) feature columns, if any
    reference_additional_features: Optional[TEngineDataType] = None
    current_additional_features: Optional[TEngineDataType] = None

    def get_datasets(self) -> Tuple[Optional[TEngineDataType], TEngineDataType]:
        """Return ``(reference, current)``; engine-specific subclasses must override this."""
        raise NotImplementedError()
145  
146  
class InputData(GenericInputData[pd.DataFrame]):
    """``GenericInputData`` backed by pandas ``DataFrame`` objects."""

    @staticmethod
    def _get_by_column_name(dataset: pd.DataFrame, additional: Optional[pd.DataFrame], column: ColumnName) -> pd.Series:
        """Select ``column`` from ``dataset`` (MAIN columns) or ``additional`` (ADDITIONAL columns).

        Raises:
            ColumnNotFound: a MAIN column is missing from ``dataset``.
            ValueError: an ADDITIONAL column is requested while ``additional`` is None, or the
                column's dataset type is neither MAIN nor ADDITIONAL.

        A missing ADDITIONAL column is not checked here; pandas' own ``KeyError`` propagates.
        """
        source = column.dataset
        if source == DatasetType.MAIN:
            if column.name in dataset.columns:
                return dataset[column.name]
            raise ColumnNotFound(column.name)
        if source == DatasetType.ADDITIONAL:
            if additional is not None:
                return additional[column.name]
            raise ValueError("no additional dataset is provided, but field requested")
        raise ValueError("unknown column data")

    def get_current_column(self, column: Union[str, ColumnName]) -> pd.Series:
        """Return ``column`` from the current data (or the current additional features)."""
        target = self._str_to_column_name(column)
        return self._get_by_column_name(self.current_data, self.current_additional_features, target)

    def get_reference_column(self, column: Union[str, ColumnName]) -> Optional[pd.Series]:
        """Return ``column`` from the reference data, or None when it cannot be provided.

        None is returned when there is no reference data at all, or when an ADDITIONAL column
        is requested but there are no reference additional features.
        """
        if self.reference_data is None:
            return None
        target = self._str_to_column_name(column)
        wants_additional = target.dataset == DatasetType.ADDITIONAL
        if wants_additional and self.reference_additional_features is None:
            return None
        return self._get_by_column_name(self.reference_data, self.reference_additional_features, target)

    def get_data(self, column: Union[str, ColumnName]) -> Tuple[ColumnType, pd.Series, Optional[pd.Series]]:
        """Return ``(column_type, current_series, reference_series_or_None)`` for ``column``."""
        # Reference lookup happens first, then type resolution, then the current lookup.
        reference = self.get_reference_column(column) if self.reference_data is not None else None
        column_type = self._determine_type(column)
        current = self.get_current_column(column)
        return column_type, current, reference

    def _determine_type(self, column: Union[str, ColumnName]) -> ColumnType:
        """Resolve the column type from its feature class if it has one, else from the data definition."""
        if isinstance(column, ColumnName):
            if column.feature_class is not None:
                return column.feature_class.get_type(column.name)
            lookup_name = column.name
        else:
            lookup_name = column
        return self.data_definition.get_column(lookup_name).column_type

    def has_column(self, column_name: Union[str, ColumnName]):
        """Return whether the column is known.

        MAIN columns are checked against the data definition; other columns against the
        current additional features (False when there are none).
        """
        column = self._str_to_column_name(column_name)
        if column.dataset == DatasetType.MAIN:
            known_names = [definition.column_name for definition in self.data_definition.get_columns()]
            return column.name in known_names
        return self.current_additional_features is not None and column.name in self.current_additional_features.columns

    def _str_to_column_name(self, column: Union[str, ColumnName]) -> ColumnName:
        """Wrap a plain string as a MAIN-dataset ``ColumnName``; pass anything else through."""
        if not isinstance(column, str):
            return column
        return ColumnName(column, column, DatasetType.MAIN, None)

    def get_datasets(self) -> Tuple[Optional[pd.DataFrame], pd.DataFrame]:
        """Return ``(reference, current)`` with the additional feature frames joined on, when present."""
        reference, current = self.reference_data, self.current_data
        if self.current_additional_features is not None:
            current = current.join(self.current_additional_features)
        if reference is not None and self.reference_additional_features is not None:
            reference = reference.join(self.reference_additional_features)
        return reference, current
212  
213  
TResult = TypeVar("TResult", bound=MetricResult)


class FieldsDescriptor:
    """Descriptor behind ``Metric.fields``.

    On a metric instance it builds the ``FieldPath`` from the instance's calculated result.
    If ``get_result()`` raises ``ValueError`` (for example, no context or no result yet), it
    warns and falls back to the class's declared result type. On the class itself it always
    uses the declared result type.
    """

    def __get__(self, instance: Optional["Metric"], type: Type["Metric"]) -> FieldPath:
        if instance is None:
            return FieldPath([], type.result_type())
        try:
            return FieldPath([], instance.get_result())
        except ValueError:
            warnings.warn("Metric is not calculated yet, using generic fields list")
        return FieldPath([], type.result_type())
225  
226  
class WithResultFieldPathMetaclass(FrozenBaseMeta):
    """Metaclass for metrics that can recover the result type a metric class was parametrized with."""

    def result_type(cls) -> Type[MetricResult]:
        """Return the first type argument of the class's first generic base (e.g. ``R`` in ``Metric[R]``).

        Raises ``StopIteration`` if the class has no generic base.
        """
        generic_bases = (
            base for base in cls.__orig_bases__ if typing_inspect.is_generic_type(base)  # type: ignore[attr-defined]
        )
        first_generic = next(generic_bases)
        return typing_inspect.get_args(first_generic)[0]
232  
233  
class BasePreset(EvidentlyBaseModel):
    # Base model class for presets. The Config keys below are evidently-specific options
    # (not standard pydantic settings); ``is_base_type`` marks this as a base type.
    class Config:
        type_alias = "evidently:base:BasePreset"
        transitive_aliases = True
        is_base_type = True
239  
240  
class Metric(WithTestAndMetricDependencies, Generic[TResult], metaclass=WithResultFieldPathMetaclass):
    # Base class of legacy metrics. Subclasses implement ``calculate`` and are parametrized
    # with their result type (``Metric[SomeResult]``); the metaclass's ``result_type()``
    # recovers that type. Results are looked up in the context set via ``set_context``.
    # (Comments instead of a docstring: a class docstring would be inherited into the
    # pydantic schema description of every metric subclass.)
    class Config:
        is_base_type = True

    _context: Optional["Context"] = None

    options: Optional[Options] = Field(default=None)

    fields: ClassVar[FieldsDescriptor] = FieldsDescriptor()
    # resulting options will be determined via
    # options = global_option.override(display_options).override(metric_options)

    def __init__(self, options: AnyOptions = None, **data):
        # ``options`` is assigned before super().__init__(); keep this order.
        # NOTE(review): presumably the frozen base model captures pre-init assignments -- confirm.
        self.options = Options.from_any_options(options)
        super().__init__(**data)

    @classmethod
    def get_id(cls) -> str:
        """Return the metric identifier: the bare class name."""
        return cls.__name__

    @classmethod
    def get_group(cls) -> str:
        """Return the ``evidently.legacy.metrics`` subpackage this metric class is defined in.

        For a class defined in ``evidently.legacy.metrics.<group>[.<module>...]`` this returns
        ``"<group>"``; for classes defined anywhere else it returns ``""``.
        """
        prefix = "evidently.legacy.metrics."
        module = cls.__module__
        if not module.startswith(prefix):
            return ""
        # Index relative to the prefix: the fixed ``split(".")[2]`` dated from the old
        # ``evidently.metrics.<group>`` layout and always yielded "metrics" under ``legacy``.
        return module[len(prefix) :].split(".")[0]

    @abc.abstractmethod
    def calculate(self, data: InputData) -> TResult:
        """Compute this metric's result from ``data``; every concrete metric implements this."""
        raise NotImplementedError()

    def set_context(self, context):
        """Attach the context that holds calculated results and global options."""
        self._context = context

    def get_result(self) -> TResult:
        """Return the result stored for this metric in its context.

        Raises:
            ValueError: no context is set, or the context holds no result for this metric.
            BaseException: the stored exception, re-raised, when the stored result is an ErrorResult.
        """
        if not hasattr(self, "_context") or self._context is None:
            raise ValueError("No context is set")
        result = self._context.metric_results.get(self, None)
        if isinstance(result, ErrorResult):
            raise result.exception
        if result is None:
            raise ValueError(f"No result found for metric {self} of type {type(self).__name__}")
        return result  # type: ignore[return-value]

    def get_parameters(self) -> Optional[tuple]:
        """Return instance attribute values as a hashable tuple, or None if they are not hashable.

        Values are ordered by attribute name and lists are converted to tuples; ``_context``
        is skipped.
        """
        attributes = []
        for field, value in sorted(self.__dict__.items(), key=lambda x: x[0]):
            if field == "_context":
                continue
            if isinstance(value, list):
                attributes.append(tuple(value))
            else:
                attributes.append(value)
        params = tuple(attributes)
        try:
            hash(params)
        except TypeError:
            logging.warning(f"unhashable params for {type(self)}. Fallback to unique.")
            return None
        return params

    def required_features(self, data_definition: DataDefinition) -> List["GeneratedFeatures"]:
        """Return the feature classes of attributes that are ``ColumnName`` objects with a feature class.

        Only direct ``ColumnName`` attribute values are inspected (not ones nested in containers).
        """
        required_features = []
        for field, value in sorted(self.__dict__.items(), key=lambda x: x[0]):
            # Skip the same private attribute that get_parameters skips.
            if field == "_context":
                continue
            if isinstance(value, ColumnName) and value.feature_class is not None:
                required_features.append(value.feature_class)
        return required_features

    def get_options(self):
        """Return the effective options.

        When a context is set, its options are overridden by this metric's own options;
        otherwise the metric's options are returned as-is.
        """
        options = self.options if hasattr(self, "options") else Options()
        # Same tolerance for a missing ``_context`` attribute as get_result().
        context = getattr(self, "_context", None)
        if context is not None:
            options = context.options.override(options)
        return options

    def get_field_fingerprint(self, field: str) -> FingerprintPart:
        """Fingerprint a field; ``options`` is delegated to ``get_options_fingerprint``."""
        if field == "options":
            return self.get_options_fingerprint()
        return super().get_field_fingerprint(field)

    def get_options_fingerprint(self) -> FingerprintPart:
        """Options do not contribute to the fingerprint by default; mixins may override this."""
        return None
323  
324  
class UsesRawDataMixin:
    # Mixin that makes the ``render_options.raw_data`` flag part of a metric's fingerprint
    # (Metric.get_options_fingerprint returns None otherwise). Comments instead of a docstring,
    # since this class is mixed into pydantic models whose schema description inherits docstrings.
    options: Options

    def get_options_fingerprint(self) -> FingerprintPart:
        raw_data_flag = self.options.render_options.raw_data
        return get_value_fingerprint(raw_data_flag)
330  
331  
class ColumnMetricResult(MetricResult):
    # Metric result that records the column it describes (``column_name``/``column_type``,
    # both tagged as parameters).
    class Config:
        type_alias = "evidently:metric_result:ColumnMetricResult"
        field_tags = {
            "column_name": {IncludeTags.Parameter},
            "column_type": {IncludeTags.Parameter},
        }

    column_name: str
    # todo: use enum
    column_type: str

    def get_pandas(self) -> pd.DataFrame:
        """Return the collected result columns as a frame indexed by ``self.column_name``.

        The index holds the single entry ``self.column_name`` and is named ``"column_name"``.
        """
        collected = self.collect_pandas_columns()
        frame = pd.DataFrame.from_dict({self.column_name: collected}, orient="index")
        frame.index.name = "column_name"
        return frame
348  
349  
ColumnTResult = TypeVar("ColumnTResult", bound=ColumnMetricResult)


class ColumnMetric(Metric[ColumnTResult], Generic[ColumnTResult], abc.ABC):
    # Base for metrics bound to one column, stored as a ColumnName (plain strings are
    # converted with ColumnName.from_any).
    column_name: ColumnName

    def __init__(self, column_name: Union[ColumnName, str], options: AnyOptions = None):
        """Normalize ``column_name`` to a ``ColumnName`` and pass ``options`` on to ``Metric.__init__``."""
        normalized = ColumnName.from_any(column_name)
        # Set before super().__init__(), mirroring how Metric.__init__ assigns ``options``.
        self.column_name = normalized
        super().__init__(options)