# test_databricks_models_artifact_repo.py
import json
from unittest import mock
from unittest.mock import ANY

import pytest
import requests

from mlflow.entities import FileInfo
from mlflow.entities.model_registry import ModelVersion
from mlflow.environment_variables import MLFLOW_MULTIPART_DOWNLOAD_CHUNK_SIZE
from mlflow.exceptions import MlflowException
from mlflow.store.artifact.databricks_models_artifact_repo import (
    DatabricksModelsArtifactRepository,
)
from mlflow.tracking._model_registry.client import ModelRegistryClient
from mlflow.utils.file_utils import _Chunk

# Dotted paths used as `mock.patch` targets throughout this module.
DATABRICKS_MODEL_ARTIFACT_REPOSITORY_PACKAGE = (
    "mlflow.store.artifact.databricks_models_artifact_repo"
)
DATABRICKS_MODEL_ARTIFACT_REPOSITORY = (
    DATABRICKS_MODEL_ARTIFACT_REPOSITORY_PACKAGE + ".DatabricksModelsArtifactRepository"
)

# Canonical model URIs / identifiers shared by the tests below.
MOCK_MODEL_ROOT_URI_WITH_PROFILE = "models://profile@databricks/MyModel/12"
MOCK_MODEL_ROOT_URI_WITHOUT_PROFILE = "models:/MyModel/12"
MOCK_PROFILE = "databricks://profile"
MOCK_MODEL_NAME = "MyModel"
MOCK_MODEL_VERSION = "12"

# REST endpoints the tests expect `_call_endpoint` to be invoked with.
REGISTRY_LIST_ARTIFACTS_ENDPOINT = "/api/2.0/mlflow/model-versions/list-artifacts"
REGISTRY_ARTIFACT_PRESIGNED_URI_ENDPOINT = "/api/2.0/mlflow/model-versions/get-signed-download-uri"


def _patch_repo_attr(attr_name, **patch_kwargs):
    """Return a `mock.patch` for an attribute of `DatabricksModelsArtifactRepository`."""
    return mock.patch(f"{DATABRICKS_MODEL_ARTIFACT_REPOSITORY}.{attr_name}", **patch_kwargs)


def _patch_package_attr(attr_name, **patch_kwargs):
    """Return a `mock.patch` for a module-level name in the repository's package."""
    return mock.patch(
        f"{DATABRICKS_MODEL_ARTIFACT_REPOSITORY_PACKAGE}.{attr_name}", **patch_kwargs
    )


def _patch_models_utils_registry_uri(registry_uri):
    """Patch `get_registry_uri` as looked up through `mlflow.store.artifact.utils.models`."""
    return mock.patch(
        "mlflow.store.artifact.utils.models.mlflow.get_registry_uri",
        return_value=registry_uri,
    )


def _patch_tracking_registry_uri(registry_uri):
    """Patch `mlflow.tracking.get_registry_uri` to return ``registry_uri``."""
    return mock.patch("mlflow.tracking.get_registry_uri", return_value=registry_uri)


def _build_model_version():
    """Create the `ModelVersion` that the patched `get_latest_versions` hands back."""
    return ModelVersion(
        MOCK_MODEL_NAME,
        MOCK_MODEL_VERSION,
        "2345671890",
        "234567890",
        "some description",
        "UserID",
        "Production",
        "source",
        "run12345",
    )


def _patch_latest_versions():
    """Patch `ModelRegistryClient.get_latest_versions` to return a single model version."""
    return mock.patch.object(
        ModelRegistryClient, "get_latest_versions", return_value=[_build_model_version()]
    )


def _signed_uri_payload():
    """Return a fresh signed-URI response payload (URI plus one header entry)."""
    return {
        "signed_uri": "https://my-amazing-signed-uri-to-rule-them-all.com/1234-numbers-yay-567",
        "headers": [{"name": "header_name", "value": "header_value"}],
    }


def _oversized_file_listing(remote_file_path):
    """Return a one-element listing whose size value is the multipart chunk size plus one."""
    oversized = MLFLOW_MULTIPART_DOWNLOAD_CHUNK_SIZE.get() + 1
    return [FileInfo(remote_file_path, True, oversized)]


@pytest.fixture
def databricks_model_artifact_repo():
    """Provide a repository built from the profile-qualified model version URI."""
    repository = DatabricksModelsArtifactRepository(MOCK_MODEL_ROOT_URI_WITH_PROFILE)
    return repository


def test_init_with_version_uri_containing_profile():
    """A version URI that embeds a profile populates URI, model name and version."""
    artifact_repo = DatabricksModelsArtifactRepository(MOCK_MODEL_ROOT_URI_WITH_PROFILE)
    assert artifact_repo.artifact_uri == MOCK_MODEL_ROOT_URI_WITH_PROFILE
    assert artifact_repo.model_name == MOCK_MODEL_NAME
    assert artifact_repo.model_version == MOCK_MODEL_VERSION


@pytest.mark.parametrize(
    "stage_uri_with_profile",
    [
        "models://profile@databricks/MyModel/Staging",
        "models://profile@databricks/MyModel/Production",
    ],
)
def test_init_with_stage_uri_containing_profile(stage_uri_with_profile):
    """A stage URI with a profile reports the version from the patched registry lookup."""
    with _patch_latest_versions():
        artifact_repo = DatabricksModelsArtifactRepository(stage_uri_with_profile)
        assert artifact_repo.artifact_uri == stage_uri_with_profile
        assert artifact_repo.model_name == MOCK_MODEL_NAME
        assert artifact_repo.model_version == MOCK_MODEL_VERSION
        assert artifact_repo.databricks_profile_uri == MOCK_PROFILE


@pytest.mark.parametrize(
    "invalid_artifact_uri",
    [
        "s3://test",
        "dbfs:/databricks/mlflow/MV-id/models",
        "dbfs://scope:key@notdatabricks/databricks/mlflow-regisry/123/models",
        "models:/MyModel/12",
        "models://scope:key@notdatabricks/MyModel/12",
    ],
)
def test_init_with_invalid_artifact_uris(invalid_artifact_uri):
    """Each listed URI is rejected with the "valid databricks profile" error."""
    expected_error = "A valid databricks profile is required to instantiate this repository"
    with pytest.raises(MlflowException, match=expected_error):
        DatabricksModelsArtifactRepository(invalid_artifact_uri)


def test_init_with_version_uri_and_profile_is_inferred():
    """A profile-less version URI picks up the profile from the patched registry URI."""
    # First mock for `is_using_databricks_registry` to pass
    # Second mock to set `databricks_profile_uri` during instantiation
    with (
        _patch_models_utils_registry_uri(MOCK_PROFILE),
        _patch_tracking_registry_uri(MOCK_PROFILE),
    ):
        artifact_repo = DatabricksModelsArtifactRepository(MOCK_MODEL_ROOT_URI_WITHOUT_PROFILE)
        assert artifact_repo.artifact_uri == MOCK_MODEL_ROOT_URI_WITHOUT_PROFILE
        assert artifact_repo.model_name == MOCK_MODEL_NAME
        assert artifact_repo.model_version == MOCK_MODEL_VERSION
        assert artifact_repo.databricks_profile_uri == MOCK_PROFILE


@pytest.mark.parametrize(
    "stage_uri_without_profile",
    ["models:/MyModel/Staging", "models:/MyModel/Production"],
)
def test_init_with_stage_uri_and_profile_is_inferred(stage_uri_without_profile):
    """A profile-less stage URI resolves both version and profile via the patched lookups."""
    with (
        _patch_latest_versions(),
        _patch_models_utils_registry_uri(MOCK_PROFILE),
        _patch_tracking_registry_uri(MOCK_PROFILE),
    ):
        artifact_repo = DatabricksModelsArtifactRepository(stage_uri_without_profile)
        assert artifact_repo.artifact_uri == stage_uri_without_profile
        assert artifact_repo.model_name == MOCK_MODEL_NAME
        assert artifact_repo.model_version == MOCK_MODEL_VERSION
        assert artifact_repo.databricks_profile_uri == MOCK_PROFILE


@pytest.mark.parametrize(
    "valid_profileless_artifact_uri",
    ["models:/MyModel/12", "models:/MyModel/Staging"],
)
def test_init_with_valid_uri_but_no_profile(valid_profileless_artifact_uri):
    """Without a registry URI (patched to None), a profile-less URI is rejected."""
    # Mock for `is_using_databricks_registry` fail when calling `get_registry_uri`
    with (
        _patch_models_utils_registry_uri(None),
        pytest.raises(
            MlflowException,
            match="A valid databricks profile is required to instantiate this repository",
        ),
    ):
        DatabricksModelsArtifactRepository(valid_profileless_artifact_uri)


def test_list_artifacts(databricks_model_artifact_repo):
    """`list_artifacts` parses a 200 response and surfaces a 404 as `MlflowException`."""
    # Read by the closure below at call time, so rebinding it later flips the behaviour.
    current_status = 200

    def _maybe_raise():
        if current_status == 404:
            raise Exception(
                "404 Client Error: Not Found for url: https://shard-uri/api/2.0/mlflow/model-versions/list-artifacts?name=model&version=1"
            )

    listing_payload = {
        "files": [
            {"path": "MLmodel", "is_dir": False, "file_size": 294},
            {"path": "data", "is_dir": True, "file_size": None},
        ]
    }
    ok_response = mock.MagicMock()
    ok_response.status_code = current_status
    ok_response.text = json.dumps(listing_payload)
    ok_response.raise_for_status.side_effect = _maybe_raise

    with _patch_repo_attr("_call_endpoint", return_value=ok_response) as call_endpoint_mock:
        listed = databricks_model_artifact_repo.list_artifacts("")
        assert isinstance(listed, list)
        assert len(listed) == 2
        model_file, data_dir = listed
        assert model_file.path == "MLmodel"
        assert model_file.is_dir is False
        assert model_file.file_size == 294
        assert data_dir.path == "data"
        assert data_dir.is_dir is True
        assert data_dir.file_size is None
        call_endpoint_mock.assert_called_once_with(ANY, REGISTRY_LIST_ARTIFACTS_ENDPOINT)

    # errors from API are propagated through to cli response
    current_status = 404
    failed_response = mock.MagicMock()
    failed_response.status_code = current_status
    failed_response.text = "An error occurred"
    failed_response.raise_for_status.side_effect = _maybe_raise

    with _patch_repo_attr("_call_endpoint", return_value=failed_response) as call_endpoint_mock:
        with pytest.raises(
            MlflowException,
            match=r"API request to list files under path `` failed with status code 404. "
            "Response body: An error occurred",
        ):
            databricks_model_artifact_repo.list_artifacts("")
        call_endpoint_mock.assert_called_once_with(ANY, REGISTRY_LIST_ARTIFACTS_ENDPOINT)


def test_list_artifacts_for_single_file(databricks_model_artifact_repo):
    """Listing a path whose only returned entry is that same path yields an empty list."""
    single_file_payload = {"files": [{"path": "MLmodel", "is_dir": False, "file_size": 294}]}
    file_response = mock.MagicMock()
    file_response.status_code = 200
    file_response.text = json.dumps(single_file_payload)

    with _patch_repo_attr("_call_endpoint", return_value=file_response):
        listed = databricks_model_artifact_repo.list_artifacts("MLmodel")
        assert len(listed) == 0


@pytest.mark.parametrize(
    ("remote_file_path", "local_path"),
    [
        ("test_file.txt", ""),
        ("test_file.txt", None),
        ("output/test_file", None),
    ],
)
def test_download_file(databricks_model_artifact_repo, remote_file_path, local_path):
    """`download_artifacts` fetches a signed URI and forwards it with flattened headers."""
    payload = _signed_uri_payload()
    header_mapping = {"header_name": "header_value"}
    uri_response = mock.MagicMock()
    uri_response.status_code = 200
    uri_response.text = json.dumps(payload)

    with (
        _patch_repo_attr("_call_endpoint", return_value=uri_response) as call_endpoint_mock,
        _patch_package_attr("download_file_using_http_uri", return_value=None) as download_mock,
    ):
        databricks_model_artifact_repo.download_artifacts(remote_file_path, local_path)
        call_endpoint_mock.assert_called_with(ANY, REGISTRY_ARTIFACT_PRESIGNED_URI_ENDPOINT)
        download_mock.assert_called_with(
            payload["signed_uri"],
            ANY,
            ANY,
            header_mapping,
        )


@pytest.mark.parametrize(
    "remote_file_path",
    ["test_file.txt", "output/test_file"],
)
def test_parallelized_download_file_using_http_uri_success(
    databricks_model_artifact_repo, remote_file_path
):
    """With a listing above the chunk size, `_download_file` calls the parallelized helper."""
    payload = _signed_uri_payload()

    with (
        _patch_repo_attr("list_artifacts", return_value=_oversized_file_listing(remote_file_path)),
        _patch_repo_attr(
            "_get_signed_download_uri",
            return_value=(payload["signed_uri"], payload["headers"]),
        ),
        mock.patch(
            "mlflow.utils.databricks_utils.get_databricks_env_vars",
            return_value={},
        ),
        _patch_package_attr(
            "parallelized_download_file_using_http_uri",
            return_value={},
        ) as download_file_mock,
    ):
        databricks_model_artifact_repo._download_file(remote_file_path, "")
        download_file_mock.assert_called()


@pytest.mark.parametrize(
    "remote_file_path",
    ["test_file.txt", "output/test_file"],
)
def test_parallelized_download_file_using_http_uri_with_error_downloads(
    databricks_model_artifact_repo, remote_file_path
):
    """A failed chunk whose `download_chunk` retry also fails raises `MlflowException`."""
    payload = _signed_uri_payload()
    chunk_failures = {_Chunk(1, 2, 3, "test"): Exception("Internal Server Error")}

    with (
        _patch_repo_attr("list_artifacts", return_value=_oversized_file_listing(remote_file_path)),
        _patch_repo_attr(
            "_get_signed_download_uri",
            return_value=(payload["signed_uri"], payload["headers"]),
        ),
        mock.patch(
            "mlflow.utils.databricks_utils.get_databricks_env_vars",
            return_value={},
        ),
        _patch_package_attr(
            "parallelized_download_file_using_http_uri",
            return_value=chunk_failures,
        ),
        mock.patch(
            "mlflow.utils.file_utils.download_chunk", side_effect=Exception("Retry failed")
        ) as mock_download_chunk,
    ):
        with pytest.raises(MlflowException, match="Retry failed"):
            databricks_model_artifact_repo._download_file(remote_file_path, "")

        mock_download_chunk.assert_called_with(
            range_start=2,
            range_end=3,
            headers={"header_name": "header_value"},
            download_path="",
            http_uri="https://my-amazing-signed-uri-to-rule-them-all.com/1234-numbers-yay-567",
        )


@pytest.mark.parametrize(
    "remote_file_path",
    ["test_file.txt", "output/test_file"],
)
def test_parallelized_download_file_using_http_uri_with_failed_downloads(
    databricks_model_artifact_repo, remote_file_path
):
    """A failed chunk triggers a `download_chunk` call; when that succeeds, nothing is raised."""
    payload = _signed_uri_payload()
    chunk_failures = {_Chunk(1, 2, 3, "test"): Exception("Internal Server Error")}

    with (
        _patch_repo_attr("list_artifacts", return_value=_oversized_file_listing(remote_file_path)),
        _patch_repo_attr(
            "_get_signed_download_uri",
            return_value=(payload["signed_uri"], payload["headers"]),
        ),
        mock.patch(
            "mlflow.utils.databricks_utils.get_databricks_env_vars",
            return_value={},
        ),
        _patch_package_attr(
            "parallelized_download_file_using_http_uri",
            return_value=chunk_failures,
        ),
        mock.patch(
            "mlflow.utils.file_utils.download_chunk",
            return_value=None,
        ) as download_chunk_mock,
    ):
        databricks_model_artifact_repo._download_file(remote_file_path, "")
        download_chunk_mock.assert_called()


def test_download_file_get_request_fail(databricks_model_artifact_repo):
    """An `MlflowException` raised by `_call_endpoint` makes `download_artifacts` raise."""
    with (
        _patch_repo_attr("_call_endpoint", side_effect=MlflowException("MOCK ERROR")),
        pytest.raises(MlflowException, match=r".+"),
    ):
        databricks_model_artifact_repo.download_artifacts("Something")


def test_log_artifact_fail(databricks_model_artifact_repo):
    """`log_artifact` raises the "does not support logging artifacts" error."""
    unsupported = "This repository does not support logging artifacts"
    with pytest.raises(MlflowException, match=unsupported):
        databricks_model_artifact_repo.log_artifact("Some file")


def test_log_artifacts_fail(databricks_model_artifact_repo):
    """`log_artifacts` raises the "does not support logging artifacts" error."""
    unsupported = "This repository does not support logging artifacts"
    with pytest.raises(MlflowException, match=unsupported):
        databricks_model_artifact_repo.log_artifacts("Some dir")


def test_delete_artifacts_fail(databricks_model_artifact_repo):
    """`delete_artifacts` raises `NotImplementedError` with the expected message."""
    unsupported = "This artifact repository does not support deleting artifacts"
    with pytest.raises(NotImplementedError, match=unsupported):
        databricks_model_artifact_repo.delete_artifacts()


def test_empty_headers_with_presigned_url(databricks_model_artifact_repo):
    """A signed-URI response without headers gives `None`, which extracts to `{}`."""
    url = "https://test.com/1234"
    encoding = "utf-8"
    # Build a real `requests.Response` whose body carries only the signed URI.
    response = requests.Response()
    response._content = json.dumps({"signed_uri": url}).encode(encoding)
    response.encoding = encoding

    with _patch_repo_attr("_call_endpoint", return_value=response) as call_endpoint_mock:
        ret_url, headers = databricks_model_artifact_repo._get_signed_download_uri("test_file.txt")
        call_endpoint_mock.assert_called_with(ANY, REGISTRY_ARTIFACT_PRESIGNED_URI_ENDPOINT)

    assert ret_url == url
    assert headers is None

    new_headers = databricks_model_artifact_repo._extract_headers_from_signed_url(headers)

    assert new_headers == {}