/ tests / tracking / context / test_databricks_command_context.py
test_databricks_command_context.py
 1  from unittest import mock
 2  
 3  from mlflow.tracking.context.databricks_command_context import DatabricksCommandRunContext
 4  from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_NOTEBOOK_COMMAND_ID
 5  
 6  
 7  def test_databricks_command_run_context_in_context():
 8      with mock.patch("mlflow.utils.databricks_utils.get_job_group_id", return_value="1"):
 9          assert DatabricksCommandRunContext().in_context()
10  
11  
def test_databricks_command_run_context_tags():
    """tags() stores the value returned by get_job_group_id under the command-id tag."""
    patcher = mock.patch("mlflow.utils.databricks_utils.get_job_group_id")
    with patcher as fake_get_job_group_id:
        # Call tags() before touching return_value, matching the original evaluation order.
        actual = DatabricksCommandRunContext().tags()
        expected = {MLFLOW_DATABRICKS_NOTEBOOK_COMMAND_ID: fake_get_job_group_id.return_value}
        assert actual == expected
17  
18  
def test_databricks_command_run_context_tags_nones():
    """A missing job group id (None) yields an empty tag dict."""
    target = "mlflow.utils.databricks_utils.get_job_group_id"
    with mock.patch(target, return_value=None):
        produced_tags = DatabricksCommandRunContext().tags()
    assert produced_tags == {}