/ tests / tracking / context / test_databricks_repo_context.py
test_databricks_repo_context.py
 1  from unittest import mock
 2  
 3  from mlflow.tracking.context.databricks_repo_context import DatabricksRepoRunContext
 4  from mlflow.utils.mlflow_tags import (
 5      MLFLOW_DATABRICKS_GIT_REPO_COMMIT,
 6      MLFLOW_DATABRICKS_GIT_REPO_PROVIDER,
 7      MLFLOW_DATABRICKS_GIT_REPO_REFERENCE,
 8      MLFLOW_DATABRICKS_GIT_REPO_REFERENCE_TYPE,
 9      MLFLOW_DATABRICKS_GIT_REPO_RELATIVE_PATH,
10      MLFLOW_DATABRICKS_GIT_REPO_STATUS,
11      MLFLOW_DATABRICKS_GIT_REPO_URL,
12  )
13  
14  
def test_databricks_repo_run_context_in_context():
    """``in_context`` reports exactly what ``is_in_databricks_repo`` returns."""
    repo_check_patch = mock.patch("mlflow.utils.databricks_utils.is_in_databricks_repo")
    with repo_check_patch as repo_check:
        reported = DatabricksRepoRunContext().in_context()
        assert reported == repo_check.return_value
18  
19  
def test_databricks_repo_run_context_tags():
    """Each git-repo getter's result is reported under its matching tag key."""
    # Pair every patched databricks_utils getter with the tag key it should populate.
    tag_for_getter = {
        "get_git_repo_url": MLFLOW_DATABRICKS_GIT_REPO_URL,
        "get_git_repo_provider": MLFLOW_DATABRICKS_GIT_REPO_PROVIDER,
        "get_git_repo_commit": MLFLOW_DATABRICKS_GIT_REPO_COMMIT,
        "get_git_repo_relative_path": MLFLOW_DATABRICKS_GIT_REPO_RELATIVE_PATH,
        "get_git_repo_reference": MLFLOW_DATABRICKS_GIT_REPO_REFERENCE,
        "get_git_repo_reference_type": MLFLOW_DATABRICKS_GIT_REPO_REFERENCE_TYPE,
        "get_git_repo_status": MLFLOW_DATABRICKS_GIT_REPO_STATUS,
    }
    # mock.DEFAULT makes patch.multiple create a fresh MagicMock per getter and
    # hand them back as a dict keyed by attribute name.
    with mock.patch.multiple(
        "mlflow.utils.databricks_utils",
        **{getter: mock.DEFAULT for getter in tag_for_getter},
    ) as getter_mocks:
        expected_tags = {
            tag: getter_mocks[getter].return_value for getter, tag in tag_for_getter.items()
        }
        assert DatabricksRepoRunContext().tags() == expected_tags
51  
52  
def test_databricks_repo_run_context_tags_nones():
    """No tags are emitted when every git-repo getter returns ``None``."""
    getter_names = (
        "get_git_repo_url",
        "get_git_repo_provider",
        "get_git_repo_commit",
        "get_git_repo_relative_path",
        "get_git_repo_reference",
        "get_git_repo_reference_type",
        "get_git_repo_status",
    )
    # Each getter is swapped for a MagicMock whose call result is None.
    none_returning = {name: mock.MagicMock(return_value=None) for name in getter_names}
    with mock.patch.multiple("mlflow.utils.databricks_utils", **none_returning):
        assert DatabricksRepoRunContext().tags() == {}