# base_metric.py
import abc
import logging
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union

import pandas as pd
import typing_inspect

from evidently._pydantic_compat import Field
from evidently._pydantic_compat import ModelMetaclass
from evidently._pydantic_compat import PrivateAttr
from evidently.legacy.core import BaseResult
from evidently.legacy.core import ColumnType
from evidently.legacy.core import IncludeTags
from evidently.legacy.options.base import AnyOptions
from evidently.legacy.options.base import Options
from evidently.legacy.pipeline.column_mapping import ColumnMapping
from evidently.legacy.utils.data_preprocessing import DataDefinition
from evidently.pydantic_utils import EnumValueMixin
from evidently.pydantic_utils import EvidentlyBaseModel
from evidently.pydantic_utils import FieldPath
from evidently.pydantic_utils import FingerprintPart
from evidently.pydantic_utils import FrozenBaseMeta
from evidently.pydantic_utils import PolymorphicModel
from evidently.pydantic_utils import WithTestAndMetricDependencies
from evidently.pydantic_utils import autoregister
from evidently.pydantic_utils import get_value_fingerprint

if TYPE_CHECKING:
    from evidently.legacy.features.generated_features import GeneratedFeatures
    from evidently.legacy.suite.base_suite import Context


class WithFieldsPathMetaclass(ModelMetaclass):
    """Model metaclass adding a class-level ``fields`` accessor."""

    @property
    def fields(cls) -> FieldPath:
        """Return an empty-path :class:`FieldPath` whose root is this model class."""
        root_path = FieldPath([], cls)
        return root_path


# Polymorphic base model for all legacy metric results. Comments are used instead of a
# docstring because pydantic v1 copies a model's ``__doc__`` into its JSON-schema description.
class MetricResult(PolymorphicModel, BaseResult, metaclass=WithFieldsPathMetaclass):  # type: ignore[misc] # pydantic Config
    class Config:
        type_alias = "evidently:metric_result:MetricResult"
        field_tags = {"type": {IncludeTags.TypeField}}
        is_base_type = True
        alias_required = True


# Stored instead of a MetricResult when a metric failed. It keeps the raised exception,
# and ``Metric.get_result`` re-raises it when it finds an ErrorResult.
class ErrorResult(BaseResult):
    class Config:
        underscore_attrs_are_private = True

    _exception: Optional[BaseException] = None  # todo: fix serialization of exceptions

    def __init__(self, exception: Optional[BaseException]):
        super().__init__()
        self._exception = exception

    @property
    def exception(self) -> Optional[BaseException]:
        """The exception captured when the failing metric ran (may be None)."""
        return self._exception


class DatasetType(Enum):
    """Which frame a :class:`ColumnName` refers to: the main data or the additional-features data."""

    MAIN = "main"
    ADDITIONAL = "additional"


DisplayName = str


# Reference to a column of either the main dataset or the additional-features dataset,
# with a human-readable display name and, optionally, the generator that produces it.
@autoregister
class ColumnName(EnumValueMixin, EvidentlyBaseModel):
    class Config:
        type_alias = "evidently:base:ColumnName"

    name: str
    display_name: DisplayName
    dataset: DatasetType
    _feature_class: Optional["GeneratedFeatures"] = PrivateAttr(None)

    def __init__(
        self, name: str, display_name: str, dataset: DatasetType, feature_class: Optional["GeneratedFeatures"] = None
    ):
        # NOTE(review): the private attribute is assigned *before* super().__init__; this presumably
        # relies on EvidentlyBaseModel accepting pre-init assignments. Confirm before reordering.
        self._feature_class = feature_class
        super().__init__(name=name, display_name=display_name, dataset=dataset)

    def is_main_dataset(self):
        """Return True when this column belongs to the main dataset."""
        return self.dataset == DatasetType.MAIN

    @staticmethod
    def main_dataset(name: str):
        """Build a main-dataset column whose display name equals its name."""
        return ColumnName(name=name, display_name=name, dataset=DatasetType.MAIN, feature_class=None)

    def __str__(self):
        return self.display_name

    @classmethod
    def from_any(cls, column_name: Union[str, "ColumnName"]):
        """Return ``column_name`` unchanged, or wrap a plain string as a main-dataset column."""
        if isinstance(column_name, str):
            return ColumnName.main_dataset(column_name)
        return column_name

    @property
    def feature_class(self) -> Optional["GeneratedFeatures"]:
        """Feature generator that produces this column, if any."""
        return self._feature_class

    def get_fingerprint_parts(self) -> Tuple[FingerprintPart, ...]:
        """Return ``(field_name, fingerprint)`` pairs, ordered by field name.

        A field is included when it is required, or when its value differs from
        its default and it is not ``display_name``.
        """
        parts = []
        for field_name, model_field in sorted(self.__fields__.items()):
            # NOTE(review): `and` binds tighter than `or`; the parentheses only make the existing
            # precedence explicit. Because `display_name` has no default (so it is required), it is
            # still fingerprinted. Confirm whether it was meant to be excluded unconditionally.
            include = model_field.required or (
                getattr(self, field_name) != model_field.get_default() and model_field.name != "display_name"
            )
            if include:
                parts.append((field_name, self.get_field_fingerprint(field_name)))
        return tuple(parts)
class ColumnNotFound(BaseException):
    """Raised by ``InputData`` when a requested main-dataset column is missing.

    NOTE(review): this derives from ``BaseException`` rather than ``Exception``, so
    ``except Exception`` handlers do not catch it. Kept as-is because callers may
    depend on that propagation.
    """

    def __init__(self, column_name: str):
        # BaseException.__new__ also records the argument in ``self.args``.
        self.column_name = column_name


TEngineDataType = TypeVar("TEngineDataType")


@dataclass
class GenericInputData(Generic[TEngineDataType]):
    """Reference/current data plus metadata, parameterised by the engine's frame type.

    ``*_additional_features`` hold the columns addressed by
    ``DatasetType.ADDITIONAL`` column names.
    """

    reference_data: Optional[TEngineDataType]
    current_data: TEngineDataType
    column_mapping: ColumnMapping
    data_definition: DataDefinition
    additional_data: Dict[str, Any]
    reference_additional_features: Optional[TEngineDataType] = None
    current_additional_features: Optional[TEngineDataType] = None

    def get_datasets(self) -> Tuple[Optional[TEngineDataType], TEngineDataType]:
        """Return ``(reference, current)``; engine-specific subclasses implement this."""
        raise NotImplementedError


class InputData(GenericInputData[pd.DataFrame]):
    """pandas implementation of :class:`GenericInputData` (the type ``Metric.calculate`` receives)."""

    @staticmethod
    def _get_by_column_name(dataset: pd.DataFrame, additional: Optional[pd.DataFrame], column: ColumnName) -> pd.Series:
        """Select ``column`` from ``dataset`` (main) or ``additional`` (generated features).

        Raises:
            ColumnNotFound: a main-dataset column is absent from ``dataset``.
            ValueError: an additional column is requested but ``additional`` is None,
                or ``column.dataset`` is neither MAIN nor ADDITIONAL.
        """
        if column.dataset == DatasetType.MAIN:
            if column.name not in dataset.columns:
                raise ColumnNotFound(column.name)
            return dataset[column.name]
        if column.dataset == DatasetType.ADDITIONAL:
            if additional is None:
                raise ValueError("no additional dataset is provided, but field requested")
            # NOTE(review): a missing additional column surfaces as pandas' KeyError, not ColumnNotFound.
            return additional[column.name]
        raise ValueError("unknown column data")

    def get_current_column(self, column: Union[str, ColumnName]) -> pd.Series:
        """Return ``column`` from the current data (strings are treated as main-dataset names)."""
        _column = self._str_to_column_name(column)
        return self._get_by_column_name(self.current_data, self.current_additional_features, _column)

    def get_reference_column(self, column: Union[str, ColumnName]) -> Optional[pd.Series]:
        """Return ``column`` from the reference data.

        Returns None when there is no reference data, or when an additional column
        is requested but no reference additional features exist.
        """
        if self.reference_data is None:
            return None
        _column = self._str_to_column_name(column)
        if self.reference_additional_features is None and _column.dataset == DatasetType.ADDITIONAL:
            return None
        return self._get_by_column_name(self.reference_data, self.reference_additional_features, _column)

    def get_data(self, column: Union[str, ColumnName]) -> Tuple[ColumnType, pd.Series, Optional[pd.Series]]:
        """Return ``(column_type, current_column, reference_column_or_None)``."""
        # get_reference_column returns None itself when there is no reference data. It is
        # evaluated first to keep the original evaluation order.
        ref_data = self.get_reference_column(column)
        return self._determine_type(column), self.get_current_column(column), ref_data

    def _determine_type(self, column: Union[str, ColumnName]) -> ColumnType:
        """Column type from the column's feature generator if it has one, else from the data definition."""
        if isinstance(column, ColumnName) and column.feature_class is not None:
            return column.feature_class.get_type(column.name)
        column_name = column.name if isinstance(column, ColumnName) else column
        return self.data_definition.get_column(column_name).column_type

    def has_column(self, column_name: Union[str, ColumnName]):
        """Return True if the column is in the data definition (main) or the current additional features."""
        column = self._str_to_column_name(column_name)
        if column.dataset == DatasetType.MAIN:
            # Short-circuiting scan; no throwaway list is built.
            return any(
                definition.column_name == column.name for definition in self.data_definition.get_columns()
            )
        if self.current_additional_features is not None:
            return column.name in self.current_additional_features.columns
        return False

    def _str_to_column_name(self, column: Union[str, ColumnName]) -> ColumnName:
        """Normalise a plain string into a main-dataset :class:`ColumnName`."""
        return ColumnName.from_any(column)

    def get_datasets(self) -> Tuple[Optional[pd.DataFrame], pd.DataFrame]:
        """Return ``(reference, current)`` with any additional feature columns joined on."""
        current = self.current_data
        if self.current_additional_features is not None:
            current = self.current_data.join(self.current_additional_features)
        reference = self.reference_data
        if self.reference_data is not None and self.reference_additional_features is not None:
            reference = self.reference_data.join(self.reference_additional_features)
        return reference, current


TResult = TypeVar("TResult", bound=MetricResult)


class FieldsDescriptor:
    """Descriptor backing ``Metric.fields``.

    On an instance it yields a path over the calculated result; on the class, or
    when the instance has no result yet, it falls back to the declared result type.
    """

    def __get__(self, instance: Optional["Metric"], type: Type["Metric"]) -> FieldPath:
        if instance is not None:
            try:
                return FieldPath([], instance.get_result())
            except ValueError:
                warnings.warn("Metric is not calculated yet, using generic fields list")
        return FieldPath([], type.result_type())


class WithResultFieldPathMetaclass(FrozenBaseMeta):
    def result_type(cls) -> Type[MetricResult]:
        """Return the first type argument of the class's first generic base (e.g. ``X`` in ``Metric[X]``).

        Raises:
            TypeError: if no generic base is present, instead of leaking a bare StopIteration.
        """
        generic_base = next(
            (b for b in cls.__orig_bases__ if typing_inspect.is_generic_type(b)),  # type: ignore[attr-defined]
            None,
        )
        if generic_base is None:
            raise TypeError(f"Cannot determine result type of {cls.__name__}: no generic base class found")
        return typing_inspect.get_args(generic_base)[0]


# Polymorphic base for presets (comment rather than docstring: pydantic uses __doc__ in schemas).
class BasePreset(EvidentlyBaseModel):
    class Config:
        type_alias = "evidently:base:BasePreset"
        transitive_aliases = True
        is_base_type = True


# Base class for legacy metrics. Subclasses parameterise it with their result type
# (``Metric[SomeResult]``) and implement ``calculate``. The computed result is read back via
# ``get_result`` from the context attached with ``set_context``.
class Metric(WithTestAndMetricDependencies, Generic[TResult], metaclass=WithResultFieldPathMetaclass):
    class Config:
        is_base_type = True

    _context: Optional["Context"] = None

    options: Optional[Options] = Field(default=None)

    fields: ClassVar[FieldsDescriptor] = FieldsDescriptor()
    # resulting options will be determined via
    # options = global_option.override(display_options).override(metric_options)

    def __init__(self, options: AnyOptions = None, **data):
        # NOTE(review): assigned before super().__init__; presumably relies on the base model
        # accepting pre-init assignments. Do not reorder without checking pydantic_utils.
        self.options = Options.from_any_options(options)
        super().__init__(**data)

    @classmethod
    def get_id(cls) -> str:
        """Identifier of the metric: its class name."""
        return cls.__name__

    @classmethod
    def get_group(cls) -> str:
        """Return the metric's group: the module component right after ``evidently.legacy.metrics.``.

        For example, ``evidently.legacy.metrics.classification_performance.x`` gives
        ``"classification_performance"``. Returns ``""`` for metrics defined outside that package.
        """
        prefix = "evidently.legacy.metrics."
        if not cls.__module__.startswith(prefix):
            return ""
        # Index relative to the prefix: a fixed index into the full dotted path breaks whenever
        # the prefix depth changes (split(".")[2] would be "metrics" for every metric here).
        return cls.__module__[len(prefix) :].split(".")[0]

    @abc.abstractmethod
    def calculate(self, data: InputData) -> TResult:
        raise NotImplementedError()

    def set_context(self, context):
        """Attach the suite context used by ``get_result`` and ``get_options``."""
        self._context = context

    def get_result(self) -> TResult:
        """Return this metric's result from the attached context.

        Raises:
            ValueError: no context is set, or the context holds no result for this metric.
            BaseException: the exception stored in an ``ErrorResult`` is re-raised as-is.
        """
        if not hasattr(self, "_context") or self._context is None:
            raise ValueError("No context is set")
        result = self._context.metric_results.get(self, None)
        if isinstance(result, ErrorResult):
            raise result.exception
        if result is None:
            raise ValueError(f"No result found for metric {self} of type {type(self).__name__}")
        return result  # type: ignore[return-value]

    def get_parameters(self) -> Optional[tuple]:
        """Return the instance attribute values as a hashable tuple, ordered by attribute name.

        Lists are converted to tuples. Returns None (and logs a warning) when the values are not hashable.
        """
        attributes = []
        for field, value in sorted(self.__dict__.items(), key=lambda x: x[0]):
            if field in ["_context"]:
                continue
            if isinstance(value, list):
                attributes.append(tuple(value))
            else:
                attributes.append(value)
        params = tuple(attributes)
        try:
            hash(params)
        except TypeError:
            logging.warning("unhashable params for %s. Fallback to unique.", type(self))
            return None
        return params

    def required_features(self, data_definition: DataDefinition) -> List["GeneratedFeatures"]:
        """Return the feature generators of every ``ColumnName`` attribute that has one."""
        required_features = []
        for field, value in sorted(self.__dict__.items(), key=lambda x: x[0]):
            # Skip the context attribute, consistent with get_parameters.
            if field in ["_context"]:
                continue
            if isinstance(value, ColumnName) and value.feature_class is not None:
                required_features.append(value.feature_class)
        return required_features

    def get_options(self):
        """Return the metric's options layered over the context's: ``context.options.override(own)``."""
        options = self.options if hasattr(self, "options") else Options()
        if self._context is not None:
            options = self._context.options.override(options)
        return options

    def get_field_fingerprint(self, field: str) -> FingerprintPart:
        """Fingerprint a field; ``options`` is delegated to ``get_options_fingerprint``."""
        if field == "options":
            return self.get_options_fingerprint()
        return super().get_field_fingerprint(field)

    def get_options_fingerprint(self) -> FingerprintPart:
        """Options contribute nothing to the fingerprint unless a subclass/mixin overrides this."""
        return None


class UsesRawDataMixin:
    """Mixin that makes ``render_options.raw_data`` part of the metric fingerprint."""

    options: Options

    def get_options_fingerprint(self) -> FingerprintPart:
        return get_value_fingerprint(self.options.render_options.raw_data)


# Base result for single-column metrics; carries the column name and type as parameters.
class ColumnMetricResult(MetricResult):
    class Config:
        type_alias = "evidently:metric_result:ColumnMetricResult"
        field_tags = {
            "column_name": {IncludeTags.Parameter},
            "column_type": {IncludeTags.Parameter},
        }

    column_name: str
    # todo: use enum
    column_type: str

    def get_pandas(self) -> pd.DataFrame:
        """Return a one-row frame indexed by ``column_name`` built from ``collect_pandas_columns()``."""
        df = pd.DataFrame.from_dict({self.column_name: self.collect_pandas_columns()}, orient="index")
        df.index.name = "column_name"
        return df


ColumnTResult = TypeVar("ColumnTResult", bound=ColumnMetricResult)


# Base for metrics computed on one column; accepts a ColumnName or a plain (main-dataset) name.
class ColumnMetric(Metric[ColumnTResult], Generic[ColumnTResult], abc.ABC):
    column_name: ColumnName

    def __init__(self, column_name: Union[ColumnName, str], options: AnyOptions = None):
        self.column_name = ColumnName.from_any(column_name)
        super().__init__(options)