# File: tests/tracing/test_databricks.py
 1  from unittest import mock
 2  
 3  import pytest
 4  
 5  from mlflow.exceptions import MlflowException
 6  from mlflow.tracing.databricks import set_databricks_monitoring_sql_warehouse_id
 7  
 8  
 9  def test_set_databricks_monitoring_sql_warehouse_id_requires_databricks_tracking_uri():
10      with mock.patch("mlflow.get_tracking_uri", return_value="file:///tmp"):
11          with pytest.raises(MlflowException, match="only supported when the tracking URI"):
12              set_databricks_monitoring_sql_warehouse_id(
13                  sql_warehouse_id="warehouse123", experiment_id="exp456"
14              )
15  
16  
def test_set_databricks_monitoring_sql_warehouse_id_with_explicit_experiment_id():
    """An explicitly passed ``experiment_id`` receives the SQL warehouse ID tag.

    Expects exactly one ``set_experiment_tag`` call whose first positional argument
    is the given experiment ID. The second positional argument must be a tag with
    key ``mlflow.monitoring.sqlWarehouseId`` and the warehouse ID as its value.
    """
    store = mock.MagicMock()
    databricks_uri = mock.patch("mlflow.tracking.get_tracking_uri", return_value="databricks")
    fake_store = mock.patch(
        "mlflow.tracking._tracking_service.utils._get_store", return_value=store
    )

    with databricks_uri, fake_store:
        set_databricks_monitoring_sql_warehouse_id(
            sql_warehouse_id="warehouse123", experiment_id="exp456"
        )

        tag_setter = store.set_experiment_tag
        tag_setter.assert_called_once()
        positional = tag_setter.call_args.args
        assert positional[0] == "exp456"
        written_tag = positional[1]
        assert written_tag.key == "mlflow.monitoring.sqlWarehouseId"
        assert written_tag.value == "warehouse123"
34  
35  
def test_set_databricks_monitoring_sql_warehouse_id_with_default_experiment_id():
    """Omitting ``experiment_id`` tags the experiment resolved by the fluent API.

    ``mlflow.tracking.fluent._get_experiment_id`` is stubbed to return
    ``"default_exp"``. The single ``set_experiment_tag`` call must target that ID
    and carry the ``mlflow.monitoring.sqlWarehouseId`` tag set to the warehouse ID.
    """
    store = mock.MagicMock()
    databricks_uri = mock.patch("mlflow.tracking.get_tracking_uri", return_value="databricks")
    fake_store = mock.patch(
        "mlflow.tracking._tracking_service.utils._get_store", return_value=store
    )
    active_experiment = mock.patch(
        "mlflow.tracking.fluent._get_experiment_id", return_value="default_exp"
    )

    with databricks_uri, fake_store, active_experiment:
        set_databricks_monitoring_sql_warehouse_id(sql_warehouse_id="warehouse789")

        tag_setter = store.set_experiment_tag
        tag_setter.assert_called_once()
        positional = tag_setter.call_args.args
        assert positional[0] == "default_exp"
        written_tag = positional[1]
        assert written_tag.key == "mlflow.monitoring.sqlWarehouseId"
        assert written_tag.value == "warehouse789"