engine.py
  1  import abc
  2  import dataclasses
  3  import functools
  4  import logging
  5  from typing import TYPE_CHECKING
  6  from typing import Dict
  7  from typing import Generic
  8  from typing import List
  9  from typing import Optional
 10  from typing import Tuple
 11  from typing import Type
 12  from typing import TypeVar
 13  from typing import Union
 14  
 15  from evidently.legacy.base_metric import ErrorResult
 16  from evidently.legacy.base_metric import GenericInputData
 17  from evidently.legacy.base_metric import Metric
 18  from evidently.legacy.base_metric import MetricResult
 19  from evidently.legacy.base_metric import TEngineDataType
 20  from evidently.legacy.calculation_engine.metric_implementation import MetricImplementation
 21  from evidently.legacy.features.generated_features import FeatureResult
 22  from evidently.legacy.features.generated_features import GeneratedFeatures
 23  from evidently.legacy.options.base import Options
 24  from evidently.legacy.pipeline.column_mapping import ColumnMapping
 25  from evidently.legacy.utils.data_preprocessing import DataDefinition
 26  from evidently.pydantic_utils import Fingerprint
 27  
 28  if TYPE_CHECKING:
 29      from evidently.legacy.suite.base_suite import Context
 30  
# Type variable for the metric implementation class an engine works with;
# bounded by MetricImplementation.
TMetricImplementation = TypeVar("TMetricImplementation", bound=MetricImplementation)
# Type variable for the input data container an engine works with;
# bounded by GenericInputData.
TInputData = TypeVar("TInputData", bound=GenericInputData)
 33  
 34  
 35  # EngineDatasets = Tuple[Optional[TEngineDataType], Optional[TEngineDataType]]
 36  
 37  
 38  @dataclasses.dataclass
 39  class EngineDatasets(Generic[TEngineDataType]):
 40      current: Optional[TEngineDataType]
 41      reference: Optional[TEngineDataType]
 42  
 43      def __iter__(self):
 44          yield self.current
 45          yield self.reference
 46  
 47  
 48  class Engine(Generic[TMetricImplementation, TInputData, TEngineDataType]):
 49      def __init__(self):
 50          self.metrics = []
 51          self.tests = []
 52  
 53      def set_metrics(self, metrics):
 54          self.metrics = metrics
 55  
 56      def set_tests(self, tests):
 57          self.tests = tests
 58  
 59      def execute_metrics(self, context: "Context", data: GenericInputData):
 60          calculations: Dict[Metric, Union[ErrorResult, MetricResult]] = {}
 61          converted_data = self.convert_input_data(data)
 62  
 63          features_list = self.get_additional_features(converted_data.data_definition)
 64          features = self.calculate_additional_features(converted_data, features_list, context.options)
 65          context.set_features(features)
 66          self.inject_additional_features(converted_data, features)
 67          context.data = converted_data
 68          for metric, calculation in self.get_metric_execution_iterator():
 69              if calculation not in calculations:
 70                  logging.debug(f"Executing {type(calculation)}...")
 71                  try:
 72                      calculations[metric] = calculation.calculate(context, converted_data)
 73                  except BaseException as ex:
 74                      calculations[metric] = ErrorResult(exception=ex)
 75              else:
 76                  logging.debug(f"Using cached result for {type(calculation)}")
 77              context.metric_results[metric] = calculations[metric]
 78  
 79      @abc.abstractmethod
 80      def convert_input_data(self, data: GenericInputData) -> TInputData:
 81          raise NotImplementedError
 82  
 83      @abc.abstractmethod
 84      def get_data_definition(
 85          self,
 86          current_data: TEngineDataType,
 87          reference_data: TEngineDataType,
 88          column_mapping: ColumnMapping,
 89          categorical_features_cardinality: Optional[int] = None,
 90      ):
 91          raise NotImplementedError
 92  
 93      @abc.abstractmethod
 94      def calculate_additional_features(
 95          self, data: TInputData, features: List[GeneratedFeatures], options: Options
 96      ) -> Dict[GeneratedFeatures, FeatureResult[TEngineDataType]]:
 97          raise NotImplementedError
 98  
 99      @abc.abstractmethod
100      def merge_additional_features(
101          self, features: Dict[GeneratedFeatures, FeatureResult[TEngineDataType]]
102      ) -> EngineDatasets[TEngineDataType]:
103          raise NotImplementedError
104  
105      def inject_additional_features(self, data: TInputData, features: Dict[GeneratedFeatures, FeatureResult]):
106          current, reference = self.merge_additional_features(features)
107          data.current_additional_features = current
108          data.reference_additional_features = reference
109  
110      def get_additional_features(self, data_definition: DataDefinition) -> List[GeneratedFeatures]:
111          features: Dict[Fingerprint, GeneratedFeatures] = {}
112          for metric, calculation in self.get_metric_execution_iterator():
113              try:
114                  required_features: List[GeneratedFeatures] = metric.required_features(data_definition)
115              except Exception as e:
116                  logging.error(f"failed to get features for {type(metric)}: {e}", exc_info=e)
117                  continue
118              for feature in required_features:
119                  fp = feature.get_fingerprint()
120                  features[fp] = feature
121          return list(features.values())
122  
123      def get_metric_implementation(self, metric):
124          """
125          Get engine specific metric implementation.
126          """
127          impl = _ImplRegistry.get(type(self), {}).get(type(metric))
128          if impl is None:
129              return None
130          return impl(self, metric)
131  
132      def get_metric_execution_iterator(self) -> List[Tuple[Metric, TMetricImplementation]]:
133          aggregated: Dict[Type[Metric], List[Metric]] = functools.reduce(_aggregate_metrics, self.metrics, {})
134          metric_to_calculations = {}
135          for metric_type, metrics in aggregated.items():
136              metrics_by_parameters: Dict[tuple, List[Metric]] = functools.reduce(_aggregate_by_parameters, metrics, {})
137  
138              for metric in metrics:
139                  parameters = metric.get_parameters()
140                  if parameters is None:
141                      metric_to_calculations[metric] = metric
142                  else:
143                      metric_to_calculations[metric] = metrics_by_parameters[parameters][0]
144  
145          return [(metric, self.get_metric_implementation(metric_to_calculations[metric])) for metric in self.metrics]
146  
147      def form_datasets(
148          self,
149          data: Optional[TInputData],
150          features: List[GeneratedFeatures],
151          data_definition: DataDefinition,
152      ) -> EngineDatasets[TEngineDataType]:
153          raise NotImplementedError
154  
155  
def _aggregate_metrics(agg, item):
    """
    Reducer step: group ``item`` under its exact type in ``agg``.

    ``agg`` maps a type to the list of items of that type, in the order they
    were added. The list is extended in place (no per-call copy) and ``agg`` is
    returned so the function can be used with ``functools.reduce``.
    """
    agg.setdefault(type(item), []).append(item)
    return agg
159  
160  
def _aggregate_by_parameters(agg: dict, metric: Metric) -> dict:
    """
    Reducer step: group ``metric`` under its ``get_parameters()`` value.

    ``get_parameters()`` is evaluated exactly once per call. The group list is
    extended in place and ``agg`` is returned so the function can be used with
    ``functools.reduce``.
    """
    parameters = metric.get_parameters()
    agg.setdefault(parameters, []).append(metric)
    return agg
164  
165  
# Registry of metric implementations: engine class -> metric class -> implementation class.
# Filled by _add_implementation and read by Engine.get_metric_implementation.
_ImplRegistry: Dict[Type, Dict[Type, Type]] = {}
167  
168  
def metric_implementation(metric_cls):
    """
    Build a class decorator that registers the decorated class as the
    implementation of ``metric_cls``.

    The decorated class is handed to ``_add_implementation`` together with
    ``metric_cls`` and is then returned unchanged.
    """

    def register(impl_cls: Type[MetricImplementation]):
        _add_implementation(metric_cls, impl_cls)
        return impl_cls

    return register
179  
180  
def _add_implementation(metric_cls, cls):
    """
    Register ``cls`` as the implementation of ``metric_cls`` for every engine
    listed by ``cls.supported_engines()``.

    The registration is all-or-nothing: every engine is checked for an existing
    implementation of ``metric_cls`` before the registry is modified.

    Returns ``cls``.

    Raises:
        ValueError: if any supported engine already has an implementation
            registered for ``metric_cls``; the registry is left unchanged.
    """
    # Materialize once: the engines are walked twice below.
    engines = tuple(cls.supported_engines())

    # Validate first so a conflict cannot leave a partially updated registry.
    for engine in engines:
        engine_impls = _ImplRegistry.get(engine, {})
        if metric_cls in engine_impls:
            raise ValueError(
                f"Multiple impls of metric {metric_cls}: {engine_impls[metric_cls]}"
                f" already set, but trying to set {cls}"
            )

    for engine in engines:
        _ImplRegistry.setdefault(engine, {})[metric_cls] = cls
    return cls