base_test.py
import abc
from abc import ABC
from enum import Enum
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Type
from typing import TypeVar
from typing import Union

from evidently._pydantic_compat import BaseModel
from evidently._pydantic_compat import Field
from evidently.legacy.base_metric import BaseResult
from evidently.legacy.base_metric import Metric
from evidently.legacy.base_metric import MetricResult
from evidently.legacy.core import IncludeTags
from evidently.legacy.utils.generators import BaseGenerator
from evidently.legacy.utils.generators import make_generator_by_columns
from evidently.legacy.utils.types import ApproxValue
from evidently.legacy.utils.types import Numeric
from evidently.legacy.utils.types import NumericApprox
from evidently.pydantic_utils import EnumValueMixin
from evidently.pydantic_utils import EvidentlyBaseModel
from evidently.pydantic_utils import ExcludeNoneMixin
from evidently.pydantic_utils import WithTestAndMetricDependencies

if TYPE_CHECKING:
    from evidently.legacy.suite.base_suite import Context


class GroupData(BaseModel):
    """One selectable value inside a grouping type (id, title, description)."""

    id: str
    title: str
    description: str
    sort_index: int = 0
    severity: Optional[str] = None


class GroupTypeData(BaseModel):
    """A way of grouping tests (e.g. by feature) plus its known values."""

    id: str
    title: str
    # possible values with description, if empty will use simple view (no severity, description and sorting).
    values: List[GroupData] = Field(default_factory=list)

    def add_value(self, data: GroupData) -> None:
        """Append ``data`` to the list of known values of this grouping type."""
        self.values.append(data)


class GroupingTypes:
    """Namespace holding the grouping types defined by this module."""

    ByFeature = GroupTypeData(
        id="by_feature",
        title="By feature",
        values=[
            GroupData(
                id="no group",
                title="Dataset-level tests",
                description="Some tests cannot be grouped by feature",
            )
        ],
    )
    ByClass = GroupTypeData(id="by_class", title="By class", values=[])
    TestGroup = GroupTypeData(
        id="test_group",
        title="By test group",
        values=[
            GroupData(
                id="no group",
                title="Ungrouped",
                description="Some tests don’t belong to any group under the selected condition",
            )
        ],
    )
    TestType = GroupTypeData(id="test_type", title="By test type", values=[])


# All grouping types defined above, in the order they are listed here.
DEFAULT_GROUP = [
    GroupingTypes.ByFeature,
    GroupingTypes.TestGroup,
    GroupingTypes.TestType,
    GroupingTypes.ByClass,
]


class TestStatus(str, Enum):
    """Constants for test result status."""

    SUCCESS = "SUCCESS"  # the test was passed
    FAIL = "FAIL"  # the test was not passed
    WARNING = "WARNING"  # the test was passed, but we have some issues during the execution
    ERROR = "ERROR"  # cannot calculate the test result, no data
    SKIPPED = "SKIPPED"  # the test was skipped


class TestParameters(EvidentlyBaseModel, BaseResult):  # type: ignore[misc] # pydantic Config
    """Base type for the parameters attached to a ``TestResult``."""

    class Config:
        type_alias = "evidently:test_parameters:TestParameters"
        field_tags = {"type": {IncludeTags.TypeField}}
        is_base_type = True


class TestResult(EnumValueMixin, MetricResult):  # todo: create common base class
    """Outcome of a single test: name, description, status, group and parameters."""

    class Config:
        type_alias = "evidently:metric_result:TestResult"

    # short name/title from the test class
    name: str
    # what was checked, what threshold (current value 13 is not ok with condition less than 5)
    description: str
    # status of the test result
    status: TestStatus
    # grouping parameters
    group: str
    parameters: Optional[TestParameters]
    # exposed read-only through the ``exception`` property; not assigned in the code shown here
    _exception: Optional[BaseException] = None

    @property
    def exception(self) -> Optional[BaseException]:
        """Return the stored exception, or ``None`` when none was stored."""
        return self._exception

    def set_status(self, status: TestStatus, description: Optional[str] = None) -> None:
        """Set the status; replace the description only when one is given."""
        self.status = status

        if description is not None:
            self.description = description

    def mark_as_fail(self, description: Optional[str] = None) -> None:
        """Set status to ``FAIL``, optionally replacing the description."""
        self.set_status(TestStatus.FAIL, description=description)

    def mark_as_error(self, description: Optional[str] = None) -> None:
        """Set status to ``ERROR``, optionally replacing the description."""
        self.set_status(TestStatus.ERROR, description=description)

    def mark_as_success(self, description: Optional[str] = None) -> None:
        """Set status to ``SUCCESS``, optionally replacing the description."""
        self.set_status(TestStatus.SUCCESS, description=description)

    def mark_as_warning(self, description: Optional[str] = None) -> None:
        """Set status to ``WARNING``, optionally replacing the description."""
        self.set_status(TestStatus.WARNING, description=description)

    def is_passed(self) -> bool:
        """Return True when the status is ``SUCCESS`` or ``WARNING``."""
        return self.status in [TestStatus.SUCCESS, TestStatus.WARNING]


class Test(WithTestAndMetricDependencies):
    """
    Base class for tests.

    All fields of a test class whose type is a subclass of Metric are used as dependencies of the test.
    """

    class Config:
        is_base_type = True

    name: ClassVar[str]
    group: ClassVar[str]
    is_critical: bool = True
    _context: Optional["Context"] = None

    @abc.abstractmethod
    def check(self) -> TestResult:
        """Run the test and return its result. Define it in a child class."""
        raise NotImplementedError

    def set_context(self, context: "Context"):
        """Store the context that ``get_result`` reads results from."""
        self._context = context

    def get_result(self) -> TestResult:
        """Return this test's result from the context's ``test_results``.

        Raises:
            ValueError: if no context was set, or the context holds no result
                for this test.
        """
        if self._context is None:
            raise ValueError("No context is set")
        result = self._context.test_results.get(self, None)
        if result is None:
            # This is a lookup of a *test* result, so the message says "test".
            raise ValueError(f"No result found for test {self} of type {type(self).__name__}")
        return result  # type: ignore[return-value]

    def get_id(self) -> str:
        """Return the class name of the test."""
        return self.__class__.__name__

    @abc.abstractmethod
    def groups(self) -> Dict[str, str]:
        """Return test-specific grouping values. Define it in a child class."""
        raise NotImplementedError

    def get_groups(self) -> Dict[str, str]:
        """Return ``groups()`` extended with the test group and test type entries.

        The result is a new dict, so the mapping returned by ``groups()`` is
        never modified.
        """
        # Copy first: ``groups()`` may return a dict the subclass keeps and reuses.
        groups = dict(self.groups())
        groups.update(
            {
                GroupingTypes.TestGroup.id: self.group,
                GroupingTypes.TestType.id: self.name,
            }
        )
        return groups


class TestValueCondition(ExcludeNoneMixin):
    """
    Class for processing a value conditions - should it be less, greater than, equals and so on.

    An object of the class stores specified conditions and can be used for checking a value by them.
    """

    class Config:
        arbitrary_types_allowed = True
        use_enum_values = True
        smart_union = True

    eq: Optional[NumericApprox] = None
    gt: Optional[NumericApprox] = None
    gte: Optional[NumericApprox] = None
    is_in: Optional[List[Union[Numeric, str, bool]]] = None
    lt: Optional[NumericApprox] = None
    lte: Optional[NumericApprox] = None
    not_eq: Optional[Numeric] = None
    not_in: Optional[List[Union[Numeric, str, bool]]] = None

    def has_condition(self) -> bool:
        """
        Checks if we have a condition in the object and returns True in this case.

        If we have no conditions - returns False.
        """
        return any(
            value is not None
            for value in (
                self.eq,
                self.gt,
                self.gte,
                self.is_in,
                self.lt,
                self.lte,
                self.not_in,
                self.not_eq,
            )
        )

    def check_value(self, value: Numeric) -> bool:
        """Check ``value`` against every condition that is set.

        Conditions are combined with AND: once one comparison is falsy the
        remaining ones are skipped. With no condition set the result is True.
        """
        result = True

        if self.eq is not None and result:
            result = value == self.eq

        if self.gt is not None and result:
            result = value > self.gt

        if self.gte is not None and result:
            result = value >= self.gte

        if self.is_in is not None and result:
            result = value in self.is_in

        if self.lt is not None and result:
            result = value < self.lt

        if self.lte is not None and result:
            result = value <= self.lte

        if self.not_eq is not None and result:
            result = value != self.not_eq

        if self.not_in is not None and result:
            result = value not in self.not_in

        return result

    def __str__(self) -> str:
        """Render the set conditions as ``op=value`` parts joined by ``' and '``."""
        conditions = []
        operations = ["eq", "gt", "gte", "lt", "lte", "not_eq", "is_in", "not_in"]

        for op in operations:
            value = getattr(self, op)

            if value is None:
                continue

            # floats and approx values are shortened to 3 significant digits
            if isinstance(value, (float, ApproxValue)):
                conditions.append(f"{op}={value:.3g}")

            else:
                conditions.append(f"{op}={value}")

        return f"{' and '.join(conditions)}"


class ConditionTestParameters(TestParameters):
    """Test parameters that carry the condition the value was checked against."""

    class Config:
        type_alias = "evidently:test_parameters:ConditionTestParameters"

    condition: TestValueCondition


class BaseConditionsTest(Test, TestValueCondition, ABC):
    """
    Base class for all tests with a condition
    """

    class Config:
        arbitrary_types_allowed = True
        use_enum_values = True
        smart_union = True
        underscore_attrs_are_private = True

    # condition: TestValueCondition

    @property
    def condition(self) -> TestValueCondition:
        """Return a new ``TestValueCondition`` built from this test's condition fields."""
        return TestValueCondition(
            eq=self.eq,
            gt=self.gt,
            gte=self.gte,
            is_in=self.is_in,
            lt=self.lt,
            lte=self.lte,
            not_eq=self.not_eq,
            not_in=self.not_in,
        )


class CheckValueParameters(ConditionTestParameters):
    """Condition parameters plus the value that was checked."""

    class Config:
        type_alias = "evidently:test_parameters:CheckValueParameters"

    value: Optional[Numeric]


class ColumnCheckValueParameters(CheckValueParameters):
    """Check-value parameters plus the name of the column."""

    class Config:
        type_alias = "evidently:test_parameters:ColumnCheckValueParameters"

    column_name: str


class BaseCheckValueTest(BaseConditionsTest):
    """
    Base class for all tests with checking a value condition
    """

    _value: Numeric

    @abc.abstractmethod
    def calculate_value_for_test(self) -> Optional[Any]:
        """Method for getting the checking value.

        Define it in a child class"""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_description(self, value: Numeric) -> str:
        """Method for getting a description that we can use.
        The description can use the checked value.

        Define it in a child class"""
        raise NotImplementedError()

    def get_condition(self) -> TestValueCondition:
        """Return the condition to check the value against."""
        return self.condition

    def groups(self) -> Dict[str, str]:
        """Return no test-specific groups."""
        return {}

    def get_parameters(self) -> CheckValueParameters:
        """Return the parameters (condition and stored value) for the result."""
        return CheckValueParameters(condition=self.get_condition(), value=self._value)

    def check(self) -> TestResult:
        """Calculate the value, check it against the condition and return the result.

        The result is ERROR when the value is ``None`` or a ``ValueError`` is
        raised while checking, SUCCESS when the condition holds, FAIL otherwise.
        """
        # Starts as SKIPPED; the status is replaced inside the try block below.
        result = TestResult(
            name=self.name,
            description="The test was not launched",
            status=TestStatus.SKIPPED,
            group=self.group,
            parameters=None,
        )
        value = self.calculate_value_for_test()
        # stored before get_parameters() is called, which reads self._value
        self._value = value
        result.description = self.get_description(value)
        result.parameters = self.get_parameters()

        try:
            if value is None:
                result.mark_as_error()

            else:
                condition = self.get_condition()

                if condition is None:
                    raise ValueError

                condition_check_result = condition.check_value(value)

                if condition_check_result:
                    result.mark_as_success()

                else:
                    result.mark_as_fail()

        except ValueError:
            result.mark_as_error("Cannot calculate the condition")

        return result


T = TypeVar("T", bound=MetricResult)


class ConditionFromReferenceMixin(BaseCheckValueTest, Generic[T], ABC):
    """Check-value test that derives its condition from reference data when none is set."""

    reference_field: ClassVar[str] = "reference"
    _metric: Metric

    def get_condition_from_reference(self, reference: Optional[T]) -> TestValueCondition:
        """Build a condition from the reference value. Define it in a child class."""
        raise NotImplementedError

    def get_condition(self) -> TestValueCondition:
        """Return the explicit condition if any is set, else one built from the reference.

        The reference is read from the metric result attribute named by
        ``reference_field``.
        """
        # The property builds a new object on each access, so evaluate it once.
        condition = self.condition
        if condition.has_condition():
            return condition

        reference_stats = getattr(self.metric.get_result(), self.reference_field)

        return self.get_condition_from_reference(reference_stats)

    @property
    def metric(self) -> Metric:
        """Return the metric stored in ``_metric``."""
        return self._metric


def generate_column_tests(
    test_class: Type[Test], columns: Optional[Union[str, list]] = None, parameters: Optional[Dict] = None
) -> BaseGenerator:
    """Function for generating tests for columns"""
    return make_generator_by_columns(
        base_class=test_class,
        columns=columns,
        parameters=parameters,
    )