regression_performance_tests.py
from abc import ABC
from typing import ClassVar
from typing import List
from typing import Optional
from typing import Union

from evidently.legacy.metrics import RegressionDummyMetric
from evidently.legacy.metrics import RegressionQualityMetric
from evidently.legacy.metrics.regression_performance.visualization import regression_perf_plot
from evidently.legacy.renderers.base_renderer import TestHtmlInfo
from evidently.legacy.renderers.base_renderer import TestRenderer
from evidently.legacy.renderers.base_renderer import default_renderer
from evidently.legacy.renderers.html_widgets import plotly_figure
from evidently.legacy.renderers.render_utils import plot_distr
from evidently.legacy.tests.base_test import BaseCheckValueTest
from evidently.legacy.tests.base_test import GroupData
from evidently.legacy.tests.base_test import GroupingTypes
from evidently.legacy.tests.base_test import TestValueCondition
from evidently.legacy.tests.utils import approx
from evidently.legacy.utils.types import Numeric
from evidently.legacy.utils.visualizations import plot_distr_with_cond_perc_button

REGRESSION_GROUP = GroupData(id="regression", title="Regression", description="")
GroupingTypes.TestGroup.add_value(REGRESSION_GROUP)


class BaseRegressionPerformanceMetricsTest(BaseCheckValueTest, ABC):
    """Common base for the regression-quality tests in this module.

    The comparison arguments (``eq``, ``gt``, ``gte``, ``is_in``, ``lt``,
    ``lte``, ``not_eq``, ``not_in``) and ``is_critical`` are forwarded
    unchanged to ``BaseCheckValueTest``. Each instance owns its own
    ``RegressionQualityMetric`` (the measured values) and
    ``RegressionDummyMetric`` (a baseline that subclasses use to build
    default conditions).
    """

    group: ClassVar = REGRESSION_GROUP.id
    _metric: RegressionQualityMetric
    _dummy_metric: RegressionDummyMetric

    def __init__(
        self,
        eq: Optional[Numeric] = None,
        gt: Optional[Numeric] = None,
        gte: Optional[Numeric] = None,
        is_in: Optional[List[Union[Numeric, str, bool]]] = None,
        lt: Optional[Numeric] = None,
        lte: Optional[Numeric] = None,
        not_eq: Optional[Numeric] = None,
        not_in: Optional[List[Union[Numeric, str, bool]]] = None,
        is_critical: bool = True,
    ):
        super().__init__(
            eq=eq,
            gt=gt,
            gte=gte,
            is_in=is_in,
            lt=lt,
            lte=lte,
            not_eq=not_eq,
            not_in=not_in,
            is_critical=is_critical,
        )
        self._metric = RegressionQualityMetric()
        self._dummy_metric = RegressionDummyMetric()

    @property
    def metric(self) -> RegressionQualityMetric:
        """Regression quality metric whose result this test checks."""
        return self._metric

    @property
    def dummy_metric(self) -> RegressionDummyMetric:
        """Dummy-model baseline metric used for default thresholds."""
        return self._dummy_metric


class TestValueMAE(BaseRegressionPerformanceMetricsTest):
    """Check the Mean Absolute Error (MAE) of the current data."""

    class Config:
        type_alias = "evidently:test:TestValueMAE"

    name: ClassVar = "Mean Absolute Error (MAE)"

    def get_condition(self) -> TestValueCondition:
        """Return the user-supplied condition, or a default one.

        Default: ``eq=approx(reference MAE, relative=0.1)`` when a reference
        MAE is available, otherwise ``lt`` the dummy-model MAE.
        """
        if self.condition.has_condition():
            return self.condition
        metric_result = self.metric.get_result()
        ref_mae = metric_result.reference.mean_abs_error if metric_result.reference is not None else None
        if ref_mae is not None:
            return TestValueCondition(eq=approx(ref_mae, relative=0.1))
        return TestValueCondition(lt=self.dummy_metric.get_result().mean_abs_error_default)

    def calculate_value_for_test(self) -> Numeric:
        """Return the MAE computed on the current data."""
        return self.metric.get_result().current.mean_abs_error

    def get_description(self, value: Numeric) -> str:
        # Trailing period added so the sentence matches every other test description in this module.
        return f"The MAE is {value:.3}. The test threshold is {self.get_condition()}."


@default_renderer(wrap_type=TestValueMAE)
class TestValueMAERenderer(TestRenderer):
    """HTML renderer for ``TestValueMAE``: adds an MAE plot to the details."""

    def render_html(self, obj: TestValueMAE) -> TestHtmlInfo:
        info = super().render_html(obj)
        result = obj.metric.get_result()
        fig = regression_perf_plot(
            val_for_plot=result.vals_for_plots.mean_abs_error,
            hist_for_plot=result.hist_for_plot,
            name="MAE",
            curr_metric=result.current.mean_abs_error,
            ref_metric=result.reference.mean_abs_error if result.reference is not None else None,
            color_options=self.color_options,
        )
        info.with_details("MAE", plotly_figure(title="", figure=fig))
        return info


class TestValueMAPE(BaseRegressionPerformanceMetricsTest):
    """Check the Mean Absolute Percentage Error (MAPE) of the current data."""

    class Config:
        type_alias = "evidently:test:TestValueMAPE"

    name: ClassVar = "Mean Absolute Percentage Error (MAPE)"

    def get_condition(self) -> TestValueCondition:
        """Return the user-supplied condition, or a default one.

        Default: ``eq=approx(reference MAPE, relative=0.1)`` when a reference
        MAPE is available, otherwise ``lt`` the dummy-model MAPE.
        """
        if self.condition.has_condition():
            return self.condition
        metric_result = self.metric.get_result()
        ref_mape = metric_result.reference.mean_abs_perc_error if metric_result.reference is not None else None
        if ref_mape is not None:
            return TestValueCondition(eq=approx(ref_mape, relative=0.1))
        return TestValueCondition(lt=self.dummy_metric.get_result().mean_abs_perc_error_default)

    def calculate_value_for_test(self) -> Numeric:
        """Return the MAPE computed on the current data."""
        return self.metric.get_result().current.mean_abs_perc_error

    def get_description(self, value: Numeric) -> str:
        return f"The MAPE is {value:.3}. The test threshold is {self.get_condition()}."
@default_renderer(wrap_type=TestValueMAPE)
class TestValueMAPERenderer(TestRenderer):
    """HTML renderer for ``TestValueMAPE``: attaches a MAPE plot to the details."""

    def render_html(self, obj: TestValueMAPE) -> TestHtmlInfo:
        html_info = super().render_html(obj)
        quality = obj.metric.get_result()
        reference = quality.reference
        # Only the per-point series is multiplied by 100 here; the scalar current/reference
        # values are passed through as-is.
        # NOTE(review): presumably a fraction -> percent conversion; confirm against the metric.
        scaled_series = quality.vals_for_plots.mean_abs_perc_error * 100
        figure = regression_perf_plot(
            val_for_plot=scaled_series,
            hist_for_plot=quality.hist_for_plot,
            name="MAPE",
            curr_metric=quality.current.mean_abs_perc_error,
            ref_metric=None if reference is None else reference.mean_abs_perc_error,
            color_options=self.color_options,
        )
        html_info.with_details("MAPE", plotly_figure(title="", figure=figure))
        return html_info


class TestValueRMSE(BaseRegressionPerformanceMetricsTest):
    """Check the Root Mean Square Error (RMSE) of the current data."""

    class Config:
        type_alias = "evidently:test:TestValueRMSE"

    name: ClassVar = "Root Mean Square Error (RMSE)"

    def get_condition(self) -> TestValueCondition:
        """Return the explicit condition if set, otherwise derive one.

        With a reference RMSE the default is ``eq=approx(reference, relative=0.1)``;
        without one it is ``lt`` the dummy-model RMSE.
        """
        if self.condition.has_condition():
            return self.condition
        reference = self.metric.get_result().reference
        if reference is not None and reference.rmse is not None:
            return TestValueCondition(eq=approx(reference.rmse, relative=0.1))
        baseline_rmse = self.dummy_metric.get_result().rmse_default
        return TestValueCondition(lt=baseline_rmse)

    def calculate_value_for_test(self) -> Numeric:
        """Return the RMSE measured on the current data."""
        current = self.metric.get_result().current
        return current.rmse

    def get_description(self, value: Numeric) -> str:
        return f"The RMSE is {value:.3}. The test threshold is {self.get_condition()}."
@default_renderer(wrap_type=TestValueRMSE)
class TestValueRMSERenderer(TestRenderer):
    """HTML renderer for ``TestValueRMSE``: attaches an RMSE plot to the details."""

    def render_html(self, obj: TestValueRMSE) -> TestHtmlInfo:
        html_info = super().render_html(obj)
        quality = obj.metric.get_result()
        reference = quality.reference
        reference_rmse = None if reference is None else reference.rmse
        figure = regression_perf_plot(
            val_for_plot=quality.vals_for_plots.rmse,
            hist_for_plot=quality.hist_for_plot,
            name="RMSE",
            curr_metric=quality.current.rmse,
            ref_metric=reference_rmse,
            color_options=self.color_options,
        )
        html_info.with_details("RMSE", plotly_figure(title="", figure=figure))
        return html_info


class TestValueMeanError(BaseRegressionPerformanceMetricsTest):
    """Check the Mean Error (ME) of the current data."""

    class Config:
        type_alias = "evidently:test:TestValueMeanError"

    name: ClassVar = "Mean Error (ME)"

    def get_condition(self) -> TestValueCondition:
        """Return the explicit condition if set, otherwise "ME approximately 0".

        The default absolute tolerance is 0.1 times ``me_default_sigma`` taken
        from the quality metric result.
        """
        if not self.condition.has_condition():
            sigma = self.metric.get_result().me_default_sigma
            return TestValueCondition(eq=approx(0, absolute=0.1 * sigma))
        return self.condition

    def calculate_value_for_test(self) -> Numeric:
        """Return the mean error measured on the current data."""
        current = self.metric.get_result().current
        return current.mean_error

    def get_description(self, value: Numeric) -> str:
        return f"The ME is {value:.3}. The test threshold is {self.get_condition()}."
@default_renderer(wrap_type=TestValueMeanError)
class TestValueMeanErrorRenderer(TestRenderer):
    """HTML renderer for ``TestValueMeanError``: error histogram plus the test condition."""

    def render_html(self, obj: TestValueMeanError) -> TestHtmlInfo:
        info = super().render_html(obj)
        metric_result = obj.metric.get_result()
        me_hist_for_plot = metric_result.me_hist_for_plot
        hist_curr = me_hist_for_plot.current
        hist_ref = me_hist_for_plot.reference

        # A single helper receives the current/reference histograms, the active test
        # condition and the current mean error, and builds one figure with count and
        # percent y-axes.
        fig = plot_distr_with_cond_perc_button(
            hist_curr=hist_curr,
            hist_ref=hist_ref,
            xaxis_name="",
            yaxis_name="count",
            yaxis_name_perc="percent",
            color_options=self.color_options,
            to_json=False,
            condition=obj.get_condition(),
            value=metric_result.current.mean_error,
            value_name="current mean error",
        )
        info.with_details("", plotly_figure(title="", figure=fig))
        return info


class TestValueAbsMaxError(BaseRegressionPerformanceMetricsTest):
    """Check the maximum absolute error of the current data."""

    class Config:
        type_alias = "evidently:test:TestValueAbsMaxError"

    name: ClassVar = "Max Absolute Error"

    def get_condition(self) -> TestValueCondition:
        """Return the user-supplied condition, or a default upper bound.

        Default: ``lte=approx(reference max abs error, relative=0.1)`` when a
        reference value is available, otherwise ``lte`` the dummy-model value.
        """
        if self.condition.has_condition():
            return self.condition
        metric_result = self.metric.get_result()
        abs_error_max_ref = metric_result.reference.abs_error_max if metric_result.reference is not None else None
        if abs_error_max_ref is not None:
            return TestValueCondition(lte=approx(abs_error_max_ref, relative=0.1))
        return TestValueCondition(lte=self.dummy_metric.get_result().abs_error_max_default)

    def calculate_value_for_test(self) -> Numeric:
        """Return the max absolute error computed on the current data."""
        return self.metric.get_result().current.abs_error_max

    def get_description(self, value: Numeric) -> str:
        return f"The Max Absolute Error is {value:.3}. The test threshold is {self.get_condition()}."


@default_renderer(wrap_type=TestValueAbsMaxError)
class TestValueAbsMaxErrorRenderer(TestRenderer):
    """HTML renderer for ``TestValueAbsMaxError``: current/reference error histogram."""

    def render_html(self, obj: TestValueAbsMaxError) -> TestHtmlInfo:
        info = super().render_html(obj)
        # NOTE(review): this plots ``me_hist_for_plot``, the same histogram the Mean
        # Error renderer uses, rather than an abs-error-specific one. Confirm this is intended.
        me_hist_for_plot = obj.metric.get_result().me_hist_for_plot
        hist_curr = me_hist_for_plot.current
        hist_ref = me_hist_for_plot.reference

        fig = plot_distr(hist_curr=hist_curr, hist_ref=hist_ref, color_options=self.color_options)
        info.with_details("", plotly_figure(title="", figure=fig))
        return info


class TestValueR2Score(BaseRegressionPerformanceMetricsTest):
    """Check the R2 score of the current data."""

    class Config:
        type_alias = "evidently:test:TestValueR2Score"

    name: ClassVar = "R2 Score"

    def get_condition(self) -> TestValueCondition:
        """Return the user-supplied condition, or a default one.

        Default: ``eq=approx(reference R2, relative=0.1)`` when a reference R2
        is available, otherwise ``gt=0``.
        """
        if self.condition.has_condition():
            return self.condition
        result = self.metric.get_result()
        r2_score_ref = result.reference.r2_score if result.reference is not None else None
        if r2_score_ref is not None:
            return TestValueCondition(eq=approx(r2_score_ref, relative=0.1))
        return TestValueCondition(gt=0)

    def calculate_value_for_test(self) -> Numeric:
        """Return the R2 score computed on the current data."""
        return self.metric.get_result().current.r2_score

    def get_description(self, value: Numeric) -> str:
        return f"The R2 score is {value:.3}. The test threshold is {self.get_condition()}."
@default_renderer(wrap_type=TestValueR2Score)
class TestValueR2ScoreRenderer(TestRenderer):
    """HTML renderer for ``TestValueR2Score``: attaches an R2 score plot to the details."""

    def render_html(self, obj: TestValueR2Score) -> TestHtmlInfo:
        html_info = super().render_html(obj)
        quality = obj.metric.get_result()
        reference = quality.reference
        reference_r2 = None if reference is None else reference.r2_score

        figure = regression_perf_plot(
            val_for_plot=quality.vals_for_plots.r2_score,
            hist_for_plot=quality.hist_for_plot,
            name="R2_score",
            curr_metric=quality.current.r2_score,
            ref_metric=reference_r2,
            color_options=self.color_options,
        )
        html_info.with_details("R2 Score", plotly_figure(title="", figure=figure))
        return html_info