test_rest_store.py
import json
import uuid
from unittest import mock

import pytest

from mlflow.entities.model_registry import ModelVersion, ModelVersionTag, RegisteredModelTag
from mlflow.entities.model_registry.model_version_status import ModelVersionStatus
from mlflow.exceptions import MlflowException
from mlflow.prompt.registry_utils import IS_PROMPT_TAG_KEY
from mlflow.protos.model_registry_pb2 import (
    CreateModelVersion,
    CreateRegisteredModel,
    DeleteModelVersion,
    DeleteModelVersionTag,
    DeleteRegisteredModel,
    DeleteRegisteredModelAlias,
    DeleteRegisteredModelTag,
    GetLatestVersions,
    GetModelVersion,
    GetModelVersionByAlias,
    GetModelVersionDownloadUri,
    GetRegisteredModel,
    RenameRegisteredModel,
    SearchModelVersions,
    SearchRegisteredModels,
    SetModelVersionTag,
    SetRegisteredModelAlias,
    SetRegisteredModelTag,
    TransitionModelVersionStage,
    UpdateModelVersion,
    UpdateRegisteredModel,
)
from mlflow.store.model_registry.rest_store import RestStore
from mlflow.utils.proto_json_utils import message_to_json
from mlflow.utils.rest_utils import MlflowHostCreds
from mlflow.utils.workspace_context import WorkspaceContext
from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME

from tests.helper_functions import mock_http_request_200, mock_http_request_403_200


@pytest.fixture(autouse=True, params=[False, True], ids=["workspace-disabled", "workspace-enabled"])
def workspaces_enabled(request):
    """
    Run every test in this module with workspaces disabled and enabled to cover both code paths.

    When enabled, the default workspace is entered and
    ``WorkspaceRestStoreMixin.supports_workspaces`` is patched to report ``True`` for the
    duration of the test. Yields the boolean parameter.
    """

    enabled = request.param
    if enabled:
        with (
            WorkspaceContext(DEFAULT_WORKSPACE_NAME),
            mock.patch(
                "mlflow.store.workspace_rest_store_mixin.WorkspaceRestStoreMixin.supports_workspaces",
                new_callable=mock.PropertyMock,
                return_value=True,
            ),
        ):
            yield enabled
    else:
        yield enabled


@pytest.fixture
def creds():
    """Host credentials pointing at a dummy tracking server URL."""
    return MlflowHostCreds("https://hello")


@pytest.fixture
def store(creds):
    """A ``RestStore`` whose credential provider always returns the ``creds`` fixture."""
    return RestStore(lambda: creds)


def _args(host_creds, endpoint, method, json_body):
    """
    Build the keyword arguments the store is expected to pass to the HTTP request helper.

    GET requests carry the decoded body as query ``params``; every other method sends it
    as a ``json`` body.
    """
    res = {"host_creds": host_creds, "endpoint": f"/api/2.0/mlflow/{endpoint}", "method": method}
    if method == "GET":
        res["params"] = json.loads(json_body)
    else:
        res["json"] = json.loads(json_body)
    return res


def _verify_requests(http_request, creds, endpoint, method, proto_message):
    """Assert ``http_request`` was called at least once with the request built from ``proto_message``."""
    json_body = message_to_json(proto_message)
    http_request.assert_any_call(**(_args(creds, endpoint, method, json_body)))


def _verify_all_requests(http_request, creds, endpoints, proto_message):
    """
    Assert ``http_request`` received one call per ``(endpoint, method)`` pair, in that order,
    all carrying the same ``proto_message`` body.
    """
    json_body = message_to_json(proto_message)
    http_request.assert_has_calls([
        mock.call(**(_args(creds, endpoint, method, json_body))) for endpoint, method in endpoints
    ])


def test_create_registered_model(store, creds):
    tags = [
        RegisteredModelTag(key="key", value="value"),
        RegisteredModelTag(key="anotherKey", value="some other value"),
    ]
    description = "best model ever"
    with mock_http_request_200() as mock_http:
        store.create_registered_model("model_1", tags, description)
    _verify_requests(
        mock_http,
        creds,
        "registered-models/create",
        "POST",
        CreateRegisteredModel(
            name="model_1", tags=[tag.to_proto() for tag in tags], description=description
        ),
    )


def test_update_registered_model_name(store, creds):
    name = "model_1"
    new_name = "model_2"
    with mock_http_request_200() as mock_http:
        store.rename_registered_model(name=name, new_name=new_name)
    _verify_requests(
        mock_http,
        creds,
        "registered-models/rename",
        "POST",
        RenameRegisteredModel(name=name, new_name=new_name),
    )


def test_update_registered_model_description(store, creds):
    name = "model_1"
    description = "test model"
    with mock_http_request_200() as mock_http:
        store.update_registered_model(name=name, description=description)
    _verify_requests(
        mock_http,
        creds,
        "registered-models/update",
        "PATCH",
        UpdateRegisteredModel(name=name, description=description),
    )


def test_delete_registered_model(store, creds):
    name = "model_1"
    with mock_http_request_200() as mock_http:
        store.delete_registered_model(name=name)
    _verify_requests(
        mock_http, creds, "registered-models/delete", "DELETE", DeleteRegisteredModel(name=name)
    )


def test_search_registered_models(store, creds):
    with mock_http_request_200() as mock_http:
        store.search_registered_models()
    _verify_requests(mock_http, creds, "registered-models/search", "GET", SearchRegisteredModels())


@pytest.mark.parametrize("filter_string", [None, "model = 'yo'"])
@pytest.mark.parametrize("max_results", [None, 400])
@pytest.mark.parametrize("page_token", [None, "blah"])
@pytest.mark.parametrize("order_by", [None, ["x", "Y"]])
def test_search_registered_models_params(
    store, creds, filter_string, max_results, page_token, order_by
):
    params = {
        "filter_string": filter_string,
        "max_results": max_results,
        "page_token": page_token,
        "order_by": order_by,
    }
    # Only forward the arguments this parametrization actually sets.
    params = {k: v for k, v in params.items() if v is not None}
    with mock_http_request_200() as mock_http:
        store.search_registered_models(**params)
    # The store argument is ``filter_string`` but the proto field is named ``filter``.
    if "filter_string" in params:
        params["filter"] = params.pop("filter_string")
    _verify_requests(
        mock_http,
        creds,
        "registered-models/search",
        "GET",
        SearchRegisteredModels(**params),
    )


def test_get_registered_model(store, creds):
    name = "model_1"
    with mock_http_request_200() as mock_http:
        store.get_registered_model(name=name)
    _verify_requests(
        mock_http, creds, "registered-models/get", "GET", GetRegisteredModel(name=name)
    )


def test_get_latest_versions(store, creds):
    name = "model_1"
    # The mocked transport presumably rejects the first request (403) and accepts the next;
    # the test expects a POST followed by a GET to the same endpoint.
    with mock_http_request_403_200() as mock_http:
        store.get_latest_versions(name=name)
    endpoint = "registered-models/get-latest-versions"
    endpoints = [(endpoint, "POST"), (endpoint, "GET")]
    _verify_all_requests(mock_http, creds, endpoints, GetLatestVersions(name=name))


def test_get_latest_versions_with_stages(store, creds):
    name = "model_1"
    with mock_http_request_403_200() as mock_http:
        store.get_latest_versions(name=name, stages=["blaah"])
    endpoint = "registered-models/get-latest-versions"
    endpoints = [(endpoint, "POST"), (endpoint, "GET")]
    _verify_all_requests(
        mock_http, creds, endpoints, GetLatestVersions(name=name, stages=["blaah"])
    )


def test_set_registered_model_tag(store, creds):
    name = "model_1"
    tag = RegisteredModelTag(key="key", value="value")
    with mock_http_request_200() as mock_http:
        store.set_registered_model_tag(name=name, tag=tag)
    _verify_requests(
        mock_http,
        creds,
        "registered-models/set-tag",
        "POST",
        SetRegisteredModelTag(name=name, key=tag.key, value=tag.value),
    )


def test_delete_registered_model_tag(store, creds):
    name = "model_1"
    with mock_http_request_200() as mock_http:
        store.delete_registered_model_tag(name=name, key="key")
    _verify_requests(
        mock_http,
        creds,
        "registered-models/delete-tag",
        "DELETE",
        DeleteRegisteredModelTag(name=name, key="key"),
    )


def test_create_model_version(store, creds):
    with mock_http_request_200() as mock_http:
        store.create_model_version("model_1", "path/to/source")
    _verify_requests(
        mock_http,
        creds,
        "model-versions/create",
        "POST",
        CreateModelVersion(name="model_1", source="path/to/source"),
    )
    # test optional fields
    run_id = uuid.uuid4().hex
    tags = [
        ModelVersionTag(key="key", value="value"),
        ModelVersionTag(key="anotherKey", value="some other value"),
    ]
    run_link = "localhost:5000/path/to/run"
    description = "version description"
    with mock_http_request_200() as mock_http:
        store.create_model_version(
            "model_1",
            "path/to/source",
            run_id,
            tags,
            run_link=run_link,
            description=description,
        )
    _verify_requests(
        mock_http,
        creds,
        "model-versions/create",
        "POST",
        CreateModelVersion(
            name="model_1",
            source="path/to/source",
            run_id=run_id,
            run_link=run_link,
            tags=[tag.to_proto() for tag in tags],
            description=description,
        ),
    )


def test_transition_model_version_stage(store, creds):
    name = "model_1"
    version = "5"
    with mock_http_request_200() as mock_http:
        store.transition_model_version_stage(
            name=name, version=version, stage="prod", archive_existing_versions=True
        )
    _verify_requests(
        mock_http,
        creds,
        "model-versions/transition-stage",
        "POST",
        TransitionModelVersionStage(
            name=name, version=version, stage="prod", archive_existing_versions=True
        ),
    )


def test_update_model_version_description(store, creds):
    name = "model_1"
    version = "5"
    description = "test model version"
    with mock_http_request_200() as mock_http:
        store.update_model_version(name=name, version=version, description=description)
    _verify_requests(
        mock_http,
        creds,
        "model-versions/update",
        "PATCH",
        UpdateModelVersion(name=name, version=version, description=description),
    )


def test_delete_model_version(store, creds):
    name = "model_1"
    version = "12"
    with mock_http_request_200() as mock_http:
        store.delete_model_version(name=name, version=version)
    _verify_requests(
        mock_http,
        creds,
        "model-versions/delete",
        "DELETE",
        DeleteModelVersion(name=name, version=version),
    )


def test_get_model_version_details(store, creds):
    name = "model_11"
    version = "8"
    with mock_http_request_200() as mock_http:
        store.get_model_version(name=name, version=version)
    _verify_requests(
        mock_http, creds, "model-versions/get", "GET", GetModelVersion(name=name, version=version)
    )


def test_get_model_version_download_uri(store, creds):
    name = "model_11"
    version = "8"
    with mock_http_request_200() as mock_http:
        store.get_model_version_download_uri(name=name, version=version)
    _verify_requests(
        mock_http,
        creds,
        "model-versions/get-download-uri",
        "GET",
        GetModelVersionDownloadUri(name=name, version=version),
    )


def test_search_model_versions(store, creds):
    with mock_http_request_200() as mock_http:
        store.search_model_versions()
    _verify_requests(mock_http, creds, "model-versions/search", "GET", SearchModelVersions())


@pytest.mark.parametrize("filter_string", [None, "name = 'model_12'"])
@pytest.mark.parametrize("max_results", [None, 400])
@pytest.mark.parametrize("page_token", [None, "blah"])
# ``order_by`` is a list of clauses (mirroring the registered-models test); include ``None``
# so the "no ordering" request shape is covered too.
@pytest.mark.parametrize("order_by", [None, ["version DESC", "creation_time DESC"]])
def test_search_model_versions_params(
    store, creds, filter_string, max_results, page_token, order_by
):
    params = {
        "filter_string": filter_string,
        "max_results": max_results,
        "page_token": page_token,
        "order_by": order_by,
    }
    # Only forward the arguments this parametrization actually sets.
    params = {k: v for k, v in params.items() if v is not None}
    with mock_http_request_200() as mock_http:
        store.search_model_versions(**params)
    # The store argument is ``filter_string`` but the proto field is named ``filter``.
    if "filter_string" in params:
        params["filter"] = params.pop("filter_string")
    _verify_requests(
        mock_http,
        creds,
        "model-versions/search",
        "GET",
        SearchModelVersions(**params),
    )


def test_set_model_version_tag(store, creds):
    name = "model_1"
    tag = ModelVersionTag(key="key", value="value")
    with mock_http_request_200() as mock_http:
        store.set_model_version_tag(name=name, version="1", tag=tag)
    _verify_requests(
        mock_http,
        creds,
        "model-versions/set-tag",
        "POST",
        SetModelVersionTag(name=name, version="1", key=tag.key, value=tag.value),
    )


def test_delete_model_version_tag(store, creds):
    name = "model_1"
    with mock_http_request_200() as mock_http:
        store.delete_model_version_tag(name=name, version="1", key="key")
    _verify_requests(
        mock_http,
        creds,
        "model-versions/delete-tag",
        "DELETE",
        DeleteModelVersionTag(name=name, version="1", key="key"),
    )


def test_set_registered_model_alias(store, creds):
    name = "model_1"
    with mock_http_request_200() as mock_http:
        store.set_registered_model_alias(name=name, alias="test_alias", version="1")
    _verify_requests(
        mock_http,
        creds,
        "registered-models/alias",
        "POST",
        SetRegisteredModelAlias(name=name, alias="test_alias", version="1"),
    )


def test_delete_registered_model_alias(store, creds):
    name = "model_1"
    with mock_http_request_200() as mock_http:
        store.delete_registered_model_alias(name=name, alias="test_alias")
    _verify_requests(
        mock_http,
        creds,
        "registered-models/alias",
        "DELETE",
        DeleteRegisteredModelAlias(name=name, alias="test_alias"),
    )


def test_get_model_version_by_alias(store, creds):
    name = "model_1"
    with mock_http_request_200() as mock_http:
        store.get_model_version_by_alias(name=name, alias="test_alias")
    _verify_requests(
        mock_http,
        creds,
        "registered-models/alias",
        "GET",
        GetModelVersionByAlias(name=name, alias="test_alias"),
    )


def test_await_model_version_creation_pending(store):
    pending_mv = ModelVersion(
        name="Model 1",
        version="1",
        creation_timestamp=123,
        status=ModelVersionStatus.to_string(ModelVersionStatus.PENDING_REGISTRATION),
    )
    # The version never leaves PENDING_REGISTRATION; with a 1s poll interval and a 0.5s
    # timeout the wait must give up.
    with (
        mock.patch(
            "mlflow.store.model_registry.abstract_store.AWAIT_MODEL_VERSION_CREATE_SLEEP_INTERVAL_SECONDS",
            1,
        ),
        mock.patch.object(store, "get_model_version", return_value=pending_mv),
        pytest.raises(MlflowException, match="Exceeded max wait time"),
    ):
        store._await_model_version_creation(pending_mv, 0.5)


def test_await_model_version_creation_failed(store):
    pending_mv = ModelVersion(
        name="Model 1",
        version="1",
        creation_timestamp=123,
        status=ModelVersionStatus.to_string(ModelVersionStatus.FAILED_REGISTRATION),
    )
    with (
        mock.patch.object(store, "get_model_version", return_value=pending_mv),
        pytest.raises(MlflowException, match="Model version creation failed for model name"),
    ):
        store._await_model_version_creation(pending_mv, 0.5)


@pytest.mark.parametrize("is_prompt", [True, False], ids=["prompt", "model"])
def test_await_model_version_creation_show_correct_message_for_prompt(store, is_prompt):
    tags = [ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true")] if is_prompt else []
    pending = ModelVersion(
        name="test",
        version="1",
        creation_timestamp=123,
        tags=tags,
        status=ModelVersionStatus.to_string(ModelVersionStatus.PENDING_REGISTRATION),
    )
    completed = ModelVersion(
        name="test",
        version="1",
        creation_timestamp=123,
        tags=tags,
        status=ModelVersionStatus.to_string(ModelVersionStatus.READY),
    )

    with (
        mock.patch("mlflow.store.model_registry.abstract_store._logger") as mock_logger,
        mock.patch.object(store, "get_model_version", return_value=completed),
    ):
        store._await_model_version_creation(pending, 10)

    mock_logger.info.assert_called_once()
    # Read the message from the ``info`` call itself rather than ``mock_calls[0]``, which
    # would pick up any other logger method invoked first.
    info_message = mock_logger.info.call_args.args[0]
    if is_prompt:
        assert "prompt" in info_message
        assert "model" not in info_message
    else:
        assert "prompt" not in info_message
        assert "model" in info_message