engine.py
import abc
import dataclasses
import functools
import logging
from typing import TYPE_CHECKING
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union

from evidently.legacy.base_metric import ErrorResult
from evidently.legacy.base_metric import GenericInputData
from evidently.legacy.base_metric import Metric
from evidently.legacy.base_metric import MetricResult
from evidently.legacy.base_metric import TEngineDataType
from evidently.legacy.calculation_engine.metric_implementation import MetricImplementation
from evidently.legacy.features.generated_features import FeatureResult
from evidently.legacy.features.generated_features import GeneratedFeatures
from evidently.legacy.options.base import Options
from evidently.legacy.pipeline.column_mapping import ColumnMapping
from evidently.legacy.utils.data_preprocessing import DataDefinition
from evidently.pydantic_utils import Fingerprint

if TYPE_CHECKING:
    from evidently.legacy.suite.base_suite import Context

TMetricImplementation = TypeVar("TMetricImplementation", bound=MetricImplementation)
TInputData = TypeVar("TInputData", bound=GenericInputData)


# EngineDatasets = Tuple[Optional[TEngineDataType], Optional[TEngineDataType]]


@dataclasses.dataclass
class EngineDatasets(Generic[TEngineDataType]):
    """Pair of engine-specific datasets: ``current`` and ``reference``.

    Either side may be ``None``. Iterating yields ``current`` and then
    ``reference``, so an instance can be unpacked like a two-element tuple.
    """

    current: Optional[TEngineDataType]
    reference: Optional[TEngineDataType]

    def __iter__(self):
        yield self.current
        yield self.reference


class Engine(Generic[TMetricImplementation, TInputData, TEngineDataType]):
    """Base class of a calculation engine.

    Holds the metrics and tests to run and drives metric execution. Subclasses
    provide the engine-specific parts: input conversion, data definition,
    additional feature calculation and merging.

    NOTE(review): several methods are decorated with ``abc.abstractmethod`` but
    the class declares no ABC metaclass, so the markers do not appear to be
    enforced at instantiation time; an unimplemented method only fails with
    ``NotImplementedError`` when it is called.
    """

    def __init__(self):
        self.metrics = []
        self.tests = []

    def set_metrics(self, metrics):
        """Replace the list of metrics to execute."""
        self.metrics = metrics

    def set_tests(self, tests):
        """Replace the list of tests."""
        self.tests = tests

    def execute_metrics(self, context: "Context", data: GenericInputData):
        """Calculate every configured metric and store results in ``context``.

        The input is converted to the engine representation, the additional
        features required by the metrics are calculated and injected into it,
        and then each metric is calculated. A failing calculation is recorded
        as an ``ErrorResult`` instead of interrupting the run.

        Args:
            context: execution context; its ``options`` are read, and its
                features, ``data`` and ``metric_results`` are written.
            data: input data to convert with ``convert_input_data``.
        """
        calculations: Dict[Metric, Union[ErrorResult, MetricResult]] = {}
        converted_data = self.convert_input_data(data)

        features_list = self.get_additional_features(converted_data.data_definition)
        features = self.calculate_additional_features(converted_data, features_list, context.options)
        context.set_features(features)
        self.inject_additional_features(converted_data, features)
        context.data = converted_data
        for metric, calculation in self.get_metric_execution_iterator():
            # The cache is keyed (and read below) by metric, so the membership
            # test has to use the metric as well; testing the implementation
            # object against metric keys could never line up with the read.
            if metric not in calculations:
                logging.debug(f"Executing {type(calculation)}...")
                try:
                    calculations[metric] = calculation.calculate(context, converted_data)
                # Broad on purpose of recording any failure as a result.
                # NOTE(review): BaseException also captures KeyboardInterrupt
                # and SystemExit - confirm this is intended.
                except BaseException as ex:
                    calculations[metric] = ErrorResult(exception=ex)
            else:
                logging.debug(f"Using cached result for {type(calculation)}")
            context.metric_results[metric] = calculations[metric]

    @abc.abstractmethod
    def convert_input_data(self, data: GenericInputData) -> TInputData:
        """Convert generic input data to the engine-specific input type."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_data_definition(
        self,
        current_data: TEngineDataType,
        reference_data: TEngineDataType,
        column_mapping: ColumnMapping,
        categorical_features_cardinality: Optional[int] = None,
    ):
        """Build the data definition for the given datasets and column mapping."""
        raise NotImplementedError

    @abc.abstractmethod
    def calculate_additional_features(
        self, data: TInputData, features: List[GeneratedFeatures], options: Options
    ) -> Dict[GeneratedFeatures, FeatureResult[TEngineDataType]]:
        """Calculate the given generated features, returning a result per feature."""
        raise NotImplementedError

    @abc.abstractmethod
    def merge_additional_features(
        self, features: Dict[GeneratedFeatures, FeatureResult[TEngineDataType]]
    ) -> EngineDatasets[TEngineDataType]:
        """Merge per-feature results into current/reference datasets."""
        raise NotImplementedError

    def inject_additional_features(self, data: TInputData, features: Dict[GeneratedFeatures, FeatureResult]):
        """Merge ``features`` and attach the merged datasets to ``data``.

        Sets ``data.current_additional_features`` and
        ``data.reference_additional_features``.
        """
        current, reference = self.merge_additional_features(features)
        data.current_additional_features = current
        data.reference_additional_features = reference

    def get_additional_features(self, data_definition: DataDefinition) -> List[GeneratedFeatures]:
        """Collect the generated features required by the configured metrics.

        Features are de-duplicated by fingerprint; when several features share
        a fingerprint the last one seen is kept. A metric whose
        ``required_features`` call raises is logged and skipped.
        """
        features: Dict[Fingerprint, GeneratedFeatures] = {}
        for metric, calculation in self.get_metric_execution_iterator():
            try:
                required_features: List[GeneratedFeatures] = metric.required_features(data_definition)
            except Exception as e:
                logging.error(f"failed to get features for {type(metric)}: {e}", exc_info=e)
                continue
            for feature in required_features:
                fp = feature.get_fingerprint()
                features[fp] = feature
        return list(features.values())

    def get_metric_implementation(self, metric):
        """
        Get engine specific metric implementation.

        The registry is looked up by the exact types of this engine and of
        ``metric``. Returns a new implementation instance built from
        ``(self, metric)``, or ``None`` when nothing is registered.
        """
        impl = _ImplRegistry.get(type(self), {}).get(type(metric))
        if impl is None:
            return None
        return impl(self, metric)

    def get_metric_execution_iterator(self) -> List[Tuple[Metric, Optional[TMetricImplementation]]]:
        """Pair every configured metric with the implementation to calculate it.

        Metrics are grouped by type and then by ``get_parameters()``. A metric
        with non-``None`` parameters is paired with an implementation built from
        the first metric of the same type and parameters; a metric whose
        parameters are ``None`` is paired with an implementation of itself.

        Returns:
            One ``(metric, implementation)`` tuple per entry of
            ``self.metrics``, in the same order. The implementation is whatever
            ``get_metric_implementation`` returns, which may be ``None``.
        """
        aggregated: Dict[Type[Metric], List[Metric]] = functools.reduce(_aggregate_metrics, self.metrics, {})
        metric_to_calculations = {}
        for metrics in aggregated.values():
            metrics_by_parameters: Dict[tuple, List[Metric]] = functools.reduce(_aggregate_by_parameters, metrics, {})

            for metric in metrics:
                parameters = metric.get_parameters()
                if parameters is None:
                    metric_to_calculations[metric] = metric
                else:
                    metric_to_calculations[metric] = metrics_by_parameters[parameters][0]

        return [(metric, self.get_metric_implementation(metric_to_calculations[metric])) for metric in self.metrics]

    def form_datasets(
        self,
        data: Optional[TInputData],
        features: List[GeneratedFeatures],
        data_definition: DataDefinition,
    ) -> EngineDatasets[TEngineDataType]:
        """Build engine datasets from input data and features.

        Not implemented in the base class; always raises ``NotImplementedError``.
        """
        raise NotImplementedError


def _aggregate_metrics(agg, item):
    """Reducer step: append ``item`` to the group of its exact type in ``agg``."""
    # Appending in place avoids rebuilding the group list on every step.
    agg.setdefault(type(item), []).append(item)
    return agg


def _aggregate_by_parameters(agg: dict, metric: Metric) -> dict:
    """Reducer step: append ``metric`` to the group of its ``get_parameters()``."""
    parameters = metric.get_parameters()
    agg.setdefault(parameters, []).append(metric)
    return agg


# Engine type -> (metric type -> implementation class).
_ImplRegistry: Dict[Type, Dict[Type, Type]] = dict()


def metric_implementation(metric_cls):
    """
    Decorate a metric implementation class as the implementation of a specific metric.

    The decorated class is registered for ``metric_cls`` under every engine it
    reports via ``supported_engines()`` and is returned unchanged.
    """

    def wrapper(cls: Type[MetricImplementation]):
        _add_implementation(metric_cls, cls)
        return cls

    return wrapper


def _add_implementation(metric_cls, cls):
    """Register ``cls`` as the implementation of ``metric_cls``.

    Registration happens for each engine in ``cls.supported_engines()``.

    Raises:
        ValueError: if an engine already has an implementation for ``metric_cls``.

    Returns:
        ``cls`` itself.
    """
    for engine in cls.supported_engines():
        engine_impls = _ImplRegistry.setdefault(engine, {})
        if metric_cls in engine_impls:
            raise ValueError(
                f"Multiple impls of metric {metric_cls}: {engine_impls[metric_cls]}"
                f" already set, but trying to set {cls}"
            )
        engine_impls[metric_cls] = cls
    return cls