/ tests / store / artifact / test_databricks_models_artifact_repo.py
test_databricks_models_artifact_repo.py
  1  import json
  2  from unittest import mock
  3  from unittest.mock import ANY
  4  
  5  import pytest
  6  import requests
  7  
  8  from mlflow.entities import FileInfo
  9  from mlflow.entities.model_registry import ModelVersion
 10  from mlflow.environment_variables import MLFLOW_MULTIPART_DOWNLOAD_CHUNK_SIZE
 11  from mlflow.exceptions import MlflowException
 12  from mlflow.store.artifact.databricks_models_artifact_repo import (
 13      DatabricksModelsArtifactRepository,
 14  )
 15  from mlflow.tracking._model_registry.client import ModelRegistryClient
 16  from mlflow.utils.file_utils import _Chunk
 17  
 18  DATABRICKS_MODEL_ARTIFACT_REPOSITORY_PACKAGE = (
 19      "mlflow.store.artifact.databricks_models_artifact_repo"
 20  )
 21  DATABRICKS_MODEL_ARTIFACT_REPOSITORY = (
 22      DATABRICKS_MODEL_ARTIFACT_REPOSITORY_PACKAGE + ".DatabricksModelsArtifactRepository"
 23  )
 24  MOCK_MODEL_ROOT_URI_WITH_PROFILE = "models://profile@databricks/MyModel/12"
 25  MOCK_MODEL_ROOT_URI_WITHOUT_PROFILE = "models:/MyModel/12"
 26  MOCK_PROFILE = "databricks://profile"
 27  MOCK_MODEL_NAME = "MyModel"
 28  MOCK_MODEL_VERSION = "12"
 29  
 30  REGISTRY_LIST_ARTIFACTS_ENDPOINT = "/api/2.0/mlflow/model-versions/list-artifacts"
 31  REGISTRY_ARTIFACT_PRESIGNED_URI_ENDPOINT = "/api/2.0/mlflow/model-versions/get-signed-download-uri"
 32  
 33  
 34  @pytest.fixture
 35  def databricks_model_artifact_repo():
 36      return DatabricksModelsArtifactRepository(MOCK_MODEL_ROOT_URI_WITH_PROFILE)
 37  
 38  
 39  def test_init_with_version_uri_containing_profile():
 40      repo = DatabricksModelsArtifactRepository(MOCK_MODEL_ROOT_URI_WITH_PROFILE)
 41      assert repo.artifact_uri == MOCK_MODEL_ROOT_URI_WITH_PROFILE
 42      assert repo.model_name == MOCK_MODEL_NAME
 43      assert repo.model_version == MOCK_MODEL_VERSION
 44  
 45  
 46  @pytest.mark.parametrize(
 47      "stage_uri_with_profile",
 48      [
 49          "models://profile@databricks/MyModel/Staging",
 50          "models://profile@databricks/MyModel/Production",
 51      ],
 52  )
 53  def test_init_with_stage_uri_containing_profile(stage_uri_with_profile):
 54      model_version_detailed = ModelVersion(
 55          MOCK_MODEL_NAME,
 56          MOCK_MODEL_VERSION,
 57          "2345671890",
 58          "234567890",
 59          "some description",
 60          "UserID",
 61          "Production",
 62          "source",
 63          "run12345",
 64      )
 65      get_latest_versions_patch = mock.patch.object(
 66          ModelRegistryClient, "get_latest_versions", return_value=[model_version_detailed]
 67      )
 68      with get_latest_versions_patch:
 69          repo = DatabricksModelsArtifactRepository(stage_uri_with_profile)
 70          assert repo.artifact_uri == stage_uri_with_profile
 71          assert repo.model_name == MOCK_MODEL_NAME
 72          assert repo.model_version == MOCK_MODEL_VERSION
 73          assert repo.databricks_profile_uri == MOCK_PROFILE
 74  
 75  
 76  @pytest.mark.parametrize(
 77      "invalid_artifact_uri",
 78      [
 79          "s3://test",
 80          "dbfs:/databricks/mlflow/MV-id/models",
 81          "dbfs://scope:key@notdatabricks/databricks/mlflow-regisry/123/models",
 82          "models:/MyModel/12",
 83          "models://scope:key@notdatabricks/MyModel/12",
 84      ],
 85  )
 86  def test_init_with_invalid_artifact_uris(invalid_artifact_uri):
 87      with pytest.raises(
 88          MlflowException,
 89          match="A valid databricks profile is required to instantiate this repository",
 90      ):
 91          DatabricksModelsArtifactRepository(invalid_artifact_uri)
 92  
 93  
 94  def test_init_with_version_uri_and_profile_is_inferred():
 95      # First mock for `is_using_databricks_registry` to pass
 96      # Second mock to set `databricks_profile_uri` during instantiation
 97      with (
 98          mock.patch(
 99              "mlflow.store.artifact.utils.models.mlflow.get_registry_uri",
100              return_value=MOCK_PROFILE,
101          ),
102          mock.patch("mlflow.tracking.get_registry_uri", return_value=MOCK_PROFILE),
103      ):
104          repo = DatabricksModelsArtifactRepository(MOCK_MODEL_ROOT_URI_WITHOUT_PROFILE)
105          assert repo.artifact_uri == MOCK_MODEL_ROOT_URI_WITHOUT_PROFILE
106          assert repo.model_name == MOCK_MODEL_NAME
107          assert repo.model_version == MOCK_MODEL_VERSION
108          assert repo.databricks_profile_uri == MOCK_PROFILE
109  
110  
@pytest.mark.parametrize(
    "stage_uri_without_profile",
    ["models:/MyModel/Staging", "models:/MyModel/Production"],
)
def test_init_with_stage_uri_and_profile_is_inferred(stage_uri_without_profile):
    """A profile-less stage URI yields the mocked version and the patched registry profile."""
    latest_version = ModelVersion(
        MOCK_MODEL_NAME,
        MOCK_MODEL_VERSION,
        "2345671890",
        "234567890",
        "some description",
        "UserID",
        "Production",
        "source",
        "run12345",
    )
    with (
        # Make `ModelRegistryClient.get_latest_versions` return the version built above.
        mock.patch.object(
            ModelRegistryClient, "get_latest_versions", return_value=[latest_version]
        ),
        mock.patch(
            "mlflow.store.artifact.utils.models.mlflow.get_registry_uri",
            return_value=MOCK_PROFILE,
        ),
        mock.patch("mlflow.tracking.get_registry_uri", return_value=MOCK_PROFILE),
    ):
        repository = DatabricksModelsArtifactRepository(stage_uri_without_profile)
        assert repository.artifact_uri == stage_uri_without_profile
        assert repository.model_name == MOCK_MODEL_NAME
        assert repository.model_version == MOCK_MODEL_VERSION
        assert repository.databricks_profile_uri == MOCK_PROFILE
143  
144  
@pytest.mark.parametrize(
    "valid_profileless_artifact_uri",
    ["models:/MyModel/12", "models:/MyModel/Staging"],
)
def test_init_with_valid_uri_but_no_profile(valid_profileless_artifact_uri):
    """Without a registry URI, a profile-less models URI is rejected with `MlflowException`."""
    expected_message = "A valid databricks profile is required to instantiate this repository"
    with (
        # `get_registry_uri` returns None so that `is_using_databricks_registry` fails.
        mock.patch(
            "mlflow.store.artifact.utils.models.mlflow.get_registry_uri",
            return_value=None,
        ),
        pytest.raises(MlflowException, match=expected_message),
    ):
        DatabricksModelsArtifactRepository(valid_profileless_artifact_uri)
160  
161  
def test_list_artifacts(databricks_model_artifact_repo):
    """List the root path from a 200 response, then check that a 404 response raises."""
    call_endpoint_target = DATABRICKS_MODEL_ARTIFACT_REPOSITORY + "._call_endpoint"
    # `_raise_for_status` reads this name when it is called, so the rebinding to 404
    # further down changes what the helper does.
    status_code = 200

    def _raise_for_status():
        if status_code == 404:
            raise Exception(
                "404 Client Error: Not Found for url: https://shard-uri/api/2.0/mlflow/model-versions/list-artifacts?name=model&version=1"
            )

    ok_response = mock.MagicMock()
    ok_response.status_code = status_code
    ok_response.text = json.dumps(
        {
            "files": [
                {"path": "MLmodel", "is_dir": False, "file_size": 294},
                {"path": "data", "is_dir": True, "file_size": None},
            ]
        }
    )
    ok_response.raise_for_status.side_effect = _raise_for_status
    with mock.patch(call_endpoint_target, return_value=ok_response) as call_endpoint_mock:
        artifacts = databricks_model_artifact_repo.list_artifacts("")
        assert isinstance(artifacts, list)
        assert len(artifacts) == 2
        mlmodel_info, data_info = artifacts
        assert mlmodel_info.path == "MLmodel"
        assert mlmodel_info.is_dir is False
        assert mlmodel_info.file_size == 294
        assert data_info.path == "data"
        assert data_info.is_dir is True
        assert data_info.file_size is None
        call_endpoint_mock.assert_called_once_with(ANY, REGISTRY_LIST_ARTIFACTS_ENDPOINT)

    # errors from API are propagated through to cli response
    status_code = 404
    error_response = mock.MagicMock()
    error_response.status_code = status_code
    error_response.text = "An error occurred"
    error_response.raise_for_status.side_effect = _raise_for_status
    expected_error = (
        r"API request to list files under path `` failed with status code 404. "
        "Response body: An error occurred"
    )
    with mock.patch(call_endpoint_target, return_value=error_response) as call_endpoint_mock:
        with pytest.raises(MlflowException, match=expected_error):
            databricks_model_artifact_repo.list_artifacts("")
        call_endpoint_mock.assert_called_once_with(ANY, REGISTRY_LIST_ARTIFACTS_ENDPOINT)
213  
214  
def test_list_artifacts_for_single_file(databricks_model_artifact_repo):
    """Listing the path "MLmodel" returns nothing when the response holds only that entry."""
    response = mock.MagicMock()
    response.status_code = 200
    response.text = json.dumps(
        {"files": [{"path": "MLmodel", "is_dir": False, "file_size": 294}]}
    )
    call_endpoint_target = DATABRICKS_MODEL_ARTIFACT_REPOSITORY + "._call_endpoint"
    with mock.patch(call_endpoint_target, return_value=response):
        listed = databricks_model_artifact_repo.list_artifacts("MLmodel")
        assert len(listed) == 0
228  
229  
@pytest.mark.parametrize(
    ("remote_file_path", "local_path"),
    [
        ("test_file.txt", ""),
        ("test_file.txt", None),
        ("output/test_file", None),
    ],
)
def test_download_file(databricks_model_artifact_repo, remote_file_path, local_path):
    """`download_artifacts` requests a signed URI and downloads from it with its headers."""
    signed_uri = "https://my-amazing-signed-uri-to-rule-them-all.com/1234-numbers-yay-567"
    payload = {
        "signed_uri": signed_uri,
        "headers": [{"name": "header_name", "value": "header_value"}],
    }
    # Header mapping the download helper is expected to be called with.
    expected_headers = {"header_name": "header_value"}

    response = mock.MagicMock()
    response.status_code = 200
    response.text = json.dumps(payload)

    with (
        mock.patch(
            DATABRICKS_MODEL_ARTIFACT_REPOSITORY + "._call_endpoint",
            return_value=response,
        ) as call_endpoint_mock,
        mock.patch(
            DATABRICKS_MODEL_ARTIFACT_REPOSITORY_PACKAGE + ".download_file_using_http_uri",
            return_value=None,
        ) as download_mock,
    ):
        databricks_model_artifact_repo.download_artifacts(remote_file_path, local_path)
        call_endpoint_mock.assert_called_with(ANY, REGISTRY_ARTIFACT_PRESIGNED_URI_ENDPOINT)
        download_mock.assert_called_with(signed_uri, ANY, ANY, expected_headers)
265  
266  
@pytest.mark.parametrize("remote_file_path", ["test_file.txt", "output/test_file"])
def test_parallelized_download_file_using_http_uri_success(
    databricks_model_artifact_repo, remote_file_path
):
    """`_download_file` calls the parallelized download helper when it reports no failures."""
    signed_uri = "https://my-amazing-signed-uri-to-rule-them-all.com/1234-numbers-yay-567"
    signed_headers = [{"name": "header_name", "value": "header_value"}]

    with (
        mock.patch(
            DATABRICKS_MODEL_ARTIFACT_REPOSITORY + ".list_artifacts",
            # Third field is one more than the configured multipart download chunk size.
            return_value=[
                FileInfo(remote_file_path, True, MLFLOW_MULTIPART_DOWNLOAD_CHUNK_SIZE.get() + 1)
            ],
        ),
        mock.patch(
            DATABRICKS_MODEL_ARTIFACT_REPOSITORY + "._get_signed_download_uri",
            return_value=(signed_uri, signed_headers),
        ),
        mock.patch(
            "mlflow.utils.databricks_utils.get_databricks_env_vars",
            return_value={},
        ),
        mock.patch(
            DATABRICKS_MODEL_ARTIFACT_REPOSITORY_PACKAGE
            + ".parallelized_download_file_using_http_uri",
            # An empty mapping: no chunk is reported as failed.
            return_value={},
        ) as parallel_download_mock,
    ):
        databricks_model_artifact_repo._download_file(remote_file_path, "")
        parallel_download_mock.assert_called()
305  
306  
@pytest.mark.parametrize("remote_file_path", ["test_file.txt", "output/test_file"])
def test_parallelized_download_file_using_http_uri_with_error_downloads(
    databricks_model_artifact_repo, remote_file_path
):
    """A failed chunk whose follow-up `download_chunk` call raises surfaces as `MlflowException`."""
    signed_uri = "https://my-amazing-signed-uri-to-rule-them-all.com/1234-numbers-yay-567"
    signed_headers = [{"name": "header_name", "value": "header_value"}]
    # Mapping returned by the patched parallel download helper: one chunk and its error.
    chunk_errors = {_Chunk(1, 2, 3, "test"): Exception("Internal Server Error")}

    with (
        mock.patch(
            DATABRICKS_MODEL_ARTIFACT_REPOSITORY + ".list_artifacts",
            return_value=[
                FileInfo(remote_file_path, True, MLFLOW_MULTIPART_DOWNLOAD_CHUNK_SIZE.get() + 1)
            ],
        ),
        mock.patch(
            DATABRICKS_MODEL_ARTIFACT_REPOSITORY + "._get_signed_download_uri",
            return_value=(signed_uri, signed_headers),
        ),
        mock.patch(
            "mlflow.utils.databricks_utils.get_databricks_env_vars",
            return_value={},
        ),
        mock.patch(
            DATABRICKS_MODEL_ARTIFACT_REPOSITORY_PACKAGE
            + ".parallelized_download_file_using_http_uri",
            return_value=chunk_errors,
        ),
        mock.patch(
            "mlflow.utils.file_utils.download_chunk", side_effect=Exception("Retry failed")
        ) as download_chunk_mock,
    ):
        with pytest.raises(MlflowException, match="Retry failed"):
            databricks_model_artifact_repo._download_file(remote_file_path, "")

        download_chunk_mock.assert_called_with(
            range_start=2,
            range_end=3,
            headers={"header_name": "header_value"},
            download_path="",
            http_uri=signed_uri,
        )
357  
358  
@pytest.mark.parametrize("remote_file_path", ["test_file.txt", "output/test_file"])
def test_parallelized_download_file_using_http_uri_with_failed_downloads(
    databricks_model_artifact_repo, remote_file_path
):
    """When the parallel download helper reports a failed chunk, `download_chunk` is called."""
    signed_uri = "https://my-amazing-signed-uri-to-rule-them-all.com/1234-numbers-yay-567"
    signed_headers = [{"name": "header_name", "value": "header_value"}]
    # Mapping returned by the patched parallel download helper: one chunk and its error.
    chunk_failures = {_Chunk(1, 2, 3, "test"): Exception("Internal Server Error")}

    with (
        mock.patch(
            DATABRICKS_MODEL_ARTIFACT_REPOSITORY + ".list_artifacts",
            return_value=[
                FileInfo(remote_file_path, True, MLFLOW_MULTIPART_DOWNLOAD_CHUNK_SIZE.get() + 1)
            ],
        ),
        mock.patch(
            DATABRICKS_MODEL_ARTIFACT_REPOSITORY + "._get_signed_download_uri",
            return_value=(signed_uri, signed_headers),
        ),
        mock.patch(
            "mlflow.utils.databricks_utils.get_databricks_env_vars",
            return_value={},
        ),
        mock.patch(
            DATABRICKS_MODEL_ARTIFACT_REPOSITORY_PACKAGE
            + ".parallelized_download_file_using_http_uri",
            return_value=chunk_failures,
        ),
        mock.patch(
            "mlflow.utils.file_utils.download_chunk",
            return_value=None,
        ) as download_chunk_mock,
    ):
        databricks_model_artifact_repo._download_file(remote_file_path, "")
        download_chunk_mock.assert_called()
402  
403  
def test_download_file_get_request_fail(databricks_model_artifact_repo):
    """`download_artifacts` raises `MlflowException` when `_call_endpoint` raises one."""
    call_endpoint_target = DATABRICKS_MODEL_ARTIFACT_REPOSITORY + "._call_endpoint"
    endpoint_error = MlflowException("MOCK ERROR")
    with mock.patch(call_endpoint_target, side_effect=endpoint_error):
        with pytest.raises(MlflowException, match=r".+"):
            databricks_model_artifact_repo.download_artifacts("Something")
409  
410  
def test_log_artifact_fail(databricks_model_artifact_repo):
    """`log_artifact` raises `MlflowException` with the "does not support logging" message."""
    expected_message = "This repository does not support logging artifacts"
    with pytest.raises(MlflowException, match=expected_message):
        databricks_model_artifact_repo.log_artifact("Some file")
414  
415  
def test_log_artifacts_fail(databricks_model_artifact_repo):
    """`log_artifacts` raises `MlflowException` with the "does not support logging" message."""
    expected_message = "This repository does not support logging artifacts"
    with pytest.raises(MlflowException, match=expected_message):
        databricks_model_artifact_repo.log_artifacts("Some dir")
419  
420  
def test_delete_artifacts_fail(databricks_model_artifact_repo):
    """`delete_artifacts` raises `NotImplementedError` with the "does not support" message."""
    expected_message = "This artifact repository does not support deleting artifacts"
    with pytest.raises(NotImplementedError, match=expected_message):
        databricks_model_artifact_repo.delete_artifacts()
427  
428  
def test_empty_headers_with_presigned_url(databricks_model_artifact_repo):
    """A signed-URI response without "headers" gives `None` headers, extracted as `{}`."""
    url = "https://test.com/1234"
    encoding = "utf-8"

    # Build a real `requests.Response` whose body is JSON with only a "signed_uri" key.
    response = requests.Response()
    response.encoding = encoding
    response._content = json.dumps({"signed_uri": url}).encode(encoding)

    call_endpoint_target = DATABRICKS_MODEL_ARTIFACT_REPOSITORY + "._call_endpoint"
    with mock.patch(call_endpoint_target, return_value=response) as call_endpoint_mock:
        ret_url, headers = databricks_model_artifact_repo._get_signed_download_uri("test_file.txt")
        call_endpoint_mock.assert_called_with(ANY, REGISTRY_ARTIFACT_PRESIGNED_URI_ENDPOINT)

        assert ret_url == url
        assert headers is None

        extracted = databricks_model_artifact_repo._extract_headers_from_signed_url(headers)

        assert extracted == {}