# test_databricks_repo_context.py
"""Tests for :class:`DatabricksRepoRunContext`.

Every ``mlflow.utils.databricks_utils`` helper the context reads is patched, so these
tests run outside a Databricks repo.
"""

import contextlib
from unittest import mock

from mlflow.tracking.context.databricks_repo_context import DatabricksRepoRunContext
from mlflow.utils.mlflow_tags import (
    MLFLOW_DATABRICKS_GIT_REPO_COMMIT,
    MLFLOW_DATABRICKS_GIT_REPO_PROVIDER,
    MLFLOW_DATABRICKS_GIT_REPO_REFERENCE,
    MLFLOW_DATABRICKS_GIT_REPO_REFERENCE_TYPE,
    MLFLOW_DATABRICKS_GIT_REPO_RELATIVE_PATH,
    MLFLOW_DATABRICKS_GIT_REPO_STATUS,
    MLFLOW_DATABRICKS_GIT_REPO_URL,
)

# Fully qualified patch target of each git-repo getter -> the run tag its value is
# expected to populate. Both ``tags`` tests below patch exactly this set, so adding a
# getter here keeps them in sync.
_GIT_REPO_GETTER_TO_TAG = {
    "mlflow.utils.databricks_utils.get_git_repo_url": MLFLOW_DATABRICKS_GIT_REPO_URL,
    "mlflow.utils.databricks_utils.get_git_repo_provider": MLFLOW_DATABRICKS_GIT_REPO_PROVIDER,
    "mlflow.utils.databricks_utils.get_git_repo_commit": MLFLOW_DATABRICKS_GIT_REPO_COMMIT,
    "mlflow.utils.databricks_utils.get_git_repo_relative_path": (
        MLFLOW_DATABRICKS_GIT_REPO_RELATIVE_PATH
    ),
    "mlflow.utils.databricks_utils.get_git_repo_reference": MLFLOW_DATABRICKS_GIT_REPO_REFERENCE,
    "mlflow.utils.databricks_utils.get_git_repo_reference_type": (
        MLFLOW_DATABRICKS_GIT_REPO_REFERENCE_TYPE
    ),
    "mlflow.utils.databricks_utils.get_git_repo_status": MLFLOW_DATABRICKS_GIT_REPO_STATUS,
}


def _patch_git_repo_getters(stack, **patch_kwargs):
    """Patch every getter in ``_GIT_REPO_GETTER_TO_TAG`` for the lifetime of ``stack``.

    Args:
        stack: A ``contextlib.ExitStack`` that owns the patches and undoes them on exit.
        **patch_kwargs: Extra keyword arguments forwarded to each ``mock.patch`` call
            (e.g. ``return_value=None``).

    Returns:
        A dict mapping each expected run tag to the mock that replaced its getter.
    """
    tag_to_mock = {}
    for target, tag in _GIT_REPO_GETTER_TO_TAG.items():
        tag_to_mock[tag] = stack.enter_context(mock.patch(target, **patch_kwargs))
    return tag_to_mock


def test_databricks_repo_run_context_in_context():
    """``in_context()`` reports whatever ``is_in_databricks_repo`` returns."""
    with mock.patch("mlflow.utils.databricks_utils.is_in_databricks_repo") as in_repo_mock:
        assert DatabricksRepoRunContext().in_context() == in_repo_mock.return_value


def test_databricks_repo_run_context_tags():
    """Each git-repo getter's return value is reported under its corresponding tag."""
    with contextlib.ExitStack() as stack:
        tag_to_mock = _patch_git_repo_getters(stack)
        expected_tags = {tag: getter.return_value for tag, getter in tag_to_mock.items()}
        assert DatabricksRepoRunContext().tags() == expected_tags


def test_databricks_repo_run_context_tags_nones():
    """When every git-repo getter returns ``None``, ``tags()`` is empty."""
    with contextlib.ExitStack() as stack:
        _patch_git_repo_getters(stack, return_value=None)
        assert DatabricksRepoRunContext().tags() == {}