# src/evidently/metrics/classification.py
import abc
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TypeVar

from evidently.core.base_types import Label
from evidently.core.metric_types import BoundTest
from evidently.core.metric_types import ByLabelCalculation
from evidently.core.metric_types import ByLabelMetric
from evidently.core.metric_types import ByLabelValue
from evidently.core.metric_types import SingleValue
from evidently.core.metric_types import SingleValueCalculation
from evidently.core.metric_types import SingleValueMetric
from evidently.core.metric_types import TMetricResult
from evidently.core.report import Context
from evidently.core.report import _default_input_data_generator
from evidently.legacy.base_metric import InputData
from evidently.legacy.base_metric import Metric
from evidently.legacy.metrics import ClassificationConfusionMatrix
from evidently.legacy.metrics import ClassificationDummyMetric
from evidently.legacy.metrics import ClassificationLiftCurve
from evidently.legacy.metrics import ClassificationLiftTable
from evidently.legacy.metrics import ClassificationPRCurve
from evidently.legacy.metrics import ClassificationProbDistribution
from evidently.legacy.metrics import ClassificationPRTable
from evidently.legacy.metrics import ClassificationQualityByClass as _ClassificationQualityByClass
from evidently.legacy.metrics import ClassificationRocCurve
from evidently.legacy.metrics.classification_performance.classification_dummy_metric import (
    ClassificationDummyMetricResults,
)
from evidently.legacy.metrics.classification_performance.classification_quality_metric import (
    ClassificationQualityMetric,
)
from evidently.legacy.metrics.classification_performance.classification_quality_metric import (
    ClassificationQualityMetricResult,
)
from evidently.legacy.metrics.classification_performance.quality_by_class_metric import (
    ClassificationQualityByClassResult,
)
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.metrics._legacy import LegacyMetricCalculation
from evidently.tests import Reference
from evidently.tests import eq
from evidently.tests import gt
from evidently.tests import lt


class ClassificationQualityByLabel(ByLabelMetric):
    """Base class for classification quality metrics reported per class label."""

    classification_name: str = "default"
    probas_threshold: Optional[float] = None
    k: Optional[int] = None


class ClassificationQualityBase(SingleValueMetric):
    """Base class for single-value classification quality metrics."""

    classification_name: str = "default"
    probas_threshold: Optional[float] = None
    k: Optional[int] = None


class ClassificationQuality(ClassificationQualityBase):
    """Classification quality metric whose default tests compare against a dummy baseline."""

    def _default_tests_with_reference(self, context: Context) -> List[BoundTest]:
        return [eq(Reference(relative=0.2)).bind_single(self.get_fingerprint())]

    def _get_dummy_value(
        self, context: Context, dummy_type: Type["DummyClassificationQuality"], **kwargs
    ) -> SingleValue:
        return context.calculate_metric(
            dummy_type(probas_threshold=self.probas_threshold, k=self.k, **kwargs).to_calculation()
        )


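# Note on default tests (illustrative): each concrete quality metric below
# tests its value against the matching dummy baseline (e.g. F1Score's default
# test is gt(<dummy F1>)). Passing explicit `tests` overrides this; the
# threshold here is hypothetical:
#
#     F1Score(tests=[gt(0.5)])
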
TByLabelMetric = TypeVar("TByLabelMetric", bound=ClassificationQualityByLabel)
TSingleValueMetric = TypeVar("TSingleValueMetric", bound=ClassificationQualityBase)


def _gen_classification_input_data(context: "Context", task_name: Optional[str]) -> InputData:
    return _default_input_data_generator(context, task_name)


class LegacyClassificationQualityByClass(
    ByLabelCalculation[TByLabelMetric],
    LegacyMetricCalculation[
        ByLabelValue,
        TByLabelMetric,
        ClassificationQualityByClassResult,
        _ClassificationQualityByClass,
    ],
    Generic[TByLabelMetric],
    abc.ABC,
):
    _legacy_metric = None

    def task_name(self) -> str:
        return self.metric.classification_name

    def legacy_metric(self) -> _ClassificationQualityByClass:
        if self._legacy_metric is None:
            self._legacy_metric = _ClassificationQualityByClass(self.metric.probas_threshold, self.metric.k)
        return self._legacy_metric

    @abc.abstractmethod
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityByClassResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[ByLabelValue, Optional[ByLabelValue]]:
        raise NotImplementedError()

    def _relabel(self, context: "Context", label: Label) -> Label:
        """Map a label from the legacy result back to the label used in the data definition."""
        classification = context.data_definition.get_classification("default")
        if classification is None:
            return label
        actual_labels = context.get_labels(classification.target, classification.prediction_labels)
        # Legacy results may stringify labels; match either the raw or the string form.
        _label = None
        for actual_label in actual_labels:
            if label == actual_label:
                _label = label
                break
            if label == str(actual_label):
                _label = actual_label
                break
        if _label is None:
            raise ValueError(f"Failed to relabel {label}")
        # Apply the user-provided display mapping, if any.
        labels = classification.labels
        if labels is not None:
            return labels[_label]
        return _label

    def get_additional_widgets(self, context: "Context") -> List[BaseWidgetInfo]:
        result = []
        for field, metric in ADDITIONAL_WIDGET_MAPPING.items():
            if hasattr(self.metric, field) and getattr(self.metric, field):
                _, widgets = context.get_legacy_metric(metric, self._gen_input_data, self.task_name())
                result += widgets
        return result


class F1ByLabel(ClassificationQualityByLabel):
    """Calculate F1 score separately for each class label in multiclass classification.

    Returns a dictionary mapping each label to its F1 score. Useful for understanding
    per-class performance in multiclass problems.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold for binary classification.
    * `k`: Optional top-k value for multiclass classification.
    * `tests`: Optional list of test conditions.
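
    Example (illustrative sketch; `Report` and the dataset objects `current` /
    `reference` are assumptions about the surrounding API, not defined here):

        from evidently import Report

        report = Report([F1ByLabel()])
        result = report.run(current, reference)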
    """


class F1ByLabelCalculation(LegacyClassificationQualityByClass[F1ByLabel]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityByClassResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[ByLabelValue, Optional[ByLabelValue]]:
        return self.collect_by_label_result(
            context,
            lambda x: x.f1,
            legacy_result.current.metrics,
            None if legacy_result.reference is None else legacy_result.reference.metrics,
        )

    def display_name(self) -> str:
        return "F1 by Label metric"


class PrecisionByLabel(ClassificationQualityByLabel):
    """Calculate precision separately for each class label in multiclass classification.

    Returns a dictionary mapping each label to its precision score. Useful for
    understanding per-class precision in multiclass problems.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold for binary classification.
    * `k`: Optional top-k value for multiclass classification.
    * `tests`: Optional list of test conditions.
    """


class PrecisionByLabelCalculation(LegacyClassificationQualityByClass[PrecisionByLabel]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityByClassResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[ByLabelValue, Optional[ByLabelValue]]:
        return self.collect_by_label_result(
            context,
            lambda x: x.precision,
            legacy_result.current.metrics,
            None if legacy_result.reference is None else legacy_result.reference.metrics,
        )

    def display_name(self) -> str:
        return "Precision by Label metric"


class RecallByLabel(ClassificationQualityByLabel):
    """Calculate recall separately for each class label in multiclass classification.

    Returns a dictionary mapping each label to its recall score. Useful for
    understanding per-class recall in multiclass problems.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold for binary classification.
    * `k`: Optional top-k value for multiclass classification.
    * `tests`: Optional list of test conditions.
    """


class RecallByLabelCalculation(LegacyClassificationQualityByClass[RecallByLabel]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityByClassResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[ByLabelValue, Optional[ByLabelValue]]:
        return self.collect_by_label_result(
            context,
            lambda x: x.recall,
            legacy_result.current.metrics,
            None if legacy_result.reference is None else legacy_result.reference.metrics,
        )

    def display_name(self) -> str:
        return "Recall by Label metric"


class RocAucByLabel(ClassificationQualityByLabel):
    """Calculate ROC AUC separately for each class label in multiclass classification.

    Returns a dictionary mapping each label to its ROC AUC score. Useful for
    understanding per-class ROC AUC in multiclass problems.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold for binary classification.
    * `k`: Optional top-k value for multiclass classification.
    * `tests`: Optional list of test conditions.
    """


class RocAucByLabelCalculation(LegacyClassificationQualityByClass[RocAucByLabel]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityByClassResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[ByLabelValue, Optional[ByLabelValue]]:
        return self.collect_by_label_result(
            context,
            lambda x: x.roc_auc if x.roc_auc is not None else 0.0,
            legacy_result.current.metrics,
            None if legacy_result.reference is None else legacy_result.reference.metrics,
        )

    def display_name(self) -> str:
        return "ROC AUC by Label metric"


ADDITIONAL_WIDGET_MAPPING: Dict[str, Metric] = {
    "prob_distribution": ClassificationProbDistribution(),
    "conf_matrix": ClassificationConfusionMatrix(),
    "pr_curve": ClassificationPRCurve(),
    "pr_table": ClassificationPRTable(),
    "roc_curve": ClassificationRocCurve(),
    "lift_curve": ClassificationLiftCurve(),
    "lift_table": ClassificationLiftTable(),
}
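
# Each key above names an optional boolean field on the metric classes below;
# setting that field to True renders the matching legacy widget alongside the
# metric. Illustrative:
#
#     Precision(conf_matrix=True, pr_curve=True)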


class LegacyClassificationQuality(
    SingleValueCalculation[TSingleValueMetric],
    LegacyMetricCalculation[
        SingleValue,
        TSingleValueMetric,
        ClassificationQualityMetricResult,
        ClassificationQualityMetric,
    ],
    Generic[TSingleValueMetric],
    abc.ABC,
):
    _legacy_metric = None

    def task_name(self) -> str:
        return self.metric.classification_name

    def legacy_metric(self) -> ClassificationQualityMetric:
        if self._legacy_metric is None:
            self._legacy_metric = ClassificationQualityMetric(self.metric.probas_threshold, self.metric.k)
        return self._legacy_metric

    @abc.abstractmethod
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        raise NotImplementedError()

    def get_additional_widgets(self, context: "Context") -> List[BaseWidgetInfo]:
        result = []
        for field, metric in ADDITIONAL_WIDGET_MAPPING.items():
            if hasattr(self.metric, field) and getattr(self.metric, field):
                _, widgets = context.get_legacy_metric(metric, self._gen_input_data, self.task_name())
                result += widgets
        return result


class F1Score(ClassificationQuality):
    """Calculate F1 score (harmonic mean of precision and recall).

    F1 score balances precision and recall, providing a single metric for
    classification performance. Higher values indicate better performance.
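
    Example (illustrative sketch; the threshold is hypothetical and `gt` is
    `evidently.tests.gt`):

        from evidently import Report

        report = Report([F1Score()])                 # default test: beat dummy F1
        strict = Report([F1Score(tests=[gt(0.6)])])  # hypothetical fixed threshold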
    """

    conf_matrix: bool = True
    """Whether to show confusion matrix visualization."""

    def _default_tests(self, context: Context) -> List[BoundTest]:
        dummy_value = self._get_dummy_value(context, DummyF1Score)
        return [gt(dummy_value.value).bind_single(self.get_fingerprint())]


class F1ScoreCalculation(LegacyClassificationQuality[F1Score]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        return (
            self.result(legacy_result.current.f1),
            None if legacy_result.reference is None else self.result(legacy_result.reference.f1),
        )

    def display_name(self) -> str:
        return "F1 score metric"


class Accuracy(ClassificationQuality):
    """Calculate classification accuracy (proportion of correct predictions).

    Accuracy measures the fraction of predictions that match the true labels.
    Simple and intuitive, but can be misleading for imbalanced datasets.
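
    Example (illustrative; `k` is only meaningful for top-k multiclass setups):

        Accuracy()     # default test compares against a dummy baseline
        Accuracy(k=3)  # hypothetical top-3 accuracy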
    """

    def _default_tests(self, context: Context) -> List[BoundTest]:
        dummy_value = self._get_dummy_value(context, DummyAccuracy)
        return [gt(dummy_value.value).bind_single(self.get_fingerprint())]


class AccuracyCalculation(LegacyClassificationQuality[Accuracy]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        return (
            self.result(legacy_result.current.accuracy),
            None if legacy_result.reference is None else self.result(legacy_result.reference.accuracy),
        )

    def display_name(self) -> str:
        return "Accuracy metric"


class Precision(ClassificationQuality):
    """Calculate precision (proportion of positive predictions that are correct).

    Precision measures how many of the predicted positive cases are actually positive.
    Useful when false positives are costly.

    Note: At least one visualization (`conf_matrix`, `pr_curve`, or `pr_table`) must be enabled.
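
    Example (illustrative; the threshold value is hypothetical):

        Precision(probas_threshold=0.7)              # binary task, custom cutoff
        Precision(conf_matrix=False, pr_curve=True)  # render the PR curve instead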
    """

    conf_matrix: bool = True
    """Whether to show confusion matrix visualization."""

    pr_curve: bool = False
    """Whether to show precision-recall curve."""

    pr_table: bool = False
    """Whether to show precision-recall table."""

    def _default_tests(self, context: Context) -> List[BoundTest]:
        dummy_value = self._get_dummy_value(context, DummyPrecision)
        return [gt(dummy_value.value).bind_single(self.get_fingerprint())]


class PrecisionCalculation(LegacyClassificationQuality[Precision]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        return (
            self.result(legacy_result.current.precision),
            None if legacy_result.reference is None else self.result(legacy_result.reference.precision),
        )

    def display_name(self) -> str:
        return "Precision metric"


class Recall(ClassificationQuality):
    """Calculate recall (proportion of actual positives that are correctly identified).

    Recall measures how many of the actual positive cases are correctly predicted.
    Useful when false negatives are costly.

    Note: At least one visualization (`conf_matrix`, `pr_curve`, or `pr_table`) must be enabled.
    """

    conf_matrix: bool = True
    """Whether to show confusion matrix visualization."""

    pr_curve: bool = False
    """Whether to show precision-recall curve."""

    pr_table: bool = False
    """Whether to show precision-recall table."""

    def _default_tests(self, context: Context) -> List[BoundTest]:
        dummy_value = self._get_dummy_value(context, DummyRecall)
        return [gt(dummy_value.value).bind_single(self.get_fingerprint())]


class RecallCalculation(LegacyClassificationQuality[Recall]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        return (
            self.result(legacy_result.current.recall),
            None if legacy_result.reference is None else self.result(legacy_result.reference.recall),
        )

    def display_name(self) -> str:
        return "Recall metric"


class TPR(ClassificationQuality):
    """Calculate True Positive Rate (TPR), also known as recall or sensitivity.

    TPR measures the proportion of actual positives correctly identified.
    Equivalent to recall. Higher values indicate better detection of positive cases.

    Note: `pr_table` visualization must be enabled.
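
    Example (illustrative; the same pattern applies to TNR, FPR and FNR):

        TPR(pr_table=True)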
    """

    pr_table: bool = False
    """Whether to show precision-recall table."""

    def _default_tests(self, context: Context) -> List[BoundTest]:
        dummy_value = self._get_dummy_value(context, DummyTPR)
        return [gt(dummy_value.value).bind_single(self.get_fingerprint())]


class TPRCalculation(LegacyClassificationQuality[TPR]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        if legacy_result.current.tpr is None:
            raise ValueError(
                "Cannot compute TPR: current TPR value is missing. "
                "Ensure prediction labels and probabilities are available."
            )
        return (
            self.result(legacy_result.current.tpr),
            None
            if legacy_result.reference is None or legacy_result.reference.tpr is None
            else self.result(legacy_result.reference.tpr),
        )

    def display_name(self) -> str:
        return "TPR metric"


class TNR(ClassificationQuality):
    """Calculate True Negative Rate (TNR), also known as specificity.

    TNR measures the proportion of actual negatives correctly identified.
    Higher values indicate better detection of negative cases.

    Note: `pr_table` visualization must be enabled.
    """

    pr_table: bool = False
    """Whether to show precision-recall table."""

    def _default_tests(self, context: Context) -> List[BoundTest]:
        dummy_value = self._get_dummy_value(context, DummyTNR)
        return [gt(dummy_value.value).bind_single(self.get_fingerprint())]


class TNRCalculation(LegacyClassificationQuality[TNR]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        if legacy_result.current.tnr is None:
            raise ValueError(
                "Cannot compute TNR: current TNR value is missing. "
                "Ensure prediction labels and probabilities are available."
            )
        return (
            self.result(legacy_result.current.tnr),
            None
            if legacy_result.reference is None or legacy_result.reference.tnr is None
            else self.result(legacy_result.reference.tnr),
        )

    def display_name(self) -> str:
        return "TNR metric"


class FPR(ClassificationQuality):
    """Calculate False Positive Rate (FPR).

    FPR measures the proportion of actual negatives incorrectly classified as positive.
    Lower values are better. FPR = 1 - TNR.

    Note: `pr_table` visualization must be enabled.
    """

    pr_table: bool = False
    """Whether to show precision-recall table."""

    def _default_tests(self, context: Context) -> List[BoundTest]:
        dummy_value = self._get_dummy_value(context, DummyFPR)
        return [lt(dummy_value.value).bind_single(self.get_fingerprint())]


class FPRCalculation(LegacyClassificationQuality[FPR]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        if legacy_result.current.fpr is None:
            raise ValueError(
                "Cannot compute FPR: current FPR value is missing. "
                "Ensure prediction labels and probabilities are available."
            )
        return (
            self.result(legacy_result.current.fpr),
            None
            if legacy_result.reference is None or legacy_result.reference.fpr is None
            else self.result(legacy_result.reference.fpr),
        )

    def display_name(self) -> str:
        return "FPR metric"


class FNR(ClassificationQuality):
    """Calculate False Negative Rate (FNR).

    FNR measures the proportion of actual positives incorrectly classified as negative.
    Lower values are better. FNR = 1 - TPR.

    Note: `pr_table` visualization must be enabled.
    """

    pr_table: bool = False
    """Whether to show precision-recall table."""

    def _default_tests(self, context: Context) -> List[BoundTest]:
        dummy_value = self._get_dummy_value(context, DummyFNR)
        return [lt(dummy_value.value).bind_single(self.get_fingerprint())]


class FNRCalculation(LegacyClassificationQuality[FNR]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        if legacy_result.current.fnr is None:
            raise ValueError(
                "Cannot compute FNR: current FNR value is missing. "
                "Ensure prediction labels and probabilities are available."
            )
        return (
            self.result(legacy_result.current.fnr),
            None
            if legacy_result.reference is None or legacy_result.reference.fnr is None
            else self.result(legacy_result.reference.fnr),
        )

    def display_name(self) -> str:
        return "FNR metric"


class RocAuc(ClassificationQuality):
    """Calculate ROC AUC (Area Under the Receiver Operating Characteristic Curve).

    ROC AUC measures the model's ability to distinguish between classes across
    all possible thresholds. Values range from 0 to 1, with 0.5 being random and 1.0 perfect.

    Note: At least one visualization (`roc_curve` or `pr_table`) must be enabled.
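
    Example (illustrative; requires prediction probabilities):

        RocAuc()                                # ROC curve rendered by default
        RocAuc(roc_curve=False, pr_table=True)  # PR table instead of the curve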
    """

    roc_curve: bool = True
    """Whether to show ROC curve."""

    pr_table: bool = False
    """Whether to show precision-recall table."""

    def _default_tests(self, context: Context) -> List[BoundTest]:
        dummy_value = self._get_dummy_value(context, DummyRocAuc)
        return [gt(dummy_value.value).bind_single(self.get_fingerprint())]


class RocAucCalculation(LegacyClassificationQuality[RocAuc]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        if legacy_result.current.roc_auc is None:
            raise ValueError(
                "Cannot compute RocAuc: current RocAuc value is missing. "
                "Ensure prediction labels and probabilities are available."
            )
        return (
            self.result(legacy_result.current.roc_auc),
            None
            if legacy_result.reference is None or legacy_result.reference.roc_auc is None
            else self.result(legacy_result.reference.roc_auc),
        )

    def display_name(self) -> str:
        return "RocAuc metric"


class LogLoss(ClassificationQuality):
    """Calculate logarithmic loss (cross-entropy loss).

    Log loss penalizes confident wrong predictions more heavily; lower values
    indicate better-calibrated probability predictions.

    Note: `pr_table` visualization must be enabled. Requires probability predictions.
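
    Example (illustrative; lower is better, so the default test is `lt`-based
    and the fixed threshold below is hypothetical):

        LogLoss(tests=[lt(0.3)])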
    """

    pr_table: bool = False
    """Whether to show precision-recall table."""

    def _default_tests(self, context: Context) -> List[BoundTest]:
        dummy_value = self._get_dummy_value(context, DummyLogLoss)
        return [lt(dummy_value.value).bind_single(self.get_fingerprint())]


class LogLossCalculation(LegacyClassificationQuality[LogLoss]):
    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationQualityMetricResult,
        render: List[BaseWidgetInfo],
    ) -> Tuple[SingleValue, Optional[SingleValue]]:
        if legacy_result.current.log_loss is None:
            raise ValueError(
                "Cannot compute LogLoss: current LogLoss value is missing. "
                "Ensure prediction labels and probabilities are available."
            )
        return (
            self.result(legacy_result.current.log_loss),
            None
            if legacy_result.reference is None or legacy_result.reference.log_loss is None
            else self.result(legacy_result.reference.log_loss),
        )

    def display_name(self) -> str:
        return "LogLoss metric"


class LegacyClassificationDummy(
    LegacyMetricCalculation[
        SingleValue,
        TSingleValueMetric,
        ClassificationDummyMetricResults,
        ClassificationDummyMetric,
    ],
    SingleValueCalculation[TSingleValueMetric],
    Generic[TSingleValueMetric],
    abc.ABC,
):
    _legacy_metric = None
    # Name of the field to read from the legacy dummy result; set by subclasses.
    __legacy_field_name__: ClassVar[str]

    def task_name(self) -> str:
        return self.metric.classification_name

    def legacy_metric(self) -> ClassificationDummyMetric:
        if self._legacy_metric is None:
            self._legacy_metric = ClassificationDummyMetric(self.metric.probas_threshold, self.metric.k)
        return self._legacy_metric

    def calculate_value(
        self,
        context: "Context",
        legacy_result: ClassificationDummyMetricResults,
        render: List[BaseWidgetInfo],
    ) -> TMetricResult:
        current_value = getattr(legacy_result.dummy, self.__legacy_field_name__)
        if current_value is None:
            raise ValueError(f"Failed to calculate {self.display_name()}")
        if legacy_result.by_reference_dummy is None:
            return self.result(current_value)
        reference_value = getattr(legacy_result.by_reference_dummy, self.__legacy_field_name__)
        return self.result(current_value), self.result(reference_value)


class DummyClassificationQuality(ClassificationQualityBase):
    """Base class for dummy (baseline) quality metrics; baselines define no default tests."""

    def _default_tests_with_reference(self, context: Context) -> List[BoundTest]:
        return []

    def _default_tests(self, context: Context) -> List[BoundTest]:
        return []


class DummyPrecision(DummyClassificationQuality):
    """Calculate precision for a dummy/baseline model.

    Computes precision using a simple heuristic-based model (e.g., always predict
    the most common class). Useful as a baseline to compare your model against.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold.
    * `k`: Optional top-k value for multiclass classification.
    """


class DummyPrecisionCalculation(LegacyClassificationDummy[DummyPrecision]):
    __legacy_field_name__ = "precision"

    def display_name(self) -> str:
        return "Dummy precision metric"


class DummyRecall(DummyClassificationQuality):
    """Calculate recall for a dummy/baseline model.

    Computes recall using a simple heuristic-based model. Useful as a baseline
    to compare your model against.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold.
    * `k`: Optional top-k value for multiclass classification.
    """


class DummyRecallCalculation(LegacyClassificationDummy[DummyRecall]):
    __legacy_field_name__ = "recall"

    def display_name(self) -> str:
        return "Dummy recall metric"


class DummyF1Score(DummyClassificationQuality):
    """Calculate F1 score for a dummy/baseline model.

    Computes F1 score using a simple heuristic-based model. Useful as a baseline
    to compare your model against.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold.
    * `k`: Optional top-k value for multiclass classification.
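
    Example (illustrative sketch; `Report` usage is an assumption): report the
    model's F1 side by side with the baseline F1.

        from evidently import Report

        report = Report([F1Score(), DummyF1Score()])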
    """


class DummyF1ScoreCalculation(LegacyClassificationDummy[DummyF1Score]):
    __legacy_field_name__ = "f1"

    def display_name(self) -> str:
        return "Dummy F1 score metric"


class DummyAccuracy(DummyClassificationQuality):
    """Calculate accuracy for a dummy/baseline model.

    Computes accuracy using a simple heuristic-based model (e.g., always predict
    the most common class). Useful as a baseline to compare your model against.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold.
    * `k`: Optional top-k value for multiclass classification.
    """


class DummyAccuracyCalculation(LegacyClassificationDummy[DummyAccuracy]):
    __legacy_field_name__ = "accuracy"

    def display_name(self) -> str:
        return "Dummy accuracy metric"


class DummyTPR(DummyClassificationQuality):
    """Calculate True Positive Rate for a dummy/baseline model.

    Computes TPR using a simple heuristic-based model. Useful as a baseline
    to compare your model against.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold.
    * `k`: Optional top-k value for multiclass classification.
    """


class DummyTPRCalculation(LegacyClassificationDummy[DummyTPR]):
    __legacy_field_name__ = "tpr"

    def display_name(self) -> str:
        return "Dummy TPR metric"


class DummyTNR(DummyClassificationQuality):
    """Calculate True Negative Rate for a dummy/baseline model.

    Computes TNR using a simple heuristic-based model. Useful as a baseline
    to compare your model against.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold.
    * `k`: Optional top-k value for multiclass classification.
    """


class DummyTNRCalculation(LegacyClassificationDummy[DummyTNR]):
    __legacy_field_name__ = "tnr"

    def display_name(self) -> str:
        return "Dummy TNR metric"


class DummyFPR(DummyClassificationQuality):
    """Calculate False Positive Rate for a dummy/baseline model.

    Computes FPR using a simple heuristic-based model. Useful as a baseline
    to compare your model against.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold.
    * `k`: Optional top-k value for multiclass classification.
    """


class DummyFPRCalculation(LegacyClassificationDummy[DummyFPR]):
    __legacy_field_name__ = "fpr"

    def display_name(self) -> str:
        return "Dummy FPR metric"


class DummyFNR(DummyClassificationQuality):
    """Calculate False Negative Rate for a dummy/baseline model.

    Computes FNR using a simple heuristic-based model. Useful as a baseline
    to compare your model against.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold.
    * `k`: Optional top-k value for multiclass classification.
    """


class DummyFNRCalculation(LegacyClassificationDummy[DummyFNR]):
    __legacy_field_name__ = "fnr"

    def display_name(self) -> str:
        return "Dummy FNR metric"


class DummyLogLoss(DummyClassificationQuality):
    """Calculate logarithmic loss for a dummy/baseline model.

    Computes log loss using a simple heuristic-based model. Useful as a baseline
    to compare your model against.

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold.
    * `k`: Optional top-k value for multiclass classification.
    """


class DummyLogLossCalculation(LegacyClassificationDummy[DummyLogLoss]):
    __legacy_field_name__ = "log_loss"

    def display_name(self) -> str:
        return "Dummy LogLoss metric"


class DummyRocAuc(DummyClassificationQuality):
    """Calculate ROC AUC for a dummy/baseline model.

    Computes ROC AUC using a simple heuristic-based model. Useful as a baseline
    to compare your model against (typically 0.5 for random).

    Args:
    * `classification_name`: Name of the classification task (default: "default").
    * `probas_threshold`: Optional probability threshold.
    * `k`: Optional top-k value for multiclass classification.
    """


class DummyRocAucCalculation(LegacyClassificationDummy[DummyRocAuc]):
    __legacy_field_name__ = "roc_auc"

    def display_name(self) -> str:
        return "Dummy RocAuc metric"