# src/evidently/presets/classification.py
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

from evidently._pydantic_compat import PrivateAttr
from evidently.core.container import MetricContainer
from evidently.core.container import MetricOrContainer
from evidently.core.datasets import BinaryClassification
from evidently.core.metric_types import ByLabelMetricTests
from evidently.core.metric_types import GenericByLabelMetricTests
from evidently.core.metric_types import GenericSingleValueMetricTests
from evidently.core.metric_types import Metric
from evidently.core.metric_types import MetricId
from evidently.core.metric_types import SingleValueMetricTests
from evidently.core.metric_types import convert_tests
from evidently.core.report import Context
from evidently.legacy.metrics import ClassificationConfusionMatrix
from evidently.legacy.metrics import ClassificationDummyMetric
from evidently.legacy.metrics import ClassificationPRCurve
from evidently.legacy.metrics import ClassificationPRTable
from evidently.legacy.metrics import ClassificationQualityByClass
from evidently.legacy.metrics import ClassificationQualityMetric
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.legacy.model.widget import link_metric
from evidently.metrics import FNR
from evidently.metrics import FPR
from evidently.metrics import TNR
from evidently.metrics import TPR
from evidently.metrics import Accuracy
from evidently.metrics import F1ByLabel
from evidently.metrics import F1Score
from evidently.metrics import LogLoss
from evidently.metrics import Precision
from evidently.metrics import PrecisionByLabel
from evidently.metrics import Recall
from evidently.metrics import RecallByLabel
from evidently.metrics import RocAuc
from evidently.metrics import RocAucByLabel
from evidently.metrics.classification import DummyF1Score
from evidently.metrics.classification import DummyPrecision
from evidently.metrics.classification import DummyRecall
from evidently.metrics.classification import _gen_classification_input_data


class ClassificationQuality(MetricContainer):
    """Small preset summarizing classification quality metrics.

    Generates aggregated classification metrics including accuracy, precision, recall,
    F1 score, ROC AUC, log loss, and binary classification rates (TPR, TNR, FPR, FNR).
    Optionally includes visualizations such as the confusion matrix, PR curve, and PR table.

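    Example:
        A minimal sketch, assuming a dataset whose data definition already declares
        a classification task; the `Report` wiring below is illustrative and may
        differ between Evidently versions, and `current_dataset`/`reference_dataset`
        are placeholder names:

            from evidently import Report

            report = Report([ClassificationQuality(conf_matrix=True, pr_curve=True)])
            my_eval = report.run(current_dataset, reference_dataset)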
    """

    classification_name: str = "default"
    """Name of the classification task."""
    probas_threshold: Optional[float] = None
    """Optional probability threshold for binary classification."""
    conf_matrix: bool = False
    """Whether to show the confusion matrix visualization."""
    pr_curve: bool = False
    """Whether to show the precision-recall curve."""
    pr_table: bool = False
    """Whether to show the precision-recall table."""
    accuracy_tests: SingleValueMetricTests = None
    """Optional test conditions for accuracy."""
    precision_tests: SingleValueMetricTests = None
    """Optional test conditions for precision."""
    recall_tests: SingleValueMetricTests = None
    """Optional test conditions for recall."""
    f1score_tests: SingleValueMetricTests = None
    """Optional test conditions for F1 score."""
    rocauc_tests: SingleValueMetricTests = None
    """Optional test conditions for ROC AUC."""
    logloss_tests: SingleValueMetricTests = None
    """Optional test conditions for log loss."""
    tpr_tests: SingleValueMetricTests = None
    """Optional test conditions for TPR."""
    tnr_tests: SingleValueMetricTests = None
    """Optional test conditions for TNR."""
    fpr_tests: SingleValueMetricTests = None
    """Optional test conditions for FPR."""
    fnr_tests: SingleValueMetricTests = None
    """Optional test conditions for FNR."""

    def __init__(
        self,
        classification_name: str = "default",
        probas_threshold: Optional[float] = None,
        conf_matrix: bool = False,
        pr_curve: bool = False,
        pr_table: bool = False,
        accuracy_tests: GenericSingleValueMetricTests = None,
        precision_tests: GenericSingleValueMetricTests = None,
        recall_tests: GenericSingleValueMetricTests = None,
        f1score_tests: GenericSingleValueMetricTests = None,
        rocauc_tests: GenericSingleValueMetricTests = None,
        logloss_tests: GenericSingleValueMetricTests = None,
        tpr_tests: GenericSingleValueMetricTests = None,
        tnr_tests: GenericSingleValueMetricTests = None,
        fpr_tests: GenericSingleValueMetricTests = None,
        fnr_tests: GenericSingleValueMetricTests = None,
        include_tests: bool = True,
    ):
        self.classification_name = classification_name
        self.accuracy_tests = convert_tests(accuracy_tests)
        self.precision_tests = convert_tests(precision_tests)
        self.recall_tests = convert_tests(recall_tests)
        self.f1score_tests = convert_tests(f1score_tests)
        self.rocauc_tests = convert_tests(rocauc_tests)
        self.logloss_tests = convert_tests(logloss_tests)
        self.tpr_tests = convert_tests(tpr_tests)
        self.tnr_tests = convert_tests(tnr_tests)
        self.fpr_tests = convert_tests(fpr_tests)
        self.fnr_tests = convert_tests(fnr_tests)
        self.probas_threshold = probas_threshold
        self.conf_matrix = conf_matrix
        self.pr_curve = pr_curve
        self.pr_table = pr_table
        super().__init__(include_tests=include_tests)

    def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
        classification = context.data_definition.get_classification(self.classification_name)
        if classification is None:
            raise ValueError("Classification with name '{}' not found".format(self.classification_name))

        metrics: List[Metric] = [
            Accuracy(
                probas_threshold=self.probas_threshold,
                classification_name=self.classification_name,
                tests=self._get_tests(self.accuracy_tests),
            ),
            Precision(
                probas_threshold=self.probas_threshold,
                classification_name=self.classification_name,
                tests=self._get_tests(self.precision_tests),
            ),
            Recall(
                probas_threshold=self.probas_threshold,
                classification_name=self.classification_name,
                tests=self._get_tests(self.recall_tests),
            ),
            F1Score(
                probas_threshold=self.probas_threshold,
                classification_name=self.classification_name,
                tests=self._get_tests(self.f1score_tests),
            ),
        ]
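        # ROC AUC and log loss require prediction probabilities, so they are added only when available.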
        if classification.prediction_probas is not None:
            metrics.extend(
                [
                    RocAuc(
                        probas_threshold=self.probas_threshold,
                        classification_name=self.classification_name,
                        tests=self._get_tests(self.rocauc_tests),
                    ),
                    LogLoss(
                        probas_threshold=self.probas_threshold,
                        classification_name=self.classification_name,
                        tests=self._get_tests(self.logloss_tests),
                    ),
                ]
            )
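        # Confusion-matrix rates (TPR/TNR/FPR/FNR) are only defined for binary classification.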
        if isinstance(classification, BinaryClassification):
            metrics.extend(
                [
                    TPR(
                        probas_threshold=self.probas_threshold,
                        classification_name=self.classification_name,
                        tests=self._get_tests(self.tpr_tests),
                    ),
                    TNR(
                        probas_threshold=self.probas_threshold,
                        classification_name=self.classification_name,
                        tests=self._get_tests(self.tnr_tests),
                    ),
                    FPR(
                        probas_threshold=self.probas_threshold,
                        classification_name=self.classification_name,
                        tests=self._get_tests(self.fpr_tests),
                    ),
                    FNR(
                        probas_threshold=self.probas_threshold,
                        classification_name=self.classification_name,
                        tests=self._get_tests(self.fnr_tests),
                    ),
                ]
            )
        return metrics

    def render(
        self,
        context: "Context",
        child_widgets: Optional[List[Tuple[Optional[MetricId], List[BaseWidgetInfo]]]] = None,
    ) -> List[BaseWidgetInfo]:
        # get_legacy_metric returns a (result, widgets) pair; only the widgets are used here.
        _, render = context.get_legacy_metric(
            ClassificationQualityMetric(probas_threshold=self.probas_threshold),
            _gen_classification_input_data,
            self.classification_name,
        )
        if self.conf_matrix:
            render += context.get_legacy_metric(
                ClassificationConfusionMatrix(probas_threshold=self.probas_threshold),
                _gen_classification_input_data,
                self.classification_name,
            )[1]
        classification = context.data_definition.get_classification(self.classification_name)
        if classification is None:
            raise ValueError("Cannot use ClassificationQuality without a classification configuration")
        if self.pr_curve and classification.prediction_probas is not None:
            render += context.get_legacy_metric(
                ClassificationPRCurve(probas_threshold=self.probas_threshold),
                _gen_classification_input_data,
                self.classification_name,
            )[1]
        if self.pr_table and classification.prediction_probas is not None:
            render += context.get_legacy_metric(
                ClassificationPRTable(probas_threshold=self.probas_threshold),
                _gen_classification_input_data,
                self.classification_name,
            )[1]
        for metric in self.list_metrics(context):
            link_metric(render, metric)
        return render


class ClassificationQualityByLabel(MetricContainer):
    """Small preset summarizing classification quality metrics by label.

    Generates per-label metrics for multiclass classification: F1 score, precision,
    and recall for each label, plus ROC AUC per label when prediction probabilities
    are available. Useful for understanding per-class performance.

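    Example:
        A minimal sketch, assuming a multiclass task in the data definition; the
        `Report` wiring is illustrative and the dataset names are placeholders:

            from evidently import Report

            report = Report([ClassificationQualityByLabel()])
            my_eval = report.run(current_dataset, reference_dataset)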
    """

    probas_threshold: Optional[float] = None
    """Optional probability threshold for binary classification."""
    k: Optional[int] = None
    """Optional top-k value for multiclass classification."""
    f1score_tests: ByLabelMetricTests = None
    """Optional test conditions for F1 score by label."""
    precision_tests: ByLabelMetricTests = None
    """Optional test conditions for precision by label."""
    recall_tests: ByLabelMetricTests = None
    """Optional test conditions for recall by label."""
    rocauc_tests: ByLabelMetricTests = None
    """Optional test conditions for ROC AUC by label."""
    classification_name: str = "default"
    """Name of the classification task."""

    def __init__(
        self,
        probas_threshold: Optional[float] = None,
        k: Optional[int] = None,
        f1score_tests: GenericByLabelMetricTests = None,
        precision_tests: GenericByLabelMetricTests = None,
        recall_tests: GenericByLabelMetricTests = None,
        rocauc_tests: GenericByLabelMetricTests = None,
        classification_name: str = "default",
        include_tests: bool = True,
    ):
        self.probas_threshold = probas_threshold
        self.k = k
        self.f1score_tests = convert_tests(f1score_tests)
        self.precision_tests = convert_tests(precision_tests)
        self.recall_tests = convert_tests(recall_tests)
        self.rocauc_tests = convert_tests(rocauc_tests)
        self.classification_name = classification_name
        super().__init__(include_tests=include_tests)

    def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
        classification = context.data_definition.get_classification(self.classification_name)
        if classification is None:
            raise ValueError("Cannot use ClassificationQualityByLabel without a classification configuration")
        return [
            F1ByLabel(
                classification_name=self.classification_name,
                probas_threshold=self.probas_threshold,
                k=self.k,
                tests=self._get_tests(self.f1score_tests),
            ),
            PrecisionByLabel(
                classification_name=self.classification_name,
                probas_threshold=self.probas_threshold,
                k=self.k,
                tests=self._get_tests(self.precision_tests),
            ),
            RecallByLabel(
                classification_name=self.classification_name,
                probas_threshold=self.probas_threshold,
                k=self.k,
                tests=self._get_tests(self.recall_tests),
            ),
        ] + (
            []
            if classification.prediction_probas is None
            else [
                RocAucByLabel(
                    classification_name=self.classification_name,
                    probas_threshold=self.probas_threshold,
                    k=self.k,
                    tests=self._get_tests(self.rocauc_tests),
                ),
            ]
        )

    def render(
        self,
        context: "Context",
        child_widgets: Optional[List[Tuple[Optional[MetricId], List[BaseWidgetInfo]]]] = None,
    ) -> List[BaseWidgetInfo]:
        widgets = context.get_legacy_metric(
            ClassificationQualityByClass(self.probas_threshold, self.k),
            _gen_classification_input_data,
            self.classification_name,
        )[1]
        widgets[0].params["counters"][0]["label"] = "Classification Quality by Label"
        for metric in self.list_metrics(context):
            link_metric(widgets, metric)
        return widgets


class ClassificationDummyQuality(MetricContainer):
    """Small preset summarizing the quality of a dummy (baseline) classification model.

    Generates metrics for a simple heuristic baseline (e.g., always predicting the
    most common class). Useful as a baseline to compare your model against.

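    Example:
        A minimal sketch that reports model quality next to the dummy baseline;
        the `Report` wiring is illustrative and the dataset names are placeholders:

            from evidently import Report

            report = Report([ClassificationQuality(), ClassificationDummyQuality()])
            my_eval = report.run(current_dataset, reference_dataset)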
    """

    probas_threshold: Optional[float] = None
    """Optional probability threshold."""
    k: Optional[int] = None
    """Optional top-k value for multiclass classification."""
    classification_name: str = "default"
    """Name of the classification task."""

    def __init__(
        self,
        probas_threshold: Optional[float] = None,
        k: Optional[int] = None,
        include_tests: bool = True,
        classification_name: str = "default",
    ):
        self.probas_threshold = probas_threshold
        self.k = k
        self.classification_name = classification_name
        super().__init__(include_tests=include_tests)

    def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
        return [
            DummyPrecision(),
            DummyRecall(),
            DummyF1Score(),
        ]

    def render(
        self,
        context: "Context",
        child_widgets: Optional[List[Tuple[Optional[MetricId], List[BaseWidgetInfo]]]] = None,
    ) -> List[BaseWidgetInfo]:
        _, widgets = context.get_legacy_metric(
            ClassificationDummyMetric(self.probas_threshold, self.k),
            _gen_classification_input_data,
            self.classification_name,
        )
        for metric in self.list_metrics(context):
            link_metric(widgets, metric)
        return widgets


class ClassificationPreset(MetricContainer):
    """Large preset with comprehensive classification quality metrics and visualizations.

    Combines `ClassificationQuality` and `ClassificationQualityByLabel` to provide
    a complete classification evaluation: aggregated metrics, per-label metrics, and
    visualizations (confusion matrix, PR curve, and PR table), plus the ROC AUC
    widgets when prediction probabilities are available.

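    Example:
        A minimal sketch, assuming a classification task named "default" in the
        data definition; the `Report` wiring is illustrative and the dataset names
        are placeholders:

            from evidently import Report

            report = Report([ClassificationPreset(probas_threshold=0.5)])
            my_eval = report.run(current_dataset, reference_dataset)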
    """

    probas_threshold: Optional[float] = None
    """Optional probability threshold for binary classification."""
    accuracy_tests: SingleValueMetricTests = None
    """Optional test conditions for accuracy."""
    precision_tests: SingleValueMetricTests = None
    """Optional test conditions for precision."""
    recall_tests: SingleValueMetricTests = None
    """Optional test conditions for recall."""
    f1score_tests: SingleValueMetricTests = None
    """Optional test conditions for F1 score."""
    rocauc_tests: SingleValueMetricTests = None
    """Optional test conditions for ROC AUC."""
    logloss_tests: SingleValueMetricTests = None
    """Optional test conditions for log loss."""
    tpr_tests: SingleValueMetricTests = None
    """Optional test conditions for TPR."""
    tnr_tests: SingleValueMetricTests = None
    """Optional test conditions for TNR."""
    fpr_tests: SingleValueMetricTests = None
    """Optional test conditions for FPR."""
    fnr_tests: SingleValueMetricTests = None
    """Optional test conditions for FNR."""
    f1score_by_label_tests: ByLabelMetricTests = None
    """Optional test conditions for F1 score by label."""
    precision_by_label_tests: ByLabelMetricTests = None
    """Optional test conditions for precision by label."""
    recall_by_label_tests: ByLabelMetricTests = None
    """Optional test conditions for recall by label."""
    rocauc_by_label_tests: ByLabelMetricTests = None
    """Optional test conditions for ROC AUC by label."""
    classification_name: str = "default"
    """Name of the classification task."""

    _quality: ClassificationQuality = PrivateAttr()
    """Internal classification quality preset."""
    _quality_by_label: ClassificationQualityByLabel = PrivateAttr()
    """Internal classification quality by label preset."""
    _roc_auc: Optional[RocAuc] = PrivateAttr()
    """Internal ROC AUC metric."""

    def __init__(
        self,
        probas_threshold: Optional[float] = None,
        accuracy_tests: GenericSingleValueMetricTests = None,
        precision_tests: GenericSingleValueMetricTests = None,
        recall_tests: GenericSingleValueMetricTests = None,
        f1score_tests: GenericSingleValueMetricTests = None,
        rocauc_tests: GenericSingleValueMetricTests = None,
        logloss_tests: GenericSingleValueMetricTests = None,
        tpr_tests: GenericSingleValueMetricTests = None,
        tnr_tests: GenericSingleValueMetricTests = None,
        fpr_tests: GenericSingleValueMetricTests = None,
        fnr_tests: GenericSingleValueMetricTests = None,
        f1score_by_label_tests: GenericByLabelMetricTests = None,
        precision_by_label_tests: GenericByLabelMetricTests = None,
        recall_by_label_tests: GenericByLabelMetricTests = None,
        rocauc_by_label_tests: GenericByLabelMetricTests = None,
        include_tests: bool = True,
        classification_name: str = "default",
    ):
        super().__init__(
            include_tests=include_tests,
            probas_threshold=probas_threshold,
            accuracy_tests=convert_tests(accuracy_tests),
            precision_tests=convert_tests(precision_tests),
            recall_tests=convert_tests(recall_tests),
            f1score_tests=convert_tests(f1score_tests),
            rocauc_tests=convert_tests(rocauc_tests),
            logloss_tests=convert_tests(logloss_tests),
            tpr_tests=convert_tests(tpr_tests),
            tnr_tests=convert_tests(tnr_tests),
            fpr_tests=convert_tests(fpr_tests),
            fnr_tests=convert_tests(fnr_tests),
            f1score_by_label_tests=convert_tests(f1score_by_label_tests),
            precision_by_label_tests=convert_tests(precision_by_label_tests),
            recall_by_label_tests=convert_tests(recall_by_label_tests),
            rocauc_by_label_tests=convert_tests(rocauc_by_label_tests),
            classification_name=classification_name,
        )
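        # Compose the two sub-presets that do the actual work; ClassificationQuality
        # is configured with all of its visualizations enabled.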
        self._quality = ClassificationQuality(
            probas_threshold=probas_threshold,
            conf_matrix=True,
            pr_curve=True,
            pr_table=True,
            accuracy_tests=accuracy_tests,
            precision_tests=precision_tests,
            recall_tests=recall_tests,
            f1score_tests=f1score_tests,
            rocauc_tests=rocauc_tests,
            logloss_tests=logloss_tests,
            tpr_tests=tpr_tests,
            tnr_tests=tnr_tests,
            fpr_tests=fpr_tests,
            fnr_tests=fnr_tests,
            include_tests=include_tests,
            classification_name=classification_name,
        )
        self._quality_by_label = ClassificationQualityByLabel(
            probas_threshold=probas_threshold,
            f1score_tests=f1score_by_label_tests,
            precision_tests=precision_by_label_tests,
            recall_tests=recall_by_label_tests,
            rocauc_tests=rocauc_by_label_tests,
            include_tests=include_tests,
            classification_name=classification_name,
        )
        self._roc_auc = None

    def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
        classification = context.data_definition.get_classification(self.classification_name)
        if classification is None:
            raise ValueError("Cannot use ClassificationPreset without a classification configuration")
        quality_metrics = self._quality.metrics(context)
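        # Keep a reference to the generated RocAuc metric (if any) so render() can append its widgets.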
        self._roc_auc = next((m for m in quality_metrics if isinstance(m, RocAuc)), None)
        return quality_metrics + self._quality_by_label.metrics(context)

    def render(
        self,
        context: "Context",
        child_widgets: Optional[List[Tuple[Optional[MetricId], List[BaseWidgetInfo]]]] = None,
    ) -> List[BaseWidgetInfo]:
        return (
            self._quality.render(context)
            + self._quality_by_label.render(context)
            + ([] if self._roc_auc is None else context.get_metric_result(self._roc_auc).get_widgets())
        )