test_databricks.py
from unittest import mock

import pytest

from mlflow.exceptions import MlflowException
from mlflow.tracing.databricks import set_databricks_monitoring_sql_warehouse_id

# Experiment tag key the tests expect the SQL warehouse ID to be stored under.
_SQL_WAREHOUSE_TAG_KEY = "mlflow.monitoring.sqlWarehouseId"


def _assert_sql_warehouse_tag_set(mock_store, experiment_id, sql_warehouse_id):
    """Assert the mocked store got exactly one tag write with the expected values.

    Args:
        mock_store: The ``MagicMock`` standing in for the tracking store.
        experiment_id: Experiment ID expected as the first positional argument.
        sql_warehouse_id: Value expected on the tag passed as the second
            positional argument.
    """
    mock_store.set_experiment_tag.assert_called_once()
    positional_args = mock_store.set_experiment_tag.call_args[0]
    assert positional_args[0] == experiment_id
    assert positional_args[1].key == _SQL_WAREHOUSE_TAG_KEY
    assert positional_args[1].value == sql_warehouse_id


def test_set_databricks_monitoring_sql_warehouse_id_requires_databricks_tracking_uri():
    """A non-Databricks tracking URI is rejected with an ``MlflowException``."""
    # Patch the same lookup path as the other tests in this module so the
    # mocked URI is the one the function under test is expected to read.
    # NOTE(review): assumes the implementation resolves the URI via
    # ``mlflow.tracking.get_tracking_uri`` -- confirm against the source.
    with (
        mock.patch("mlflow.tracking.get_tracking_uri", return_value="file:///tmp"),
        pytest.raises(MlflowException, match="only supported when the tracking URI"),
    ):
        set_databricks_monitoring_sql_warehouse_id(
            sql_warehouse_id="warehouse123", experiment_id="exp456"
        )


def test_set_databricks_monitoring_sql_warehouse_id_with_explicit_experiment_id():
    """An explicitly passed experiment ID is the one tagged on the store."""
    mock_store = mock.MagicMock()
    with (
        mock.patch("mlflow.tracking.get_tracking_uri", return_value="databricks"),
        mock.patch(
            "mlflow.tracking._tracking_service.utils._get_store",
            return_value=mock_store,
        ),
    ):
        set_databricks_monitoring_sql_warehouse_id(
            sql_warehouse_id="warehouse123", experiment_id="exp456"
        )

    _assert_sql_warehouse_tag_set(
        mock_store, experiment_id="exp456", sql_warehouse_id="warehouse123"
    )


def test_set_databricks_monitoring_sql_warehouse_id_with_default_experiment_id():
    """Without an experiment ID, the one from ``_get_experiment_id`` is tagged."""
    mock_store = mock.MagicMock()
    with (
        mock.patch("mlflow.tracking.get_tracking_uri", return_value="databricks"),
        mock.patch(
            "mlflow.tracking._tracking_service.utils._get_store",
            return_value=mock_store,
        ),
        mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value="default_exp"),
    ):
        set_databricks_monitoring_sql_warehouse_id(sql_warehouse_id="warehouse789")

    _assert_sql_warehouse_tag_set(
        mock_store, experiment_id="default_exp", sql_warehouse_id="warehouse789"
    )