test_databricks_cluster_context.py
from unittest import mock

from mlflow.tracking.context.databricks_cluster_context import DatabricksClusterRunContext
from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_CLUSTER_ID


def test_databricks_cluster_run_context_in_context():
    """``in_context()`` reports exactly what ``databricks_utils.is_in_cluster()`` returns."""
    with mock.patch("mlflow.utils.databricks_utils.is_in_cluster") as in_cluster_mock:
        assert DatabricksClusterRunContext().in_context() == in_cluster_mock.return_value


def test_databricks_cluster_run_context_tags():
    """``tags()`` records the id returned by ``databricks_utils.get_cluster_id()``."""
    patch_cluster_id = mock.patch("mlflow.utils.databricks_utils.get_cluster_id")
    with patch_cluster_id as cluster_id_mock:
        assert DatabricksClusterRunContext().tags() == {
            MLFLOW_DATABRICKS_CLUSTER_ID: cluster_id_mock.return_value
        }


def test_databricks_cluster_run_context_tags_nones():
    """``tags()`` is empty when no cluster id is available."""
    # Pin the "no cluster id" case explicitly: without this patch the outcome depends on
    # whether the machine running the test suite is itself a Databricks cluster.
    with mock.patch("mlflow.utils.databricks_utils.get_cluster_id", return_value=None):
        assert DatabricksClusterRunContext().tags() == {}