test_generic_tests_creation.py
"""Tests for constructing generic tests via the ``evidently.tests`` helpers.

Each helper (``eq``, ``not_eq``, ``gt``, ``lt``) produces a ``GenericTest``.
These tests check both its metric-level form (``.metric``) and its
descriptor-level form (``.descriptor``).
"""

from typing import Optional
from typing import Sequence
from typing import Type

import pytest

from evidently.core.datasets import Dataset
from evidently.core.metric_types import BoundTest
from evidently.core.metric_types import Metric
from evidently.core.metric_types import SingleValueCalculation
from evidently.core.metric_types import SingleValueMetric
from evidently.core.metric_types import TestStatus
from evidently.core.metric_types import TMetricResult
from evidently.core.report import Context
from evidently.core.report import Report
from evidently.core.tests import GenericTest
from evidently.tests import eq
from evidently.tests import gt
from evidently.tests import lt
from evidently.tests import not_eq


class StubMetric(SingleValueMetric):
    """Minimal single-value metric used only as a carrier in these tests."""

    def get_bound_tests(self, context: "Context") -> Sequence[BoundTest]:
        """Report that this stub metric has no bound tests."""
        return []


class StubMetricCalculation(SingleValueCalculation[StubMetric]):
    """Calculation stub that pairs with ``StubMetric``.

    Only ``result`` (inherited) is used by the tests. ``calculate`` itself does no work.
    """

    def calculate(self, context: "Context", current_data: Dataset, reference_data: Optional[Dataset]) -> TMetricResult:
        """Stub: performs no computation and implicitly returns None."""
        return None

    def display_name(self) -> str:
        """Return the fixed human-readable name of the stub metric."""
        return "Stub metric"

    def to_metric(self) -> "Metric":
        """Build a fresh ``StubMetric`` instance."""
        return StubMetric()


# Parametrized cases, in the form:
#   (generic test, checked value, expected metric-level status, expected descriptor check result)
# Where the metric-level status is None, the generic test must have no metric
# counterpart (the string-threshold cases below).
_GENERIC_TEST_CASES = [
    (eq(1), 1, TestStatus.SUCCESS, True),
    (not_eq(1), 1, TestStatus.FAIL, False),
    (gt(1), 2, TestStatus.SUCCESS, True),
    (lt(1), 0, TestStatus.SUCCESS, True),
    (eq("a"), "a", None, True),
    (not_eq("a"), "a", None, False),
]


@pytest.mark.parametrize("test,value,expected_metric,expected_descriptor", _GENERIC_TEST_CASES)
def test_instances(test: GenericTest, value, expected_metric, expected_descriptor):
    """Check the metric and descriptor forms of a generic test against ``value``."""
    if expected_metric is not None:
        # Run the metric-level test against a stub result wrapping ``value``,
        # using a context built from an empty report.
        stub_calculation = StubMetricCalculation("stub_metric", StubMetric())
        metric_test = test.metric
        empty_context = Context(Report([]))
        stub_result = stub_calculation.result(value)
        outcome = metric_test.run(empty_context, stub_calculation, stub_result)
        assert outcome.status == expected_metric
    else:
        assert test.metric is None

    if expected_descriptor is not None:
        # The descriptor form exposes a condition that can be checked directly.
        assert test.descriptor.condition.check(value) == expected_descriptor
    else:
        assert test.descriptor is None


@pytest.mark.parametrize("test,args", [(lt, ("a",))])
def test_failed_instances(test: Type, args):
    """Check that building a generic test from unsupported arguments raises ValueError."""
    with pytest.raises(ValueError):
        test(*args)