classification_performance_tests.py
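"""Classification performance tests for the legacy Evidently API.

Each test checks one classification quality value (accuracy, precision,
recall, F1, ROC AUC, log loss, or a confusion-matrix rate) against a
condition resolved in priority order: explicit test parameters, the
reference dataset, or a dummy-model baseline.
"""
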
import abc
from abc import ABC
from typing import Any
from typing import ClassVar
from typing import List
from typing import Optional
from typing import Union

from evidently.legacy.metric_results import DatasetClassificationQuality
from evidently.legacy.metric_results import Label
from evidently.legacy.metric_results import ROCCurve
from evidently.legacy.metrics.classification_performance.classification_dummy_metric import ClassificationDummyMetric
from evidently.legacy.metrics.classification_performance.classification_quality_metric import (
    ClassificationConfusionMatrix,
)
from evidently.legacy.metrics.classification_performance.classification_quality_metric import (
    ClassificationQualityMetric,
)
from evidently.legacy.metrics.classification_performance.classification_quality_metric import (
    ClassificationQualityMetricResult,
)
from evidently.legacy.metrics.classification_performance.confusion_matrix_metric import (
    ClassificationConfusionMatrixParameters,
)
from evidently.legacy.metrics.classification_performance.objects import ClassMetric
from evidently.legacy.metrics.classification_performance.quality_by_class_metric import ClassificationQualityByClass
from evidently.legacy.metrics.classification_performance.roc_curve_metric import ClassificationRocCurve
from evidently.legacy.renderers.base_renderer import TestHtmlInfo
from evidently.legacy.renderers.base_renderer import TestRenderer
from evidently.legacy.renderers.base_renderer import default_renderer
from evidently.legacy.renderers.html_widgets import TabData
from evidently.legacy.renderers.html_widgets import get_roc_auc_tab_data
from evidently.legacy.renderers.html_widgets import plotly_figure
from evidently.legacy.renderers.html_widgets import widget_tabs
from evidently.legacy.tests.base_test import BaseCheckValueTest
from evidently.legacy.tests.base_test import CheckValueParameters
from evidently.legacy.tests.base_test import GroupData
from evidently.legacy.tests.base_test import GroupingTypes
from evidently.legacy.tests.base_test import TestValueCondition
from evidently.legacy.tests.utils import approx
from evidently.legacy.tests.utils import plot_boxes
from evidently.legacy.tests.utils import plot_conf_mtrx
from evidently.legacy.tests.utils import plot_rates
from evidently.legacy.utils.types import Numeric

CLASSIFICATION_GROUP = GroupData(id="classification", title="Classification", description="")
GroupingTypes.TestGroup.add_value(CLASSIFICATION_GROUP)


class SimpleClassificationTest(BaseCheckValueTest):
    """Base class for dataset-level classification tests on a single numeric value."""

    condition_arg: ClassVar[str] = "gt"

    group: ClassVar = CLASSIFICATION_GROUP.id
    name: ClassVar[str]
    _metric: ClassificationQualityMetric
    _dummy_metric: ClassificationDummyMetric

    def __init__(
        self,
        eq: Optional[Numeric] = None,
        gt: Optional[Numeric] = None,
        gte: Optional[Numeric] = None,
        is_in: Optional[List[Union[Numeric, str, bool]]] = None,
        lt: Optional[Numeric] = None,
        lte: Optional[Numeric] = None,
        not_eq: Optional[Numeric] = None,
        not_in: Optional[List[Union[Numeric, str, bool]]] = None,
        is_critical: bool = True,
    ):
        super().__init__(
            eq=eq,
            gt=gt,
            gte=gte,
            is_in=is_in,
            lt=lt,
            lte=lte,
            not_eq=not_eq,
            not_in=not_in,
            is_critical=is_critical,
        )
        self._metric = ClassificationQualityMetric()
        self._dummy_metric = ClassificationDummyMetric()

    @property
    def metric(self):
        return self._metric

    @property
    def dummy_metric(self):
        return self._dummy_metric

    def calculate_value_for_test(self) -> Optional[Any]:
        return self.get_value(self.metric.get_result().current)

    def get_condition(self) -> TestValueCondition:
        # Condition resolution order: explicit condition -> reference value
        # (with a 20% relative margin) -> dummy-model baseline.
        if self.condition.has_condition():
            return self.condition

        ref_metrics = self.metric.get_result().reference

        if ref_metrics is not None:
            return TestValueCondition(eq=approx(self.get_value(ref_metrics), relative=0.2))

        if self.get_value(self.dummy_metric.get_result().dummy) is None:
            raise ValueError("Neither required test parameters nor reference data has been provided.")

        return TestValueCondition(**{self.condition_arg: self.get_value(self.dummy_metric.get_result().dummy)})

    @abc.abstractmethod
    def get_value(self, result: DatasetClassificationQuality):
        raise NotImplementedError()


class SimpleClassificationTestTopK(SimpleClassificationTest, ClassificationConfusionMatrixParameters, ABC):
    """Classification test that supports `probas_threshold` or top-`k` relabeling."""

    _conf_matrix: ClassificationConfusionMatrix

    def __init__(
        self,
        probas_threshold: Optional[float] = None,
        k: Optional[int] = None,
        eq: Optional[Numeric] = None,
        gt: Optional[Numeric] = None,
        gte: Optional[Numeric] = None,
        is_in: Optional[List[Union[Numeric, str, bool]]] = None,
        lt: Optional[Numeric] = None,
        lte: Optional[Numeric] = None,
        not_eq: Optional[Numeric] = None,
        not_in: Optional[List[Union[Numeric, str, bool]]] = None,
        is_critical: bool = True,
    ):
        if k is not None and probas_threshold is not None:
            raise ValueError("Only one of 'probas_threshold' or 'k' should be given")
        self.k = k
        self.probas_threshold = probas_threshold
        super().__init__(
            eq=eq,
            gt=gt,
            gte=gte,
            is_in=is_in,
            lt=lt,
            lte=lte,
            not_eq=not_eq,
            not_in=not_in,
            is_critical=is_critical,
        )
        self._dummy_metric = ClassificationDummyMetric(probas_threshold=self.probas_threshold, k=self.k)
        self._metric = ClassificationQualityMetric(probas_threshold=self.probas_threshold, k=self.k)
        self._conf_matrix = self.confusion_matric_metric()  # provided by ClassificationConfusionMatrixParameters

    def calculate_value_for_test(self) -> Optional[Any]:
        return self.get_value(self.metric.get_result().current)

    @property
    def conf_matrix(self):
        return self._conf_matrix


class TestAccuracyScore(SimpleClassificationTestTopK):
    class Config:
        type_alias = "evidently:test:TestAccuracyScore"

    name = "Accuracy Score"

    def get_value(self, result: DatasetClassificationQuality):
        return result.accuracy

    def get_description(self, value: Numeric) -> str:
        return f"The Accuracy Score is {value:.3g}. The test threshold is {self.get_condition()}"
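
# A minimal usage sketch, assuming the legacy `TestSuite` entry point and a
# DataFrame with the default "target" and "prediction" columns (`ref_df` and
# `cur_df` are placeholder names):
#
#     from evidently.legacy.test_suite import TestSuite
#
#     suite = TestSuite(tests=[TestAccuracyScore(gte=0.8)])
#     suite.run(reference_data=ref_df, current_data=cur_df)
#
# Without an explicit condition, the test auto-configures itself from the
# reference accuracy (within a 20% relative margin) or, failing that, from
# the dummy-model baseline.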


@default_renderer(wrap_type=TestAccuracyScore)
class TestAccuracyScoreRenderer(TestRenderer):
    def render_html(self, obj: TestAccuracyScore) -> TestHtmlInfo:
        info = super().render_html(obj)
        curr_matrix = obj.conf_matrix.get_result().current_matrix
        ref_matrix = obj.conf_matrix.get_result().reference_matrix
        fig = plot_conf_mtrx(curr_matrix, ref_matrix)
        info.with_details("Accuracy Score", plotly_figure(figure=fig, title=""))
        return info


class TestPrecisionScore(SimpleClassificationTestTopK):
    class Config:
        type_alias = "evidently:test:TestPrecisionScore"

    name = "Precision Score"

    def get_value(self, result: DatasetClassificationQuality):
        return result.precision

    def get_description(self, value: Numeric) -> str:
        return f"The Precision Score is {value:.3g}. The test threshold is {self.get_condition()}"


@default_renderer(wrap_type=TestPrecisionScore)
class TestPrecisionScoreRenderer(TestRenderer):
    def render_html(self, obj: TestPrecisionScore) -> TestHtmlInfo:
        info = super().render_html(obj)
        curr_matrix = obj.conf_matrix.get_result().current_matrix
        ref_matrix = obj.conf_matrix.get_result().reference_matrix
        fig = plot_conf_mtrx(curr_matrix, ref_matrix)
        info.with_details("Precision Score", plotly_figure(figure=fig, title=""))
        return info


class TestF1Score(SimpleClassificationTestTopK):
    class Config:
        type_alias = "evidently:test:TestF1Score"

    name: ClassVar = "F1 Score"

    def get_value(self, result: DatasetClassificationQuality):
        return result.f1

    def get_description(self, value: Numeric) -> str:
        return f"The F1 Score is {value:.3g}. The test threshold is {self.get_condition()}"


@default_renderer(wrap_type=TestF1Score)
class TestF1ScoreRenderer(TestRenderer):
    def render_html(self, obj: TestF1Score) -> TestHtmlInfo:
        info = super().render_html(obj)
        curr_matrix = obj.conf_matrix.get_result().current_matrix
        ref_matrix = obj.conf_matrix.get_result().reference_matrix
        fig = plot_conf_mtrx(curr_matrix, ref_matrix)
        info.with_details("F1 Score", plotly_figure(title="", figure=fig))
        return info


class TestRecallScore(SimpleClassificationTestTopK):
    class Config:
        type_alias = "evidently:test:TestRecallScore"

    name = "Recall Score"

    def get_value(self, result: DatasetClassificationQuality):
        return result.recall

    def get_description(self, value: Numeric) -> str:
        return f"The Recall Score is {value:.3g}. The test threshold is {self.get_condition()}"
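
# Each top-K test above accepts either `probas_threshold` or `k` (never both)
# to re-label predicted probabilities before scoring, e.g.:
#
#     TestPrecisionScore(probas_threshold=0.7, gte=0.5)
#     TestRecallScore(k=10)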


@default_renderer(wrap_type=TestRecallScore)
class TestRecallScoreRenderer(TestRenderer):
    def render_html(self, obj: TestRecallScore) -> TestHtmlInfo:
        info = super().render_html(obj)
        curr_matrix = obj.conf_matrix.get_result().current_matrix
        ref_matrix = obj.conf_matrix.get_result().reference_matrix
        fig = plot_conf_mtrx(curr_matrix, ref_matrix)
        info.with_details("Recall Score", plotly_figure(title="", figure=fig))
        return info


class TestRocAuc(SimpleClassificationTest):
    class Config:
        type_alias = "evidently:test:TestRocAuc"

    name: ClassVar = "ROC AUC Score"
    _roc_curve: ClassificationRocCurve

    def __init__(
        self,
        eq: Optional[Numeric] = None,
        gt: Optional[Numeric] = None,
        gte: Optional[Numeric] = None,
        is_in: Optional[List[Union[Numeric, str, bool]]] = None,
        lt: Optional[Numeric] = None,
        lte: Optional[Numeric] = None,
        not_eq: Optional[Numeric] = None,
        not_in: Optional[List[Union[Numeric, str, bool]]] = None,
        is_critical: bool = True,
    ):
        self._roc_curve = ClassificationRocCurve()
        super().__init__(
            eq=eq,
            gt=gt,
            gte=gte,
            is_in=is_in,
            lt=lt,
            lte=lte,
            not_eq=not_eq,
            not_in=not_in,
            is_critical=is_critical,
        )

    def get_value(self, result: DatasetClassificationQuality):
        return result.roc_auc

    def get_description(self, value: Numeric) -> str:
        if value is None:
            return "Not enough data to calculate ROC AUC. Consider providing probabilities instead of labels."
        else:
            return f"The ROC AUC Score is {value:.3g}. The test threshold is {self.get_condition()}"


@default_renderer(wrap_type=TestRocAuc)
class TestRocAucRenderer(TestRenderer):
    def render_html(self, obj: TestRocAuc) -> TestHtmlInfo:
        info = super().render_html(obj)
        curr_roc_curve: Optional[ROCCurve] = obj._roc_curve.get_result().current_roc_curve
        ref_roc_curve: Optional[ROCCurve] = obj._roc_curve.get_result().reference_roc_curve

        if curr_roc_curve is None:
            return info

        tab_data = get_roc_auc_tab_data(curr_roc_curve, ref_roc_curve, color_options=self.color_options)

        if len(tab_data) == 1:
            return info.with_details("ROC Curve", tab_data[0][1])

        tabs = [TabData(name, widget) for name, widget in tab_data]
        return info.with_details("", widget_tabs(title="", tabs=tabs))


class TestLogLoss(SimpleClassificationTest):
    class Config:
        type_alias = "evidently:test:TestLogLoss"

    condition_arg = "lt"
    name = "Logarithmic Loss"

    def get_value(self, result: DatasetClassificationQuality):
        return result.log_loss

    def get_description(self, value: Numeric) -> str:
        if value is None:
            return "Not enough data to calculate Logarithmic Loss. Consider providing probabilities instead of labels."
        else:
            return f"The Logarithmic Loss is {value:.3g}. The test threshold is {self.get_condition()}"
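
# Note: `condition_arg = "lt"` inverts the dummy-baseline fallback in
# SimpleClassificationTest.get_condition: for loss-like metrics the test
# should score *below* the dummy model, not above it. TestFPR and TestFNR
# below use the same mechanism.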


@default_renderer(wrap_type=TestLogLoss)
class TestLogLossRenderer(TestRenderer):
    def render_html(self, obj: TestLogLoss) -> TestHtmlInfo:
        info = super().render_html(obj)
        result: ClassificationQualityMetricResult = obj.metric.get_result()

        curr_metrics = result.current.plot_data
        ref_metrics = None if result.reference is None else result.reference.plot_data

        if curr_metrics is not None:
            fig = plot_boxes(
                curr_for_plots=curr_metrics,
                ref_for_plots=ref_metrics,
                color_options=self.color_options,
            )
            info.with_details("Logarithmic Loss", plotly_figure(title="", figure=fig))

        return info


class TestTPR(SimpleClassificationTestTopK):
    class Config:
        type_alias = "evidently:test:TestTPR"

    name = "True Positive Rate"

    def get_value(self, result: DatasetClassificationQuality):
        return result.tpr

    def get_description(self, value: Numeric) -> str:
        if value is None:
            return "This test is applicable only for binary classification"

        return f"The True Positive Rate is {value:.3g}. The test threshold is {self.get_condition()}"


@default_renderer(wrap_type=TestTPR)
class TestTPRRenderer(TestRenderer):
    def render_html(self, obj: TestTPR) -> TestHtmlInfo:
        info = super().render_html(obj)
        curr_metrics = obj.metric.get_result().current
        ref_metrics = obj.metric.get_result().reference
        curr_rate_plots_data = curr_metrics.rate_plots_data
        ref_rate_plots_data = None

        if ref_metrics is not None:
            ref_rate_plots_data = ref_metrics.rate_plots_data

        if curr_rate_plots_data is not None:
            fig = plot_rates(
                curr_rate_plots_data=curr_rate_plots_data,
                ref_rate_plots_data=ref_rate_plots_data,
                color_options=self.color_options,
            )
            info.with_details("TPR", plotly_figure(title="", figure=fig))

        return info


class TestTNR(SimpleClassificationTestTopK):
    class Config:
        type_alias = "evidently:test:TestTNR"

    name = "True Negative Rate"

    def get_value(self, result: DatasetClassificationQuality):
        return result.tnr

    def get_description(self, value: Numeric) -> str:
        if value is None:
            return "This test is applicable only for binary classification"

        return f"The True Negative Rate is {value:.3g}. The test threshold is {self.get_condition()}"


@default_renderer(wrap_type=TestTNR)
class TestTNRRenderer(TestRenderer):
    def render_html(self, obj: TestTNR) -> TestHtmlInfo:
        info = super().render_html(obj)
        curr_metrics = obj.metric.get_result().current
        ref_metrics = obj.metric.get_result().reference
        curr_rate_plots_data = curr_metrics.rate_plots_data
        ref_rate_plots_data = None

        if ref_metrics is not None:
            ref_rate_plots_data = ref_metrics.rate_plots_data

        if curr_rate_plots_data is not None:
            fig = plot_rates(
                curr_rate_plots_data=curr_rate_plots_data,
                ref_rate_plots_data=ref_rate_plots_data,
                color_options=self.color_options,
            )
            info.with_details("TNR", plotly_figure(title="", figure=fig))

        return info


class TestFPR(SimpleClassificationTestTopK):
    class Config:
        type_alias = "evidently:test:TestFPR"

    condition_arg: ClassVar = "lt"
    name = "False Positive Rate"

    def get_value(self, result: DatasetClassificationQuality):
        return result.fpr

    def get_description(self, value: Numeric) -> str:
        if value is None:
            return "This test is applicable only for binary classification"

        return f"The False Positive Rate is {value:.3g}. The test threshold is {self.get_condition()}"


@default_renderer(wrap_type=TestFPR)
class TestFPRRenderer(TestRenderer):
    def render_html(self, obj: TestFPR) -> TestHtmlInfo:
        info = super().render_html(obj)
        curr_metrics = obj.metric.get_result().current
        ref_metrics = obj.metric.get_result().reference
        curr_rate_plots_data = curr_metrics.rate_plots_data
        ref_rate_plots_data = None

        if ref_metrics is not None:
            ref_rate_plots_data = ref_metrics.rate_plots_data

        if curr_rate_plots_data is not None:
            fig = plot_rates(
                curr_rate_plots_data=curr_rate_plots_data,
                ref_rate_plots_data=ref_rate_plots_data,
                color_options=self.color_options,
            )
            info.with_details("FPR", plotly_figure(title="", figure=fig))

        return info


class TestFNR(SimpleClassificationTestTopK):
    class Config:
        type_alias = "evidently:test:TestFNR"

    condition_arg: ClassVar = "lt"
    name = "False Negative Rate"

    def get_value(self, result: DatasetClassificationQuality):
        return result.fnr

    def get_description(self, value: Numeric) -> str:
        if value is None:
            return "This test is applicable only for binary classification"

        return f"The False Negative Rate is {value:.3g}. The test threshold is {self.get_condition()}"


@default_renderer(wrap_type=TestFNR)
class TestFNRRenderer(TestRenderer):
    def render_html(self, obj: TestFNR) -> TestHtmlInfo:
        info = super().render_html(obj)
        curr_metrics = obj.metric.get_result().current
        ref_metrics = obj.metric.get_result().reference
        curr_rate_plots_data = curr_metrics.rate_plots_data
        ref_rate_plots_data = None

        if ref_metrics is not None:
            ref_rate_plots_data = ref_metrics.rate_plots_data

        if curr_rate_plots_data is not None:
            fig = plot_rates(
                curr_rate_plots_data=curr_rate_plots_data,
                ref_rate_plots_data=ref_rate_plots_data,
                color_options=self.color_options,
            )
            info.with_details("FNR", plotly_figure(title="", figure=fig))

        return info


class ByClassParameters(CheckValueParameters):
    class Config:
        type_alias = "evidently:test_parameters:ByClassParameters"

    label: Label


class ByClassClassificationTest(BaseCheckValueTest, ABC):
    """Base class for per-label classification tests (precision/recall/F1 by class)."""

    group: ClassVar = CLASSIFICATION_GROUP.id
    _metric: ClassificationQualityMetric
    _by_class_metric: ClassificationQualityByClass
    _dummy_metric: ClassificationDummyMetric
    _conf_matrix: ClassificationConfusionMatrix
    label: Label
    probas_threshold: Optional[float] = None
    k: Optional[int] = None

    def __init__(
        self,
        label: Label,
        probas_threshold: Optional[float] = None,
        k: Optional[int] = None,
        eq: Optional[Numeric] = None,
        gt: Optional[Numeric] = None,
        gte: Optional[Numeric] = None,
        is_in: Optional[List[Union[Numeric, str, bool]]] = None,
        lt: Optional[Numeric] = None,
        lte: Optional[Numeric] = None,
        not_eq: Optional[Numeric] = None,
        not_in: Optional[List[Union[Numeric, str, bool]]] = None,
        is_critical: bool = True,
    ):
        if k is not None and probas_threshold is not None:
            raise ValueError("Only one of 'probas_threshold' or 'k' should be given")

        self.label = label
        self.probas_threshold = probas_threshold
        self.k = k
        super().__init__(
            eq=eq,
            gt=gt,
            gte=gte,
            is_in=is_in,
            lt=lt,
            lte=lte,
            not_eq=not_eq,
            not_in=not_in,
            is_critical=is_critical,
        )

        self._metric = ClassificationQualityMetric(probas_threshold=self.probas_threshold, k=self.k)
        self._dummy_metric = ClassificationDummyMetric(probas_threshold=self.probas_threshold, k=self.k)
        self._by_class_metric = ClassificationQualityByClass(probas_threshold=self.probas_threshold, k=self.k)
        self._conf_matrix = ClassificationConfusionMatrix(probas_threshold=self.probas_threshold, k=self.k)

    @property
    def metric(self):
        return self._metric

    @property
    def dummy_metric(self):
        return self._dummy_metric

    @property
    def by_class_metric(self):
        return self._by_class_metric

    @property
    def conf_matrix(self):
        return self._conf_matrix

    def calculate_value_for_test(self) -> Optional[Any]:
        return self.get_value(self.by_class_metric.get_result().current.metrics[str(self.label)])

    def get_condition(self) -> TestValueCondition:
        if self.condition.has_condition():
            return self.condition

        result = self.by_class_metric.get_result()
        ref_metrics = result.reference.metrics if result.reference is not None else None

        if ref_metrics is not None:
            return TestValueCondition(eq=approx(self.get_value(ref_metrics[str(self.label)]), relative=0.2))

        dummy_result = self.dummy_metric.get_result().metrics_matrix[str(self.label)]
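
        # Last-resort fallback: require the per-label value to beat the
        # dummy-model baseline (mirrors SimpleClassificationTest.get_condition,
        # but always uses a strict "greater than" condition).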

        if self.get_value(dummy_result) is None:
            raise ValueError("Neither required test parameters nor reference data has been provided.")

        return TestValueCondition(gt=self.get_value(dummy_result))

    @abc.abstractmethod
    def get_value(self, result: ClassMetric):
        raise NotImplementedError()

    def get_parameters(self) -> ByClassParameters:
        return ByClassParameters(condition=self.get_condition(), value=self._value, label=self.label)


class TestPrecisionByClass(ByClassClassificationTest):
    class Config:
        type_alias = "evidently:test:TestPrecisionByClass"

    name: ClassVar[str] = "Precision Score by Class"

    def get_value(self, result: ClassMetric):
        return result.precision

    def get_description(self, value: Numeric) -> str:
        return (
            f"The precision score of the label **{self.label}** is {value:.3g}. "
            f"The test threshold is {self.get_condition()}"
        )


@default_renderer(wrap_type=TestPrecisionByClass)
class TestPrecisionByClassRenderer(TestRenderer):
    def render_html(self, obj: TestPrecisionByClass) -> TestHtmlInfo:
        info = super().render_html(obj)
        curr_matrix = obj.conf_matrix.get_result().current_matrix
        ref_matrix = obj.conf_matrix.get_result().reference_matrix
        fig = plot_conf_mtrx(curr_matrix, ref_matrix)
        info.with_details("Precision by Class", plotly_figure(title="", figure=fig))
        return info


class TestRecallByClass(ByClassClassificationTest):
    class Config:
        type_alias = "evidently:test:TestRecallByClass"

    name: ClassVar[str] = "Recall Score by Class"

    def get_value(self, result: ClassMetric):
        return result.recall

    def get_description(self, value: Numeric) -> str:
        return (
            f"The recall score of the label **{self.label}** is {value:.3g}. "
            f"The test threshold is {self.get_condition()}"
        )


@default_renderer(wrap_type=TestRecallByClass)
class TestRecallByClassRenderer(TestRenderer):
    def render_html(self, obj: TestRecallByClass) -> TestHtmlInfo:
        info = super().render_html(obj)
        curr_matrix = obj.conf_matrix.get_result().current_matrix
        ref_matrix = obj.conf_matrix.get_result().reference_matrix
        fig = plot_conf_mtrx(curr_matrix, ref_matrix)
        info.with_details("Recall by Class", plotly_figure(title="", figure=fig))
        return info


class TestF1ByClass(ByClassClassificationTest):
    class Config:
        type_alias = "evidently:test:TestF1ByClass"

    name: ClassVar[str] = "F1 Score by Class"

    def get_value(self, result: ClassMetric):
        return result.f1

    def get_description(self, value: Numeric) -> str:
        return (
            f"The F1 score of the label **{self.label}** is {value:.3g}. The test threshold is {self.get_condition()}"
        )


@default_renderer(wrap_type=TestF1ByClass)
class TestF1ByClassRenderer(TestRenderer):
    def render_html(self, obj: TestF1ByClass) -> TestHtmlInfo:
        info = super().render_html(obj)
        curr_matrix = obj.conf_matrix.get_result().current_matrix
        ref_matrix = obj.conf_matrix.get_result().reference_matrix
        fig = plot_conf_mtrx(curr_matrix, ref_matrix)
        info.with_details("F1 by Class", plotly_figure(title="", figure=fig))
        return info
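

# ---------------------------------------------------------------------------
# End-to-end sketch (illustrative only; assumes the legacy TestSuite entry
# point and pandas; `show()` renders the report in a notebook):
#
#     import pandas as pd
#     from evidently.legacy.test_suite import TestSuite
#
#     ref = pd.DataFrame({"target": [0, 1, 1, 0], "prediction": [0, 1, 1, 0]})
#     cur = pd.DataFrame({"target": [0, 1, 1, 0], "prediction": [0, 1, 0, 0]})
#
#     suite = TestSuite(
#         tests=[
#             TestAccuracyScore(),              # condition inferred from `ref`
#             TestF1ByClass(label=1, gte=0.5),  # per-label check
#         ]
#     )
#     suite.run(reference_data=ref, current_data=cur)
#     suite.show()
# ---------------------------------------------------------------------------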