test_validation.py
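# Tests for MetricThreshold argument validation and for
# mlflow.validate_evaluation_results() threshold checks (value,
# min_absolute_change, min_relative_change, and combined thresholds),
# using a mocked evaluator plus one end-to-end case without mocks.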
import random
from unittest import mock

import pytest

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.evaluation import (
    EvaluationResult,
    MetricThreshold,
    ModelEvaluator,
    evaluate,
)
from mlflow.models.evaluation.evaluator_registry import _model_evaluation_registry
from mlflow.models.evaluation.validation import (
    MetricThresholdClassException,
    ModelValidationFailedException,
    _MetricValidationResult,
)

from tests.evaluate.test_evaluation import (
    iris_dataset,  # noqa: F401
    multiclass_logistic_regressor_model_uri,  # noqa: F401
)

message_separator = "\n"


@pytest.fixture
def metric_threshold_class_test_spec(request):
    """
    Test specification for the MetricThreshold class:

    Returns:
        A tuple containing the following elements:

        - class_params: A dictionary mapping MetricThreshold class parameter names to values.
        - expected_failure_message: Expected failure message.
    """
    class_params = {
        "threshold": 1,
        "min_absolute_change": 1,
        "min_relative_change": 0.1,
        "greater_is_better": True,
    }

    if request.param == "threshold_is_not_number":
        class_params["threshold"] = "string"
        expected_failure_message = "`threshold` parameter must be a number."
    elif request.param == "min_absolute_change_is_not_number":
        class_params["min_absolute_change"] = "string"
        expected_failure_message = "`min_absolute_change` parameter must be a positive number."
    elif request.param == "min_absolute_change_is_not_positive":
        class_params["min_absolute_change"] = -1
        expected_failure_message = "`min_absolute_change` parameter must be a positive number."
    elif request.param == "min_relative_change_is_not_float":
        class_params["min_relative_change"] = 2
        expected_failure_message = (
            "`min_relative_change` parameter must be a floating point number."
        )
    elif request.param == "min_relative_change_is_not_between_0_and_1":
        class_params["min_relative_change"] = -0.1
        expected_failure_message = "`min_relative_change` parameter must be between 0 and 1."
    elif request.param == "greater_is_better_is_not_defined":
        class_params["greater_is_better"] = None
        expected_failure_message = "`greater_is_better` parameter must be defined."
    elif request.param == "greater_is_better_is_not_bool":
        class_params["greater_is_better"] = 1
        expected_failure_message = "`greater_is_better` parameter must be a boolean."
    elif request.param == "no_threshold":
        class_params["threshold"] = None
        class_params["min_absolute_change"] = None
        class_params["min_relative_change"] = None
        expected_failure_message = "no threshold was specified."

    return (class_params, expected_failure_message)


@pytest.mark.parametrize(
    "metric_threshold_class_test_spec",
    [
        ("threshold_is_not_number"),
        ("min_absolute_change_is_not_number"),
        ("min_absolute_change_is_not_positive"),
        ("min_relative_change_is_not_float"),
        ("min_relative_change_is_not_between_0_and_1"),
        ("greater_is_better_is_not_defined"),
        ("greater_is_better_is_not_bool"),
        ("no_threshold"),
    ],
    indirect=["metric_threshold_class_test_spec"],
)
def test_metric_threshold_class_should_fail(metric_threshold_class_test_spec):
    class_params, expected_failure_message = metric_threshold_class_test_spec
    with pytest.raises(
        MetricThresholdClassException,
        match=expected_failure_message,
    ):
        MetricThreshold(
            threshold=class_params["threshold"],
            min_absolute_change=class_params["min_absolute_change"],
            min_relative_change=class_params["min_relative_change"],
            greater_is_better=class_params["greater_is_better"],
        )


@pytest.fixture
def faulty_baseline_model_param_test_spec(request):
    """
    Test specification for faulty `baseline_model` parameter tests:

    Returns:
        A tuple containing the following elements:

        - validation_thresholds: A dictionary mapping scalar metric names to
          MetricThreshold objects.
        - baseline_model: Value for the `baseline_model` param passed into mlflow.evaluate().
        - expected_failure_message: Expected failure message.
    """
    if request.param == "min_relative_change_present":
        return (
            {"accuracy": MetricThreshold(min_relative_change=0.1, greater_is_better=True)},
            None,
            "The baseline model must be specified",
        )
    if request.param == "min_absolute_change_present":
        return (
            {"accuracy": MetricThreshold(min_absolute_change=0.1, greater_is_better=True)},
            None,
            "The baseline model must be specified",
        )
    if request.param == "both_relative_absolute_change_present":
        return (
            {
                "accuracy": MetricThreshold(
                    min_absolute_change=0.05, min_relative_change=0.1, greater_is_better=True
                )
            },
            None,
            "The baseline model must be specified",
        )
    if request.param == "baseline_model_is_not_string":
        return (
            {
                "accuracy": MetricThreshold(
                    min_absolute_change=0.05, min_relative_change=0.1, greater_is_better=True
                )
            },
            1.0,
            "The baseline model argument must be a string URI",
        )


@pytest.mark.parametrize(
    "validation_thresholds",
    [
        pytest.param(1, id="param_not_dict"),
        pytest.param(
            {1: MetricThreshold(min_absolute_change=0.1, greater_is_better=True)}, id="key_not_str"
        ),
        pytest.param({"accuracy": 1}, id="value_not_metric_threshold"),
    ],
)
def test_validation_faulty_validation_thresholds(validation_thresholds):
    with pytest.raises(MlflowException, match="The validation thresholds argument"):
        mlflow.validate_evaluation_results(
            candidate_result={},
            baseline_result={},
            validation_thresholds=validation_thresholds,
        )


@pytest.fixture
def value_threshold_test_spec(request):
    """
    Test specification for value threshold tests:

    Returns:
        A tuple containing the following elements:

        - metrics: A dictionary mapping scalar metric names to scalar metric values.
        - validation_thresholds: A dictionary mapping scalar metric names to
          MetricThreshold objects.
        - expected_validation_results: A dictionary mapping scalar metric names
          to _MetricValidationResult.
    """
    acc_threshold = MetricThreshold(threshold=0.9, greater_is_better=True)
    acc_validation_result = _MetricValidationResult("accuracy", 0.8, acc_threshold, None)
    acc_validation_result.threshold_failed = True

    f1score_threshold = MetricThreshold(threshold=0.8, greater_is_better=True)
    f1score_validation_result = _MetricValidationResult("f1_score", 0.7, f1score_threshold, None)
    f1score_validation_result.threshold_failed = True

    log_loss_threshold = MetricThreshold(threshold=0.5, greater_is_better=False)
    log_loss_validation_result = _MetricValidationResult("log_loss", 0.3, log_loss_threshold, None)

    l1_loss_threshold = MetricThreshold(threshold=0.3, greater_is_better=False)
    l1_loss_validation_result = _MetricValidationResult(
        "custom_l1_loss", 0.5, l1_loss_threshold, None
    )
    l1_loss_validation_result.threshold_failed = True

    if request.param == "single_metric_not_satisfied_higher_better":
        return ({"accuracy": 0.8}, {"accuracy": acc_threshold}, {"accuracy": acc_validation_result})

    if request.param == "multiple_metrics_not_satisfied_higher_better":
        return (
            {"accuracy": 0.8, "f1_score": 0.7},
            {"accuracy": acc_threshold, "f1_score": f1score_threshold},
            {"accuracy": acc_validation_result, "f1_score": f1score_validation_result},
        )

    if request.param == "single_metric_not_satisfied_lower_better":
        return (
            {"custom_l1_loss": 0.5},
            {"custom_l1_loss": l1_loss_threshold},
            {"custom_l1_loss": l1_loss_validation_result},
        )

    if request.param == "multiple_metrics_not_satisfied_lower_better":
        log_loss_validation_result.candidate_metric_value = 0.8
        log_loss_validation_result.threshold_failed = True
        return (
            {"custom_l1_loss": 0.5, "log_loss": 0.8},
            {"custom_l1_loss": l1_loss_threshold, "log_loss": log_loss_threshold},
            {"custom_l1_loss": l1_loss_validation_result, "log_loss": log_loss_validation_result},
        )

    if request.param == "missing_candidate_metric":
        acc_validation_result.missing_candidate = True
        return ({}, {"accuracy": acc_threshold}, {"accuracy": acc_validation_result})

    if request.param == "multiple_metrics_not_all_satisfied":
        return (
            {"accuracy": 0.8, "f1_score": 0.7, "log_loss": 0.3},
            {
                "accuracy": acc_threshold,
                "f1_score": f1score_threshold,
                "log_loss": log_loss_threshold,
            },
            {"accuracy": acc_validation_result, "f1_score": f1score_validation_result},
        )

    if request.param == "equality_boundary":
        return (
            {"accuracy": 0.9, "log_loss": 0.5},
            {"accuracy": acc_threshold, "log_loss": log_loss_threshold},
            {},
        )

    if request.param == "single_metric_satisfied_higher_better":
        return ({"accuracy": 0.91}, {"accuracy": acc_threshold}, {})

    if request.param == "single_metric_satisfied_lower_better":
        return ({"log_loss": 0.3}, {"log_loss": log_loss_threshold}, {})

    if request.param == "multiple_metrics_all_satisfied":
        return (
            {"accuracy": 0.9, "f1_score": 0.8, "log_loss": 0.3},
            {
                "accuracy": acc_threshold,
                "f1_score": f1score_threshold,
                "log_loss": log_loss_threshold,
            },
            {},
        )


@pytest.mark.parametrize(
    "value_threshold_test_spec",
    [
        ("single_metric_not_satisfied_higher_better"),
        ("multiple_metrics_not_satisfied_higher_better"),
        ("single_metric_not_satisfied_lower_better"),
        ("missing_candidate_metric"),
        ("multiple_metrics_not_satisfied_lower_better"),
        ("multiple_metrics_not_all_satisfied"),
    ],
    indirect=["value_threshold_test_spec"],
)
def test_validation_value_threshold_should_fail(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    value_threshold_test_spec,
):
    metrics, validation_thresholds, expected_validation_results = value_threshold_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    MockEvaluator().evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        candidate_result = evaluate(
            multiclass_logistic_regressor_model_uri,
            data=iris_dataset._constructor_args["data"],
            model_type="classifier",
            targets=iris_dataset._constructor_args["targets"],
            evaluators="test_evaluator1",
        )

    with pytest.raises(
        ModelValidationFailedException,
        match=message_separator.join(map(str, list(expected_validation_results.values()))),
    ):
        mlflow.validate_evaluation_results(
            candidate_result=candidate_result,
            baseline_result=None,
            validation_thresholds=validation_thresholds,
        )


@pytest.mark.parametrize(
    "value_threshold_test_spec",
    [
        ("single_metric_satisfied_higher_better"),
        ("single_metric_satisfied_lower_better"),
        ("equality_boundary"),
        ("multiple_metrics_all_satisfied"),
    ],
    indirect=["value_threshold_test_spec"],
)
def test_validation_value_threshold_should_pass(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    value_threshold_test_spec,
):
    metrics, validation_thresholds, _ = value_threshold_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    MockEvaluator().evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        candidate_result = evaluate(
            multiclass_logistic_regressor_model_uri,
            data=iris_dataset._constructor_args["data"],
            model_type="classifier",
            targets=iris_dataset._constructor_args["targets"],
            evaluators="test_evaluator1",
        )

    mlflow.validate_evaluation_results(
        candidate_result=candidate_result,
        baseline_result=None,
        validation_thresholds=validation_thresholds,
    )


@pytest.fixture
def min_absolute_change_threshold_test_spec(request):
    """
    Test specification for min_absolute_change threshold tests:

    Returns:
        A tuple containing the following elements:

        - metrics: A dictionary mapping scalar metric names to scalar metric values.
        - baseline_model_metrics: A dictionary mapping scalar metric names
          to scalar metric values of the baseline model.
        - validation_thresholds: A dictionary mapping scalar metric names
          to MetricThreshold objects.
        - expected_validation_results: A dictionary mapping scalar metric names
          to _MetricValidationResult.
    """
    acc_threshold = MetricThreshold(min_absolute_change=0.1, greater_is_better=True)
    f1score_threshold = MetricThreshold(min_absolute_change=0.15, greater_is_better=True)
    log_loss_threshold = MetricThreshold(min_absolute_change=0.1, greater_is_better=False)
    l1_loss_threshold = MetricThreshold(min_absolute_change=0.15, greater_is_better=False)

    if request.param == "single_metric_not_satisfied_higher_better":
        acc_validation_result = _MetricValidationResult("accuracy", 0.79, acc_threshold, 0.7)
        acc_validation_result.min_absolute_change_failed = True
        return (
            {"accuracy": 0.79},
            {"accuracy": 0.7},
            {"accuracy": acc_threshold},
            {"accuracy": acc_validation_result},
        )

    if request.param == "multiple_metrics_not_satisfied_higher_better":
        acc_validation_result = _MetricValidationResult("accuracy", 0.79, acc_threshold, 0.7)
        acc_validation_result.min_absolute_change_failed = True
        f1score_validation_result = _MetricValidationResult("f1_score", 0.8, f1score_threshold, 0.7)
        f1score_validation_result.min_absolute_change_failed = True
        return (
            {"accuracy": 0.79, "f1_score": 0.8},
            {"accuracy": 0.7, "f1_score": 0.7},
            {"accuracy": acc_threshold, "f1_score": f1score_threshold},
            {"accuracy": acc_validation_result, "f1_score": f1score_validation_result},
        )

    if request.param == "single_metric_not_satisfied_lower_better":
        l1_loss_validation_result = _MetricValidationResult(
            "custom_l1_loss", 0.5, l1_loss_threshold, 0.6
        )
        l1_loss_validation_result.min_absolute_change_failed = True
        return (
            {"custom_l1_loss": 0.5},
            {"custom_l1_loss": 0.6},
            {"custom_l1_loss": l1_loss_threshold},
            {"custom_l1_loss": l1_loss_validation_result},
        )

    if request.param == "multiple_metrics_not_satisfied_lower_better":
        l1_loss_validation_result = _MetricValidationResult(
            "custom_l1_loss", 0.5, l1_loss_threshold, 0.6
        )
        l1_loss_validation_result.min_absolute_change_failed = True
        log_loss_validation_result = _MetricValidationResult(
            "log_loss", 0.45, log_loss_threshold, 0.3
        )
        log_loss_validation_result.min_absolute_change_failed = True
        return (
            {"custom_l1_loss": 0.5, "log_loss": 0.45},
            {"custom_l1_loss": 0.6, "log_loss": 0.3},
            {"custom_l1_loss": l1_loss_threshold, "log_loss": log_loss_threshold},
            {
                "custom_l1_loss": l1_loss_validation_result,
                "log_loss": log_loss_validation_result,
            },
        )

    if request.param == "equality_boundary":
        acc_validation_result = _MetricValidationResult("accuracy", 0.8, acc_threshold, 0.7)
        log_loss_validation_result = _MetricValidationResult(
            "custom_log_loss", 0.2, log_loss_threshold, 0.3
        )
        return (
            {"accuracy": 0.8 + 1e-10, "log_loss": 0.2 - 1e-10},
            {"accuracy": 0.7, "log_loss": 0.3},
            {"accuracy": acc_threshold, "log_loss": log_loss_threshold},
            {},
        )

    if request.param == "single_metric_satisfied_higher_better":
        return ({"accuracy": 0.9 + 1e-2}, {"accuracy": 0.8}, {"accuracy": acc_threshold}, {})

    if request.param == "single_metric_satisfied_lower_better":
        return ({"log_loss": 0.3}, {"log_loss": 0.4 + 1e-3}, {"log_loss": log_loss_threshold}, {})

    if request.param == "multiple_metrics_all_satisfied":
        return (
            {"accuracy": 0.9, "f1_score": 0.8, "log_loss": 0.3},
            {"accuracy": 0.7, "f1_score": 0.6, "log_loss": 0.5},
            {
                "accuracy": acc_threshold,
                "f1_score": f1score_threshold,
                "log_loss": log_loss_threshold,
            },
            {},
        )

    if request.param == "missing_baseline_metric":
        l1_loss_validation_result = _MetricValidationResult(
            "custom_l1_loss", 0.72, l1_loss_threshold, None
        )
        l1_loss_validation_result.missing_baseline = True
        return (
            {"custom_l1_loss": 0.72},
            None,
            {"custom_l1_loss": l1_loss_threshold},
            {"custom_l1_loss": l1_loss_validation_result},
        )


@pytest.mark.parametrize(
    "min_absolute_change_threshold_test_spec",
    [
        ("single_metric_not_satisfied_higher_better"),
        ("multiple_metrics_not_satisfied_higher_better"),
        ("single_metric_not_satisfied_lower_better"),
        ("multiple_metrics_not_satisfied_lower_better"),
        ("missing_baseline_metric"),
    ],
    indirect=["min_absolute_change_threshold_test_spec"],
)
def test_validation_model_comparison_absolute_threshold_should_fail(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    min_absolute_change_threshold_test_spec,
):
    (
        metrics,
        baseline_model_metrics,
        validation_thresholds,
        expected_validation_results,
    ) = min_absolute_change_threshold_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    mock_evaluate = MockEvaluator().evaluate

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        common_kwargs = {
            "data": iris_dataset._constructor_args["data"],
            "model_type": "classifier",
            "targets": iris_dataset._constructor_args["targets"],
            "evaluators": "test_evaluator1",
        }

        mock_evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})
        candidate_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

        if baseline_model_metrics is None:
            baseline_result = None
        else:
            mock_evaluate.return_value = EvaluationResult(
                metrics=baseline_model_metrics, artifacts={}
            )
            baseline_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

    with pytest.raises(
        ModelValidationFailedException,
        match=message_separator.join(map(str, list(expected_validation_results.values()))),
    ):
        mlflow.validate_evaluation_results(
            candidate_result=candidate_result,
            baseline_result=baseline_result,
            validation_thresholds=validation_thresholds,
        )


@pytest.mark.parametrize(
    "min_absolute_change_threshold_test_spec",
    [
        ("single_metric_satisfied_higher_better"),
        ("single_metric_satisfied_lower_better"),
        ("equality_boundary"),
        ("multiple_metrics_all_satisfied"),
    ],
    indirect=["min_absolute_change_threshold_test_spec"],
)
def test_validation_model_comparison_absolute_threshold_should_pass(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    min_absolute_change_threshold_test_spec,
):
    (
        metrics,
        baseline_model_metrics,
        validation_thresholds,
        _,
    ) = min_absolute_change_threshold_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    mock_evaluate = MockEvaluator().evaluate

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        common_kwargs = {
            "data": iris_dataset._constructor_args["data"],
            "model_type": "classifier",
            "targets": iris_dataset._constructor_args["targets"],
            "evaluators": "test_evaluator1",
        }

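        # The same mocked evaluator serves both runs: it returns the candidate
        # metrics for the first evaluate() call and the baseline metrics for the second.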
        mock_evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})
        candidate_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

        mock_evaluate.return_value = EvaluationResult(metrics=baseline_model_metrics, artifacts={})
        baseline_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

    mlflow.validate_evaluation_results(
        candidate_result=candidate_result,
        baseline_result=baseline_result,
        validation_thresholds=validation_thresholds,
    )


@pytest.fixture
def min_relative_change_threshold_test_spec(request):
    """
    Test specification for min_relative_change threshold tests:

    Returns:
        A tuple containing the following elements:

        - metrics: A dictionary mapping scalar metric names to scalar metric values.
        - baseline_model_metrics: A dictionary mapping scalar metric names
          to scalar metric values of the baseline model.
        - validation_thresholds: A dictionary mapping scalar metric names
          to MetricThreshold objects.
        - expected_validation_results: A dictionary mapping scalar metric names
          to _MetricValidationResult.
    """
    acc_threshold = MetricThreshold(min_relative_change=0.1, greater_is_better=True)
    f1score_threshold = MetricThreshold(min_relative_change=0.15, greater_is_better=True)
    log_loss_threshold = MetricThreshold(min_relative_change=0.15, greater_is_better=False)
    l1_loss_threshold = MetricThreshold(min_relative_change=0.1, greater_is_better=False)

    if request.param == "single_metric_not_satisfied_higher_better":
        acc_validation_result = _MetricValidationResult("accuracy", 0.75, acc_threshold, 0.7)
        acc_validation_result.min_relative_change_failed = True
        return (
            {"accuracy": 0.75},
            {"accuracy": 0.7},
            {"accuracy": acc_threshold},
            {"accuracy": acc_validation_result},
        )

    if request.param == "multiple_metrics_not_satisfied_higher_better":
        acc_validation_result = _MetricValidationResult("accuracy", 0.53, acc_threshold, 0.5)
        acc_validation_result.min_relative_change_failed = True
        f1score_validation_result = _MetricValidationResult("f1_score", 0.8, f1score_threshold, 0.7)
        f1score_validation_result.min_relative_change_failed = True
        return (
            {"accuracy": 0.53, "f1_score": 0.8},
            {"accuracy": 0.5, "f1_score": 0.7},
            {"accuracy": acc_threshold, "f1_score": f1score_threshold},
            {"accuracy": acc_validation_result, "f1_score": f1score_validation_result},
        )

    if request.param == "single_metric_not_satisfied_lower_better":
        l1_loss_validation_result = _MetricValidationResult(
            "custom_l1_loss", 0.55, l1_loss_threshold, 0.6
        )
        l1_loss_validation_result.min_relative_change_failed = True
        return (
            {"custom_l1_loss": 0.55},
            {"custom_l1_loss": 0.6},
            {"custom_l1_loss": l1_loss_threshold},
            {"custom_l1_loss": l1_loss_validation_result},
        )

    if request.param == "missing_baseline_metric":
        l1_loss_validation_result = _MetricValidationResult(
            "custom_l1_loss", 0.72, l1_loss_threshold, None
        )
        l1_loss_validation_result.missing_baseline = True
        return (
            {"custom_l1_loss": 0.72},
            None,
            {"custom_l1_loss": l1_loss_threshold},
            {"custom_l1_loss": l1_loss_validation_result},
        )

    if request.param == "multiple_metrics_not_satisfied_lower_better":
        l1_loss_validation_result = _MetricValidationResult(
            "custom_l1_loss", 0.72 + 1e-3, l1_loss_threshold, 0.8
        )
        l1_loss_validation_result.min_relative_change_failed = True
        log_loss_validation_result = _MetricValidationResult(
            "log_loss", 0.27 + 1e-5, log_loss_threshold, 0.3
        )
        log_loss_validation_result.min_relative_change_failed = True
        return (
            {"custom_l1_loss": 0.72 + 1e-3, "log_loss": 0.27 + 1e-5},
            {"custom_l1_loss": 0.8, "log_loss": 0.3},
            {"custom_l1_loss": l1_loss_threshold, "log_loss": log_loss_threshold},
            {
                "custom_l1_loss": l1_loss_validation_result,
                "log_loss": log_loss_validation_result,
            },
        )

    if request.param == "equality_boundary":
        acc_validation_result = _MetricValidationResult("accuracy", 0.77, acc_threshold, 0.7)
        log_loss_validation_result = _MetricValidationResult(
            "custom_log_loss", 0.3 * 0.85 - 1e-10, log_loss_threshold, 0.3
        )
        return (
            {"accuracy": 0.77, "log_loss": 0.3 * 0.85 - 1e-10},
            {"accuracy": 0.7, "log_loss": 0.3},
            {"accuracy": acc_threshold, "log_loss": log_loss_threshold},
            {},
        )

    if request.param == "single_metric_satisfied_higher_better":
        return ({"accuracy": 0.99 + 1e-10}, {"accuracy": 0.9}, {"accuracy": acc_threshold}, {})

    if request.param == "single_metric_satisfied_lower_better":
        return ({"log_loss": 0.3}, {"log_loss": 0.4}, {"log_loss": log_loss_threshold}, {})

    if request.param == "multiple_metrics_all_satisfied":
        return (
            {"accuracy": 0.9, "f1_score": 0.9, "log_loss": 0.3},
            {"accuracy": 0.7, "f1_score": 0.6, "log_loss": 0.5},
            {
                "accuracy": acc_threshold,
                "f1_score": f1score_threshold,
                "log_loss": log_loss_threshold,
            },
            {},
        )

    if request.param == "baseline_metric_value_equals_0_succeeds":
        threshold = MetricThreshold(min_relative_change=0.1, greater_is_better=True)
        return (
            {"metric_1": 1e-10},
            {"metric_1": 0},
            {"metric_1": threshold},
            {"metric_1": _MetricValidationResult("metric_1", 0.8, threshold, 0.7)},
        )

    if request.param == "baseline_metric_value_equals_0_fails":
        metric_1_threshold = MetricThreshold(min_relative_change=0.1, greater_is_better=True)
        metric_1_result = _MetricValidationResult("metric_1", 0, metric_1_threshold, 0)
        metric_1_result.min_relative_change_failed = True
        return (
            {"metric_1": 0},
            {"metric_1": 0},
            {"metric_1": metric_1_threshold},
            {"metric_1": metric_1_result},
        )


@pytest.mark.parametrize(
    "min_relative_change_threshold_test_spec",
    [
        ("single_metric_not_satisfied_higher_better"),
        ("multiple_metrics_not_satisfied_higher_better"),
        ("single_metric_not_satisfied_lower_better"),
        ("multiple_metrics_not_satisfied_lower_better"),
        ("missing_baseline_metric"),
        ("baseline_metric_value_equals_0_fails"),
    ],
    indirect=["min_relative_change_threshold_test_spec"],
)
def test_validation_model_comparison_relative_threshold_should_fail(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    min_relative_change_threshold_test_spec,
):
    (
        metrics,
        baseline_model_metrics,
        validation_thresholds,
        expected_validation_results,
    ) = min_relative_change_threshold_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    mock_evaluate = MockEvaluator().evaluate

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        common_kwargs = {
            "data": iris_dataset._constructor_args["data"],
            "model_type": "classifier",
            "targets": iris_dataset._constructor_args["targets"],
            "evaluators": "test_evaluator1",
        }

        mock_evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})
        candidate_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

        if baseline_model_metrics is None:
            baseline_result = None
        else:
            mock_evaluate.return_value = EvaluationResult(
                metrics=baseline_model_metrics, artifacts={}
            )
            baseline_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

    with pytest.raises(
        ModelValidationFailedException,
        match=message_separator.join(map(str, list(expected_validation_results.values()))),
    ):
        mlflow.validate_evaluation_results(
            candidate_result=candidate_result,
            baseline_result=baseline_result,
            validation_thresholds=validation_thresholds,
        )


@pytest.mark.parametrize(
    "min_relative_change_threshold_test_spec",
    [
        ("single_metric_satisfied_higher_better"),
        ("single_metric_satisfied_lower_better"),
        ("equality_boundary"),
        ("multiple_metrics_all_satisfied"),
        ("baseline_metric_value_equals_0_succeeds"),
    ],
    indirect=["min_relative_change_threshold_test_spec"],
)
def test_validation_model_comparison_relative_threshold_should_pass(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    min_relative_change_threshold_test_spec,
):
    (
        metrics,
        baseline_model_metrics,
        validation_thresholds,
        _,
    ) = min_relative_change_threshold_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    mock_evaluate = MockEvaluator().evaluate

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        common_kwargs = {
            "data": iris_dataset._constructor_args["data"],
            "model_type": "classifier",
            "targets": iris_dataset._constructor_args["targets"],
            "evaluators": "test_evaluator1",
        }

        mock_evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})
        candidate_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

        mock_evaluate.return_value = EvaluationResult(metrics=baseline_model_metrics, artifacts={})
        baseline_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

    mlflow.validate_evaluation_results(
        candidate_result=candidate_result,
        baseline_result=baseline_result,
        validation_thresholds=validation_thresholds,
    )


@pytest.fixture
def multi_thresholds_test_spec(request):
    """
    Test specification for multi-thresholds tests:

    Returns:
        A tuple containing the following elements:

        - metrics: A dictionary mapping scalar metric names to scalar metric values.
        - baseline_model_metrics: A dictionary mapping scalar metric names
          to scalar metric values of the baseline model.
        - validation_thresholds: A dictionary mapping scalar metric names
          to MetricThreshold objects.
        - expected_validation_results: A dictionary mapping scalar metric names
          to _MetricValidationResult.
    """
    acc_threshold = MetricThreshold(
        threshold=0.8, min_absolute_change=0.1, min_relative_change=0.1, greater_is_better=True
    )

    if request.param == "single_metric_all_thresholds_failed":
        acc_validation_result = _MetricValidationResult("accuracy", 0.75, acc_threshold, 0.7)
        acc_validation_result.threshold_failed = True
        acc_validation_result.min_relative_change_failed = True
        acc_validation_result.min_absolute_change_failed = True
        return (
            {"accuracy": 0.75},
            {"accuracy": 0.7},
            {"accuracy": acc_threshold},
            {"accuracy": acc_validation_result},
        )


@pytest.mark.parametrize(
    "multi_thresholds_test_spec",
    [
        ("single_metric_all_thresholds_failed"),
    ],
    indirect=["multi_thresholds_test_spec"],
)
def test_validation_multi_thresholds_should_fail(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    multi_thresholds_test_spec,
):
    (
        metrics,
        baseline_model_metrics,
        validation_thresholds,
        expected_validation_results,
    ) = multi_thresholds_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    mock_evaluate = MockEvaluator().evaluate

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        common_kwargs = {
            "data": iris_dataset._constructor_args["data"],
            "model_type": "classifier",
            "targets": iris_dataset._constructor_args["targets"],
            "evaluators": "test_evaluator1",
        }

        mock_evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})
        candidate_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

        mock_evaluate.return_value = EvaluationResult(metrics=baseline_model_metrics, artifacts={})
        baseline_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

    with pytest.raises(
        ModelValidationFailedException,
        match=message_separator.join(map(str, list(expected_validation_results.values()))),
    ):
        mlflow.validate_evaluation_results(
            candidate_result=candidate_result,
            baseline_result=baseline_result,
            validation_thresholds=validation_thresholds,
        )


def test_validation_thresholds_no_mock():
    targets = [0, 1, 1, 1]
    data = [[random.random()] for _ in targets]

    class BaseModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            return len(model_input) * [0]

    class CandidateModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            return len(model_input) * [1]

    with mlflow.start_run():
        base = mlflow.pyfunc.log_model(name="base", python_model=BaseModel())
        candidate = mlflow.pyfunc.log_model(name="candidate", python_model=CandidateModel())

        candidate_result = evaluate(
            candidate.model_uri,
            data=data,
            model_type="classifier",
            targets=targets,
        )

        baseline_result = evaluate(
            base.model_uri,
            data=data,
            model_type="classifier",
            targets=targets,
        )

        mlflow.validate_evaluation_results(
            candidate_result=candidate_result,
            baseline_result=baseline_result,
            validation_thresholds={
                "recall_score": MetricThreshold(
                    threshold=0.9,
                    min_absolute_change=0.1,
                    greater_is_better=True,
                ),
            },
        )

        with pytest.raises(
            ModelValidationFailedException,
            match="recall_score value threshold check failed",
        ):
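            # With candidate and baseline swapped, the all-zeros model becomes the
            # candidate, so its recall_score of 0 fails the 0.9 value threshold.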
            mlflow.validate_evaluation_results(
                candidate_result=baseline_result,
                baseline_result=candidate_result,
                validation_thresholds={
                    "recall_score": MetricThreshold(
                        threshold=0.9,
                        min_absolute_change=0.1,
                        greater_is_better=True,
                    ),
                },
            )