# test_databricks_job_context.py
from unittest import mock

from mlflow.entities import SourceType
from mlflow.tracking.context.databricks_job_context import DatabricksJobRunContext
from mlflow.utils.mlflow_tags import (
    MLFLOW_DATABRICKS_JOB_ID,
    MLFLOW_DATABRICKS_JOB_RUN_ID,
    MLFLOW_DATABRICKS_JOB_TYPE,
    MLFLOW_DATABRICKS_WEBAPP_URL,
    MLFLOW_DATABRICKS_WORKSPACE_ID,
    MLFLOW_DATABRICKS_WORKSPACE_URL,
    MLFLOW_SOURCE_NAME,
    MLFLOW_SOURCE_TYPE,
)


def test_databricks_job_run_context_in_context():
    """in_context() must report exactly what is_in_databricks_job() returns."""
    with mock.patch("mlflow.utils.databricks_utils.is_in_databricks_job") as in_job_mock:
        assert DatabricksJobRunContext().in_context() == in_job_mock.return_value


def test_databricks_job_run_context_tags():
    """tags() exposes job metadata and resolves the workspace URL.

    Two scenarios are checked:
      1. get_workspace_url() returns a value -> that value is used, even though
         get_workspace_info_from_dbutils() reports a different URL.
      2. get_workspace_url() returns None -> the URL from
         get_workspace_info_from_dbutils() is used as a fallback.
    """
    patch_job_id = mock.patch("mlflow.utils.databricks_utils.get_job_id")
    patch_job_run_id = mock.patch("mlflow.utils.databricks_utils.get_job_run_id")
    patch_job_type = mock.patch("mlflow.utils.databricks_utils.get_job_type")
    patch_webapp_url = mock.patch("mlflow.utils.databricks_utils.get_webapp_url")
    patch_workspace_url = mock.patch(
        "mlflow.utils.databricks_utils.get_workspace_url",
        return_value="https://dev.databricks.com",
    )
    patch_workspace_id = mock.patch(
        "mlflow.utils.databricks_utils.get_workspace_id", return_value="123456"
    )
    patch_workspace_url_none = mock.patch(
        "mlflow.utils.databricks_utils.get_workspace_url", return_value=None
    )
    # Deliberately reports a different URL than get_workspace_url() so the test
    # can tell which source tags() picked.
    patch_workspace_info = mock.patch(
        "mlflow.utils.databricks_utils.get_workspace_info_from_dbutils",
        return_value=("https://databricks.com", "123456"),
    )

    # Scenario 1: get_workspace_url() wins over the dbutils-derived URL.
    # The patcher objects are reused below; each is exited before re-entry.
    with (
        patch_job_id as job_id_mock,
        patch_job_run_id as job_run_id_mock,
        patch_job_type as job_type_mock,
        patch_webapp_url as webapp_url_mock,
        patch_workspace_url as workspace_url_mock,
        patch_workspace_info,
        patch_workspace_id as workspace_id_mock,
    ):
        assert DatabricksJobRunContext().tags() == {
            MLFLOW_SOURCE_NAME: (
                f"jobs/{job_id_mock.return_value}/run/{job_run_id_mock.return_value}"
            ),
            MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
            MLFLOW_DATABRICKS_JOB_ID: job_id_mock.return_value,
            MLFLOW_DATABRICKS_JOB_RUN_ID: job_run_id_mock.return_value,
            MLFLOW_DATABRICKS_JOB_TYPE: job_type_mock.return_value,
            MLFLOW_DATABRICKS_WEBAPP_URL: webapp_url_mock.return_value,
            MLFLOW_DATABRICKS_WORKSPACE_URL: workspace_url_mock.return_value,
            MLFLOW_DATABRICKS_WORKSPACE_ID: workspace_id_mock.return_value,
        }

    # Scenario 2: get_workspace_url() yields None, so the dbutils URL is used.
    with (
        patch_job_id as job_id_mock,
        patch_job_run_id as job_run_id_mock,
        patch_job_type as job_type_mock,
        patch_webapp_url as webapp_url_mock,
        patch_workspace_url_none,
        patch_workspace_info as workspace_info_mock,
        patch_workspace_id as workspace_id_mock,
    ):
        assert DatabricksJobRunContext().tags() == {
            MLFLOW_SOURCE_NAME: (
                f"jobs/{job_id_mock.return_value}/run/{job_run_id_mock.return_value}"
            ),
            MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
            MLFLOW_DATABRICKS_JOB_ID: job_id_mock.return_value,
            MLFLOW_DATABRICKS_JOB_RUN_ID: job_run_id_mock.return_value,
            MLFLOW_DATABRICKS_JOB_TYPE: job_type_mock.return_value,
            MLFLOW_DATABRICKS_WEBAPP_URL: webapp_url_mock.return_value,
            MLFLOW_DATABRICKS_WORKSPACE_URL: workspace_info_mock.return_value[0],  # fallback value
            MLFLOW_DATABRICKS_WORKSPACE_ID: workspace_id_mock.return_value,
        }


def test_databricks_job_run_context_tags_nones():
    """When every metadata helper yields None, only the source tags remain.

    All workspace helpers are patched explicitly. Otherwise the test would call
    the real get_workspace_url()/get_workspace_id() and depend on the execution
    environment, for example failing when run on Databricks.
    """
    patch_job_id = mock.patch("mlflow.utils.databricks_utils.get_job_id", return_value=None)
    patch_job_run_id = mock.patch("mlflow.utils.databricks_utils.get_job_run_id", return_value=None)
    patch_job_type = mock.patch("mlflow.utils.databricks_utils.get_job_type", return_value=None)
    patch_webapp_url = mock.patch("mlflow.utils.databricks_utils.get_webapp_url", return_value=None)
    patch_workspace_url = mock.patch(
        "mlflow.utils.databricks_utils.get_workspace_url", return_value=None
    )
    patch_workspace_id = mock.patch(
        "mlflow.utils.databricks_utils.get_workspace_id", return_value=None
    )
    patch_workspace_info = mock.patch(
        "mlflow.utils.databricks_utils.get_workspace_info_from_dbutils",
        return_value=(None, None),
    )

    with (
        patch_job_id,
        patch_job_run_id,
        patch_job_type,
        patch_webapp_url,
        patch_workspace_url,
        patch_workspace_id,
        patch_workspace_info,
    ):
        assert DatabricksJobRunContext().tags() == {
            MLFLOW_SOURCE_NAME: None,
            MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
        }