classification.py
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

from evidently._pydantic_compat import PrivateAttr
from evidently.core.container import MetricContainer
from evidently.core.container import MetricOrContainer
from evidently.core.datasets import BinaryClassification
from evidently.core.metric_types import ByLabelMetricTests
from evidently.core.metric_types import GenericByLabelMetricTests
from evidently.core.metric_types import GenericSingleValueMetricTests
from evidently.core.metric_types import Metric
from evidently.core.metric_types import MetricId
from evidently.core.metric_types import SingleValueMetricTests
from evidently.core.metric_types import convert_tests
from evidently.core.report import Context
from evidently.legacy.metrics import ClassificationConfusionMatrix
from evidently.legacy.metrics import ClassificationDummyMetric
from evidently.legacy.metrics import ClassificationPRCurve
from evidently.legacy.metrics import ClassificationPRTable
from evidently.legacy.metrics import ClassificationQualityByClass
from evidently.legacy.metrics import ClassificationQualityMetric
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.legacy.model.widget import link_metric
from evidently.metrics import FNR
from evidently.metrics import FPR
from evidently.metrics import TNR
from evidently.metrics import TPR
from evidently.metrics import Accuracy
from evidently.metrics import F1ByLabel
from evidently.metrics import F1Score
from evidently.metrics import LogLoss
from evidently.metrics import Precision
from evidently.metrics import PrecisionByLabel
from evidently.metrics import Recall
from evidently.metrics import RecallByLabel
from evidently.metrics import RocAuc
from evidently.metrics import RocAucByLabel
from evidently.metrics.classification import DummyF1Score
from evidently.metrics.classification import DummyPrecision
from evidently.metrics.classification import DummyRecall
from evidently.metrics.classification import _gen_classification_input_data


class ClassificationQuality(MetricContainer):
    """Small preset summarizing classification quality metrics.

    Generates aggregated classification metrics including accuracy, precision, recall,
    F1 score, ROC AUC, log loss, and binary classification rates (TPR, TNR, FPR, FNR).
    Optionally includes visualizations like confusion matrix, PR curve, and PR table.
    """
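
    # Usage sketch (illustrative, kept as a comment): a preset like this is
    # passed to a report and run over datasets whose data definition declares a
    # classification task. The top-level `Report` entry point and the dataset
    # names below are assumptions for the example, not guaranteed by this module.
    #
    #     report = Report([ClassificationQuality(conf_matrix=True, pr_curve=True)])
    #     snapshot = report.run(current_dataset, reference_dataset)
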
    classification_name: str = "default"
    """Name of the classification task."""
    probas_threshold: Optional[float] = None
    """Optional probability threshold for binary classification."""
    conf_matrix: bool = False
    """Whether to show confusion matrix visualization."""
    pr_curve: bool = False
    """Whether to show precision-recall curve."""
    pr_table: bool = False
    """Whether to show precision-recall table."""
    accuracy_tests: SingleValueMetricTests = None
    """Optional test conditions for accuracy."""
    precision_tests: SingleValueMetricTests = None
    """Optional test conditions for precision."""
    recall_tests: SingleValueMetricTests = None
    """Optional test conditions for recall."""
    f1score_tests: SingleValueMetricTests = None
    """Optional test conditions for F1 score."""
    rocauc_tests: SingleValueMetricTests = None
    """Optional test conditions for ROC AUC."""
    logloss_tests: SingleValueMetricTests = None
    """Optional test conditions for log loss."""
    tpr_tests: SingleValueMetricTests = None
    """Optional test conditions for TPR."""
    tnr_tests: SingleValueMetricTests = None
    """Optional test conditions for TNR."""
    fpr_tests: SingleValueMetricTests = None
    """Optional test conditions for FPR."""
    fnr_tests: SingleValueMetricTests = None
    """Optional test conditions for FNR."""

    def __init__(
        self,
        classification_name: str = "default",
        probas_threshold: Optional[float] = None,
        conf_matrix: bool = False,
        pr_curve: bool = False,
        pr_table: bool = False,
        accuracy_tests: GenericSingleValueMetricTests = None,
        precision_tests: GenericSingleValueMetricTests = None,
        recall_tests: GenericSingleValueMetricTests = None,
        f1score_tests: GenericSingleValueMetricTests = None,
        rocauc_tests: GenericSingleValueMetricTests = None,
        logloss_tests: GenericSingleValueMetricTests = None,
        tpr_tests: GenericSingleValueMetricTests = None,
        tnr_tests: GenericSingleValueMetricTests = None,
        fpr_tests: GenericSingleValueMetricTests = None,
        fnr_tests: GenericSingleValueMetricTests = None,
        include_tests: bool = True,
    ):
        self.classification_name = classification_name
        self.accuracy_tests = convert_tests(accuracy_tests)
        self.precision_tests = convert_tests(precision_tests)
        self.recall_tests = convert_tests(recall_tests)
        self.f1score_tests = convert_tests(f1score_tests)
        self.rocauc_tests = convert_tests(rocauc_tests)
        self.logloss_tests = convert_tests(logloss_tests)
        self.tpr_tests = convert_tests(tpr_tests)
        self.tnr_tests = convert_tests(tnr_tests)
        self.fpr_tests = convert_tests(fpr_tests)
        self.fnr_tests = convert_tests(fnr_tests)
        self.probas_threshold = probas_threshold
        self.conf_matrix = conf_matrix
        self.pr_curve = pr_curve
        self.pr_table = pr_table
        super().__init__(include_tests=include_tests)

    def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
        classification = context.data_definition.get_classification(self.classification_name)
        if classification is None:
            raise ValueError("Classification with name '{}' not found".format(self.classification_name))

        # Base metrics are always included; probability-based and binary-only
        # metrics are appended conditionally below.
        metrics: List[Metric] = [
            Accuracy(
                probas_threshold=self.probas_threshold,
                classification_name=self.classification_name,
                tests=self._get_tests(self.accuracy_tests),
            ),
            Precision(
                probas_threshold=self.probas_threshold,
                classification_name=self.classification_name,
                tests=self._get_tests(self.precision_tests),
            ),
            Recall(
                probas_threshold=self.probas_threshold,
                classification_name=self.classification_name,
                tests=self._get_tests(self.recall_tests),
            ),
            F1Score(
                probas_threshold=self.probas_threshold,
                classification_name=self.classification_name,
                tests=self._get_tests(self.f1score_tests),
            ),
        ]
        if classification.prediction_probas is not None:
            metrics.extend(
                [
                    RocAuc(
                        probas_threshold=self.probas_threshold,
                        classification_name=self.classification_name,
                        tests=self._get_tests(self.rocauc_tests),
                    ),
                    LogLoss(
                        probas_threshold=self.probas_threshold,
                        classification_name=self.classification_name,
                        tests=self._get_tests(self.logloss_tests),
                    ),
                ]
            )
        if isinstance(classification, BinaryClassification):
            metrics.extend(
                [
                    TPR(
                        probas_threshold=self.probas_threshold,
                        classification_name=self.classification_name,
                        tests=self._get_tests(self.tpr_tests),
                    ),
                    TNR(
                        probas_threshold=self.probas_threshold,
                        classification_name=self.classification_name,
                        tests=self._get_tests(self.tnr_tests),
                    ),
                    FPR(
                        probas_threshold=self.probas_threshold,
                        classification_name=self.classification_name,
                        tests=self._get_tests(self.fpr_tests),
                    ),
                    FNR(
                        probas_threshold=self.probas_threshold,
                        classification_name=self.classification_name,
                        tests=self._get_tests(self.fnr_tests),
                    ),
                ]
            )
        return metrics

    def render(
        self,
        context: "Context",
        child_widgets: Optional[List[Tuple[Optional[MetricId], List[BaseWidgetInfo]]]] = None,
    ) -> List[BaseWidgetInfo]:
        _, render = context.get_legacy_metric(
            ClassificationQualityMetric(probas_threshold=self.probas_threshold),
            _gen_classification_input_data,
            self.classification_name,
        )
        if self.conf_matrix:
            render += context.get_legacy_metric(
                ClassificationConfusionMatrix(probas_threshold=self.probas_threshold),
                _gen_classification_input_data,
                self.classification_name,
            )[1]
        classification = context.data_definition.get_classification(self.classification_name)
        if classification is None:
            raise ValueError("Cannot use ClassificationQuality without classification data")
        if self.pr_curve and classification.prediction_probas is not None:
            render += context.get_legacy_metric(
                ClassificationPRCurve(probas_threshold=self.probas_threshold),
                _gen_classification_input_data,
                self.classification_name,
            )[1]
        if self.pr_table and classification.prediction_probas is not None:
            render += context.get_legacy_metric(
                ClassificationPRTable(probas_threshold=self.probas_threshold),
                _gen_classification_input_data,
                self.classification_name,
            )[1]
        for metric in self.list_metrics(context):
            link_metric(render, metric)
        return render


class ClassificationQualityByLabel(MetricContainer):
    """Small preset summarizing classification quality metrics by label.

    Generates per-class metrics for multiclass classification, including F1, precision,
    recall, and ROC AUC for each label. Useful for understanding per-class performance.
    """
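
    # Usage sketch (illustrative, kept as a comment): per-label tests are
    # supplied per metric. The dict-keyed-by-label form and the `gte` test
    # helper below are assumptions for the example, not guaranteed by this
    # module.
    #
    #     preset = ClassificationQualityByLabel(
    #         f1score_tests={0: [gte(0.7)], 1: [gte(0.7)]},
    #     )
    #     report = Report([preset])
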
    probas_threshold: Optional[float] = None
    """Optional probability threshold for binary classification."""
    k: Optional[int] = None
    """Optional top-k value for multiclass classification."""
    f1score_tests: ByLabelMetricTests = None
    """Optional test conditions for F1 score by label."""
    precision_tests: ByLabelMetricTests = None
    """Optional test conditions for precision by label."""
    recall_tests: ByLabelMetricTests = None
    """Optional test conditions for recall by label."""
    rocauc_tests: ByLabelMetricTests = None
    """Optional test conditions for ROC AUC by label."""
    classification_name: str = "default"
    """Name of the classification task."""

    def __init__(
        self,
        probas_threshold: Optional[float] = None,
        k: Optional[int] = None,
        f1score_tests: GenericByLabelMetricTests = None,
        precision_tests: GenericByLabelMetricTests = None,
        recall_tests: GenericByLabelMetricTests = None,
        rocauc_tests: GenericByLabelMetricTests = None,
        classification_name: str = "default",
        include_tests: bool = True,
    ):
        self.probas_threshold = probas_threshold
        self.k = k
        self.f1score_tests = convert_tests(f1score_tests)
        self.precision_tests = convert_tests(precision_tests)
        self.recall_tests = convert_tests(recall_tests)
        self.rocauc_tests = convert_tests(rocauc_tests)
        self.classification_name = classification_name
        super().__init__(include_tests=include_tests)

    def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
        classification = context.data_definition.get_classification(self.classification_name)
        if classification is None:
            raise ValueError("Cannot use ClassificationQualityByLabel without a classification configuration")
        return [
            F1ByLabel(
                classification_name=self.classification_name,
                probas_threshold=self.probas_threshold,
                k=self.k,
                tests=self._get_tests(self.f1score_tests),
            ),
            PrecisionByLabel(
                classification_name=self.classification_name,
                probas_threshold=self.probas_threshold,
                k=self.k,
                tests=self._get_tests(self.precision_tests),
            ),
            RecallByLabel(
                classification_name=self.classification_name,
                probas_threshold=self.probas_threshold,
                k=self.k,
                tests=self._get_tests(self.recall_tests),
            ),
        ] + (
            []
            if classification.prediction_probas is None
            else [
                RocAucByLabel(
                    classification_name=self.classification_name,
                    probas_threshold=self.probas_threshold,
                    k=self.k,
                    tests=self._get_tests(self.rocauc_tests),
                ),
            ]
        )

    def render(
        self,
        context: "Context",
        child_widgets: Optional[List[Tuple[Optional[MetricId], List[BaseWidgetInfo]]]] = None,
    ) -> List[BaseWidgetInfo]:
        widgets = context.get_legacy_metric(
            ClassificationQualityByClass(self.probas_threshold, self.k),
            _gen_classification_input_data,
            self.classification_name,
        )[1]
        widgets[0].params["counters"][0]["label"] = "Classification Quality by Label"
        for metric in self.list_metrics(context):
            link_metric(widgets, metric)
        return widgets


class ClassificationDummyQuality(MetricContainer):
    """Small preset summarizing the quality of a dummy/baseline classification model.

    Generates metrics for a simple heuristic-based baseline model (e.g., always predict
    the most common class). Useful as a baseline to compare your model against.
    """
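
    # Usage sketch (illustrative, kept as a comment): run the dummy baseline
    # alongside the real quality preset to see how far the model beats a
    # constant/heuristic predictor. The `Report` entry point is an assumption
    # for the example.
    #
    #     report = Report([ClassificationQuality(), ClassificationDummyQuality()])
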
    probas_threshold: Optional[float] = None
    """Optional probability threshold."""
    k: Optional[int] = None
    """Optional top-k value for multiclass classification."""
    classification_name: str = "default"
    """Name of the classification task."""

    def __init__(
        self,
        probas_threshold: Optional[float] = None,
        k: Optional[int] = None,
        include_tests: bool = True,
        classification_name: str = "default",
    ):
        self.probas_threshold = probas_threshold
        self.k = k
        self.classification_name = classification_name
        super().__init__(include_tests=include_tests)

    def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
        return [
            DummyPrecision(),
            DummyRecall(),
            DummyF1Score(),
        ]

    def render(
        self,
        context: "Context",
        child_widgets: Optional[List[Tuple[Optional[MetricId], List[BaseWidgetInfo]]]] = None,
    ) -> List[BaseWidgetInfo]:
        _, widgets = context.get_legacy_metric(
            ClassificationDummyMetric(self.probas_threshold, self.k),
            _gen_classification_input_data,
            self.classification_name,
        )
        for metric in self.list_metrics(context):
            link_metric(widgets, metric)
        return widgets


class ClassificationPreset(MetricContainer):
    """Large preset with comprehensive classification quality metrics and visualizations.

    Combines `ClassificationQuality` and `ClassificationQualityByLabel` to provide
    a complete classification evaluation including aggregated metrics, per-label metrics,
    and optional visualizations (confusion matrix, PR curves, ROC curves).
    """
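
    # Usage sketch (illustrative, kept as a comment): the large preset accepts
    # both aggregate and per-label test conditions. The `gte` test helper and
    # the `Report` entry point are assumptions for the example.
    #
    #     preset = ClassificationPreset(
    #         probas_threshold=0.5,
    #         accuracy_tests=[gte(0.8)],
    #         f1score_by_label_tests={0: [gte(0.7)], 1: [gte(0.7)]},
    #     )
    #     report = Report([preset])
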
    probas_threshold: Optional[float] = None
    """Optional probability threshold for binary classification."""
    accuracy_tests: SingleValueMetricTests = None
    """Optional test conditions for accuracy."""
    precision_tests: SingleValueMetricTests = None
    """Optional test conditions for precision."""
    recall_tests: SingleValueMetricTests = None
    """Optional test conditions for recall."""
    f1score_tests: SingleValueMetricTests = None
    """Optional test conditions for F1 score."""
    rocauc_tests: SingleValueMetricTests = None
    """Optional test conditions for ROC AUC."""
    logloss_tests: SingleValueMetricTests = None
    """Optional test conditions for log loss."""
    tpr_tests: SingleValueMetricTests = None
    """Optional test conditions for TPR."""
    tnr_tests: SingleValueMetricTests = None
    """Optional test conditions for TNR."""
    fpr_tests: SingleValueMetricTests = None
    """Optional test conditions for FPR."""
    fnr_tests: SingleValueMetricTests = None
    """Optional test conditions for FNR."""
    f1score_by_label_tests: ByLabelMetricTests = None
    """Optional test conditions for F1 score by label."""
    precision_by_label_tests: ByLabelMetricTests = None
    """Optional test conditions for precision by label."""
    recall_by_label_tests: ByLabelMetricTests = None
    """Optional test conditions for recall by label."""
    rocauc_by_label_tests: ByLabelMetricTests = None
    """Optional test conditions for ROC AUC by label."""
    classification_name: str = "default"
    """Name of the classification task."""

    _quality: ClassificationQuality = PrivateAttr()
    """Internal classification quality preset."""
    _quality_by_label: ClassificationQualityByLabel = PrivateAttr()
    """Internal classification quality by label preset."""
    _roc_auc: Optional[RocAuc] = PrivateAttr()
    """Internal ROC AUC metric."""

    def __init__(
        self,
        probas_threshold: Optional[float] = None,
        accuracy_tests: GenericSingleValueMetricTests = None,
        precision_tests: GenericSingleValueMetricTests = None,
        recall_tests: GenericSingleValueMetricTests = None,
        f1score_tests: GenericSingleValueMetricTests = None,
        rocauc_tests: GenericSingleValueMetricTests = None,
        logloss_tests: GenericSingleValueMetricTests = None,
        tpr_tests: GenericSingleValueMetricTests = None,
        tnr_tests: GenericSingleValueMetricTests = None,
        fpr_tests: GenericSingleValueMetricTests = None,
        fnr_tests: GenericSingleValueMetricTests = None,
        f1score_by_label_tests: GenericByLabelMetricTests = None,
        precision_by_label_tests: GenericByLabelMetricTests = None,
        recall_by_label_tests: GenericByLabelMetricTests = None,
        rocauc_by_label_tests: GenericByLabelMetricTests = None,
        include_tests: bool = True,
        classification_name: str = "default",
    ):
        super().__init__(
            include_tests=include_tests,
            probas_threshold=probas_threshold,
            accuracy_tests=convert_tests(accuracy_tests),
            precision_tests=convert_tests(precision_tests),
            recall_tests=convert_tests(recall_tests),
            f1score_tests=convert_tests(f1score_tests),
            rocauc_tests=convert_tests(rocauc_tests),
            logloss_tests=convert_tests(logloss_tests),
            tpr_tests=convert_tests(tpr_tests),
            tnr_tests=convert_tests(tnr_tests),
            fpr_tests=convert_tests(fpr_tests),
            fnr_tests=convert_tests(fnr_tests),
            f1score_by_label_tests=convert_tests(f1score_by_label_tests),
            precision_by_label_tests=convert_tests(precision_by_label_tests),
            recall_by_label_tests=convert_tests(recall_by_label_tests),
            rocauc_by_label_tests=convert_tests(rocauc_by_label_tests),
            classification_name=classification_name,
        )
        self._quality = ClassificationQuality(
            probas_threshold=probas_threshold,
            conf_matrix=True,
            pr_curve=True,
            pr_table=True,
            accuracy_tests=accuracy_tests,
            precision_tests=precision_tests,
            recall_tests=recall_tests,
            f1score_tests=f1score_tests,
            rocauc_tests=rocauc_tests,
            logloss_tests=logloss_tests,
            tpr_tests=tpr_tests,
            tnr_tests=tnr_tests,
            fpr_tests=fpr_tests,
            fnr_tests=fnr_tests,
            include_tests=include_tests,
            classification_name=classification_name,
        )
        self._quality_by_label = ClassificationQualityByLabel(
            probas_threshold=probas_threshold,
            f1score_tests=f1score_by_label_tests,
            precision_tests=precision_by_label_tests,
            recall_tests=recall_by_label_tests,
            rocauc_tests=rocauc_by_label_tests,
            include_tests=include_tests,
            classification_name=classification_name,
        )
        self._roc_auc = None

    def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
        classification = context.data_definition.get_classification(self.classification_name)
        if classification is None:
            raise ValueError("Cannot use ClassificationPreset without a classification configuration")
        quality_metrics = self._quality.metrics(context)
        # Remember the RocAuc metric (if generated) so render() can append its widgets.
        self._roc_auc = next((m for m in quality_metrics if isinstance(m, RocAuc)), None)
        return quality_metrics + self._quality_by_label.metrics(context)

    def render(
        self,
        context: "Context",
        child_widgets: Optional[List[Tuple[Optional[MetricId], List[BaseWidgetInfo]]]] = None,
    ) -> List[BaseWidgetInfo]:
        return (
            self._quality.render(context)
            + self._quality_by_label.render(context)
            + ([] if self._roc_auc is None else context.get_metric_result(self._roc_auc).get_widgets())
        )
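
# End-to-end sketch (illustrative, kept as a comment to avoid import-time side
# effects): wiring one of these presets into a report. The top-level Report and
# Dataset entry points, the DataDefinition keyword names, and the column names
# are assumptions for the example and may differ between evidently releases;
# BinaryClassification is the same class imported at the top of this module.
#
#     import pandas as pd
#
#     df = pd.DataFrame({"target": [0, 1, 1, 0], "prediction": [0, 1, 0, 0]})
#     definition = DataDefinition(
#         classification=[BinaryClassification(target="target", prediction_labels="prediction")]
#     )
#     dataset = Dataset.from_pandas(df, data_definition=definition)
#     report = Report([ClassificationPreset()])
#     snapshot = report.run(dataset)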