test_sklearn_autolog_without_matplotlib.py
from unittest import mock

import pytest
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

import mlflow
from mlflow import MlflowClient

from tests.helper_functions import AnyStringWith


def is_matplotlib_installed():
    """Return True when ``import matplotlib`` succeeds, False on ImportError."""
    try:
        import matplotlib  # noqa: F401
    except ImportError:
        return False
    return True


@pytest.mark.skipif(
    is_matplotlib_installed(), reason="matplotlib must be uninstalled to run this test"
)
def test_sklearn_autolog_works_without_matplotlib():
    """Fit a classifier under sklearn autologging when matplotlib is absent.

    Expects exactly one warning mentioning the failed matplotlib import, and
    expects the training metrics to be recorded on the run regardless.
    """
    mlflow.sklearn.autolog()
    classifier = RandomForestClassifier(max_depth=2, random_state=0, n_estimators=10)
    features, labels = load_breast_cancer(return_X_y=True)
    with (
        mlflow.start_run() as active_run,
        mock.patch("mlflow.sklearn.utils._logger.warning") as warning_spy,
    ):
        classifier.fit(features, labels)
        warning_spy.assert_called_once_with(AnyStringWith("Failed to import matplotlib"))

    # Re-fetch the run through the client so the logged metrics are visible.
    logged_run = MlflowClient().get_run(active_run.info.run_id)
    required_metrics = {
        "training_score",
        "training_accuracy_score",
        "training_precision_score",
        "training_recall_score",
        "training_f1_score",
        "training_log_loss",
    }
    # Extra metrics are allowed; only the required ones must be present.
    assert required_metrics.issubset(logged_run.data.metrics)