# numerical_tests.py
"""Numerical tests for single-value metrics.

Each test compares a metric's current value with a threshold. A threshold is
either a literal number or a ``Reference`` marker; in the latter case the
metric's value computed on the reference dataset is used, wrapped in an
``ApproxValue`` built from the marker's ``relative``/``absolute`` settings.
"""

import abc
from typing import ClassVar
from typing import Union

from evidently.core.metric_types import DatasetType
from evidently.core.metric_types import MetricCalculationBase
from evidently.core.metric_types import MetricTest
from evidently.core.metric_types import MetricTestResult
from evidently.core.metric_types import MetricValueLocation
from evidently.core.metric_types import SingleValue
from evidently.core.metric_types import SingleValueTest
from evidently.core.metric_types import TestStatus
from evidently.core.metric_types import Value
from evidently.core.report import Context
from evidently.core.tests import ApproxValue
from evidently.core.tests import Reference
from evidently.core.tests import ThresholdType
from evidently.core.tests import ThresholdValue


def _resolve_reference(context: Context, metric_location: MetricValueLocation, reference: Reference) -> ApproxValue:
    """Return the metric's reference-dataset value wrapped with ``reference``'s tolerances.

    Args:
        context: Calculation context holding the input datasets.
        metric_location: Locates the metric value to read from ``context``.
        reference: The ``Reference`` marker supplying ``relative``/``absolute``.

    Raises:
        ValueError: If the context's reference input slot is empty.
    """
    # Both ComparisonTest and the equality tests resolve Reference thresholds
    # here so they fail the same, explicit way when no reference data exists.
    if context._input_data[1] is None:
        raise ValueError("No Reference dataset provided, but tests contains Reference thresholds")
    reference_value = metric_location.value(context, DatasetType.Reference).value
    return ApproxValue(reference_value, reference.relative, reference.absolute)


class ComparisonTest(MetricTest):
    """Base class for tests that compare a metric value against a threshold.

    Subclasses set the class-level identifiers below and implement :meth:`check`.
    """

    threshold: ThresholdType
    # Short id used as ``MetricTestResult.id`` (e.g. "le", "gt").
    __short_name__: ClassVar[str]
    # Human-readable comparison name used in the result title.
    __full_name__: ClassVar[str]
    # Comparison direction ("less" / "greater"); not read within this module.
    __reference_relation__: ClassVar[str]

    @abc.abstractmethod
    def check(self, value: Value, threshold: ThresholdValue) -> bool:
        """Return True if ``value`` passes the comparison against ``threshold``."""
        raise NotImplementedError

    def to_test(self) -> SingleValueTest:
        """Build the callable that evaluates this comparison for one metric value."""

        def func(context: Context, metric: MetricCalculationBase, value: SingleValue):
            threshold = self.get_threshold(context, value.get_metric_value_location())
            if isinstance(self.threshold, Reference):
                if isinstance(threshold, ApproxValue):
                    # Reference thresholds carry a tolerance; show it in the title.
                    title_threshold = f"Reference {threshold:0.3f} ± {threshold.tolerance:0.3f}"
                else:
                    title_threshold = f"Reference {threshold:0.3f}"
            else:
                title_threshold = f"{threshold:0.3f}"
            return MetricTestResult(
                id=self.__short_name__,
                name=f"{value.display_name}: {self.__full_name__} {title_threshold}",
                description=f"Actual value {value.value:0.3f} {'<' if value.value < threshold else '>='} {threshold:0.3f}",
                status=TestStatus.SUCCESS if self.check(value.value, threshold) else TestStatus.FAIL,
                metric_config=metric.to_metric_config(),
                test_config=self.dict(),
            )

        return func

    def get_threshold(self, context: Context, metric_location: MetricValueLocation) -> Union[float, int, ApproxValue]:
        """Resolve the configured threshold to a concrete comparable value.

        A ``Reference`` threshold is replaced by the metric's value on the
        reference dataset (as an ``ApproxValue``); any other threshold is
        returned unchanged.

        Raises:
            ValueError: If a ``Reference`` threshold is used without reference data.
        """
        if isinstance(self.threshold, Reference):
            return _resolve_reference(context, metric_location, self.threshold)
        return self.threshold


class LessOrEqualMetricTest(ComparisonTest):
    """Passes when the value is less than or equal to the threshold."""

    __short_name__: ClassVar[str] = "le"
    __full_name__: ClassVar[str] = "Less or Equal"
    __reference_relation__: ClassVar[str] = "less"

    def check(self, value: Value, threshold: ThresholdValue) -> bool:
        return value <= threshold


class GreaterOrEqualMetricTest(ComparisonTest):
    """Passes when the value is greater than or equal to the threshold."""

    __short_name__: ClassVar[str] = "ge"
    __full_name__: ClassVar[str] = "Greater or Equal"
    __reference_relation__: ClassVar[str] = "greater"

    def check(self, value: Value, threshold: ThresholdValue) -> bool:
        return value >= threshold


class GreaterThanMetricTest(ComparisonTest):
    """Passes when the value is strictly greater than the threshold."""

    __short_name__: ClassVar[str] = "gt"
    __full_name__: ClassVar[str] = "Greater"
    __reference_relation__: ClassVar[str] = "greater"

    def check(self, value: Value, threshold: ThresholdValue) -> bool:
        return value > threshold


class LessThanMetricTest(ComparisonTest):
    """Passes when the value is strictly less than the threshold."""

    __short_name__: ClassVar[str] = "lt"
    __full_name__: ClassVar[str] = "Less"
    __reference_relation__: ClassVar[str] = "less"

    def check(self, value: Value, threshold: ThresholdValue) -> bool:
        return value < threshold


class EqualMetricTestBase(MetricTest, abc.ABC):
    """Shared resolution/comparison logic for equality-based tests."""

    expected: ThresholdType

    def is_equal(self, context: Context, value: SingleValue):
        """Resolve the expected value and compare it with ``value.value``.

        Returns:
            A tuple ``(expected, title_expected, is_equal)``: the resolved
            expected value, its ``0.3f``-formatted text, and the result of
            ``expected == value.value``.

        Raises:
            ValueError: If ``expected`` is a ``Reference`` but no reference data exists.
        """
        expected: Union[float, int, ApproxValue]
        if isinstance(self.expected, Reference):
            expected = _resolve_reference(context, value.get_metric_value_location(), self.expected)
        else:
            expected = self.expected
        title_expected = f"{expected:0.3f}"
        return expected, title_expected, expected == value.value


class EqualMetricTest(EqualMetricTestBase):
    """Passes when the metric value equals the expected value."""

    expected: ThresholdType

    def to_test(self) -> SingleValueTest:
        """Build the callable that evaluates the equality test for one metric value."""

        def func(context: Context, metric: MetricCalculationBase, value: SingleValue):
            expected, title_expected, is_equal = self.is_equal(context, value)
            if is_equal:
                detail = f" expected {expected:0.3f}"
            else:
                detail = f", but expected {expected:0.3f}"
            return MetricTestResult(
                id="eq",
                name=f"{metric.display_name()}: Equal {title_expected}",
                description=f"Actual value {value.value:0.3f}{detail}",
                status=TestStatus.SUCCESS if is_equal else TestStatus.FAIL,
                metric_config=metric.to_metric_config(),
                test_config=self.dict(),
            )

        return func


class NotEqualMetricTest(EqualMetricTestBase):
    """Passes when the metric value differs from the expected value."""

    def to_test(self) -> SingleValueTest:
        """Build the callable that evaluates the inequality test for one metric value."""

        def func(context: Context, metric: MetricCalculationBase, value: SingleValue):
            expected, title_expected, is_equal = self.is_equal(context, value)
            if is_equal:
                detail = f", but expected not {expected:0.3f}"
            else:
                detail = f" not equal to {expected:0.3f}"
            return MetricTestResult(
                id="not_eq",
                name=f"{metric.display_name()}: Not equal {title_expected}",
                description=f"Actual value {value.value:0.3f}{detail}",
                status=TestStatus.SUCCESS if not is_equal else TestStatus.FAIL,
                metric_config=metric.to_metric_config(),
                test_config=self.dict(),
            )

        return func