/ src / evidently / tests / numerical_tests.py
numerical_tests.py
  1  import abc
  2  from typing import ClassVar
  3  from typing import Union
  4  
  5  from evidently.core.metric_types import DatasetType
  6  from evidently.core.metric_types import MetricCalculationBase
  7  from evidently.core.metric_types import MetricTest
  8  from evidently.core.metric_types import MetricTestResult
  9  from evidently.core.metric_types import MetricValueLocation
 10  from evidently.core.metric_types import SingleValue
 11  from evidently.core.metric_types import SingleValueTest
 12  from evidently.core.metric_types import TestStatus
 13  from evidently.core.metric_types import Value
 14  from evidently.core.report import Context
 15  from evidently.core.tests import ApproxValue
 16  from evidently.core.tests import Reference
 17  from evidently.core.tests import ThresholdType
 18  from evidently.core.tests import ThresholdValue
 19  
 20  
 21  class ComparisonTest(MetricTest):
 22      threshold: ThresholdType
 23      __short_name__: ClassVar[str]
 24      __full_name__: ClassVar[str]
 25      __reference_relation__: ClassVar[str]
 26  
 27      @abc.abstractmethod
 28      def check(self, value: Value, threshold: ThresholdValue) -> bool:
 29          raise NotImplementedError
 30  
 31      def to_test(self) -> SingleValueTest:
 32          def func(context: Context, metric: MetricCalculationBase, value: SingleValue):
 33              threshold = self.get_threshold(context, value.get_metric_value_location())
 34              title_threshold = f"{threshold:0.3f}"
 35              if isinstance(self.threshold, Reference):
 36                  if isinstance(threshold, ApproxValue):
 37                      title_threshold += f"Reference {threshold:0.3f} ± {threshold.tolerance:0.3f}"
 38                  else:
 39                      title_threshold = f"Reference {threshold:0.3f}"
 40              return MetricTestResult(
 41                  id=self.__short_name__,
 42                  name=f"{value.display_name}: {self.__full_name__} {title_threshold}",
 43                  description=f"Actual value {value.value:0.3f} {'<' if value.value < threshold else '>='} {threshold:0.3f}",
 44                  status=TestStatus.SUCCESS if self.check(value.value, threshold) else TestStatus.FAIL,
 45                  metric_config=metric.to_metric_config(),
 46                  test_config=self.dict(),
 47              )
 48  
 49          return func
 50  
 51      def get_threshold(self, context: Context, metric_location: MetricValueLocation) -> Union[float, int, ApproxValue]:
 52          if isinstance(self.threshold, Reference):
 53              if context._input_data[1] is None:
 54                  raise ValueError("No Reference dataset provided, but tests contains Reference thresholds")
 55              value = metric_location.value(context, DatasetType.Reference).value
 56              return ApproxValue(value, self.threshold.relative, self.threshold.absolute)
 57          return self.threshold
 58  
 59  
 60  class LessOrEqualMetricTest(ComparisonTest):
 61      __short_name__: ClassVar[str] = "le"
 62      __full_name__: ClassVar[str] = "Less or Equal"
 63      __reference_relation__ = "less"
 64  
 65      def check(self, value: Value, threshold: ThresholdValue) -> bool:
 66          return value <= threshold
 67  
 68  
 69  class GreaterOrEqualMetricTest(ComparisonTest):
 70      __short_name__: ClassVar[str] = "ge"
 71      __full_name__: ClassVar[str] = "Greater or Equal"
 72      __reference_relation__: ClassVar[str] = "greater"
 73  
 74      def check(self, value: Value, threshold: ThresholdValue):
 75          return value >= threshold
 76  
 77  
 78  class GreaterThanMetricTest(ComparisonTest):
 79      __short_name__: ClassVar[str] = "gt"
 80      __full_name__: ClassVar[str] = "Greater"
 81      __reference_relation__: ClassVar[str] = "greater"
 82  
 83      def check(self, value: Value, threshold: ThresholdValue):
 84          return value > threshold
 85  
 86  
 87  class LessThanMetricTest(ComparisonTest):
 88      __short_name__: ClassVar[str] = "lt"
 89      __full_name__: ClassVar[str] = "Less"
 90      __reference_relation__ = "less"
 91  
 92      def check(self, value: Value, threshold: ThresholdValue):
 93          return value < threshold
 94  
 95  
 96  class EqualMetricTestBase(MetricTest, abc.ABC):
 97      expected: ThresholdType
 98  
 99      def is_equal(self, context: Context, value: SingleValue):
100          expected: Union[float, int, ApproxValue]
101          if isinstance(self.expected, Reference):
102              result = value.get_metric_value_location().value(context, DatasetType.Reference)
103              expected = ApproxValue(result.value, self.expected.relative, self.expected.absolute)
104          else:
105              expected = self.expected
106          title_expected = f"{expected:0.3f}"
107          return expected, title_expected, expected == value.value
108  
109  
class EqualMetricTest(EqualMetricTestBase):
    """Test that passes when the metric value equals ``expected`` (result id ``"eq"``)."""

    expected: ThresholdType

    def to_test(self) -> SingleValueTest:
        """Return a callable that checks one metric value for equality with ``expected``."""

        def func(context: Context, metric: MetricCalculationBase, value: SingleValue):
            expected, title_expected, matched = self.is_equal(context, value)
            title = f"{metric.display_name()}: Equal {title_expected}"
            actual = f"Actual value {value.value:0.3f}"
            # The phrasing of the tail depends on whether the values matched.
            if matched:
                outcome = f" expected {expected:0.3f}"
            else:
                outcome = f", but expected {expected:0.3f}"
            return MetricTestResult(
                id="eq",
                name=title,
                description=f"{actual} {outcome}",
                status=TestStatus.SUCCESS if matched else TestStatus.FAIL,
                metric_config=metric.to_metric_config(),
                test_config=self.dict(),
            )

        return func
127  
128  
class NotEqualMetricTest(EqualMetricTestBase):
    """Test that passes when the metric value differs from ``expected`` (result id ``"not_eq"``)."""

    def to_test(self) -> SingleValueTest:
        """Return a callable that checks one metric value for inequality with ``expected``."""

        def func(context: Context, metric: MetricCalculationBase, value: SingleValue):
            expected, title_expected, matched = self.is_equal(context, value)
            title = f"{metric.display_name()}: Not equal {title_expected}"
            actual = f"Actual value {value.value}"
            # A match is the failure case for this test.
            if matched:
                outcome = f", but expected not {expected:0.3f}"
            else:
                outcome = f" not equal to {expected:0.3f}"
            return MetricTestResult(
                id="not_eq",
                name=title,
                description=f"{actual} {outcome}",
                status=TestStatus.FAIL if matched else TestStatus.SUCCESS,
                metric_config=metric.to_metric_config(),
                test_config=self.dict(),
            )

        return func