# src/evidently/legacy/tests/regression_performance_tests.py
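"""Legacy regression performance tests.

Value tests for MAE, MAPE, RMSE, mean error, max absolute error, and R2.
Default test conditions come from the reference dataset when one is
provided, and from a dummy-model baseline otherwise. HTML renderers for
each test are registered alongside the test classes.
"""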
from abc import ABC
from typing import ClassVar
from typing import List
from typing import Optional
from typing import Union

from evidently.legacy.metrics import RegressionDummyMetric
from evidently.legacy.metrics import RegressionQualityMetric
from evidently.legacy.metrics.regression_performance.visualization import regression_perf_plot
from evidently.legacy.renderers.base_renderer import TestHtmlInfo
from evidently.legacy.renderers.base_renderer import TestRenderer
from evidently.legacy.renderers.base_renderer import default_renderer
from evidently.legacy.renderers.html_widgets import plotly_figure
from evidently.legacy.renderers.render_utils import plot_distr
from evidently.legacy.tests.base_test import BaseCheckValueTest
from evidently.legacy.tests.base_test import GroupData
from evidently.legacy.tests.base_test import GroupingTypes
from evidently.legacy.tests.base_test import TestValueCondition
from evidently.legacy.tests.utils import approx
from evidently.legacy.utils.types import Numeric
from evidently.legacy.utils.visualizations import plot_distr_with_cond_perc_button

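# Register the "Regression" test group so these tests are grouped together in reports.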
REGRESSION_GROUP = GroupData(id="regression", title="Regression", description="")
GroupingTypes.TestGroup.add_value(REGRESSION_GROUP)


class BaseRegressionPerformanceMetricsTest(BaseCheckValueTest, ABC):
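    """Base class for regression performance value tests.

    Wires up a RegressionQualityMetric (current/reference quality values)
    and a RegressionDummyMetric (baseline used when no reference is given).
    """
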
    group: ClassVar = REGRESSION_GROUP.id
    _metric: RegressionQualityMetric
    _dummy_metric: RegressionDummyMetric

    def __init__(
        self,
        eq: Optional[Numeric] = None,
        gt: Optional[Numeric] = None,
        gte: Optional[Numeric] = None,
        is_in: Optional[List[Union[Numeric, str, bool]]] = None,
        lt: Optional[Numeric] = None,
        lte: Optional[Numeric] = None,
        not_eq: Optional[Numeric] = None,
        not_in: Optional[List[Union[Numeric, str, bool]]] = None,
        is_critical: bool = True,
    ):
        super().__init__(
            eq=eq,
            gt=gt,
            gte=gte,
            is_in=is_in,
            lt=lt,
            lte=lte,
            not_eq=not_eq,
            not_in=not_in,
            is_critical=is_critical,
        )
        self._metric = RegressionQualityMetric()
        self._dummy_metric = RegressionDummyMetric()

    @property
    def metric(self):
        return self._metric

    @property
    def dummy_metric(self):
        return self._dummy_metric


class TestValueMAE(BaseRegressionPerformanceMetricsTest):
    class Config:
        type_alias = "evidently:test:TestValueMAE"

    name: ClassVar = "Mean Absolute Error (MAE)"

    def get_condition(self) -> TestValueCondition:
        if self.condition.has_condition():
            return self.condition
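        # Default condition: match the reference MAE within ±10% when reference
        # data is available; otherwise require beating the dummy-model MAE.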
        metric_result = self.metric.get_result()
        ref_mae = metric_result.reference.mean_abs_error if metric_result.reference is not None else None
        if ref_mae is not None:
            return TestValueCondition(eq=approx(ref_mae, relative=0.1))
        return TestValueCondition(lt=self.dummy_metric.get_result().mean_abs_error_default)

    def calculate_value_for_test(self) -> Numeric:
        return self.metric.get_result().current.mean_abs_error

    def get_description(self, value: Numeric) -> str:
        return f"The MAE is {value:.3}. The test threshold is {self.get_condition()}."


@default_renderer(wrap_type=TestValueMAE)
class TestValueMAERenderer(TestRenderer):
    def render_html(self, obj: TestValueMAE) -> TestHtmlInfo:
        info = super().render_html(obj)
        result = obj.metric.get_result()
        fig = regression_perf_plot(
            val_for_plot=result.vals_for_plots.mean_abs_error,
            hist_for_plot=result.hist_for_plot,
            name="MAE",
            curr_metric=result.current.mean_abs_error,
            ref_metric=result.reference.mean_abs_error if result.reference is not None else None,
            color_options=self.color_options,
        )
        info.with_details("MAE", plotly_figure(title="", figure=fig))
        return info


class TestValueMAPE(BaseRegressionPerformanceMetricsTest):
    class Config:
        type_alias = "evidently:test:TestValueMAPE"

    name: ClassVar = "Mean Absolute Percentage Error (MAPE)"

    def get_condition(self) -> TestValueCondition:
        if self.condition.has_condition():
            return self.condition
        metric_result = self.metric.get_result()
        ref_mape = metric_result.reference.mean_abs_perc_error if metric_result.reference is not None else None
        if ref_mape is not None:
            return TestValueCondition(eq=approx(ref_mape, relative=0.1))
        return TestValueCondition(lt=self.dummy_metric.get_result().mean_abs_perc_error_default)

    def calculate_value_for_test(self) -> Numeric:
        return self.metric.get_result().current.mean_abs_perc_error

    def get_description(self, value: Numeric) -> str:
        return f"The MAPE is {value:.3}. The test threshold is {self.get_condition()}."


@default_renderer(wrap_type=TestValueMAPE)
class TestValueMAPERenderer(TestRenderer):
    def render_html(self, obj: TestValueMAPE) -> TestHtmlInfo:
        info = super().render_html(obj)
        result = obj.metric.get_result()
        # Scale the per-bin MAPE values by 100 so the plot shows percentages.
        val_for_plot = result.vals_for_plots.mean_abs_perc_error * 100
        fig = regression_perf_plot(
            val_for_plot=val_for_plot,
            hist_for_plot=result.hist_for_plot,
            name="MAPE",
            curr_metric=result.current.mean_abs_perc_error,
            ref_metric=result.reference.mean_abs_perc_error if result.reference is not None else None,
            color_options=self.color_options,
        )
        info.with_details("MAPE", plotly_figure(title="", figure=fig))
        return info


class TestValueRMSE(BaseRegressionPerformanceMetricsTest):
    class Config:
        type_alias = "evidently:test:TestValueRMSE"

    name: ClassVar = "Root Mean Square Error (RMSE)"

    def get_condition(self) -> TestValueCondition:
        if self.condition.has_condition():
            return self.condition
        metric_result = self.metric.get_result()
        rmse_ref = metric_result.reference.rmse if metric_result.reference is not None else None
        if rmse_ref is not None:
            return TestValueCondition(eq=approx(rmse_ref, relative=0.1))
        return TestValueCondition(lt=self.dummy_metric.get_result().rmse_default)

    def calculate_value_for_test(self) -> Numeric:
        return self.metric.get_result().current.rmse

    def get_description(self, value: Numeric) -> str:
        return f"The RMSE is {value:.3}. The test threshold is {self.get_condition()}."


@default_renderer(wrap_type=TestValueRMSE)
class TestValueRMSERenderer(TestRenderer):
    def render_html(self, obj: TestValueRMSE) -> TestHtmlInfo:
        info = super().render_html(obj)
        result = obj.metric.get_result()
        fig = regression_perf_plot(
            val_for_plot=result.vals_for_plots.rmse,
            hist_for_plot=result.hist_for_plot,
            name="RMSE",
            curr_metric=result.current.rmse,
            ref_metric=result.reference.rmse if result.reference is not None else None,
            color_options=self.color_options,
        )
        info.with_details("RMSE", plotly_figure(title="", figure=fig))
        return info


class TestValueMeanError(BaseRegressionPerformanceMetricsTest):
    class Config:
        type_alias = "evidently:test:TestValueMeanError"

    name: ClassVar = "Mean Error (ME)"

    def get_condition(self) -> TestValueCondition:
        if self.condition.has_condition():
            return self.condition
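        # Default condition: mean error close to zero, within 10% of the
        # metric's default error sigma (me_default_sigma).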
        return TestValueCondition(eq=approx(0, absolute=0.1 * self.metric.get_result().me_default_sigma))

    def calculate_value_for_test(self) -> Numeric:
        return self.metric.get_result().current.mean_error

    def get_description(self, value: Numeric) -> str:
        return f"The ME is {value:.3}. The test threshold is {self.get_condition()}."


@default_renderer(wrap_type=TestValueMeanError)
class TestValueMeanErrorRenderer(TestRenderer):
    def render_html(self, obj: TestValueMeanError) -> TestHtmlInfo:
        info = super().render_html(obj)
        metric_result = obj.metric.get_result()
        me_hist_for_plot = metric_result.me_hist_for_plot
        hist_curr = me_hist_for_plot.current
        hist_ref = me_hist_for_plot.reference

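        # Error distribution with a count/percent toggle; the figure also draws
        # the test condition and marks the current mean error.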
        fig = plot_distr_with_cond_perc_button(
            hist_curr=hist_curr,
            hist_ref=hist_ref,
            xaxis_name="",
            yaxis_name="count",
            yaxis_name_perc="percent",
            color_options=self.color_options,
            to_json=False,
            condition=obj.get_condition(),
            value=metric_result.current.mean_error,
            value_name="current mean error",
        )
        info.with_details("", plotly_figure(title="", figure=fig))
        return info


class TestValueAbsMaxError(BaseRegressionPerformanceMetricsTest):
    class Config:
        type_alias = "evidently:test:TestValueAbsMaxError"

    name: ClassVar = "Max Absolute Error"

    def get_condition(self) -> TestValueCondition:
        if self.condition.has_condition():
            return self.condition
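        # One-sided default: max absolute error must not exceed the reference
        # value (+10%) or, without a reference, the dummy-model baseline.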
        metric_result = self.metric.get_result()
        abs_error_max_ref = metric_result.reference.abs_error_max if metric_result.reference is not None else None
        if abs_error_max_ref is not None:
            return TestValueCondition(lte=approx(abs_error_max_ref, relative=0.1))
        return TestValueCondition(lte=self.dummy_metric.get_result().abs_error_max_default)

    def calculate_value_for_test(self) -> Numeric:
        return self.metric.get_result().current.abs_error_max

    def get_description(self, value: Numeric) -> str:
        return f"The Max Absolute Error is {value:.3}. The test threshold is {self.get_condition()}."


@default_renderer(wrap_type=TestValueAbsMaxError)
class TestValueAbsMaxErrorRenderer(TestRenderer):
    def render_html(self, obj: TestValueAbsMaxError) -> TestHtmlInfo:
        info = super().render_html(obj)
        me_hist_for_plot = obj.metric.get_result().me_hist_for_plot
        hist_curr = me_hist_for_plot.current
        hist_ref = me_hist_for_plot.reference

        fig = plot_distr(hist_curr=hist_curr, hist_ref=hist_ref, color_options=self.color_options)
        info.with_details("", plotly_figure(title="", figure=fig))
        return info


class TestValueR2Score(BaseRegressionPerformanceMetricsTest):
    class Config:
        type_alias = "evidently:test:TestValueR2Score"

    name: ClassVar = "R2 Score"

    def get_condition(self) -> TestValueCondition:
        if self.condition.has_condition():
            return self.condition
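        # Default condition: match the reference R2 within ±10%; without a
        # reference, require better-than-constant predictions (R2 > 0).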
        result = self.metric.get_result()
        r2_score_ref = result.reference.r2_score if result.reference is not None else None
        if r2_score_ref is not None:
            return TestValueCondition(eq=approx(r2_score_ref, relative=0.1))
        return TestValueCondition(gt=0)

    def calculate_value_for_test(self) -> Numeric:
        return self.metric.get_result().current.r2_score

    def get_description(self, value: Numeric) -> str:
        return f"The R2 score is {value:.3}. The test threshold is {self.get_condition()}."


@default_renderer(wrap_type=TestValueR2Score)
class TestValueR2ScoreRenderer(TestRenderer):
    def render_html(self, obj: TestValueR2Score) -> TestHtmlInfo:
        info = super().render_html(obj)
        result = obj.metric.get_result()

        fig = regression_perf_plot(
            val_for_plot=result.vals_for_plots.r2_score,
            hist_for_plot=result.hist_for_plot,
            name="R2_score",
            curr_metric=result.current.r2_score,
            ref_metric=result.reference.r2_score if result.reference is not None else None,
            color_options=self.color_options,
        )
        info.with_details("R2 Score", plotly_figure(title="", figure=fig))
        return info