/ tests / tracking / context / test_databricks_job_context.py
test_databricks_job_context.py
  1  from unittest import mock
  2  
  3  from mlflow.entities import SourceType
  4  from mlflow.tracking.context.databricks_job_context import DatabricksJobRunContext
  5  from mlflow.utils.mlflow_tags import (
  6      MLFLOW_DATABRICKS_JOB_ID,
  7      MLFLOW_DATABRICKS_JOB_RUN_ID,
  8      MLFLOW_DATABRICKS_JOB_TYPE,
  9      MLFLOW_DATABRICKS_WEBAPP_URL,
 10      MLFLOW_DATABRICKS_WORKSPACE_ID,
 11      MLFLOW_DATABRICKS_WORKSPACE_URL,
 12      MLFLOW_SOURCE_NAME,
 13      MLFLOW_SOURCE_TYPE,
 14  )
 15  
 16  
 17  def test_databricks_job_run_context_in_context():
 18      with mock.patch("mlflow.utils.databricks_utils.is_in_databricks_job") as in_job_mock:
 19          assert DatabricksJobRunContext().in_context() == in_job_mock.return_value
 20  
 21  
 22  def test_databricks_job_run_context_tags():
 23      patch_job_id = mock.patch("mlflow.utils.databricks_utils.get_job_id")
 24      patch_job_run_id = mock.patch("mlflow.utils.databricks_utils.get_job_run_id")
 25      patch_job_type = mock.patch("mlflow.utils.databricks_utils.get_job_type")
 26      patch_webapp_url = mock.patch("mlflow.utils.databricks_utils.get_webapp_url")
 27      patch_workspace_url = mock.patch(
 28          "mlflow.utils.databricks_utils.get_workspace_url",
 29          return_value="https://dev.databricks.com",
 30      )
 31      patch_workspace_id = mock.patch(
 32          "mlflow.utils.databricks_utils.get_workspace_id", return_value="123456"
 33      )
 34      patch_workspace_url_none = mock.patch(
 35          "mlflow.utils.databricks_utils.get_workspace_url", return_value=None
 36      )
 37      patch_workspace_info = mock.patch(
 38          "mlflow.utils.databricks_utils.get_workspace_info_from_dbutils",
 39          return_value=("https://databricks.com", "123456"),
 40      )
 41  
 42      with (
 43          patch_job_id as job_id_mock,
 44          patch_job_run_id as job_run_id_mock,
 45          patch_job_type as job_type_mock,
 46          patch_webapp_url as webapp_url_mock,
 47          patch_workspace_url as workspace_url_mock,
 48          patch_workspace_info as workspace_info_mock,
 49          patch_workspace_id as workspace_id_mock,
 50      ):
 51          assert DatabricksJobRunContext().tags() == {
 52              MLFLOW_SOURCE_NAME: (
 53                  f"jobs/{job_id_mock.return_value}/run/{job_run_id_mock.return_value}"
 54              ),
 55              MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
 56              MLFLOW_DATABRICKS_JOB_ID: job_id_mock.return_value,
 57              MLFLOW_DATABRICKS_JOB_RUN_ID: job_run_id_mock.return_value,
 58              MLFLOW_DATABRICKS_JOB_TYPE: job_type_mock.return_value,
 59              MLFLOW_DATABRICKS_WEBAPP_URL: webapp_url_mock.return_value,
 60              MLFLOW_DATABRICKS_WORKSPACE_URL: workspace_url_mock.return_value,
 61              MLFLOW_DATABRICKS_WORKSPACE_ID: workspace_id_mock.return_value,
 62          }
 63  
 64      with (
 65          patch_job_id as job_id_mock,
 66          patch_job_run_id as job_run_id_mock,
 67          patch_job_type as job_type_mock,
 68          patch_webapp_url as webapp_url_mock,
 69          patch_workspace_url_none as workspace_url_mock,
 70          patch_workspace_info as workspace_info_mock,
 71          patch_workspace_id as workspace_id_mock,
 72      ):
 73          assert DatabricksJobRunContext().tags() == {
 74              MLFLOW_SOURCE_NAME: (
 75                  f"jobs/{job_id_mock.return_value}/run/{job_run_id_mock.return_value}"
 76              ),
 77              MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
 78              MLFLOW_DATABRICKS_JOB_ID: job_id_mock.return_value,
 79              MLFLOW_DATABRICKS_JOB_RUN_ID: job_run_id_mock.return_value,
 80              MLFLOW_DATABRICKS_JOB_TYPE: job_type_mock.return_value,
 81              MLFLOW_DATABRICKS_WEBAPP_URL: webapp_url_mock.return_value,
 82              MLFLOW_DATABRICKS_WORKSPACE_URL: workspace_info_mock.return_value[0],  # fallback value
 83              MLFLOW_DATABRICKS_WORKSPACE_ID: workspace_id_mock.return_value,
 84          }
 85  
 86  
 87  def test_databricks_job_run_context_tags_nones():
 88      patch_job_id = mock.patch("mlflow.utils.databricks_utils.get_job_id", return_value=None)
 89      patch_job_run_id = mock.patch("mlflow.utils.databricks_utils.get_job_run_id", return_value=None)
 90      patch_job_type = mock.patch("mlflow.utils.databricks_utils.get_job_type", return_value=None)
 91      patch_webapp_url = mock.patch("mlflow.utils.databricks_utils.get_webapp_url", return_value=None)
 92      patch_workspace_info = mock.patch(
 93          "mlflow.utils.databricks_utils.get_workspace_info_from_dbutils", return_value=(None, None)
 94      )
 95  
 96      with patch_job_id, patch_job_run_id, patch_job_type, patch_webapp_url, patch_workspace_info:
 97          assert DatabricksJobRunContext().tags() == {
 98              MLFLOW_SOURCE_NAME: None,
 99              MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
100          }