test_databricks_command_context.py
from unittest import mock

from mlflow.tracking.context.databricks_command_context import DatabricksCommandRunContext
from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_NOTEBOOK_COMMAND_ID

# Dotted path of the helper patched by every test below.
_GET_JOB_GROUP_ID = "mlflow.utils.databricks_utils.get_job_group_id"


def test_databricks_command_run_context_in_context():
    """in_context() is truthy while get_job_group_id is patched to return "1"."""
    with mock.patch(_GET_JOB_GROUP_ID, return_value="1"):
        run_context = DatabricksCommandRunContext()
        assert run_context.in_context()


def test_databricks_command_run_context_tags():
    """tags() maps the notebook-command-id tag to the patched helper's return value.

    No return_value is given to mock.patch, so the patched helper returns its
    auto-created MagicMock ``return_value``; the tag value must be that object.
    """
    with mock.patch(_GET_JOB_GROUP_ID) as patched_get_job_group_id:
        actual_tags = DatabricksCommandRunContext().tags()
        assert actual_tags == {
            MLFLOW_DATABRICKS_NOTEBOOK_COMMAND_ID: patched_get_job_group_id.return_value
        }


def test_databricks_command_run_context_tags_nones():
    """tags() is an empty dict when get_job_group_id is patched to return None."""
    with mock.patch(_GET_JOB_GROUP_ID, return_value=None):
        actual_tags = DatabricksCommandRunContext().tags()
        assert actual_tags == {}