/ src / evidently / legacy / tests / base_test.py
base_test.py
  1  import abc
  2  from abc import ABC
  3  from enum import Enum
  4  from typing import TYPE_CHECKING
  5  from typing import Any
  6  from typing import ClassVar
  7  from typing import Dict
  8  from typing import Generic
  9  from typing import List
 10  from typing import Optional
 11  from typing import Type
 12  from typing import TypeVar
 13  from typing import Union
 14  
 15  from evidently._pydantic_compat import BaseModel
 16  from evidently._pydantic_compat import Field
 17  from evidently.legacy.base_metric import BaseResult
 18  from evidently.legacy.base_metric import Metric
 19  from evidently.legacy.base_metric import MetricResult
 20  from evidently.legacy.core import IncludeTags
 21  from evidently.legacy.utils.generators import BaseGenerator
 22  from evidently.legacy.utils.generators import make_generator_by_columns
 23  from evidently.legacy.utils.types import ApproxValue
 24  from evidently.legacy.utils.types import Numeric
 25  from evidently.legacy.utils.types import NumericApprox
 26  from evidently.pydantic_utils import EnumValueMixin
 27  from evidently.pydantic_utils import EvidentlyBaseModel
 28  from evidently.pydantic_utils import ExcludeNoneMixin
 29  from evidently.pydantic_utils import WithTestAndMetricDependencies
 30  
 31  if TYPE_CHECKING:
 32      from evidently.legacy.suite.base_suite import Context
 33  
 34  
class GroupData(BaseModel):
    """Description of one value that a grouping can take."""

    # identifier of the group value
    id: str
    # title of the group value
    title: str
    description: str
    # defaults to 0; presumably controls display ordering - TODO confirm against the renderer
    sort_index: int = 0
    # optional; stays None when no severity is assigned
    severity: Optional[str] = None
 41  
 42  
class GroupTypeData(BaseModel):
    """A grouping type: an id, a title and the list of values known for it."""

    id: str
    title: str
    # possible values with description, if empty will use simple view (no severity, description and sorting).
    values: List[GroupData] = Field(default_factory=list)

    def add_value(self, data: GroupData) -> None:
        """Append ``data`` to ``values`` in place."""
        self.values.append(data)
 51  
 52  
class GroupingTypes:
    """Namespace holding the grouping types defined in this module as class attributes."""

    # Grouping by feature; pre-populated with a "no group" value for dataset-level tests.
    ByFeature = GroupTypeData(
        id="by_feature",
        title="By feature",
        values=[
            GroupData(
                id="no group",
                title="Dataset-level tests",
                description="Some tests cannot be grouped by feature",
            )
        ],
    )
    # Grouping by class; starts with no predefined values.
    ByClass = GroupTypeData(id="by_class", title="By class", values=[])
    # Grouping by test group; pre-populated with a "no group" value for ungrouped tests.
    TestGroup = GroupTypeData(
        id="test_group",
        title="By test group",
        values=[
            GroupData(
                id="no group",
                title="Ungrouped",
                description="Some tests don’t belong to any group under the selected condition",
            )
        ],
    )
    # Grouping by test type; starts with no predefined values.
    TestType = GroupTypeData(id="test_type", title="By test type", values=[])
 78  
 79  
# The four grouping types declared on GroupingTypes, in a fixed order:
# by feature, by test group, by test type, by class.
# NOTE(review): presumably the default set of groupings offered to consumers - confirm against callers.
DEFAULT_GROUP = [
    GroupingTypes.ByFeature,
    GroupingTypes.TestGroup,
    GroupingTypes.TestType,
    GroupingTypes.ByClass,
]
 86  
 87  
class TestStatus(str, Enum):
    """Possible statuses of a test result; members are also plain strings."""

    # Constants for test result status
    SUCCESS = "SUCCESS"  # the test was passed
    FAIL = "FAIL"  # the test was not passed (the checked condition was not met)
    WARNING = "WARNING"  # the test was passed, but we have some issues during the execution
    ERROR = "ERROR"  # cannot calculate the test result, no data
    SKIPPED = "SKIPPED"  # the test was skipped
 95  
 96  
class TestParameters(EvidentlyBaseModel, BaseResult):  # type: ignore[misc] # pydantic Config
    """Base type for the parameters attached to a test result; declares no fields of its own."""

    class Config:
        # alias under which this type is registered
        type_alias = "evidently:test_parameters:TestParameters"
        # tags the "type" field with IncludeTags.TypeField
        field_tags = {"type": {IncludeTags.TypeField}}
        is_base_type = True
102  
103  
class TestResult(EnumValueMixin, MetricResult):  # todo: create common base class
    """Outcome of one test: a status, a human-readable description and optional parameters."""

    class Config:
        type_alias = "evidently:metric_result:TestResult"

    # short name/title from the test class
    name: str
    # what was checked, what threshold (current value 13 is not ok with condition less than 5)
    description: str
    # status of the test result
    status: TestStatus
    # grouping parameters
    group: str
    parameters: Optional[TestParameters]
    _exception: Optional[BaseException] = None

    @property
    def exception(self):
        """Return the exception stored on this result, or None when none was stored."""
        return self._exception

    def set_status(self, status: TestStatus, description: Optional[str] = None) -> None:
        """Store ``status``; overwrite the description only when a new one is supplied."""
        self.status = status

        if description is None:
            # keep the current description untouched
            return
        self.description = description

    def mark_as_fail(self, description: Optional[str] = None):
        """Switch the status to FAIL, optionally replacing the description."""
        self.set_status(TestStatus.FAIL, description)

    def mark_as_error(self, description: Optional[str] = None):
        """Switch the status to ERROR, optionally replacing the description."""
        self.set_status(TestStatus.ERROR, description)

    def mark_as_success(self, description: Optional[str] = None):
        """Switch the status to SUCCESS, optionally replacing the description."""
        self.set_status(TestStatus.SUCCESS, description)

    def mark_as_warning(self, description: Optional[str] = None):
        """Switch the status to WARNING, optionally replacing the description."""
        self.set_status(TestStatus.WARNING, description)

    def is_passed(self):
        """Return True when the status is SUCCESS or WARNING."""
        passing_statuses = (TestStatus.SUCCESS, TestStatus.WARNING)
        return self.status in passing_statuses
143  
144  
class Test(WithTestAndMetricDependencies):
    """
    Base class for tests.

    all fields in test class with type that is subclass of Metric would be used as dependencies of test.
    """

    class Config:
        is_base_type = True

    # short name of the test; also reported as its "test_type" group
    name: ClassVar[str]
    # name of the test group; reported as its "test_group" group
    group: ClassVar[str]
    is_critical: bool = True
    # execution context; stays None until set_context() is called
    _context: Optional["Context"] = None

    @abc.abstractmethod
    def check(self) -> TestResult:
        """Run the test and return its result. Must be implemented by subclasses."""
        raise NotImplementedError

    def set_context(self, context: "Context"):
        """Attach the context that get_result() reads results from."""
        self._context = context

    def get_result(self) -> TestResult:
        """Return the result stored for this test in the attached context.

        Raises:
            ValueError: if no context is set, or the context holds no result for this test.
        """
        if self._context is None:
            raise ValueError("No context is set")
        result = self._context.test_results.get(self, None)
        if result is None:
            # the message used to say "metric", which was misleading for a test lookup
            raise ValueError(f"No result found for test {self} of type {type(self).__name__}")
        return result  # type: ignore[return-value]

    def get_id(self) -> str:
        """Return the class name of the test."""
        return self.__class__.__name__

    @abc.abstractmethod
    def groups(self) -> Dict[str, str]:
        """Return test-specific groups as a mapping. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_groups(self) -> Dict[str, str]:
        """Return the groups from groups() plus the test-group and test-type entries.

        A new dict is built, so the mapping returned by groups() is never mutated
        (a subclass may return a stored dict). The test-group and test-type entries
        override same-named keys coming from groups().
        """
        return {
            **self.groups(),
            GroupingTypes.TestGroup.id: self.group,
            GroupingTypes.TestType.id: self.name,
        }
189  
190  
class TestValueCondition(ExcludeNoneMixin):
    """
    Holder for value conditions (equal, greater, less, membership and their negations).

    An instance stores the configured conditions and can check a value against all of them.
    """

    class Config:
        arbitrary_types_allowed = True
        use_enum_values = True
        smart_union = True

    eq: Optional[NumericApprox] = None
    gt: Optional[NumericApprox] = None
    gte: Optional[NumericApprox] = None
    is_in: Optional[List[Union[Numeric, str, bool]]] = None
    lt: Optional[NumericApprox] = None
    lte: Optional[NumericApprox] = None
    not_eq: Optional[Numeric] = None
    not_in: Optional[List[Union[Numeric, str, bool]]] = None

    def has_condition(self) -> bool:
        """
        Tell whether at least one condition is configured.

        Returns True if any condition field is not None, False when all of them are None.
        """
        bounds = (
            self.eq,
            self.gt,
            self.gte,
            self.is_in,
            self.lt,
            self.lte,
            self.not_in,
            self.not_eq,
        )
        return not all(bound is None for bound in bounds)

    def check_value(self, value: Numeric) -> bool:
        """
        Check ``value`` against every configured condition.

        Conditions are applied in a fixed order; once one of them yields a falsy
        result the remaining ones are not evaluated and that result is returned.
        With no conditions configured the result is True.
        """
        # (configured bound, comparison to run against it), in evaluation order
        checks = (
            (self.eq, lambda bound: value == bound),
            (self.gt, lambda bound: value > bound),
            (self.gte, lambda bound: value >= bound),
            (self.is_in, lambda bound: value in bound),
            (self.lt, lambda bound: value < bound),
            (self.lte, lambda bound: value <= bound),
            (self.not_eq, lambda bound: value != bound),
            (self.not_in, lambda bound: value not in bound),
        )

        outcome = True
        for bound, compare in checks:
            # skip unset conditions, and everything after the first failed one
            if bound is not None and outcome:
                outcome = compare(bound)

        return outcome

    def __str__(self) -> str:
        """Render the configured conditions as ``op=value`` pairs joined with " and "."""

        def render(name, bound):
            # floats and approximate values are shortened to 3 significant digits
            if isinstance(bound, (float, ApproxValue)):
                return f"{name}={bound:.3g}"
            return f"{name}={bound}"

        ordered_names = ("eq", "gt", "gte", "lt", "lte", "not_eq", "is_in", "not_in")
        pairs = ((name, getattr(self, name)) for name in ordered_names)

        return " and ".join(render(name, bound) for name, bound in pairs if bound is not None)
278  
279  
class ConditionTestParameters(TestParameters):
    """Test parameters that carry the condition used by the test."""

    class Config:
        type_alias = "evidently:test_parameters:ConditionTestParameters"

    # the condition the value was (or would be) checked against
    condition: TestValueCondition
285  
286  
class BaseConditionsTest(Test, TestValueCondition, ABC):
    """
    Base class for all tests with a condition
    """

    class Config:
        arbitrary_types_allowed = True
        use_enum_values = True
        smart_union = True
        underscore_attrs_are_private = True

    # ``condition`` is not a stored field: it is assembled on demand from the
    # condition fields inherited from TestValueCondition.

    @property
    def condition(self) -> TestValueCondition:
        """Return a new TestValueCondition built from this test's condition fields."""
        field_names = ("eq", "gt", "gte", "is_in", "lt", "lte", "not_eq", "not_in")
        return TestValueCondition(**{field: getattr(self, field) for field in field_names})
312  
313  
class CheckValueParameters(ConditionTestParameters):
    """Condition parameters extended with the value that was checked."""

    class Config:
        type_alias = "evidently:test_parameters:CheckValueParameters"

    # the checked value; may be None
    value: Optional[Numeric]
319  
320  
class ColumnCheckValueParameters(CheckValueParameters):
    """Check-value parameters extended with a column name."""

    class Config:
        type_alias = "evidently:test_parameters:ColumnCheckValueParameters"

    # name of the column; presumably the one the checked value was computed for - TODO confirm
    column_name: str
326  
327  
class BaseCheckValueTest(BaseConditionsTest):
    """
    Base class for all tests with checking a value condition
    """

    # value computed by the latest check() call
    _value: Numeric

    @abc.abstractmethod
    def calculate_value_for_test(self) -> Optional[Any]:
        """Return the value to check, or None when it is unavailable.

        Must be implemented by subclasses."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_description(self, value: Numeric) -> str:
        """Return the description for the result; it may mention the checked value.

        Must be implemented by subclasses."""
        raise NotImplementedError()

    def get_condition(self) -> TestValueCondition:
        """Return the condition to check the value against."""
        return self.condition

    def groups(self) -> Dict[str, str]:
        """Return an empty mapping: this base class adds no groups of its own."""
        return {}

    def get_parameters(self) -> CheckValueParameters:
        """Return the parameters for the result: the condition and the stored value."""
        condition = self.get_condition()
        checked_value = self._value
        return CheckValueParameters(condition=condition, value=checked_value)

    def check(self):
        """Compute the value, check it against the condition and return a TestResult.

        The result is ERROR when the value is None or a ValueError is raised while
        evaluating the condition, SUCCESS when the condition holds, FAIL otherwise.
        """
        result = TestResult(
            name=self.name,
            description="The test was not launched",
            status=TestStatus.SKIPPED,
            group=self.group,
            parameters=None,
        )
        value = self.calculate_value_for_test()
        self._value = value
        # description and parameters are filled in before the condition is evaluated,
        # so they are present on the result whatever the final status is
        result.description = self.get_description(value)
        result.parameters = self.get_parameters()

        try:
            if value is None:
                status = TestStatus.ERROR
            else:
                condition = self.get_condition()
                if condition is None:
                    raise ValueError
                status = TestStatus.SUCCESS if condition.check_value(value) else TestStatus.FAIL
            result.set_status(status)
        except ValueError:
            result.mark_as_error("Cannot calculate the condition")

        return result
394  
395  
# Type of the reference value handed to ConditionFromReferenceMixin.get_condition_from_reference;
# bound to MetricResult.
T = TypeVar("T", bound=MetricResult)
397  
398  
class ConditionFromReferenceMixin(BaseCheckValueTest, Generic[T], ABC):
    """Check-value test that derives its condition from the metric result when none is set explicitly."""

    # attribute of the metric result that is read as the reference
    reference_field: ClassVar[str] = "reference"
    _metric: Metric

    def get_condition_from_reference(self, reference: Optional[T]) -> TestValueCondition:
        """Build a condition from ``reference``. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_condition(self) -> TestValueCondition:
        """Return the explicit condition if one is configured, else one derived from the reference."""
        explicit_condition = self.condition
        if explicit_condition.has_condition():
            return explicit_condition

        metric_result = self.metric.get_result()
        reference_stats = getattr(metric_result, self.reference_field)
        return self.get_condition_from_reference(reference_stats)

    @property
    def metric(self):
        """Return the metric this test reads its result from."""
        return self._metric
417  
418  
def generate_column_tests(
    test_class: Type[Test], columns: Optional[Union[str, list]] = None, parameters: Optional[Dict] = None
) -> BaseGenerator:
    """Return a generator of ``test_class`` tests for columns.

    All three arguments are forwarded unchanged to make_generator_by_columns.
    """
    generator_kwargs = {
        "base_class": test_class,
        "columns": columns,
        "parameters": parameters,
    }
    return make_generator_by_columns(**generator_kwargs)