tests/sklearn/test_sklearn_autolog_without_matplotlib.py
test_sklearn_autolog_without_matplotlib.py
 1  from unittest import mock
 2  
 3  import pytest
 4  from sklearn.datasets import load_breast_cancer
 5  from sklearn.ensemble import RandomForestClassifier
 6  
 7  import mlflow
 8  from mlflow import MlflowClient
 9  
10  from tests.helper_functions import AnyStringWith
11  
12  
def is_matplotlib_installed():
    """Return ``True`` if ``import matplotlib`` succeeds, otherwise ``False``.

    The probe performs a real import, so a successful check leaves matplotlib
    loaded in ``sys.modules``. Any ``ImportError`` (including
    ``ModuleNotFoundError``) is treated as "not installed".
    """
    try:
        import matplotlib  # noqa: F401
    except ImportError:
        return False
    else:
        return True
20  
21  
@pytest.mark.skipif(
    is_matplotlib_installed(),
    reason="matplotlib must be uninstalled to run this test",
)
def test_sklearn_autolog_works_without_matplotlib():
    """Sklearn autologging must still work when matplotlib cannot be imported.

    Fitting a classifier under ``mlflow.sklearn.autolog()`` should call the
    ``mlflow.sklearn.utils`` logger's ``warning`` exactly once, with a message
    containing "Failed to import matplotlib". The resulting run must still
    contain the standard training metrics.
    """
    # The skip condition above is evaluated at collection time, so this body
    # only executes in an environment where matplotlib is unavailable.
    mlflow.sklearn.autolog()
    features, labels = load_breast_cancer(return_X_y=True)
    classifier = RandomForestClassifier(max_depth=2, random_state=0, n_estimators=10)

    with (
        mlflow.start_run() as active_run,
        mock.patch("mlflow.sklearn.utils._logger.warning") as patched_warning,
    ):
        classifier.fit(features, labels)
        # Assert while the patch is still active: exactly one warning about
        # the missing matplotlib import is expected.
        patched_warning.assert_called_once_with(AnyStringWith("Failed to import matplotlib"))

    # Re-fetch the run from the tracking store rather than relying on the
    # ActiveRun handle, so the logged metrics are read back as persisted.
    finished_run = MlflowClient().get_run(active_run.info.run_id)
    required_metrics = {
        "training_score",
        "training_accuracy_score",
        "training_precision_score",
        "training_recall_score",
        "training_f1_score",
        "training_log_loss",
    }
    # Scalar training metrics must be logged even without matplotlib.
    assert required_metrics.issubset(finished_run.data.metrics)