classification.py
import abc
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TypeVar

from evidently.core.base_types import Label
from evidently.core.metric_types import BoundTest
from evidently.core.metric_types import ByLabelCalculation
from evidently.core.metric_types import ByLabelMetric
from evidently.core.metric_types import ByLabelValue
from evidently.core.metric_types import SingleValue
from evidently.core.metric_types import SingleValueCalculation
from evidently.core.metric_types import SingleValueMetric
from evidently.core.metric_types import TMetricResult
from evidently.core.report import Context
from evidently.core.report import _default_input_data_generator
from evidently.legacy.base_metric import InputData
from evidently.legacy.base_metric import Metric
from evidently.legacy.metrics import ClassificationConfusionMatrix
from evidently.legacy.metrics import ClassificationDummyMetric
from evidently.legacy.metrics import ClassificationLiftCurve
from evidently.legacy.metrics import ClassificationLiftTable
from evidently.legacy.metrics import ClassificationPRCurve
from evidently.legacy.metrics import ClassificationProbDistribution
from evidently.legacy.metrics import ClassificationPRTable
from evidently.legacy.metrics import ClassificationQualityByClass as _ClassificationQualityByClass
from evidently.legacy.metrics import ClassificationRocCurve
from evidently.legacy.metrics.classification_performance.classification_dummy_metric import (
    ClassificationDummyMetricResults,
)
from evidently.legacy.metrics.classification_performance.classification_quality_metric import (
    ClassificationQualityMetric,
)
from evidently.legacy.metrics.classification_performance.classification_quality_metric import (
    ClassificationQualityMetricResult,
)
from evidently.legacy.metrics.classification_performance.quality_by_class_metric import (
    ClassificationQualityByClassResult,
)
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.metrics._legacy import LegacyMetricCalculation
from evidently.tests import Reference
from evidently.tests import eq
from evidently.tests import gt
from evidently.tests import lt


class ClassificationQualityByLabel(ByLabelMetric):
    classification_name: str = "default"
    probas_threshold: Optional[float] = None
    k: Optional[int] = None


class ClassificationQualityBase(SingleValueMetric):
    classification_name: str = "default"
    probas_threshold: Optional[float] = None
    k: Optional[int] = None


class ClassificationQuality(ClassificationQualityBase):
    def _default_tests_with_reference(self, context: Context) -> List[BoundTest]:
        return [eq(Reference(relative=0.2)).bind_single(self.get_fingerprint())]

    def _get_dummy_value(
        self, context: Context, dummy_type: Type["DummyClassificationQuality"], **kwargs
    ) -> SingleValue:
        return context.calculate_metric(
            dummy_type(probas_threshold=self.probas_threshold, k=self.k, **kwargs).to_calculation()
        )


TByLabelMetric = TypeVar("TByLabelMetric", bound=ClassificationQualityByLabel)
TSingleValueMetric = TypeVar("TSingleValueMetric", bound=ClassificationQualityBase)


def _gen_classification_input_data(context: "Context", task_name: Optional[str]) -> InputData:
    default_input_data = _default_input_data_generator(context, task_name)
    return default_input_data
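
# Usage sketch (illustrative only, not executed by this module). Assuming the
# top-level ``Report`` API exported by the ``evidently`` package and two
# user-supplied datasets ``current`` and ``reference`` (hypothetical names),
# the single-value metrics defined below are typically attached like this:
#
#     from evidently import Report
#
#     report = Report(metrics=[F1Score(), Accuracy(), RocAuc()])
#     snapshot = report.run(current, reference)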


class LegacyClassificationQualityByClass(
    ByLabelCalculation[TByLabelMetric],
    LegacyMetricCalculation[
        ByLabelValue,
        TByLabelMetric,
        ClassificationQualityByClassResult,
        _ClassificationQualityByClass,
    ],
    Generic[TByLabelMetric],
    abc.ABC,
):
    _legacy_metric = None

    def task_name(self) -> str:
        return self.metric.classification_name

    def legacy_metric(self) -> _ClassificationQualityByClass:
        if self._legacy_metric is None:
            self._legacy_metric = _ClassificationQualityByClass(self.metric.probas_threshold, self.metric.k)
        return self._legacy_metric

    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityByClassResult,
        render: List[BaseWidgetInfo],
    ):
        raise NotImplementedError()

    def _relabel(self, context: "Context", label: Label) -> Label:
        classification = context.data_definition.get_classification("default")
        if classification is None:
            return label
        actual_labels = context.get_labels(classification.target, classification.prediction_labels)
        _label = None
        for actual_label in actual_labels:
            if label == actual_label:
                _label = label
                break
            if label == str(actual_label):
                _label = actual_label
                break
        if _label is None:
            raise ValueError(f"Failed to relabel {label}")
        labels = classification.labels
        if labels is not None:
            return labels[_label]
        return _label

    def get_additional_widgets(self, context: "Context") -> List[BaseWidgetInfo]:
        result = []
        for field, metric in ADDITIONAL_WIDGET_MAPPING.items():
            if hasattr(self.metric, field) and getattr(self.metric, field):
                _, widgets = context.get_legacy_metric(metric, self._gen_input_data, self.task_name())
                result += widgets
        return result


class F1ByLabel(ClassificationQualityByLabel):
    """Calculate F1 score separately for each class label in multiclass classification.

    Returns a dictionary mapping each label to its F1 score. Useful for understanding
    per-class performance in multiclass problems.

    Args:
        * `classification_name`: Name of the classification task (default: "default").
        * `probas_threshold`: Optional probability threshold for binary classification.
        * `k`: Optional top-k value for multiclass classification.
        * `tests`: Optional list of test conditions.
    """

    pass


class F1ByLabelCalculation(LegacyClassificationQualityByClass[F1ByLabel]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityByClassResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[ByLabelValue, Optional[ByLabelValue]]:
        return self.collect_by_label_result(
            context,
            lambda x: x.f1,
            legacy_result.current.metrics,
            None if legacy_result.reference is None else legacy_result.reference.metrics,
        )

    def display_name(self) -> str:
        return "F1 by Label metric"


class PrecisionByLabel(ClassificationQualityByLabel):
    """Calculate precision separately for each class label in multiclass classification.

    Returns a dictionary mapping each label to its precision score. Useful for
    understanding per-class precision in multiclass problems.

    Args:
        * `classification_name`: Name of the classification task (default: "default").
        * `probas_threshold`: Optional probability threshold for binary classification.
        * `k`: Optional top-k value for multiclass classification.
        * `tests`: Optional list of test conditions.
    """

    pass


class PrecisionByLabelCalculation(LegacyClassificationQualityByClass[PrecisionByLabel]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityByClassResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[ByLabelValue, Optional[ByLabelValue]]:
        return self.collect_by_label_result(
            context,
            lambda x: x.precision,
            legacy_result.current.metrics,
            None if legacy_result.reference is None else legacy_result.reference.metrics,
        )

    def display_name(self) -> str:
        return "Precision by Label metric"


class RecallByLabel(ClassificationQualityByLabel):
    """Calculate recall separately for each class label in multiclass classification.

    Returns a dictionary mapping each label to its recall score. Useful for
    understanding per-class recall in multiclass problems.

    Args:
        * `classification_name`: Name of the classification task (default: "default").
        * `probas_threshold`: Optional probability threshold for binary classification.
        * `k`: Optional top-k value for multiclass classification.
        * `tests`: Optional list of test conditions.
    """

    pass


class RecallByLabelCalculation(LegacyClassificationQualityByClass[RecallByLabel]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityByClassResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[ByLabelValue, Optional[ByLabelValue]]:
        return self.collect_by_label_result(
            context,
            lambda x: x.recall,
            legacy_result.current.metrics,
            None if legacy_result.reference is None else legacy_result.reference.metrics,
        )

    def display_name(self) -> str:
        return "Recall by Label metric"


class RocAucByLabel(ClassificationQualityByLabel):
    """Calculate ROC AUC separately for each class label in multiclass classification.

    Returns a dictionary mapping each label to its ROC AUC score. Useful for
    understanding per-class ROC AUC in multiclass problems.

    Args:
        * `classification_name`: Name of the classification task (default: "default").
        * `probas_threshold`: Optional probability threshold for binary classification.
        * `k`: Optional top-k value for multiclass classification.
        * `tests`: Optional list of test conditions.
    """

    pass
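
# Illustrative sketch of the one-vs-rest computation that "ROC AUC by label"
# refers to, written against scikit-learn for clarity. This is not the code
# path used by RocAucByLabelCalculation below (which delegates to the legacy
# ClassificationQualityByClass metric); ``y_true``, ``proba`` and ``labels``
# are hypothetical inputs.
#
#     import numpy as np
#     from sklearn.metrics import roc_auc_score
#
#     def roc_auc_by_label(y_true, proba, labels):
#         # proba: (n_samples, n_labels) probabilities, columns ordered as ``labels``
#         y_true = np.asarray(y_true)
#         return {
#             label: roc_auc_score((y_true == label).astype(int), proba[:, i])
#             for i, label in enumerate(labels)
#         }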
257 """ 258 259 pass 260 261 262 class RocAucByLabelCalculation(LegacyClassificationQualityByClass[RocAucByLabel]): 263 def calculate_value( 264 self, 265 context: "Context", 266 legacy_result: ClassificationQualityByClassResult, 267 render: List[BaseWidgetInfo], 268 ) -> Tuple[ByLabelValue, Optional[ByLabelValue]]: 269 return self.collect_by_label_result( 270 context, 271 lambda x: x.roc_auc if x.roc_auc is not None else 0.0, 272 legacy_result.current.metrics, 273 None if legacy_result.reference is None else legacy_result.reference.metrics, 274 ) 275 276 def display_name(self) -> str: 277 return "ROC AUC by Label metric" 278 279 280 ADDITIONAL_WIDGET_MAPPING: Dict[str, Metric] = { 281 "prob_distribution": ClassificationProbDistribution(), 282 "conf_matrix": ClassificationConfusionMatrix(), 283 "pr_curve": ClassificationPRCurve(), 284 "pr_table": ClassificationPRTable(), 285 "roc_curve": ClassificationRocCurve(), 286 "lift_curve": ClassificationLiftCurve(), 287 "lift_table": ClassificationLiftTable(), 288 } 289 290 291 class LegacyClassificationQuality( 292 SingleValueCalculation[TSingleValueMetric], 293 LegacyMetricCalculation[ 294 SingleValue, 295 TSingleValueMetric, 296 ClassificationQualityMetricResult, 297 ClassificationQualityMetric, 298 ], 299 Generic[TSingleValueMetric], 300 abc.ABC, 301 ): 302 _legacy_metric = None 303 304 def legacy_metric(self) -> ClassificationQualityMetric: 305 if self._legacy_metric is None: 306 self._legacy_metric = ClassificationQualityMetric(self.metric.probas_threshold, self.metric.k) 307 return self._legacy_metric 308 309 @abc.abstractmethod 310 def calculate_value( 311 self, 312 context: "Context", 313 legacy_result: ClassificationQualityMetricResult, 314 render: List[BaseWidgetInfo], 315 ) -> Tuple[SingleValue, Optional[SingleValue]]: 316 raise NotImplementedError() 317 318 def get_additional_widgets(self, context: "Context") -> List[BaseWidgetInfo]: 319 result = [] 320 for field, metric in ADDITIONAL_WIDGET_MAPPING.items(): 321 if hasattr(self.metric, field) and getattr(self.metric, field): 322 _, widgets = context.get_legacy_metric(metric, self._gen_input_data, self.task_name()) 323 result += widgets 324 return result 325 326 327 class F1Score(ClassificationQuality): 328 """Calculate F1 score (harmonic mean of precision and recall). 329 330 F1 score balances precision and recall, providing a single metric for 331 classification performance. Higher values indicate better performance. 332 """ 333 334 conf_matrix: bool = True 335 """Whether to show confusion matrix visualization.""" 336 337 def _default_tests(self, context: Context) -> List[BoundTest]: 338 dummy_value = self._get_dummy_value(context, DummyF1Score) 339 return [gt(dummy_value.value).bind_single(self.get_fingerprint())] 340 341 342 class F1ScoreCalculation(LegacyClassificationQuality[F1Score]): 343 def task_name(self) -> str: 344 return self.metric.classification_name 345 346 def calculate_value( 347 self, 348 context: "Context", 349 legacy_result: ClassificationQualityMetricResult, 350 render: List[BaseWidgetInfo], 351 ) -> Tuple[SingleValue, Optional[SingleValue]]: 352 return ( 353 self.result(legacy_result.current.f1), 354 None if legacy_result.reference is None else self.result(legacy_result.reference.f1), 355 ) 356 357 def display_name(self) -> str: 358 return "F1 score metric" 359 360 361 class Accuracy(ClassificationQuality): 362 """Calculate classification accuracy (proportion of correct predictions). 363 364 Accuracy measures the fraction of predictions that match the true labels. 


class Accuracy(ClassificationQuality):
    """Calculate classification accuracy (proportion of correct predictions).

    Accuracy measures the fraction of predictions that match the true labels.
    Simple and intuitive, but can be misleading for imbalanced datasets.
    """

    def _default_tests(self, context: Context) -> List[BoundTest]:
        dummy_value = self._get_dummy_value(context, DummyAccuracy)
        return [gt(dummy_value.value).bind_single(self.get_fingerprint())]


class AccuracyCalculation(LegacyClassificationQuality[Accuracy]):
    def task_name(self) -> str:
        return self.metric.classification_name

    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        return (
            self.result(legacy_result.current.accuracy),
            None if legacy_result.reference is None else self.result(legacy_result.reference.accuracy),
        )

    def display_name(self) -> str:
        return "Accuracy metric"


class Precision(ClassificationQuality):
    """Calculate precision (proportion of positive predictions that are correct).

    Precision measures how many of the predicted positive cases are actually positive.
    Useful when false positives are costly.

    Note: At least one visualization (`conf_matrix`, `pr_curve`, or `pr_table`) must be enabled.
    """

    conf_matrix: bool = True
    """Whether to show confusion matrix visualization."""

    pr_curve: bool = False
    """Whether to show precision-recall curve."""
    pr_table: bool = False
    """Whether to show precision-recall table."""

    def _default_tests(self, context: Context) -> List[BoundTest]:
        dummy_value = self._get_dummy_value(context, DummyPrecision)
        return [gt(dummy_value.value).bind_single(self.get_fingerprint())]


class PrecisionCalculation(LegacyClassificationQuality[Precision]):
    def task_name(self) -> str:
        return self.metric.classification_name

    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        return (
            self.result(legacy_result.current.precision),
            None if legacy_result.reference is None else self.result(legacy_result.reference.precision),
        )

    def display_name(self) -> str:
        return "Precision metric"
440 """ 441 442 conf_matrix: bool = True 443 """Whether to show confusion matrix visualization.""" 444 445 pr_curve: bool = False 446 """Whether to show precision-recall curve.""" 447 pr_table: bool = False 448 """Whether to show precision-recall table.""" 449 450 def _default_tests(self, context: Context) -> List[BoundTest]: 451 dummy_value = self._get_dummy_value(context, DummyRecall) 452 return [gt(dummy_value.value).bind_single(self.get_fingerprint())] 453 454 455 class RecallCalculation(LegacyClassificationQuality[Recall]): 456 def task_name(self) -> str: 457 return self.metric.classification_name 458 459 def calculate_value( 460 self, 461 context: "Context", 462 legacy_result: ClassificationQualityMetricResult, 463 render: List[BaseWidgetInfo], 464 ) -> Tuple[SingleValue, Optional[SingleValue]]: 465 return ( 466 self.result(legacy_result.current.recall), 467 None if legacy_result.reference is None else self.result(legacy_result.reference.recall), 468 ) 469 470 def display_name(self) -> str: 471 return "Recall metric" 472 473 474 class TPR(ClassificationQuality): 475 """Calculate True Positive Rate (TPR), also known as recall or sensitivity. 476 477 TPR measures the proportion of actual positives correctly identified. 478 Equivalent to recall. Higher values indicate better detection of positive cases. 479 480 Note: `pr_table` visualization must be enabled. 481 """ 482 483 pr_table: bool = False 484 """Whether to show precision-recall table.""" 485 486 def _default_tests(self, context: Context) -> List[BoundTest]: 487 dummy_value = self._get_dummy_value(context, DummyTPR) 488 return [gt(dummy_value.value).bind_single(self.get_fingerprint())] 489 490 491 class TPRCalculation(LegacyClassificationQuality[TPR]): 492 def task_name(self) -> str: 493 return self.metric.classification_name 494 495 def calculate_value( 496 self, 497 context: "Context", 498 legacy_result: ClassificationQualityMetricResult, 499 render: List[BaseWidgetInfo], 500 ) -> Tuple[SingleValue, Optional[SingleValue]]: 501 if legacy_result.current.tpr is None: 502 raise ValueError( 503 "Cannot compute TPR: current TPR value is missing. " 504 "Ensure prediction labels and probabilities are available. " 505 ) 506 return ( 507 self.result(legacy_result.current.tpr), 508 None 509 if legacy_result.reference is None or legacy_result.reference.tpr is None 510 else self.result(legacy_result.reference.tpr), 511 ) 512 513 def display_name(self) -> str: 514 return "TPR metric" 515 516 517 class TNR(ClassificationQuality): 518 """Calculate True Negative Rate (TNR), also known as specificity. 519 520 TNR measures the proportion of actual negatives correctly identified. 521 Higher values indicate better detection of negative cases. 522 523 Note: `pr_table` visualization must be enabled. 524 """ 525 526 pr_table: bool = False 527 """Whether to show precision-recall table.""" 528 529 def _default_tests(self, context: Context) -> List[BoundTest]: 530 dummy_value = self._get_dummy_value(context, DummyTNR) 531 return [gt(dummy_value.value).bind_single(self.get_fingerprint())] 532 533 534 class TNRCalculation(LegacyClassificationQuality[TNR]): 535 def task_name(self) -> str: 536 return self.metric.classification_name 537 538 def calculate_value( 539 self, 540 context: "Context", 541 legacy_result: ClassificationQualityMetricResult, 542 render: List[BaseWidgetInfo], 543 ) -> Tuple[SingleValue, Optional[SingleValue]]: 544 if legacy_result.current.tnr is None: 545 raise ValueError( 546 "Cannot compute TNR: current TNR value is missing. 
" 547 "Ensure prediction labels and probabilities are available. " 548 ) 549 return ( 550 self.result(legacy_result.current.tnr), 551 None 552 if legacy_result.reference is None or legacy_result.reference.tnr is None 553 else self.result(legacy_result.reference.tnr), 554 ) 555 556 def display_name(self) -> str: 557 return "TNR metric" 558 559 560 class FPR(ClassificationQuality): 561 """Calculate False Positive Rate (FPR). 562 563 FPR measures the proportion of actual negatives incorrectly classified as positive. 564 Lower values are better. FPR = 1 - TNR. 565 566 Note: `pr_table` visualization must be enabled. 567 """ 568 569 pr_table: bool = False 570 """Whether to show precision-recall table.""" 571 572 def _default_tests(self, context: Context) -> List[BoundTest]: 573 dummy_value = self._get_dummy_value(context, DummyFPR) 574 return [lt(dummy_value.value).bind_single(self.get_fingerprint())] 575 576 577 class FPRCalculation(LegacyClassificationQuality[FPR]): 578 def task_name(self) -> str: 579 return self.metric.classification_name 580 581 def calculate_value( 582 self, 583 context: "Context", 584 legacy_result: ClassificationQualityMetricResult, 585 render: List[BaseWidgetInfo], 586 ) -> Tuple[SingleValue, Optional[SingleValue]]: 587 if legacy_result.current.fpr is None: 588 raise ValueError( 589 "Cannot compute FPR: current FPR value is missing. " 590 "Ensure prediction labels and probabilities are available. " 591 ) 592 return ( 593 self.result(legacy_result.current.fpr), 594 None 595 if legacy_result.reference is None or legacy_result.reference.fpr is None 596 else self.result(legacy_result.reference.fpr), 597 ) 598 599 def display_name(self) -> str: 600 return "FPR metric" 601 602 603 class FNR(ClassificationQuality): 604 """Calculate False Negative Rate (FNR). 605 606 FNR measures the proportion of actual positives incorrectly classified as negative. 607 Lower values are better. FNR = 1 - TPR. 608 609 Note: `pr_table` visualization must be enabled. 610 """ 611 612 pr_table: bool = False 613 """Whether to show precision-recall table.""" 614 615 def _default_tests(self, context: Context) -> List[BoundTest]: 616 dummy_value = self._get_dummy_value(context, DummyFNR) 617 return [lt(dummy_value.value).bind_single(self.get_fingerprint())] 618 619 620 class FNRCalculation(LegacyClassificationQuality[FNR]): 621 def task_name(self) -> str: 622 return self.metric.classification_name 623 624 def calculate_value( 625 self, 626 context: "Context", 627 legacy_result: ClassificationQualityMetricResult, 628 render: List[BaseWidgetInfo], 629 ) -> Tuple[SingleValue, Optional[SingleValue]]: 630 if legacy_result.current.fnr is None: 631 raise ValueError( 632 "Cannot compute FNR: current FNR value is missing. " 633 "Ensure prediction labels and probabilities are available. " 634 ) 635 return ( 636 self.result(legacy_result.current.fnr), 637 None 638 if legacy_result.reference is None or legacy_result.reference.fnr is None 639 else self.result(legacy_result.reference.fnr), 640 ) 641 642 def display_name(self) -> str: 643 return "FNR metric" 644 645 646 class RocAuc(ClassificationQuality): 647 """Calculate ROC AUC (Area Under the Receiver Operating Characteristic Curve). 648 649 ROC AUC measures the model's ability to distinguish between classes across 650 all possible thresholds. Values range from 0 to 1, with 0.5 being random and 1.0 perfect. 651 652 Note: At least one visualization (`roc_curve` or `pr_table`) must be enabled. 
653 """ 654 655 roc_curve: bool = True 656 """Whether to show ROC curve.""" 657 658 pr_table: bool = False 659 """Whether to show precision-recall table.""" 660 661 def _default_tests(self, context: Context) -> List[BoundTest]: 662 dummy_value = self._get_dummy_value(context, DummyRocAuc) 663 return [gt(dummy_value.value).bind_single(self.get_fingerprint())] 664 665 666 class RocAucCalculation(LegacyClassificationQuality[RocAuc]): 667 def task_name(self) -> str: 668 return self.metric.classification_name 669 670 def calculate_value( 671 self, 672 context: "Context", 673 legacy_result: ClassificationQualityMetricResult, 674 render: List[BaseWidgetInfo], 675 ) -> Tuple[SingleValue, Optional[SingleValue]]: 676 if legacy_result.current.roc_auc is None: 677 raise ValueError( 678 "Cannot compute RocAuc: current RocAuc value is missing. " 679 "Ensure prediction labels and probabilities are available. " 680 ) 681 return ( 682 self.result(legacy_result.current.roc_auc), 683 None 684 if legacy_result.reference is None or legacy_result.reference.roc_auc is None 685 else self.result(legacy_result.reference.roc_auc), 686 ) 687 688 def display_name(self) -> str: 689 return "RocAuc metric" 690 691 692 class LogLoss(ClassificationQuality): 693 """Calculate logarithmic loss (cross-entropy loss). 694 695 Log loss penalizes confident wrong predictions more heavily. Lower values 696 indicate better calibrated probability predictions. Requires probability predictions. 697 698 Note: `pr_table` visualization must be enabled. Requires probability predictions. 699 """ 700 701 pr_table: bool = False 702 """Whether to show precision-recall table.""" 703 704 def _default_tests(self, context: Context) -> List[BoundTest]: 705 dummy_value = self._get_dummy_value(context, DummyLogLoss) 706 return [lt(dummy_value.value).bind_single(self.get_fingerprint())] 707 708 709 class LogLossCalculation(LegacyClassificationQuality[LogLoss]): 710 def task_name(self) -> str: 711 return self.metric.classification_name 712 713 def calculate_value( 714 self, 715 context: "Context", 716 legacy_result: ClassificationQualityMetricResult, 717 render: List[BaseWidgetInfo], 718 ) -> Tuple[SingleValue, Optional[SingleValue]]: 719 if legacy_result.current.log_loss is None: 720 raise ValueError( 721 "Cannot compute LogLoss: current LogLoss value is missing. " 722 "Ensure prediction labels and probabilities are available. 
" 723 ) 724 return ( 725 self.result(legacy_result.current.log_loss), 726 None 727 if legacy_result.reference is None or legacy_result.reference.log_loss is None 728 else self.result(legacy_result.reference.log_loss), 729 ) 730 731 def display_name(self) -> str: 732 return "LogLoss metric" 733 734 735 class LegacyClassificationDummy( 736 LegacyMetricCalculation[ 737 SingleValue, 738 TSingleValueMetric, 739 ClassificationDummyMetricResults, 740 ClassificationDummyMetric, 741 ], 742 SingleValueCalculation[TSingleValueMetric], 743 Generic[TSingleValueMetric], 744 abc.ABC, 745 ): 746 _legacy_metric = None 747 __legacy_field_name__: ClassVar[str] 748 749 def task_name(self) -> str: 750 return self.metric.classification_name 751 752 def legacy_metric(self) -> ClassificationDummyMetric: 753 if self._legacy_metric is None: 754 self._legacy_metric = ClassificationDummyMetric(self.metric.probas_threshold, self.metric.k) 755 return self._legacy_metric 756 757 def calculate_value( 758 self, 759 context: "Context", 760 legacy_result: ClassificationDummyMetricResults, 761 render: List[BaseWidgetInfo], 762 ) -> TMetricResult: 763 current_value = getattr(legacy_result.dummy, self.__legacy_field_name__) 764 if current_value is None: 765 raise ValueError(f"Failed to calculate {self.display_name()}") 766 if legacy_result.by_reference_dummy is None: 767 return self.result(current_value) 768 reference_value = getattr(legacy_result.by_reference_dummy, self.__legacy_field_name__) 769 return self.result(current_value), self.result(reference_value) 770 771 772 class DummyClassificationQuality(ClassificationQualityBase): 773 def _default_tests_with_reference(self, context: Context) -> List[BoundTest]: 774 return [] 775 776 def _default_tests(self, context: Context) -> List[BoundTest]: 777 return [] 778 779 780 class DummyPrecision(DummyClassificationQuality): 781 """Calculate precision for a dummy/baseline model. 782 783 Computes precision using a simple heuristic-based model (e.g., always predict 784 the most common class). Useful as a baseline to compare your model against. 785 786 Args: 787 * `classification_name`: Name of the classification task (default: "default"). 788 * `probas_threshold`: Optional probability threshold. 789 * `k`: Optional top-k value for multiclass classification. 790 """ 791 792 pass 793 794 795 class DummyPrecisionCalculation(LegacyClassificationDummy[DummyPrecision]): 796 def task_name(self) -> str: 797 return self.metric.classification_name 798 799 __legacy_field_name__ = "precision" 800 801 def display_name(self) -> str: 802 return "Dummy precision metric" 803 804 805 class DummyRecall(DummyClassificationQuality): 806 """Calculate recall for a dummy/baseline model. 807 808 Computes recall using a simple heuristic-based model. Useful as a baseline 809 to compare your model against. 810 811 Args: 812 * `classification_name`: Name of the classification task (default: "default"). 813 * `probas_threshold`: Optional probability threshold. 814 * `k`: Optional top-k value for multiclass classification. 815 """ 816 817 pass 818 819 820 class DummyRecallCalculation(LegacyClassificationDummy[DummyRecall]): 821 __legacy_field_name__ = "recall" 822 823 def display_name(self) -> str: 824 return "Dummy recall metric" 825 826 827 class DummyF1Score(DummyClassificationQuality): 828 """Calculate F1 score for a dummy/baseline model. 829 830 Computes F1 score using a simple heuristic-based model. Useful as a baseline 831 to compare your model against. 


class DummyF1ScoreCalculation(LegacyClassificationDummy[DummyF1Score]):
    __legacy_field_name__ = "f1"

    def task_name(self) -> str:
        return self.metric.classification_name

    def display_name(self) -> str:
        return "Dummy F1 score metric"


class DummyAccuracy(DummyClassificationQuality):
    """Calculate accuracy for a dummy/baseline model.

    Computes accuracy using a simple heuristic-based model (e.g., always predict
    the most common class). Useful as a baseline to compare your model against.

    Args:
        * `classification_name`: Name of the classification task (default: "default").
        * `probas_threshold`: Optional probability threshold.
        * `k`: Optional top-k value for multiclass classification.
    """

    pass


class DummyAccuracyCalculation(LegacyClassificationDummy[DummyAccuracy]):
    __legacy_field_name__ = "accuracy"

    def task_name(self) -> str:
        return self.metric.classification_name

    def display_name(self) -> str:
        return "Dummy accuracy metric"


class DummyTPR(DummyClassificationQuality):
    """Calculate True Positive Rate for a dummy/baseline model.

    Computes TPR using a simple heuristic-based model. Useful as a baseline
    to compare your model against.

    Args:
        * `classification_name`: Name of the classification task (default: "default").
        * `probas_threshold`: Optional probability threshold.
        * `k`: Optional top-k value for multiclass classification.
    """

    pass


class DummyTPRCalculation(LegacyClassificationDummy[DummyTPR]):
    __legacy_field_name__ = "tpr"

    def task_name(self) -> str:
        return self.metric.classification_name

    def display_name(self) -> str:
        return "Dummy TPR metric"


class DummyTNR(DummyClassificationQuality):
    """Calculate True Negative Rate for a dummy/baseline model.

    Computes TNR using a simple heuristic-based model. Useful as a baseline
    to compare your model against.

    Args:
        * `classification_name`: Name of the classification task (default: "default").
        * `probas_threshold`: Optional probability threshold.
        * `k`: Optional top-k value for multiclass classification.
    """

    pass


class DummyTNRCalculation(LegacyClassificationDummy[DummyTNR]):
    __legacy_field_name__ = "tnr"

    def task_name(self) -> str:
        return self.metric.classification_name

    def display_name(self) -> str:
        return "Dummy TNR metric"


class DummyFPR(DummyClassificationQuality):
    """Calculate False Positive Rate for a dummy/baseline model.

    Computes FPR using a simple heuristic-based model. Useful as a baseline
    to compare your model against.

    Args:
        * `classification_name`: Name of the classification task (default: "default").
        * `probas_threshold`: Optional probability threshold.
        * `k`: Optional top-k value for multiclass classification.
    """

    pass
937 """ 938 939 pass 940 941 942 class DummyFPRCalculation(LegacyClassificationDummy[DummyFPR]): 943 __legacy_field_name__ = "fpr" 944 945 def task_name(self) -> str: 946 return self.metric.classification_name 947 948 def display_name(self) -> str: 949 return "Dummy FPR metric" 950 951 952 class DummyFNR(DummyClassificationQuality): 953 """Calculate False Negative Rate for a dummy/baseline model. 954 955 Computes FNR using a simple heuristic-based model. Useful as a baseline 956 to compare your model against. 957 958 Args: 959 * `classification_name`: Name of the classification task (default: "default"). 960 * `probas_threshold`: Optional probability threshold. 961 * `k`: Optional top-k value for multiclass classification. 962 """ 963 964 pass 965 966 967 class DummyFNRCalculation(LegacyClassificationDummy[DummyFNR]): 968 __legacy_field_name__ = "fnr" 969 970 def task_name(self) -> str: 971 return self.metric.classification_name 972 973 def display_name(self) -> str: 974 return "Dummy FNR metric" 975 976 977 class DummyLogLoss(DummyClassificationQuality): 978 """Calculate logarithmic loss for a dummy/baseline model. 979 980 Computes log loss using a simple heuristic-based model (equals 0.5 for a 981 constant model). Useful as a baseline to compare your model against. 982 983 Args: 984 * `classification_name`: Name of the classification task (default: "default"). 985 * `probas_threshold`: Optional probability threshold. 986 * `k`: Optional top-k value for multiclass classification. 987 """ 988 989 pass 990 991 992 class DummyLogLossCalculation(LegacyClassificationDummy[DummyLogLoss]): 993 __legacy_field_name__ = "log_loss" 994 995 def task_name(self) -> str: 996 return self.metric.classification_name 997 998 def display_name(self) -> str: 999 return "Dummy LogLoss metric" 1000 1001 1002 class DummyRocAuc(DummyClassificationQuality): 1003 """Calculate ROC AUC for a dummy/baseline model. 1004 1005 Computes ROC AUC using a simple heuristic-based model. Useful as a baseline 1006 to compare your model against (typically 0.5 for random). 1007 1008 Args: 1009 * `classification_name`: Name of the classification task (default: "default"). 1010 * `probas_threshold`: Optional probability threshold. 1011 * `k`: Optional top-k value for multiclass classification. 1012 """ 1013 1014 pass 1015 1016 1017 class DummyRocAucCalculation(LegacyClassificationDummy[DummyRocAuc]): 1018 __legacy_field_name__ = "roc_auc" 1019 1020 def task_name(self) -> str: 1021 return self.metric.classification_name 1022 1023 def display_name(self) -> str: 1024 return "Dummy RocAuc metric"