# tests/evaluate/test_validation.py
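"""
Tests for MLflow's model validation APIs: argument validation for `MetricThreshold`,
and `mlflow.validate_evaluation_results` checks against value, minimum-absolute-change,
and minimum-relative-change thresholds.
"""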
import random
from unittest import mock

import pytest

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.evaluation import (
    EvaluationResult,
    MetricThreshold,
    ModelEvaluator,
    evaluate,
)
from mlflow.models.evaluation.evaluator_registry import _model_evaluation_registry
from mlflow.models.evaluation.validation import (
    MetricThresholdClassException,
    ModelValidationFailedException,
    _MetricValidationResult,
)

from tests.evaluate.test_evaluation import (
    iris_dataset,  # noqa: F401
    multiclass_logistic_regressor_model_uri,  # noqa: F401
)

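# Expected per-metric failure messages are joined with this separator to build the
# `match` pattern passed to pytest.raises() in the "should_fail" tests below.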
message_separator = "\n"


@pytest.fixture
def metric_threshold_class_test_spec(request):
    """
    Test specification for MetricThreshold class:

    Returns:
        A tuple containing the following elements:

        - class_params: A dictionary mapping MetricThreshold class parameter names to values.
        - expected_failure_message: Expected failure message.
    """
    class_params = {
        "threshold": 1,
        "min_absolute_change": 1,
        "min_relative_change": 0.1,
        "greater_is_better": True,
    }

    if request.param == "threshold_is_not_number":
        class_params["threshold"] = "string"
        expected_failure_message = "`threshold` parameter must be a number."
    elif request.param == "min_absolute_change_is_not_number":
        class_params["min_absolute_change"] = "string"
        expected_failure_message = "`min_absolute_change` parameter must be a positive number."
    elif request.param == "min_absolute_change_is_not_positive":
        class_params["min_absolute_change"] = -1
        expected_failure_message = "`min_absolute_change` parameter must be a positive number."
    elif request.param == "min_relative_change_is_not_float":
        class_params["min_relative_change"] = 2
        expected_failure_message = (
            "`min_relative_change` parameter must be a floating point number."
        )
    elif request.param == "min_relative_change_is_not_between_0_and_1":
        class_params["min_relative_change"] = -0.1
        expected_failure_message = "`min_relative_change` parameter must be between 0 and 1."
    elif request.param == "greater_is_better_is_not_defined":
        class_params["greater_is_better"] = None
        expected_failure_message = "`greater_is_better` parameter must be defined."
    elif request.param == "greater_is_better_is_not_bool":
        class_params["greater_is_better"] = 1
        expected_failure_message = "`greater_is_better` parameter must be a boolean."
    elif request.param == "no_threshold":
        class_params["threshold"] = None
        class_params["min_absolute_change"] = None
        class_params["min_relative_change"] = None
        expected_failure_message = "no threshold was specified."

    return (class_params, expected_failure_message)


@pytest.mark.parametrize(
    "metric_threshold_class_test_spec",
    [
        ("threshold_is_not_number"),
        ("min_absolute_change_is_not_number"),
        ("min_absolute_change_is_not_positive"),
        ("min_relative_change_is_not_float"),
        ("min_relative_change_is_not_between_0_and_1"),
        ("greater_is_better_is_not_defined"),
        ("greater_is_better_is_not_bool"),
        ("no_threshold"),
    ],
    indirect=["metric_threshold_class_test_spec"],
)
def test_metric_threshold_class_should_fail(metric_threshold_class_test_spec):
    class_params, expected_failure_message = metric_threshold_class_test_spec
    with pytest.raises(
        MetricThresholdClassException,
        match=expected_failure_message,
    ):
        MetricThreshold(
            threshold=class_params["threshold"],
            min_absolute_change=class_params["min_absolute_change"],
            min_relative_change=class_params["min_relative_change"],
            greater_is_better=class_params["greater_is_better"],
        )

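# For contrast with the failure cases above, a fully specified, well-typed threshold
# such as MetricThreshold(threshold=0.9, min_absolute_change=0.05,
# min_relative_change=0.05, greater_is_better=True) is expected to construct without
# raising (illustrative values; not asserted by this suite).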

@pytest.fixture
def faulty_baseline_model_param_test_spec(request):
    """
    Test specification for faulty `baseline_model` parameter tests:

    Returns:
        A tuple containing the following elements:

        - validation_thresholds: A dictionary mapping scalar metric names to
          MetricThreshold objects.
        - baseline_model: Value for the `baseline_model` param passed into mlflow.evaluate().
        - expected_failure_message: Expected failure message.
    """
    if request.param == "min_relative_change_present":
        return (
            {"accuracy": MetricThreshold(min_relative_change=0.1, greater_is_better=True)},
            None,
            "The baseline model must be specified",
        )
    if request.param == "min_absolute_change_present":
        return (
            {"accuracy": MetricThreshold(min_absolute_change=0.1, greater_is_better=True)},
            None,
            "The baseline model must be specified",
        )
    if request.param == "both_relative_absolute_change_present":
        return (
            {
                "accuracy": MetricThreshold(
                    min_absolute_change=0.05, min_relative_change=0.1, greater_is_better=True
                )
            },
            None,
            "The baseline model must be specified",
        )
    if request.param == "baseline_model_is_not_string":
        return (
            {
                "accuracy": MetricThreshold(
                    min_absolute_change=0.05, min_relative_change=0.1, greater_is_better=True
                )
            },
            1.0,
            "The baseline model argument must be a string URI",
        )


@pytest.mark.parametrize(
    "validation_thresholds",
    [
        pytest.param(1, id="param_not_dict"),
        pytest.param(
            {1: MetricThreshold(min_absolute_change=0.1, greater_is_better=True)}, id="key_not_str"
        ),
        pytest.param({"accuracy": 1}, id="value_not_metric_threshold"),
    ],
)
def test_validation_faulty_validation_thresholds(validation_thresholds):
    with pytest.raises(MlflowException, match="The validation thresholds argument"):
        mlflow.validate_evaluation_results(
            candidate_result={},
            baseline_result={},
            validation_thresholds=validation_thresholds,
        )


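# Value-threshold semantics exercised by the spec below (as implied by the expected
# results, not asserted directly here): with greater_is_better=True the candidate
# metric must be >= threshold, with greater_is_better=False it must be <= threshold,
# and meeting the threshold exactly counts as passing (see "equality_boundary").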
@pytest.fixture
def value_threshold_test_spec(request):
    """
    Test specification for value threshold tests:

    Returns:
        A tuple containing the following elements:

        - metrics: A dictionary mapping scalar metric names to scalar metric values.
        - validation_thresholds: A dictionary mapping scalar metric names to
          MetricThreshold objects.
        - expected_validation_results: A dictionary mapping scalar metric names
          to _MetricValidationResult.
    """
    acc_threshold = MetricThreshold(threshold=0.9, greater_is_better=True)
    acc_validation_result = _MetricValidationResult("accuracy", 0.8, acc_threshold, None)
    acc_validation_result.threshold_failed = True

    f1score_threshold = MetricThreshold(threshold=0.8, greater_is_better=True)
    f1score_validation_result = _MetricValidationResult("f1_score", 0.7, f1score_threshold, None)
    f1score_validation_result.threshold_failed = True

    log_loss_threshold = MetricThreshold(threshold=0.5, greater_is_better=False)
    log_loss_validation_result = _MetricValidationResult("log_loss", 0.3, log_loss_threshold, None)

    l1_loss_threshold = MetricThreshold(threshold=0.3, greater_is_better=False)
    l1_loss_validation_result = _MetricValidationResult(
        "custom_l1_loss", 0.5, l1_loss_threshold, None
    )
    l1_loss_validation_result.threshold_failed = True

    if request.param == "single_metric_not_satisfied_higher_better":
        return ({"accuracy": 0.8}, {"accuracy": acc_threshold}, {"accuracy": acc_validation_result})

    if request.param == "multiple_metrics_not_satisfied_higher_better":
        return (
            {"accuracy": 0.8, "f1_score": 0.7},
            {"accuracy": acc_threshold, "f1_score": f1score_threshold},
            {"accuracy": acc_validation_result, "f1_score": f1score_validation_result},
        )

    if request.param == "single_metric_not_satisfied_lower_better":
        return (
            {"custom_l1_loss": 0.5},
            {"custom_l1_loss": l1_loss_threshold},
            {"custom_l1_loss": l1_loss_validation_result},
        )

    if request.param == "multiple_metrics_not_satisfied_lower_better":
        log_loss_validation_result.candidate_metric_value = 0.8
        log_loss_validation_result.threshold_failed = True
        return (
            {"custom_l1_loss": 0.5, "log_loss": 0.8},
            {"custom_l1_loss": l1_loss_threshold, "log_loss": log_loss_threshold},
            {"custom_l1_loss": l1_loss_validation_result, "log_loss": log_loss_validation_result},
        )

    if request.param == "missing_candidate_metric":
        acc_validation_result.missing_candidate = True
        return ({}, {"accuracy": acc_threshold}, {"accuracy": acc_validation_result})

    if request.param == "multiple_metrics_not_all_satisfied":
        return (
            {"accuracy": 0.8, "f1_score": 0.7, "log_loss": 0.3},
            {
                "accuracy": acc_threshold,
                "f1_score": f1score_threshold,
                "log_loss": log_loss_threshold,
            },
            {"accuracy": acc_validation_result, "f1_score": f1score_validation_result},
        )

    if request.param == "equality_boundary":
        return (
            {"accuracy": 0.9, "log_loss": 0.5},
            {"accuracy": acc_threshold, "log_loss": log_loss_threshold},
            {},
        )

    if request.param == "single_metric_satisfied_higher_better":
        return ({"accuracy": 0.91}, {"accuracy": acc_threshold}, {})

    if request.param == "single_metric_satisfied_lower_better":
        return ({"log_loss": 0.3}, {"log_loss": log_loss_threshold}, {})

    if request.param == "multiple_metrics_all_satisfied":
        return (
            {"accuracy": 0.9, "f1_score": 0.8, "log_loss": 0.3},
            {
                "accuracy": acc_threshold,
                "f1_score": f1score_threshold,
                "log_loss": log_loss_threshold,
            },
            {},
        )


@pytest.mark.parametrize(
    "value_threshold_test_spec",
    [
        ("single_metric_not_satisfied_higher_better"),
        ("multiple_metrics_not_satisfied_higher_better"),
        ("single_metric_not_satisfied_lower_better"),
        ("missing_candidate_metric"),
        ("multiple_metrics_not_satisfied_lower_better"),
        ("multiple_metrics_not_all_satisfied"),
    ],
    indirect=["value_threshold_test_spec"],
)
def test_validation_value_threshold_should_fail(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    value_threshold_test_spec,
):
    metrics, validation_thresholds, expected_validation_results = value_threshold_test_spec

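    # Register a mock evaluator so that evaluate() returns the canned metrics from the
    # test spec instead of actually scoring the logged model.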
    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    MockEvaluator().evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        candidate_result = evaluate(
            multiclass_logistic_regressor_model_uri,
            data=iris_dataset._constructor_args["data"],
            model_type="classifier",
            targets=iris_dataset._constructor_args["targets"],
            evaluators="test_evaluator1",
        )

    with pytest.raises(
        ModelValidationFailedException,
        match=message_separator.join(map(str, list(expected_validation_results.values()))),
    ):
        mlflow.validate_evaluation_results(
            candidate_result=candidate_result,
            baseline_result=None,
            validation_thresholds=validation_thresholds,
        )


@pytest.mark.parametrize(
    "value_threshold_test_spec",
    [
        ("single_metric_satisfied_higher_better"),
        ("single_metric_satisfied_lower_better"),
        ("equality_boundary"),
        ("multiple_metrics_all_satisfied"),
    ],
    indirect=["value_threshold_test_spec"],
)
def test_validation_value_threshold_should_pass(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    value_threshold_test_spec,
):
    metrics, validation_thresholds, _ = value_threshold_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    MockEvaluator().evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        candidate_result = evaluate(
            multiclass_logistic_regressor_model_uri,
            data=iris_dataset._constructor_args["data"],
            model_type="classifier",
            targets=iris_dataset._constructor_args["targets"],
            evaluators="test_evaluator1",
        )

    mlflow.validate_evaluation_results(
        candidate_result=candidate_result,
        baseline_result=None,
        validation_thresholds=validation_thresholds,
    )


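# Absolute-change semantics exercised by the spec below (implied by the expected
# results): with greater_is_better=True the candidate must beat the baseline by at
# least min_absolute_change, and with greater_is_better=False it must fall below the
# baseline by at least that amount; "equality_boundary" sits 1e-10 past the exact
# boundary and is expected to pass.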
@pytest.fixture
def min_absolute_change_threshold_test_spec(request):
    """
    Test specification for min_absolute_change threshold tests:

    Returns:
        A tuple containing the following elements:

        - metrics: A dictionary mapping scalar metric names to scalar metric values.
        - baseline_model_metrics: A dictionary mapping scalar metric names
            to scalar metric values of the baseline model.
        - validation_thresholds: A dictionary mapping scalar metric names
            to MetricThreshold objects.
        - expected_validation_results: A dictionary mapping scalar metric names
            to _MetricValidationResult.
    """
    acc_threshold = MetricThreshold(min_absolute_change=0.1, greater_is_better=True)
    f1score_threshold = MetricThreshold(min_absolute_change=0.15, greater_is_better=True)
    log_loss_threshold = MetricThreshold(min_absolute_change=0.1, greater_is_better=False)
    l1_loss_threshold = MetricThreshold(min_absolute_change=0.15, greater_is_better=False)

    if request.param == "single_metric_not_satisfied_higher_better":
        acc_validation_result = _MetricValidationResult("accuracy", 0.79, acc_threshold, 0.7)
        acc_validation_result.min_absolute_change_failed = True
        return (
            {"accuracy": 0.79},
            {"accuracy": 0.7},
            {"accuracy": acc_threshold},
            {"accuracy": acc_validation_result},
        )

    if request.param == "multiple_metrics_not_satisfied_higher_better":
        acc_validation_result = _MetricValidationResult("accuracy", 0.79, acc_threshold, 0.7)
        acc_validation_result.min_absolute_change_failed = True
        f1score_validation_result = _MetricValidationResult("f1_score", 0.8, f1score_threshold, 0.7)
        f1score_validation_result.min_absolute_change_failed = True
        return (
            {"accuracy": 0.79, "f1_score": 0.8},
            {"accuracy": 0.7, "f1_score": 0.7},
            {"accuracy": acc_threshold, "f1_score": f1score_threshold},
            {"accuracy": acc_validation_result, "f1_score": f1score_validation_result},
        )

    if request.param == "single_metric_not_satisfied_lower_better":
        l1_loss_validation_result = _MetricValidationResult(
            "custom_l1_loss", 0.5, l1_loss_threshold, 0.6
        )
        l1_loss_validation_result.min_absolute_change_failed = True
        return (
            {"custom_l1_loss": 0.5},
            {"custom_l1_loss": 0.6},
            {"custom_l1_loss": l1_loss_threshold},
            {"custom_l1_loss": l1_loss_validation_result},
        )

    if request.param == "multiple_metrics_not_satisfied_lower_better":
        l1_loss_validation_result = _MetricValidationResult(
            "custom_l1_loss", 0.5, l1_loss_threshold, 0.6
        )
        l1_loss_validation_result.min_absolute_change_failed = True
        log_loss_validation_result = _MetricValidationResult(
            "log_loss", 0.45, log_loss_threshold, 0.3
        )
        log_loss_validation_result.min_absolute_change_failed = True
        return (
            {"custom_l1_loss": 0.5, "log_loss": 0.45},
            {"custom_l1_loss": 0.6, "log_loss": 0.3},
            {"custom_l1_loss": l1_loss_threshold, "log_loss": log_loss_threshold},
            {
                "custom_l1_loss": l1_loss_validation_result,
                "log_loss": log_loss_validation_result,
            },
        )

    if request.param == "equality_boundary":
        acc_validation_result = _MetricValidationResult("accuracy", 0.8, acc_threshold, 0.7)
        log_loss_validation_result = _MetricValidationResult(
            "custom_log_loss", 0.2, log_loss_threshold, 0.3
        )
        return (
            {"accuracy": 0.8 + 1e-10, "log_loss": 0.2 - 1e-10},
            {"accuracy": 0.7, "log_loss": 0.3},
            {"accuracy": acc_threshold, "log_loss": log_loss_threshold},
            {},
        )

    if request.param == "single_metric_satisfied_higher_better":
        return ({"accuracy": 0.9 + 1e-2}, {"accuracy": 0.8}, {"accuracy": acc_threshold}, {})

    if request.param == "single_metric_satisfied_lower_better":
        return ({"log_loss": 0.3}, {"log_loss": 0.4 + 1e-3}, {"log_loss": log_loss_threshold}, {})

    if request.param == "multiple_metrics_all_satisfied":
        return (
            {"accuracy": 0.9, "f1_score": 0.8, "log_loss": 0.3},
            {"accuracy": 0.7, "f1_score": 0.6, "log_loss": 0.5},
            {
                "accuracy": acc_threshold,
                "f1_score": f1score_threshold,
                "log_loss": log_loss_threshold,
            },
            {},
        )

    if request.param == "missing_baseline_metric":
        l1_loss_validation_result = _MetricValidationResult(
            "custom_l1_loss", 0.72, l1_loss_threshold, None
        )
        l1_loss_validation_result.missing_baseline = True
        return (
            {"custom_l1_loss": 0.72},
            None,
            {"custom_l1_loss": l1_loss_threshold},
            {"custom_l1_loss": l1_loss_validation_result},
        )


@pytest.mark.parametrize(
    "min_absolute_change_threshold_test_spec",
    [
        ("single_metric_not_satisfied_higher_better"),
        ("multiple_metrics_not_satisfied_higher_better"),
        ("single_metric_not_satisfied_lower_better"),
        ("multiple_metrics_not_satisfied_lower_better"),
        ("missing_baseline_metric"),
    ],
    indirect=["min_absolute_change_threshold_test_spec"],
)
def test_validation_model_comparison_absolute_threshold_should_fail(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    min_absolute_change_threshold_test_spec,
):
    (
        metrics,
        baseline_model_metrics,
        validation_thresholds,
        expected_validation_results,
    ) = min_absolute_change_threshold_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    mock_evaluate = MockEvaluator().evaluate

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        common_kwargs = {
            "data": iris_dataset._constructor_args["data"],
            "model_type": "classifier",
            "targets": iris_dataset._constructor_args["targets"],
            "evaluators": "test_evaluator1",
        }

        mock_evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})
        candidate_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

        if baseline_model_metrics is None:
            baseline_result = None
        else:
            mock_evaluate.return_value = EvaluationResult(
                metrics=baseline_model_metrics, artifacts={}
            )
            baseline_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

    with pytest.raises(
        ModelValidationFailedException,
        match=message_separator.join(map(str, list(expected_validation_results.values()))),
    ):
        mlflow.validate_evaluation_results(
            candidate_result=candidate_result,
            baseline_result=baseline_result,
            validation_thresholds=validation_thresholds,
        )


@pytest.mark.parametrize(
    "min_absolute_change_threshold_test_spec",
    [
        ("single_metric_satisfied_higher_better"),
        ("single_metric_satisfied_lower_better"),
        ("equality_boundary"),
        ("multiple_metrics_all_satisfied"),
    ],
    indirect=["min_absolute_change_threshold_test_spec"],
)
def test_validation_model_comparison_absolute_threshold_should_pass(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    min_absolute_change_threshold_test_spec,
):
    (
        metrics,
        baseline_model_metrics,
        validation_thresholds,
        _,
    ) = min_absolute_change_threshold_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    mock_evaluate = MockEvaluator().evaluate

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        common_kwargs = {
            "data": iris_dataset._constructor_args["data"],
            "model_type": "classifier",
            "targets": iris_dataset._constructor_args["targets"],
            "evaluators": "test_evaluator1",
        }

        mock_evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})
        candidate_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

        mock_evaluate.return_value = EvaluationResult(metrics=baseline_model_metrics, artifacts={})
        baseline_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

    mlflow.validate_evaluation_results(
        candidate_result=candidate_result,
        baseline_result=baseline_result,
        validation_thresholds=validation_thresholds,
    )


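# Relative-change semantics exercised by the spec below (implied by the expected
# results): the candidate must improve on the baseline by at least
# min_relative_change * baseline, e.g. accuracy 0.77 vs. 0.7 is exactly a 10% gain
# and passes a min_relative_change=0.1 check, while accuracy 0.75 vs. 0.7 (~7%) fails.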
@pytest.fixture
def min_relative_change_threshold_test_spec(request):
    """
    Test specification for min_relative_change threshold tests:

    Returns:
        A tuple containing the following elements:

        - metrics: A dictionary mapping scalar metric names to scalar metric values.
        - baseline_model_metrics: A dictionary mapping scalar metric names
            to scalar metric values of the baseline model.
        - validation_thresholds: A dictionary mapping scalar metric names
            to MetricThreshold objects.
        - expected_validation_results: A dictionary mapping scalar metric names
            to _MetricValidationResult.
    """
    acc_threshold = MetricThreshold(min_relative_change=0.1, greater_is_better=True)
    f1score_threshold = MetricThreshold(min_relative_change=0.15, greater_is_better=True)
    log_loss_threshold = MetricThreshold(min_relative_change=0.15, greater_is_better=False)
    l1_loss_threshold = MetricThreshold(min_relative_change=0.1, greater_is_better=False)

    if request.param == "single_metric_not_satisfied_higher_better":
        acc_validation_result = _MetricValidationResult("accuracy", 0.75, acc_threshold, 0.7)
        acc_validation_result.min_relative_change_failed = True
        return (
            {"accuracy": 0.75},
            {"accuracy": 0.7},
            {"accuracy": acc_threshold},
            {"accuracy": acc_validation_result},
        )

    if request.param == "multiple_metrics_not_satisfied_higher_better":
        acc_validation_result = _MetricValidationResult("accuracy", 0.53, acc_threshold, 0.5)
        acc_validation_result.min_relative_change_failed = True
        f1score_validation_result = _MetricValidationResult("f1_score", 0.8, f1score_threshold, 0.7)
        f1score_validation_result.min_relative_change_failed = True
        return (
            {"accuracy": 0.53, "f1_score": 0.8},
            {"accuracy": 0.5, "f1_score": 0.7},
            {"accuracy": acc_threshold, "f1_score": f1score_threshold},
            {"accuracy": acc_validation_result, "f1_score": f1score_validation_result},
        )

    if request.param == "single_metric_not_satisfied_lower_better":
        l1_loss_validation_result = _MetricValidationResult(
            "custom_l1_loss", 0.55, l1_loss_threshold, 0.6
        )
        l1_loss_validation_result.min_relative_change_failed = True
        return (
            {"custom_l1_loss": 0.55},
            {"custom_l1_loss": 0.6},
            {"custom_l1_loss": l1_loss_threshold},
            {"custom_l1_loss": l1_loss_validation_result},
        )

    if request.param == "missing_baseline_metric":
        l1_loss_validation_result = _MetricValidationResult(
            "custom_l1_loss", 0.72, l1_loss_threshold, None
        )
        l1_loss_validation_result.missing_baseline = True
        return (
            {"custom_l1_loss": 0.72},
            None,
            {"custom_l1_loss": l1_loss_threshold},
            {"custom_l1_loss": l1_loss_validation_result},
        )

    if request.param == "multiple_metrics_not_satisfied_lower_better":
        l1_loss_validation_result = _MetricValidationResult(
            "custom_l1_loss", 0.72 + 1e-3, l1_loss_threshold, 0.8
        )
        l1_loss_validation_result.min_relative_change_failed = True
        log_loss_validation_result = _MetricValidationResult(
            "log_loss", 0.27 + 1e-5, log_loss_threshold, 0.3
        )
        log_loss_validation_result.min_relative_change_failed = True
        return (
            {"custom_l1_loss": 0.72 + 1e-3, "log_loss": 0.27 + 1e-5},
            {"custom_l1_loss": 0.8, "log_loss": 0.3},
            {"custom_l1_loss": l1_loss_threshold, "log_loss": log_loss_threshold},
            {
                "custom_l1_loss": l1_loss_validation_result,
                "log_loss": log_loss_validation_result,
            },
        )

    if request.param == "equality_boundary":
        acc_validation_result = _MetricValidationResult("accuracy", 0.77, acc_threshold, 0.7)
        log_loss_validation_result = _MetricValidationResult(
            "custom_log_loss", 0.3 * 0.85 - 1e-10, log_loss_threshold, 0.3
        )
        return (
            {"accuracy": 0.77, "log_loss": 0.3 * 0.85 - 1e-10},
            {"accuracy": 0.7, "log_loss": 0.3},
            {"accuracy": acc_threshold, "log_loss": log_loss_threshold},
            {},
        )

    if request.param == "single_metric_satisfied_higher_better":
        return ({"accuracy": 0.99 + 1e-10}, {"accuracy": 0.9}, {"accuracy": acc_threshold}, {})

    if request.param == "single_metric_satisfied_lower_better":
        return ({"log_loss": 0.3}, {"log_loss": 0.4}, {"log_loss": log_loss_threshold}, {})

    if request.param == "multiple_metrics_all_satisfied":
        return (
            {"accuracy": 0.9, "f1_score": 0.9, "log_loss": 0.3},
            {"accuracy": 0.7, "f1_score": 0.6, "log_loss": 0.5},
            {
                "accuracy": acc_threshold,
                "f1_score": f1score_threshold,
                "log_loss": log_loss_threshold,
            },
            {},
        )

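    # A baseline value of exactly 0 makes the relative change ill-defined; the two cases
    # below pin down the expected behavior: a tiny positive improvement over a zero
    # baseline is expected to pass, while no improvement at all is expected to fail.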
    if request.param == "baseline_metric_value_equals_0_succeeds":
        threshold = MetricThreshold(min_relative_change=0.1, greater_is_better=True)
        return (
            {"metric_1": 1e-10},
            {"metric_1": 0},
            {"metric_1": threshold},
            {"metric_1": _MetricValidationResult("metric_1", 1e-10, threshold, 0)},
        )

    if request.param == "baseline_metric_value_equals_0_fails":
        metric_1_threshold = MetricThreshold(min_relative_change=0.1, greater_is_better=True)
        metric_1_result = _MetricValidationResult("metric_1", 0, metric_1_threshold, 0)
        metric_1_result.min_relative_change_failed = True
        return (
            {"metric_1": 0},
            {"metric_1": 0},
            {"metric_1": metric_1_threshold},
            {"metric_1": metric_1_result},
        )


@pytest.mark.parametrize(
    "min_relative_change_threshold_test_spec",
    [
        ("single_metric_not_satisfied_higher_better"),
        ("multiple_metrics_not_satisfied_higher_better"),
        ("single_metric_not_satisfied_lower_better"),
        ("multiple_metrics_not_satisfied_lower_better"),
        ("missing_baseline_metric"),
        ("baseline_metric_value_equals_0_fails"),
    ],
    indirect=["min_relative_change_threshold_test_spec"],
)
def test_validation_model_comparison_relative_threshold_should_fail(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    min_relative_change_threshold_test_spec,
):
    (
        metrics,
        baseline_model_metrics,
        validation_thresholds,
        expected_validation_results,
    ) = min_relative_change_threshold_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    mock_evaluate = MockEvaluator().evaluate

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        common_kwargs = {
            "data": iris_dataset._constructor_args["data"],
            "model_type": "classifier",
            "targets": iris_dataset._constructor_args["targets"],
            "evaluators": "test_evaluator1",
        }

        mock_evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})
        candidate_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

        if baseline_model_metrics is None:
            baseline_result = None
        else:
            mock_evaluate.return_value = EvaluationResult(
                metrics=baseline_model_metrics, artifacts={}
            )
            baseline_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

        with pytest.raises(
            ModelValidationFailedException,
            match=message_separator.join(map(str, list(expected_validation_results.values()))),
        ):
            mlflow.validate_evaluation_results(
                candidate_result=candidate_result,
                baseline_result=baseline_result,
                validation_thresholds=validation_thresholds,
            )


@pytest.mark.parametrize(
    "min_relative_change_threshold_test_spec",
    [
        ("single_metric_satisfied_higher_better"),
        ("single_metric_satisfied_lower_better"),
        ("equality_boundary"),
        ("multiple_metrics_all_satisfied"),
        ("baseline_metric_value_equals_0_succeeds"),
    ],
    indirect=["min_relative_change_threshold_test_spec"],
)
def test_validation_model_comparison_relative_threshold_should_pass(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    min_relative_change_threshold_test_spec,
):
    (
        metrics,
        baseline_model_metrics,
        validation_thresholds,
        _,
    ) = min_relative_change_threshold_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    mock_evaluate = MockEvaluator().evaluate

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        common_kwargs = {
            "data": iris_dataset._constructor_args["data"],
            "model_type": "classifier",
            "targets": iris_dataset._constructor_args["targets"],
            "evaluators": "test_evaluator1",
        }

        mock_evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})
        candidate_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

        mock_evaluate.return_value = EvaluationResult(metrics=baseline_model_metrics, artifacts={})
        baseline_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

    mlflow.validate_evaluation_results(
        candidate_result=candidate_result,
        baseline_result=baseline_result,
        validation_thresholds=validation_thresholds,
    )


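# When threshold, min_absolute_change, and min_relative_change are all set on a single
# MetricThreshold, each check can fail independently; the case below makes all three
# fail at once (0.75 < 0.8, an absolute gain of 0.05 < 0.1, and a relative gain of
# ~7% < 10%).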
@pytest.fixture
def multi_thresholds_test_spec(request):
    """
    Test specification for multi-threshold tests:

    Returns:
        A tuple containing the following elements:

        - metrics: A dictionary mapping scalar metric names to scalar metric values.
        - baseline_model_metrics: A dictionary mapping scalar metric names
            to scalar metric values of the baseline model.
        - validation_thresholds: A dictionary mapping scalar metric names
            to MetricThreshold objects.
        - expected_validation_results: A dictionary mapping scalar metric names
            to _MetricValidationResult.
    """
    acc_threshold = MetricThreshold(
        threshold=0.8, min_absolute_change=0.1, min_relative_change=0.1, greater_is_better=True
    )

    if request.param == "single_metric_all_thresholds_failed":
        acc_validation_result = _MetricValidationResult("accuracy", 0.75, acc_threshold, 0.7)
        acc_validation_result.threshold_failed = True
        acc_validation_result.min_relative_change_failed = True
        acc_validation_result.min_absolute_change_failed = True
        return (
            {"accuracy": 0.75},
            {"accuracy": 0.7},
            {"accuracy": acc_threshold},
            {"accuracy": acc_validation_result},
        )


@pytest.mark.parametrize(
    "multi_thresholds_test_spec",
    [
        ("single_metric_all_thresholds_failed"),
    ],
    indirect=["multi_thresholds_test_spec"],
)
def test_validation_multi_thresholds_should_fail(
    multiclass_logistic_regressor_model_uri,
    iris_dataset,
    multi_thresholds_test_spec,
):
    (
        metrics,
        baseline_model_metrics,
        validation_thresholds,
        expected_validation_results,
    ) = multi_thresholds_test_spec

    MockEvaluator = mock.MagicMock(spec=ModelEvaluator)
    MockEvaluator().can_evaluate.return_value = True
    mock_evaluate = MockEvaluator().evaluate

    with mock.patch.object(
        _model_evaluation_registry, "_registry", {"test_evaluator1": MockEvaluator}
    ):
        common_kwargs = {
            "data": iris_dataset._constructor_args["data"],
            "model_type": "classifier",
            "targets": iris_dataset._constructor_args["targets"],
            "evaluators": "test_evaluator1",
        }

        mock_evaluate.return_value = EvaluationResult(metrics=metrics, artifacts={})
        candidate_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

        mock_evaluate.return_value = EvaluationResult(metrics=baseline_model_metrics, artifacts={})
        baseline_result = evaluate(multiclass_logistic_regressor_model_uri, **common_kwargs)

    with pytest.raises(
        ModelValidationFailedException,
        match=message_separator.join(map(str, list(expected_validation_results.values()))),
    ):
        mlflow.validate_evaluation_results(
            candidate_result=candidate_result,
            baseline_result=baseline_result,
            validation_thresholds=validation_thresholds,
        )


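# End-to-end check without mocks: with targets [0, 1, 1, 1], the candidate model
# (always predicts 1) gets recall 1.0 while the baseline (always predicts 0) gets
# recall 0.0, so validation passes as written and fails on "recall_score value
# threshold check failed" when the candidate and baseline results are swapped.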
def test_validation_thresholds_no_mock():
    targets = [0, 1, 1, 1]
    data = [[random.random()] for _ in targets]

    class BaseModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            return len(model_input) * [0]

    class CandidateModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            return len(model_input) * [1]

    with mlflow.start_run():
        base = mlflow.pyfunc.log_model(name="base", python_model=BaseModel())
        candidate = mlflow.pyfunc.log_model(name="candidate", python_model=CandidateModel())

        candidate_result = evaluate(
            candidate.model_uri,
            data=data,
            model_type="classifier",
            targets=targets,
        )

        baseline_result = evaluate(
            base.model_uri,
            data=data,
            model_type="classifier",
            targets=targets,
        )

    mlflow.validate_evaluation_results(
        candidate_result=candidate_result,
        baseline_result=baseline_result,
        validation_thresholds={
            "recall_score": MetricThreshold(
                threshold=0.9,
                min_absolute_change=0.1,
                greater_is_better=True,
            ),
        },
    )

    with pytest.raises(
        ModelValidationFailedException,
        match="recall_score value threshold check failed",
    ):
        mlflow.validate_evaluation_results(
            candidate_result=baseline_result,
            baseline_result=candidate_result,
            validation_thresholds={
                "recall_score": MetricThreshold(
                    threshold=0.9,
                    min_absolute_change=0.1,
                    greater_is_better=True,
                ),
            },
        )