# tests/tracking/context/test_databricks_cluster_context.py
 1  from unittest import mock
 2  
 3  from mlflow.tracking.context.databricks_cluster_context import DatabricksClusterRunContext
 4  from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_CLUSTER_ID
 5  
 6  
 7  def test_databricks_cluster_run_context_in_context():
 8      with mock.patch("mlflow.utils.databricks_utils.is_in_cluster") as in_cluster_mock:
 9          assert DatabricksClusterRunContext().in_context() == in_cluster_mock.return_value
10  
11  
def test_databricks_cluster_run_context_tags():
    """tags() must map MLFLOW_DATABRICKS_CLUSTER_ID to the value from ``get_cluster_id``."""
    with mock.patch("mlflow.utils.databricks_utils.get_cluster_id") as fake_get_cluster_id:
        actual_tags = DatabricksClusterRunContext().tags()
        expected_tags = {MLFLOW_DATABRICKS_CLUSTER_ID: fake_get_cluster_id.return_value}
        assert actual_tags == expected_tags
18  
19  
def test_databricks_notebook_run_context_tags_nones():
    """tags() must be empty when no Databricks cluster id is available.

    ``get_cluster_id`` is patched to return ``None`` explicitly, like the sibling tests
    patch their dependencies. Previously the test relied on the real, unpatched
    ``get_cluster_id`` happening to return ``None`` in whatever environment the suite
    ran in, which made the result environment-dependent.

    NOTE(review): despite "notebook" in its name, this test exercises
    DatabricksClusterRunContext; the name is left unchanged here.
    """
    with mock.patch("mlflow.utils.databricks_utils.get_cluster_id", return_value=None):
        assert DatabricksClusterRunContext().tags() == {}