test_sqlalchemy_store.py
1 import concurrent.futures 2 import shutil 3 import time 4 import uuid 5 from pathlib import Path 6 from unittest import mock 7 8 import pytest 9 from sqlalchemy import create_engine, text 10 11 from mlflow.entities.model_registry import ( 12 ModelVersion, 13 ModelVersionTag, 14 RegisteredModelTag, 15 ) 16 from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY 17 from mlflow.entities.webhook import WebhookAction, WebhookEntity, WebhookEvent, WebhookStatus 18 from mlflow.environment_variables import ( 19 _MLFLOW_GO_STORE_TESTING, 20 MLFLOW_ENABLE_WORKSPACES, 21 MLFLOW_TRACKING_URI, 22 ) 23 from mlflow.exceptions import MlflowException 24 from mlflow.prompt.constants import PROMPT_TEXT_TAG_KEY 25 from mlflow.protos.databricks_pb2 import ( 26 INVALID_PARAMETER_VALUE, 27 RESOURCE_ALREADY_EXISTS, 28 RESOURCE_DOES_NOT_EXIST, 29 ErrorCode, 30 ) 31 from mlflow.store.model_registry.dbmodels.models import ( 32 SqlModelVersion, 33 SqlModelVersionTag, 34 SqlRegisteredModel, 35 SqlRegisteredModelTag, 36 SqlWebhook, 37 ) 38 from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore 39 from mlflow.store.model_registry.sqlalchemy_workspace_store import ( 40 WorkspaceAwareSqlAlchemyStore, 41 ) 42 from mlflow.utils.workspace_context import WorkspaceContext 43 from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME 44 45 from tests.helper_functions import random_str 46 47 pytestmark = pytest.mark.notrackingurimock 48 49 GO_MOCK_TIME_TAG = "mock.time.go.testing.tag" 50 51 52 @pytest.fixture(autouse=True, params=[False, True], ids=["workspace-disabled", "workspace-enabled"]) 53 def workspaces_enabled(request, monkeypatch, disable_workspace_mode_by_default): 54 """ 55 Run every test in this module with workspaces disabled and enabled to cover both code paths. 56 """ 57 enabled = request.param 58 monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true" if enabled else "false") 59 if enabled: 60 with WorkspaceContext(DEFAULT_WORKSPACE_NAME): 61 yield enabled 62 else: 63 yield enabled 64 65 66 @pytest.fixture 67 def store(tmp_path: Path, cached_db: Path, workspaces_enabled): 68 store_cls = WorkspaceAwareSqlAlchemyStore if workspaces_enabled else SqlAlchemyStore 69 if db_uri_env := MLFLOW_TRACKING_URI.get(): 70 s = store_cls(db_uri_env) 71 yield s 72 _cleanup_database(s) 73 else: 74 db_path = tmp_path / "mlflow.db" 75 shutil.copy(cached_db, db_path) 76 db_uri = f"sqlite:///{db_path}" 77 s = store_cls(db_uri) 78 yield s 79 80 # Dispose the engine to close all pooled connections 81 s.engine.dispose() 82 83 84 def _cleanup_database(store: SqlAlchemyStore): 85 with store.ManagedSessionMaker() as session: 86 # Delete all rows in all tables 87 for model in ( 88 SqlModelVersionTag, 89 SqlRegisteredModelTag, 90 SqlModelVersion, 91 SqlRegisteredModel, 92 SqlWebhook, 93 ): 94 session.query(model).delete() 95 96 97 def _rm_maker(store, name, tags=None, description=None): 98 return store.create_registered_model(name, tags, description) 99 100 101 def _add_go_test_tags(tags, val): 102 if _MLFLOW_GO_STORE_TESTING.get(): 103 return tags + [RegisteredModelTag(GO_MOCK_TIME_TAG, val)] 104 return tags 105 106 107 def _mv_maker( 108 store, 109 name, 110 source="path/to/source", 111 run_id=uuid.uuid4().hex, 112 tags=None, 113 run_link=None, 114 description=None, 115 ): 116 return store.create_model_version( 117 name, source, run_id, tags, run_link=run_link, description=description 118 ) 119 120 121 def _extract_latest_by_stage(latest_versions): 122 return {mvd.current_stage: str(mvd.version) for mvd in latest_versions} 123 124 125 def test_create_registered_model(store): 126 name = random_str() + "abCD" 127 rm1 = _rm_maker(store, name) 128 assert rm1.name == name 129 assert rm1.description is None 130 131 # error on duplicate 132 with pytest.raises( 133 MlflowException, match=rf"Registered Model \(name={name}\) already exists" 134 ) as exception_context: 135 _rm_maker(store, name) 136 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS) 137 138 # slightly different name is ok 139 for name2 in [name + "extra", name + name]: 140 rm2 = _rm_maker(store, name2) 141 assert rm2.name == name2 142 143 # test create model with tags 144 name2 = random_str() + "tags" 145 tags = [ 146 RegisteredModelTag("key", "value"), 147 RegisteredModelTag("anotherKey", "some other value"), 148 ] 149 rm2 = _rm_maker(store, name2, tags) 150 rmd2 = store.get_registered_model(name2) 151 assert rm2.name == name2 152 assert rm2.tags == {tag.key: tag.value for tag in tags} 153 assert rmd2.name == name2 154 assert rmd2.tags == {tag.key: tag.value for tag in tags} 155 156 # create with description 157 name3 = random_str() + "-description" 158 description = "the best model ever" 159 rm3 = _rm_maker(store, name3, description=description) 160 rmd3 = store.get_registered_model(name3) 161 assert rm3.name == name3 162 assert rm3.description == description 163 assert rmd3.name == name3 164 assert rmd3.description == description 165 166 # invalid model name will fail 167 with pytest.raises( 168 MlflowException, match=r"Missing value for required parameter 'name'" 169 ) as exception_context: 170 _rm_maker(store, None) 171 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 172 with pytest.raises( 173 MlflowException, match=r"Missing value for required parameter 'name'" 174 ) as exception_context: 175 _rm_maker(store, "") 176 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 177 178 179 def test_get_registered_model(store): 180 name = "model_1" 181 tags = [ 182 RegisteredModelTag("key", "value"), 183 RegisteredModelTag("anotherKey", "some other value"), 184 ] 185 # use fake clock 186 with mock.patch("time.time", return_value=1234): 187 rm = _rm_maker(store, name, _add_go_test_tags(tags, "1234000")) 188 assert rm.name == name 189 190 rmd = store.get_registered_model(name=name) 191 assert rmd.name == name 192 assert rmd.creation_timestamp == 1234000 193 assert rmd.last_updated_timestamp == 1234000 194 assert rmd.description is None 195 assert rmd.latest_versions == [] 196 assert rmd.tags == {tag.key: tag.value for tag in tags} 197 198 199 def test_update_registered_model(store): 200 name = "model_for_update_RM" 201 rm1 = _rm_maker(store, name) 202 rmd1 = store.get_registered_model(name=name) 203 assert rm1.name == name 204 assert rmd1.description is None 205 206 # update description 207 rm2 = store.update_registered_model(name=name, description="test model") 208 rmd2 = store.get_registered_model(name=name) 209 assert rm2.name == "model_for_update_RM" 210 assert rmd2.name == "model_for_update_RM" 211 assert rmd2.description == "test model" 212 213 214 def test_rename_registered_model(store): 215 original_name = "original name" 216 new_name = "new name" 217 _rm_maker(store, original_name) 218 _mv_maker(store, original_name) 219 _mv_maker(store, original_name) 220 rm = store.get_registered_model(original_name) 221 mv1 = store.get_model_version(original_name, 1) 222 mv2 = store.get_model_version(original_name, 2) 223 assert rm.name == original_name 224 assert mv1.name == original_name 225 assert mv2.name == original_name 226 227 # test renaming registered model also updates its model versions 228 store.rename_registered_model(original_name, new_name) 229 rm = store.get_registered_model(new_name) 230 mv1 = store.get_model_version(new_name, 1) 231 mv2 = store.get_model_version(new_name, 2) 232 assert rm.name == new_name 233 assert mv1.name == new_name 234 assert mv2.name == new_name 235 236 # test accessing the model with the old name will fail 237 with pytest.raises( 238 MlflowException, match=rf"Registered Model with name={original_name} not found" 239 ) as exception_context: 240 store.get_registered_model(original_name) 241 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 242 243 # test name another model with the replaced name is ok 244 _rm_maker(store, original_name) 245 # cannot rename model to conflict with an existing model 246 with pytest.raises( 247 MlflowException, 248 match=rf"Registered Model \(name={original_name}\) already exists", 249 ) as exception_context: 250 store.rename_registered_model(new_name, original_name) 251 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS) 252 # invalid model name will fail 253 with pytest.raises( 254 MlflowException, match=r"Missing value for required parameter 'new_name'" 255 ) as exception_context: 256 store.rename_registered_model(original_name, None) 257 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 258 with pytest.raises( 259 MlflowException, match=r"Missing value for required parameter 'new_name'" 260 ) as exception_context: 261 store.rename_registered_model(original_name, "") 262 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 263 264 265 def test_delete_registered_model(store): 266 name = "model_for_delete_RM" 267 _rm_maker(store, name) 268 _mv_maker(store, name) 269 rm1 = store.get_registered_model(name=name) 270 mv1 = store.get_model_version(name, 1) 271 assert rm1.name == name 272 assert mv1.name == name 273 274 # delete model 275 store.delete_registered_model(name=name) 276 277 # cannot get model 278 with pytest.raises( 279 MlflowException, match=rf"Registered Model with name={name} not found" 280 ) as exception_context: 281 store.get_registered_model(name=name) 282 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 283 284 # cannot update a delete model 285 with pytest.raises( 286 MlflowException, match=rf"Registered Model with name={name} not found" 287 ) as exception_context: 288 store.update_registered_model(name=name, description="deleted") 289 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 290 291 # cannot delete it again 292 with pytest.raises( 293 MlflowException, match=rf"Registered Model with name={name} not found" 294 ) as exception_context: 295 store.delete_registered_model(name=name) 296 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 297 298 # model versions are cascade deleted with the registered model 299 with pytest.raises( 300 MlflowException, match=rf"Model Version \(name={name}, version=1\) not found" 301 ) as exception_context: 302 store.get_model_version(name, 1) 303 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 304 305 306 def test_get_latest_versions(store): 307 name = "test_for_latest_versions" 308 _rm_maker(store, name) 309 rmd1 = store.get_registered_model(name=name) 310 assert rmd1.latest_versions == [] 311 312 mv1 = _mv_maker(store, name) 313 assert mv1.version == 1 314 rmd2 = store.get_registered_model(name=name) 315 assert _extract_latest_by_stage(rmd2.latest_versions) == {"None": "1"} 316 317 # add a bunch more 318 mv2 = _mv_maker(store, name) 319 assert mv2.version == 2 320 store.transition_model_version_stage( 321 name=mv2.name, 322 version=mv2.version, 323 stage="Production", 324 archive_existing_versions=False, 325 ) 326 327 mv3 = _mv_maker(store, name) 328 assert mv3.version == 3 329 store.transition_model_version_stage( 330 name=mv3.name, 331 version=mv3.version, 332 stage="Production", 333 archive_existing_versions=False, 334 ) 335 mv4 = _mv_maker(store, name) 336 assert mv4.version == 4 337 store.transition_model_version_stage( 338 name=mv4.name, 339 version=mv4.version, 340 stage="Staging", 341 archive_existing_versions=False, 342 ) 343 344 # test that correct latest versions are returned for each stage 345 rmd4 = store.get_registered_model(name=name) 346 assert _extract_latest_by_stage(rmd4.latest_versions) == { 347 "None": "1", 348 "Production": "3", 349 "Staging": "4", 350 } 351 assert _extract_latest_by_stage(store.get_latest_versions(name=name, stages=None)) == { 352 "None": "1", 353 "Production": "3", 354 "Staging": "4", 355 } 356 assert _extract_latest_by_stage(store.get_latest_versions(name=name, stages=[])) == { 357 "None": "1", 358 "Production": "3", 359 "Staging": "4", 360 } 361 assert _extract_latest_by_stage( 362 store.get_latest_versions(name=name, stages=["Production"]) 363 ) == {"Production": "3"} 364 assert _extract_latest_by_stage( 365 store.get_latest_versions(name=name, stages=["production"]) 366 ) == {"Production": "3"} # The stages are case insensitive. 367 assert _extract_latest_by_stage( 368 store.get_latest_versions(name=name, stages=["pROduction"]) 369 ) == {"Production": "3"} # The stages are case insensitive. 370 assert _extract_latest_by_stage( 371 store.get_latest_versions(name=name, stages=["None", "Production"]) 372 ) == {"None": "1", "Production": "3"} 373 374 # delete latest Production, and should point to previous one 375 store.delete_model_version(name=mv3.name, version=mv3.version) 376 rmd5 = store.get_registered_model(name=name) 377 assert _extract_latest_by_stage(rmd5.latest_versions) == { 378 "None": "1", 379 "Production": "2", 380 "Staging": "4", 381 } 382 assert _extract_latest_by_stage(store.get_latest_versions(name=name, stages=None)) == { 383 "None": "1", 384 "Production": "2", 385 "Staging": "4", 386 } 387 assert _extract_latest_by_stage( 388 store.get_latest_versions(name=name, stages=["Production"]) 389 ) == {"Production": "2"} 390 391 392 def test_set_registered_model_tag(store): 393 name1 = "SetRegisteredModelTag_TestMod" 394 name2 = "SetRegisteredModelTag_TestMod 2" 395 initial_tags = [ 396 RegisteredModelTag("key", "value"), 397 RegisteredModelTag("anotherKey", "some other value"), 398 ] 399 _rm_maker(store, name1, initial_tags) 400 _rm_maker(store, name2, initial_tags) 401 new_tag = RegisteredModelTag("randomTag", "not a random value") 402 store.set_registered_model_tag(name1, new_tag) 403 rm1 = store.get_registered_model(name=name1) 404 all_tags = initial_tags + [new_tag] 405 assert rm1.tags == {tag.key: tag.value for tag in all_tags} 406 407 # test overriding a tag with the same key 408 overriding_tag = RegisteredModelTag("key", "overriding") 409 store.set_registered_model_tag(name1, overriding_tag) 410 all_tags = [tag for tag in all_tags if tag.key != "key"] + [overriding_tag] 411 rm1 = store.get_registered_model(name=name1) 412 assert rm1.tags == {tag.key: tag.value for tag in all_tags} 413 # does not affect other models with the same key 414 rm2 = store.get_registered_model(name=name2) 415 assert rm2.tags == {tag.key: tag.value for tag in initial_tags} 416 417 # can not set tag on deleted (non-existed) registered model 418 store.delete_registered_model(name1) 419 with pytest.raises( 420 MlflowException, match=rf"Registered Model with name={name1} not found" 421 ) as exception_context: 422 store.set_registered_model_tag(name1, overriding_tag) 423 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 424 # test cannot set tags that are too long 425 long_tag = RegisteredModelTag("longTagKey", "a" * 100_001) 426 with pytest.raises( 427 MlflowException, 428 match=r"'value' exceeds the maximum length of \d+ characters", 429 ) as exception_context: 430 store.set_registered_model_tag(name2, long_tag) 431 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 432 # test can set tags that are somewhat long 433 long_tag = RegisteredModelTag("longTagKey", "a" * 4999) 434 store.set_registered_model_tag(name2, long_tag) 435 # can not set invalid tag 436 with pytest.raises( 437 MlflowException, match=r"Missing value for required parameter 'key'" 438 ) as exception_context: 439 store.set_registered_model_tag(name2, RegisteredModelTag(key=None, value="")) 440 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 441 # can not use invalid model name 442 with pytest.raises( 443 MlflowException, match=r"Missing value for required parameter 'name'" 444 ) as exception_context: 445 store.set_registered_model_tag(None, RegisteredModelTag(key="key", value="value")) 446 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 447 448 449 def test_delete_registered_model_tag(store): 450 name1 = "DeleteRegisteredModelTag_TestMod" 451 name2 = "DeleteRegisteredModelTag_TestMod 2" 452 initial_tags = [ 453 RegisteredModelTag("key", "value"), 454 RegisteredModelTag("anotherKey", "some other value"), 455 ] 456 _rm_maker(store, name1, initial_tags) 457 _rm_maker(store, name2, initial_tags) 458 new_tag = RegisteredModelTag("randomTag", "not a random value") 459 store.set_registered_model_tag(name1, new_tag) 460 store.delete_registered_model_tag(name1, "randomTag") 461 rm1 = store.get_registered_model(name=name1) 462 assert rm1.tags == {tag.key: tag.value for tag in initial_tags} 463 464 # testing deleting a key does not affect other models with the same key 465 store.delete_registered_model_tag(name1, "key") 466 rm1 = store.get_registered_model(name=name1) 467 rm2 = store.get_registered_model(name=name2) 468 assert rm1.tags == {"anotherKey": "some other value"} 469 assert rm2.tags == {tag.key: tag.value for tag in initial_tags} 470 471 # delete tag that is already deleted does nothing 472 store.delete_registered_model_tag(name1, "key") 473 rm1 = store.get_registered_model(name=name1) 474 assert rm1.tags == {"anotherKey": "some other value"} 475 476 # can not delete tag on deleted (non-existed) registered model 477 store.delete_registered_model(name1) 478 with pytest.raises( 479 MlflowException, match=rf"Registered Model with name={name1} not found" 480 ) as exception_context: 481 store.delete_registered_model_tag(name1, "anotherKey") 482 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 483 # can not delete tag with invalid key 484 with pytest.raises( 485 MlflowException, match=r"Missing value for required parameter 'key'" 486 ) as exception_context: 487 store.delete_registered_model_tag(name2, None) 488 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 489 # can not use invalid model name 490 with pytest.raises( 491 MlflowException, match=r"Missing value for required parameter 'name'" 492 ) as exception_context: 493 store.delete_registered_model_tag(None, "key") 494 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 495 496 497 def test_create_model_version(store): 498 name = "test_for_update_MV" 499 _rm_maker(store, name) 500 run_id = uuid.uuid4().hex 501 with mock.patch("time.time", return_value=456778): 502 mv1 = _mv_maker(store, name, "a/b/CD", run_id) 503 assert mv1.name == name 504 assert mv1.version == 1 505 506 mvd1 = store.get_model_version(mv1.name, mv1.version) 507 assert mvd1.name == name 508 assert int(mvd1.version) == 1 509 assert mvd1.current_stage == "None" 510 assert mvd1.creation_timestamp == 456778000 511 assert mvd1.last_updated_timestamp == 456778000 512 assert mvd1.description is None 513 assert mvd1.source == "a/b/CD" 514 assert mvd1.run_id == run_id 515 assert mvd1.status == "READY" 516 assert mvd1.status_message is None 517 assert mvd1.tags == {} 518 519 # new model versions for same name autoincrement versions 520 mv2 = _mv_maker(store, name) 521 mvd2 = store.get_model_version(name=mv2.name, version=mv2.version) 522 assert mv2.version == 2 523 assert int(mvd2.version) == 2 524 525 # create model version with tags return model version entity with tags 526 tags = [ 527 ModelVersionTag("key", "value"), 528 ModelVersionTag("anotherKey", "some other value"), 529 ] 530 mv3 = _mv_maker(store, name, tags=tags) 531 mvd3 = store.get_model_version(name=mv3.name, version=mv3.version) 532 assert mv3.version == 3 533 assert mv3.tags == {tag.key: tag.value for tag in tags} 534 assert int(mvd3.version) == 3 535 assert mvd3.tags == {tag.key: tag.value for tag in tags} 536 537 # create model versions with runLink 538 run_link = "http://localhost:3000/path/to/run/" 539 mv4 = _mv_maker(store, name, run_link=run_link) 540 mvd4 = store.get_model_version(name, mv4.version) 541 assert mv4.version == 4 542 assert mv4.run_link == run_link 543 assert int(mvd4.version) == 4 544 assert mvd4.run_link == run_link 545 546 # create model version with description 547 description = "the best model ever" 548 mv5 = _mv_maker(store, name, description=description) 549 mvd5 = store.get_model_version(name, mv5.version) 550 assert mv5.version == 5 551 assert mv5.description == description 552 assert int(mvd5.version) == 5 553 assert mvd5.description == description 554 555 # create model version without runId 556 mv6 = _mv_maker(store, name, run_id=None) 557 mvd6 = store.get_model_version(name, mv6.version) 558 assert mv6.version == 6 559 assert mv6.run_id is None 560 assert int(mvd6.version) == 6 561 assert mvd6.run_id is None 562 563 564 def test_update_model_version(store): 565 name = "test_for_update_MV" 566 _rm_maker(store, name) 567 mv1 = _mv_maker(store, name) 568 mvd1 = store.get_model_version(name=mv1.name, version=mv1.version) 569 assert mvd1.name == name 570 assert int(mvd1.version) == 1 571 assert mvd1.current_stage == "None" 572 573 # update stage 574 store.transition_model_version_stage( 575 name=mv1.name, 576 version=mv1.version, 577 stage="Production", 578 archive_existing_versions=False, 579 ) 580 mvd2 = store.get_model_version(name=mv1.name, version=mv1.version) 581 assert mvd2.name == name 582 assert int(mvd2.version) == 1 583 assert mvd2.current_stage == "Production" 584 assert mvd2.description is None 585 586 # update description 587 store.update_model_version(name=mv1.name, version=mv1.version, description="test model version") 588 mvd3 = store.get_model_version(name=mv1.name, version=mv1.version) 589 assert mvd3.name == name 590 assert int(mvd3.version) == 1 591 assert mvd3.current_stage == "Production" 592 assert mvd3.description == "test model version" 593 594 # only valid stages can be set 595 with pytest.raises( 596 MlflowException, 597 match=( 598 "Invalid Model Version stage: unknown. " 599 "Value must be one of None, Staging, Production, Archived." 600 ), 601 ) as exception_context: 602 store.transition_model_version_stage( 603 mv1.name, mv1.version, stage="unknown", archive_existing_versions=False 604 ) 605 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 606 607 # stages are case-insensitive and auto-corrected to system stage names 608 for stage_name in ["STAGING", "staging", "StAgInG"]: 609 store.transition_model_version_stage( 610 name=mv1.name, 611 version=mv1.version, 612 stage=stage_name, 613 archive_existing_versions=False, 614 ) 615 mvd5 = store.get_model_version(name=mv1.name, version=mv1.version) 616 assert mvd5.current_stage == "Staging" 617 618 619 def test_transition_model_version_stage_when_archive_existing_versions_is_false(store): 620 name = "model" 621 _rm_maker(store, name) 622 mv1 = _mv_maker(store, name) 623 mv2 = _mv_maker(store, name) 624 mv3 = _mv_maker(store, name) 625 626 # test that when `archive_existing_versions` is False, transitioning a model version 627 # to the inactive stages ("Archived" and "None") does not throw. 628 for stage in ["Archived", "None"]: 629 store.transition_model_version_stage(name, mv1.version, stage, False) 630 631 store.transition_model_version_stage(name, mv1.version, "Staging", False) 632 store.transition_model_version_stage(name, mv2.version, "Production", False) 633 store.transition_model_version_stage(name, mv3.version, "Staging", False) 634 635 mvd1 = store.get_model_version(name=name, version=mv1.version) 636 mvd2 = store.get_model_version(name=name, version=mv2.version) 637 mvd3 = store.get_model_version(name=name, version=mv3.version) 638 639 assert mvd1.current_stage == "Staging" 640 assert mvd2.current_stage == "Production" 641 assert mvd3.current_stage == "Staging" 642 643 store.transition_model_version_stage(name, mv3.version, "Production", False) 644 645 mvd1 = store.get_model_version(name=name, version=mv1.version) 646 mvd2 = store.get_model_version(name=name, version=mv2.version) 647 mvd3 = store.get_model_version(name=name, version=mv3.version) 648 649 assert mvd1.current_stage == "Staging" 650 assert mvd2.current_stage == "Production" 651 assert mvd3.current_stage == "Production" 652 653 654 def test_transition_model_version_stage_when_archive_existing_versions_is_true(store): 655 name = "model" 656 _rm_maker(store, name) 657 mv1 = _mv_maker(store, name) 658 mv2 = _mv_maker(store, name) 659 mv3 = _mv_maker(store, name) 660 661 msg = ( 662 r"Model version transition cannot archive existing model versions " 663 r"because .+ is not an Active stage" 664 ) 665 666 # test that when `archive_existing_versions` is True, transitioning a model version 667 # to the inactive stages ("Archived" and "None") throws. 668 for stage in ["Archived", "None"]: 669 with pytest.raises(MlflowException, match=msg): 670 store.transition_model_version_stage(name, mv1.version, stage, True) 671 672 store.transition_model_version_stage(name, mv1.version, "Staging", False) 673 store.transition_model_version_stage(name, mv2.version, "Production", False) 674 store.transition_model_version_stage(name, mv3.version, "Staging", True) 675 676 mvd1 = store.get_model_version(name=name, version=mv1.version) 677 mvd2 = store.get_model_version(name=name, version=mv2.version) 678 mvd3 = store.get_model_version(name=name, version=mv3.version) 679 680 assert mvd1.current_stage == "Archived" 681 assert mvd2.current_stage == "Production" 682 assert mvd3.current_stage == "Staging" 683 assert mvd1.last_updated_timestamp == mvd3.last_updated_timestamp 684 685 store.transition_model_version_stage(name, mv3.version, "Production", True) 686 687 mvd1 = store.get_model_version(name=name, version=mv1.version) 688 mvd2 = store.get_model_version(name=name, version=mv2.version) 689 mvd3 = store.get_model_version(name=name, version=mv3.version) 690 691 assert mvd1.current_stage == "Archived" 692 assert mvd2.current_stage == "Archived" 693 assert mvd3.current_stage == "Production" 694 assert mvd2.last_updated_timestamp == mvd3.last_updated_timestamp 695 696 for uncanonical_stage_name in ["STAGING", "staging", "StAgInG"]: 697 store.transition_model_version_stage(mv1.name, mv1.version, "Staging", False) 698 store.transition_model_version_stage(mv2.name, mv2.version, "None", False) 699 700 # stage names are case-insensitive and auto-corrected to system stage names 701 store.transition_model_version_stage(mv2.name, mv2.version, uncanonical_stage_name, True) 702 703 mvd1 = store.get_model_version(name=mv1.name, version=mv1.version) 704 mvd2 = store.get_model_version(name=mv2.name, version=mv2.version) 705 assert mvd1.current_stage == "Archived" 706 assert mvd2.current_stage == "Staging" 707 708 709 def test_delete_model_version(store): 710 name = "test_for_delete_MV" 711 initial_tags = [ 712 ModelVersionTag("key", "value"), 713 ModelVersionTag("anotherKey", "some other value"), 714 ] 715 _rm_maker(store, name) 716 mv = _mv_maker(store, name, tags=initial_tags) 717 mvd = store.get_model_version(name=mv.name, version=mv.version) 718 assert mvd.name == name 719 720 store.delete_model_version(name=mv.name, version=mv.version) 721 722 # cannot get a deleted model version 723 with pytest.raises( 724 MlflowException, 725 match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found", 726 ) as exception_context: 727 store.get_model_version(name=mv.name, version=mv.version) 728 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 729 730 # cannot update a delete 731 with pytest.raises( 732 MlflowException, 733 match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found", 734 ) as exception_context: 735 store.update_model_version(mv.name, mv.version, description="deleted!") 736 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 737 738 # cannot delete it again 739 with pytest.raises( 740 MlflowException, 741 match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found", 742 ) as exception_context: 743 store.delete_model_version(name=mv.name, version=mv.version) 744 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 745 746 747 def test_delete_model_version_redaction(store): 748 name = "test_for_delete_MV_redaction" 749 run_link = "http://localhost:5000/path/to/run" 750 run_id = "12345" 751 source = "path/to/source" 752 _rm_maker(store, name) 753 mv = _mv_maker(store, name, source=source, run_id=run_id, run_link=run_link) 754 mvd = store.get_model_version(name=name, version=mv.version) 755 assert mvd.run_link == run_link 756 assert mvd.run_id == run_id 757 assert mvd.source == source 758 # delete the MV now 759 store.delete_model_version(name, mv.version) 760 # verify that the relevant fields are redacted 761 mvd_deleted = store._get_sql_model_version_including_deleted(name=name, version=mv.version) 762 assert "REDACTED" in mvd_deleted.run_link 763 assert "REDACTED" in mvd_deleted.source 764 assert "REDACTED" in mvd_deleted.run_id 765 766 767 def test_get_model_version_download_uri(store): 768 name = "test_for_update_MV" 769 _rm_maker(store, name) 770 source_path = "path/to/source" 771 mv = _mv_maker(store, name, source=source_path, run_id=uuid.uuid4().hex) 772 mvd1 = store.get_model_version(name=mv.name, version=mv.version) 773 assert mvd1.name == name 774 assert mvd1.source == source_path 775 776 # download location points to source 777 assert store.get_model_version_download_uri(name=mv.name, version=mv.version) == source_path 778 779 # download URI does not change even if model version is updated 780 store.transition_model_version_stage( 781 name=mv.name, 782 version=mv.version, 783 stage="Production", 784 archive_existing_versions=False, 785 ) 786 store.update_model_version(name=mv.name, version=mv.version, description="Test for Path") 787 mvd2 = store.get_model_version(name=mv.name, version=mv.version) 788 assert mvd2.source == source_path 789 assert store.get_model_version_download_uri(name=mv.name, version=mv.version) == source_path 790 791 # cannot retrieve download URI for deleted model versions 792 store.delete_model_version(name=mv.name, version=mv.version) 793 with pytest.raises( 794 MlflowException, 795 match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found", 796 ) as exception_context: 797 store.get_model_version_download_uri(name=mv.name, version=mv.version) 798 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 799 800 801 def test_search_model_versions(store): 802 # create some model versions 803 name = "test_for_search_MV" 804 _rm_maker(store, name) 805 run_id_1 = uuid.uuid4().hex 806 run_id_2 = uuid.uuid4().hex 807 run_id_3 = uuid.uuid4().hex 808 mv1 = _mv_maker(store, name=name, source="A/B", run_id=run_id_1) 809 assert mv1.version == 1 810 mv2 = _mv_maker(store, name=name, source="A/C", run_id=run_id_2) 811 assert mv2.version == 2 812 mv3 = _mv_maker(store, name=name, source="A/D", run_id=run_id_2) 813 assert mv3.version == 3 814 mv4 = _mv_maker(store, name=name, source="A/D", run_id=run_id_3) 815 assert mv4.version == 4 816 817 def search_versions(filter_string, max_results=10, order_by=None, page_token=None): 818 return [ 819 mvd.version 820 for mvd in store.search_model_versions(filter_string, max_results, order_by, page_token) 821 ] 822 823 # search using name should return all 4 versions 824 assert set(search_versions(f"name='{name}'")) == {1, 2, 3, 4} 825 826 # search using version 827 assert set(search_versions("version_number=2")) == {2} 828 assert set(search_versions("version_number<=3")) == {1, 2, 3} 829 830 # search using run_id_1 should return version 1 831 assert set(search_versions(f"run_id='{run_id_1}'")) == {1} 832 833 # search using run_id_2 should return versions 2 and 3 834 assert set(search_versions(f"run_id='{run_id_2}'")) == {2, 3} 835 836 # search using the IN operator should return all versions 837 assert set(search_versions(f"run_id IN ('{run_id_1}','{run_id_2}')")) == {1, 2, 3} 838 839 # search IN operator is case sensitive 840 assert set(search_versions(f"run_id IN ('{run_id_1.upper()}','{run_id_2}')")) == { 841 2, 842 3, 843 } 844 845 # search IN operator with other conditions 846 assert set( 847 search_versions(f"version_number=2 AND run_id IN ('{run_id_1.upper()}','{run_id_2}')") 848 ) == {2} 849 850 # search IN operator with right-hand side value containing whitespaces 851 assert set(search_versions(f"run_id IN ('{run_id_1}', '{run_id_2}')")) == {1, 2, 3} 852 853 # search using the IN operator with bad lists should return exceptions 854 with pytest.raises( 855 MlflowException, 856 match=( 857 r"While parsing a list in the query, " 858 r"expected string value, punctuation, or whitespace, " 859 r"but got different type in list" 860 ), 861 ) as exception_context: 862 search_versions("run_id IN (1,2,3)") 863 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 864 865 assert set(search_versions(f"run_id LIKE '{run_id_2[:30]}%'")) == {2, 3} 866 867 assert set(search_versions(f"run_id ILIKE '{run_id_2[:30].upper()}%'")) == {2, 3} 868 869 # search using the IN operator with empty lists should return exceptions 870 with pytest.raises( 871 MlflowException, 872 match=( 873 r"While parsing a list in the query, " 874 r"expected a non-empty list of string values, " 875 r"but got empty list" 876 ), 877 ) as exception_context: 878 search_versions("run_id IN ()") 879 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 880 881 # search using an ill-formed IN operator correctly throws exception 882 with pytest.raises( 883 MlflowException, match=r"Invalid clause\(s\) in filter string" 884 ) as exception_context: 885 search_versions("run_id IN (") 886 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 887 888 with pytest.raises( 889 MlflowException, match=r"Invalid clause\(s\) in filter string" 890 ) as exception_context: 891 search_versions("run_id IN") 892 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 893 894 with pytest.raises( 895 MlflowException, match=r"Invalid clause\(s\) in filter string" 896 ) as exception_context: 897 search_versions("name LIKE") 898 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 899 900 with pytest.raises( 901 MlflowException, 902 match=( 903 r"While parsing a list in the query, " 904 r"expected a non-empty list of string values, " 905 r"but got ill-formed list" 906 ), 907 ) as exception_context: 908 search_versions("run_id IN (,)") 909 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 910 911 with pytest.raises( 912 MlflowException, 913 match=( 914 r"While parsing a list in the query, " 915 r"expected a non-empty list of string values, " 916 r"but got ill-formed list" 917 ), 918 ) as exception_context: 919 search_versions("run_id IN ('runid1',,'runid2')") 920 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 921 922 # search using source_path "A/D" should return version 3 and 4 923 assert set(search_versions("source_path = 'A/D'")) == {3, 4} 924 925 # search using source_path "A" should not return anything 926 assert len(search_versions("source_path = 'A'")) == 0 927 assert len(search_versions("source_path = 'A/'")) == 0 928 assert len(search_versions("source_path = ''")) == 0 929 930 # delete mv4. search should not return version 4 931 store.delete_model_version(name=mv4.name, version=mv4.version) 932 assert set(search_versions("")) == {1, 2, 3} 933 934 assert set(search_versions(None)) == {1, 2, 3} 935 936 assert set(search_versions(f"name='{name}'")) == {1, 2, 3} 937 938 assert set(search_versions("source_path = 'A/D'")) == {3} 939 940 store.transition_model_version_stage( 941 name=mv1.name, 942 version=mv1.version, 943 stage="production", 944 archive_existing_versions=False, 945 ) 946 947 store.update_model_version( 948 name=mv1.name, version=mv1.version, description="Online prediction model!" 949 ) 950 951 mvds = store.search_model_versions(f"run_id = '{run_id_1}'", max_results=10) 952 assert len(mvds) == 1 953 assert isinstance(mvds[0], ModelVersion) 954 assert mvds[0].current_stage == "Production" 955 assert mvds[0].run_id == run_id_1 956 assert mvds[0].source == "A/B" 957 assert mvds[0].description == "Online prediction model!" 958 959 960 def test_search_model_versions_order_by_simple(store): 961 # create some model versions 962 names = ["RM1", "RM2", "RM3", "RM4", "RM1", "RM4"] 963 sources = ["A"] * 3 + ["B"] * 3 964 run_ids = [uuid.uuid4().hex for _ in range(6)] 965 for name in set(names): 966 _rm_maker(store, name) 967 for i in range(6): 968 time.sleep(0.001) # sleep to ensure each model version has a different creation_time 969 _mv_maker(store, name=names[i], source=sources[i], run_id=run_ids[i]) 970 971 # by default order by last_updated_timestamp DESC 972 mvs = store.search_model_versions(filter_string=None) 973 assert [mv.name for mv in mvs] == names[::-1] 974 assert [mv.version for mv in mvs] == [2, 2, 1, 1, 1, 1] 975 976 # order by name DESC 977 mvs = store.search_model_versions(filter_string=None, order_by=["name DESC"]) 978 assert [mv.name for mv in mvs] == sorted(names)[::-1] 979 assert [mv.version for mv in mvs] == [2, 1, 1, 1, 2, 1] 980 981 # order by version DESC 982 mvs = store.search_model_versions(filter_string=None, order_by=["version_number DESC"]) 983 assert [mv.name for mv in mvs] == ["RM1", "RM4", "RM1", "RM2", "RM3", "RM4"] 984 assert [mv.version for mv in mvs] == [2, 2, 1, 1, 1, 1] 985 986 # order by creation_timestamp DESC 987 mvs = store.search_model_versions(filter_string=None, order_by=["creation_timestamp DESC"]) 988 assert [mv.name for mv in mvs] == names[::-1] 989 assert [mv.version for mv in mvs] == [2, 2, 1, 1, 1, 1] 990 991 # order by last_updated_timestamp ASC 992 store.update_model_version(names[0], 1, "latest updated") 993 mvs = store.search_model_versions(filter_string=None, order_by=["last_updated_timestamp ASC"]) 994 assert mvs[-1].name == names[0] 995 assert mvs[-1].version == 1 996 997 998 def test_search_model_versions_order_by_errors(store): 999 # create some model versions 1000 name = "RM1" 1001 _rm_maker(store, name) 1002 for _ in range(6): 1003 _mv_maker(store, name=name) 1004 query = "name LIKE 'RM%'" 1005 # test that invalid columns throw even if they come after valid columns 1006 with pytest.raises( 1007 MlflowException, match=r"Invalid attribute key '.+' specified" 1008 ) as exception_context: 1009 store.search_model_versions( 1010 query, 1011 page_token=None, 1012 order_by=["name ASC", "run_id DESC"], 1013 max_results=5, 1014 ) 1015 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1016 # test that invalid columns with random text throw even if they come after valid columns 1017 with pytest.raises(MlflowException, match=r"Invalid order_by clause '.+'") as exception_context: 1018 store.search_model_versions( 1019 query, 1020 page_token=None, 1021 order_by=["name ASC", "last_updated_timestamp DESC blah"], 1022 max_results=5, 1023 ) 1024 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1025 1026 1027 def test_search_model_versions_pagination(store): 1028 def search_versions(filter_string, page_token=None, max_results=10): 1029 result = store.search_model_versions( 1030 filter_string=filter_string, page_token=page_token, max_results=max_results 1031 ) 1032 return result.to_list(), result.token 1033 1034 name = "test_for_search_MV_pagination" 1035 _rm_maker(store, name) 1036 mvs = [_mv_maker(store, name) for _ in range(50)][::-1] 1037 1038 # test flow with fixed max_results 1039 returned_mvs = [] 1040 query = "name LIKE 'test_for_search_MV_pagination%'" 1041 result, token = search_versions(query, page_token=None, max_results=5) 1042 returned_mvs.extend(result) 1043 while token: 1044 result, token = search_versions(query, page_token=token, max_results=5) 1045 returned_mvs.extend(result) 1046 assert mvs == returned_mvs 1047 1048 # test that pagination will return all valid results in sorted order 1049 # by name ascending 1050 result, token1 = search_versions(query, max_results=5) 1051 assert token1 is not None 1052 assert result == mvs[0:5] 1053 1054 result, token2 = search_versions(query, page_token=token1, max_results=10) 1055 assert token2 is not None 1056 assert result == mvs[5:15] 1057 1058 result, token3 = search_versions(query, page_token=token2, max_results=20) 1059 assert token3 is not None 1060 assert result == mvs[15:35] 1061 1062 result, token4 = search_versions(query, page_token=token3, max_results=100) 1063 # assert that page token is None 1064 assert token4 is None 1065 assert result == mvs[35:] 1066 1067 # test that providing a completely invalid page token throws 1068 with pytest.raises( 1069 MlflowException, match=r"Invalid page token, could not base64-decode" 1070 ) as exception_context: 1071 search_versions(query, page_token="evilhax", max_results=20) 1072 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1073 1074 # test that providing too large of a max_results throws 1075 with pytest.raises( 1076 MlflowException, match=r"Invalid value for max_results\." 1077 ) as exception_context: 1078 search_versions(query, page_token="evilhax", max_results=1e15) 1079 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1080 1081 1082 def test_search_model_versions_by_tag(store): 1083 # create some model versions 1084 name = "test_for_search_MV_by_tag" 1085 _rm_maker(store, name) 1086 run_id_1 = uuid.uuid4().hex 1087 run_id_2 = uuid.uuid4().hex 1088 1089 mv1 = _mv_maker( 1090 store, 1091 name=name, 1092 source="A/B", 1093 run_id=run_id_1, 1094 tags=[ModelVersionTag("t1", "abc"), ModelVersionTag("t2", "xyz")], 1095 ) 1096 assert mv1.version == 1 1097 mv2 = _mv_maker( 1098 store, 1099 name=name, 1100 source="A/C", 1101 run_id=run_id_2, 1102 tags=[ModelVersionTag("t1", "abc"), ModelVersionTag("t2", "x123")], 1103 ) 1104 assert mv2.version == 2 1105 1106 def search_versions(filter_string): 1107 return [mvd.version for mvd in store.search_model_versions(filter_string)] 1108 1109 assert search_versions(f"name = '{name}' and tag.t2 = 'xyz'") == [1] 1110 assert search_versions("name = 'wrong_name' and tag.t2 = 'xyz'") == [] 1111 assert search_versions("tag.`t2` = 'xyz'") == [1] 1112 assert search_versions("tag.t3 = 'xyz'") == [] 1113 assert search_versions("tag.t2 != 'xy'") == [2, 1] 1114 assert search_versions("tag.t2 LIKE 'xy%'") == [1] 1115 assert search_versions("tag.t2 LIKE 'xY%'") == [] 1116 assert search_versions("tag.t2 ILIKE 'xY%'") == [1] 1117 assert search_versions("tag.t2 LIKE 'x%'") == [2, 1] 1118 assert search_versions("tag.T2 = 'xyz'") == [] 1119 assert search_versions("tag.t1 = 'abc' and tag.t2 = 'xyz'") == [1] 1120 assert search_versions("tag.t1 = 'abc' and tag.t2 LIKE 'x%'") == [2, 1] 1121 assert search_versions("tag.t1 = 'abc' and tag.t2 LIKE 'y%'") == [] 1122 # test filter with duplicated keys 1123 assert search_versions("tag.t2 like 'x%' and tag.t2 != 'xyz'") == [2] 1124 1125 1126 def _assert_workspace_in_main_queries(captured_sql, table_name, context_label): 1127 """Assert that every main search query (ORDER BY + LIMIT) includes workspace in WHERE.""" 1128 main_queries = [s for s in captured_sql if "ORDER BY" in s and table_name in s and "LIMIT" in s] 1129 assert main_queries, f"No main search query found for {context_label}" 1130 for sql in main_queries: 1131 where_clause = sql.split("WHERE", 1)[-1] if "WHERE" in sql else "" 1132 assert "workspace" in where_clause.lower(), ( 1133 f"{context_label} missing workspace predicate in WHERE — " 1134 f"causes full table scan on (workspace, ...) PK:\n{sql}" 1135 ) 1136 1137 1138 def test_search_model_versions_includes_workspace_predicate(store): 1139 from sqlalchemy import event 1140 1141 name = "test_ws_predicate_mv" 1142 _rm_maker(store, name) 1143 _mv_maker(store, name=name, source="A/B", tags=[ModelVersionTag("t1", "abc")]) 1144 1145 for filter_string in [ 1146 f"name = '{name}'", 1147 "tag.t1 = 'abc'", 1148 f"name = '{name}' AND tag.t1 = 'abc'", 1149 ]: 1150 captured_sql: list[str] = [] 1151 1152 def _capture(conn, cursor, statement, parameters, context, executemany): 1153 captured_sql.append(statement) 1154 1155 event.listen(store.engine, "before_cursor_execute", _capture) 1156 try: 1157 store.search_model_versions(filter_string) 1158 finally: 1159 event.remove(store.engine, "before_cursor_execute", _capture) 1160 1161 _assert_workspace_in_main_queries( 1162 captured_sql, "model_versions", f"search_model_versions('{filter_string}')" 1163 ) 1164 1165 1166 def test_search_registered_models_includes_workspace_predicate(store): 1167 from sqlalchemy import event 1168 1169 name = "test_ws_predicate_rm" 1170 _rm_maker(store, name, tags=[RegisteredModelTag("t1", "abc")]) 1171 1172 for filter_string in [ 1173 f"name = '{name}'", 1174 "tag.t1 = 'abc'", 1175 f"name = '{name}' AND tag.t1 = 'abc'", 1176 ]: 1177 captured_sql: list[str] = [] 1178 1179 def _capture(conn, cursor, statement, parameters, context, executemany): 1180 captured_sql.append(statement) 1181 1182 event.listen(store.engine, "before_cursor_execute", _capture) 1183 try: 1184 store.search_registered_models(filter_string) 1185 finally: 1186 event.remove(store.engine, "before_cursor_execute", _capture) 1187 1188 _assert_workspace_in_main_queries( 1189 captured_sql, 1190 "registered_models", 1191 f"search_registered_models('{filter_string}')", 1192 ) 1193 1194 1195 def _search_registered_models(store, filter_string, max_results=10, order_by=None, page_token=None): 1196 result = store.search_registered_models( 1197 filter_string=filter_string, 1198 max_results=max_results, 1199 order_by=order_by, 1200 page_token=page_token, 1201 ) 1202 return [registered_model.name for registered_model in result], result.token 1203 1204 1205 def test_search_registered_models(store): 1206 # create some registered models 1207 prefix = "test_for_search_" 1208 names = [prefix + name for name in ["RM1", "RM2", "RM3", "RM4", "RM4A", "RM4ab"]] 1209 for name in names: 1210 _rm_maker(store, name) 1211 1212 # search with no filter should return all registered models 1213 rms, _ = _search_registered_models(store, None) 1214 assert rms == names 1215 1216 # equality search using name should return exactly the 1 name 1217 rms, _ = _search_registered_models(store, f"name='{names[0]}'") 1218 assert rms == [names[0]] 1219 1220 # equality search using name that is not valid should return nothing 1221 rms, _ = _search_registered_models(store, "name='{}'".format(names[0] + "cats")) 1222 assert rms == [] 1223 1224 # case-sensitive prefix search using LIKE should return all the RMs 1225 rms, _ = _search_registered_models(store, f"name LIKE '{prefix}%'") 1226 assert rms == names 1227 1228 # case-sensitive prefix search using LIKE with surrounding % should return all the RMs 1229 rms, _ = _search_registered_models(store, "name LIKE '%RM%'") 1230 assert rms == names 1231 1232 # case-sensitive prefix search using LIKE with surrounding % should return all the RMs 1233 # _e% matches test_for_search_ , so all RMs should match 1234 rms, _ = _search_registered_models(store, "name LIKE '_e%'") 1235 assert rms == names 1236 1237 # case-sensitive prefix search using LIKE should return just rm4 1238 rms, _ = _search_registered_models(store, "name LIKE '{}%'".format(prefix + "RM4A")) 1239 assert rms == [names[4]] 1240 1241 # case-sensitive prefix search using LIKE should return no models if no match 1242 rms, _ = _search_registered_models(store, "name LIKE '{}%'".format(prefix + "cats")) 1243 assert rms == [] 1244 1245 # confirm that LIKE is not case-sensitive 1246 rms, _ = _search_registered_models(store, "name lIkE '%blah%'") 1247 assert rms == [] 1248 1249 rms, _ = _search_registered_models(store, "name like '{}%'".format(prefix + "RM4A")) 1250 assert rms == [names[4]] 1251 1252 # case-insensitive prefix search using ILIKE should return both rm5 and rm6 1253 rms, _ = _search_registered_models(store, "name ILIKE '{}%'".format(prefix + "RM4A")) 1254 assert rms == names[4:] 1255 1256 # case-insensitive postfix search with ILIKE 1257 rms, _ = _search_registered_models(store, "name ILIKE '%RM4a%'") 1258 assert rms == names[4:] 1259 1260 # case-insensitive prefix search using ILIKE should return both rm5 and rm6 1261 rms, _ = _search_registered_models(store, "name ILIKE '{}%'".format(prefix + "cats")) 1262 assert rms == [] 1263 1264 # confirm that ILIKE is not case-sensitive 1265 rms, _ = _search_registered_models(store, "name iLike '%blah%'") 1266 assert rms == [] 1267 1268 # confirm that ILIKE works for empty query 1269 rms, _ = _search_registered_models(store, "name iLike '%%'") 1270 assert rms == names 1271 1272 rms, _ = _search_registered_models(store, "name ilike '%RM4a%'") 1273 assert rms == names[4:] 1274 1275 # cannot search by invalid comparator types 1276 with pytest.raises( 1277 MlflowException, 1278 match="Parameter value is either not quoted or unidentified quote types used for " 1279 "string value something", 1280 ) as exception_context: 1281 _search_registered_models(store, "name!=something") 1282 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1283 1284 # cannot search by run_id 1285 with pytest.raises( 1286 MlflowException, match=r"Invalid attribute key 'run_id' specified." 1287 ) as exception_context: 1288 _search_registered_models(store, "run_id='somerunID'") 1289 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1290 1291 # cannot search by source_path 1292 with pytest.raises( 1293 MlflowException, match=r"Invalid attribute key 'source_path' specified." 1294 ) as exception_context: 1295 _search_registered_models(store, "source_path = 'A/D'") 1296 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1297 1298 # cannot search by other params 1299 with pytest.raises( 1300 MlflowException, match=r"Invalid clause\(s\) in filter string" 1301 ) as exception_context: 1302 _search_registered_models(store, "evilhax = true") 1303 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1304 1305 # delete last registered model. search should not return the first 5 1306 store.delete_registered_model(name=names[-1]) 1307 assert _search_registered_models(store, None, max_results=1000) == ( 1308 names[:-1], 1309 None, 1310 ) 1311 1312 # equality search using name should return no names 1313 assert _search_registered_models(store, f"name='{names[-1]}'") == ([], None) 1314 1315 # case-sensitive prefix search using LIKE should return all the RMs 1316 assert _search_registered_models(store, f"name LIKE '{prefix}%'") == ( 1317 names[0:5], 1318 None, 1319 ) 1320 1321 # case-insensitive prefix search using ILIKE should return both rm5 and rm6 1322 assert _search_registered_models(store, "name ILIKE '{}%'".format(prefix + "RM4A")) == ( 1323 [names[4]], 1324 None, 1325 ) 1326 1327 1328 def test_search_registered_models_by_tag(store): 1329 name1 = "test_for_search_RM_by_tag1" 1330 name2 = "test_for_search_RM_by_tag2" 1331 tags1 = [ 1332 RegisteredModelTag("t1", "abc"), 1333 RegisteredModelTag("t2", "xyz"), 1334 ] 1335 tags2 = [ 1336 RegisteredModelTag("t1", "abcd"), 1337 RegisteredModelTag("t2", "xyz123"), 1338 RegisteredModelTag("t3", "XYZ"), 1339 ] 1340 _rm_maker(store, name1, tags1) 1341 _rm_maker(store, name2, tags2) 1342 1343 rms, _ = _search_registered_models(store, "tag.t3 = 'XYZ'") 1344 assert rms == [name2] 1345 1346 rms, _ = _search_registered_models(store, f"name = '{name1}' and tag.t1 = 'abc'") 1347 assert rms == [name1] 1348 1349 rms, _ = _search_registered_models(store, "tag.t1 LIKE 'ab%'") 1350 assert rms == [name1, name2] 1351 1352 rms, _ = _search_registered_models(store, "tag.t1 ILIKE 'aB%'") 1353 assert rms == [name1, name2] 1354 1355 rms, _ = _search_registered_models(store, "tag.t1 LIKE 'ab%' AND tag.t2 LIKE 'xy%'") 1356 assert rms == [name1, name2] 1357 1358 rms, _ = _search_registered_models(store, "tag.t3 = 'XYz'") 1359 assert rms == [] 1360 1361 rms, _ = _search_registered_models(store, "tag.T3 = 'XYZ'") 1362 assert rms == [] 1363 1364 rms, _ = _search_registered_models(store, "tag.t1 != 'abc'") 1365 assert rms == [name2] 1366 1367 # test filter with duplicated keys 1368 rms, _ = _search_registered_models(store, "tag.t1 != 'abcd' and tag.t1 LIKE 'ab%'") 1369 assert rms == [name1] 1370 1371 1372 def test_parse_search_registered_models_order_by(): 1373 # test that "registered_models.name ASC" is returned by default 1374 parsed = SqlAlchemyStore._parse_search_registered_models_order_by([]) 1375 assert [str(x) for x in parsed] == ["registered_models.name ASC"] 1376 1377 # test that the given 'name' replaces the default one ('registered_models.name ASC') 1378 parsed = SqlAlchemyStore._parse_search_registered_models_order_by(["name DESC"]) 1379 assert [str(x) for x in parsed] == ["registered_models.name DESC"] 1380 1381 # test that an exception is raised when order_by contains duplicate fields 1382 msg = "`order_by` contains duplicate fields:" 1383 with pytest.raises(MlflowException, match=msg): 1384 SqlAlchemyStore._parse_search_registered_models_order_by([ 1385 "last_updated_timestamp", 1386 "last_updated_timestamp", 1387 ]) 1388 1389 with pytest.raises(MlflowException, match=msg): 1390 SqlAlchemyStore._parse_search_registered_models_order_by(["timestamp", "timestamp"]) 1391 1392 with pytest.raises(MlflowException, match=msg): 1393 SqlAlchemyStore._parse_search_registered_models_order_by( 1394 ["timestamp", "last_updated_timestamp"], 1395 ) 1396 1397 with pytest.raises(MlflowException, match=msg): 1398 SqlAlchemyStore._parse_search_registered_models_order_by( 1399 ["last_updated_timestamp ASC", "last_updated_timestamp DESC"], 1400 ) 1401 1402 with pytest.raises(MlflowException, match=msg): 1403 SqlAlchemyStore._parse_search_registered_models_order_by( 1404 ["last_updated_timestamp", "last_updated_timestamp DESC"], 1405 ) 1406 1407 1408 def test_search_registered_model_pagination(store): 1409 rms = [_rm_maker(store, f"RM{i:03}").name for i in range(50)] 1410 1411 # test flow with fixed max_results 1412 returned_rms = [] 1413 query = "name LIKE 'RM%'" 1414 result, token = _search_registered_models(store, query, page_token=None, max_results=5) 1415 returned_rms.extend(result) 1416 while token: 1417 result, token = _search_registered_models(store, query, page_token=token, max_results=5) 1418 returned_rms.extend(result) 1419 assert rms == returned_rms 1420 1421 # test that pagination will return all valid results in sorted order 1422 # by name ascending 1423 result, token1 = _search_registered_models(store, query, max_results=5) 1424 assert token1 is not None 1425 assert result == rms[0:5] 1426 1427 result, token2 = _search_registered_models(store, query, page_token=token1, max_results=10) 1428 assert token2 is not None 1429 assert result == rms[5:15] 1430 1431 result, token3 = _search_registered_models(store, query, page_token=token2, max_results=20) 1432 assert token3 is not None 1433 assert result == rms[15:35] 1434 1435 result, token4 = _search_registered_models(store, query, page_token=token3, max_results=100) 1436 # assert that page token is None 1437 assert token4 is None 1438 assert result == rms[35:] 1439 1440 # test that providing a completely invalid page token throws 1441 with pytest.raises( 1442 MlflowException, match=r"Invalid page token, could not base64-decode" 1443 ) as exception_context: 1444 _search_registered_models(store, query, page_token="evilhax", max_results=20) 1445 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1446 1447 # test that providing too large of a max_results throws 1448 with pytest.raises( 1449 MlflowException, match=r"Invalid value for request parameter max_results" 1450 ) as exception_context: 1451 _search_registered_models(store, query, page_token="evilhax", max_results=1e15) 1452 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1453 1454 1455 def test_search_registered_model_order_by(store): 1456 rms = [] 1457 # explicitly mock the creation_timestamps because timestamps seem to be unstable in Windows 1458 for i in range(50): 1459 with mock.patch( 1460 "mlflow.store.model_registry.sqlalchemy_store.get_current_time_millis", 1461 return_value=i, 1462 ): 1463 rms.append(_rm_maker(store, f"RM{i:03}", _add_go_test_tags([], f"{i}")).name) 1464 1465 # test flow with fixed max_results and order_by (test stable order across pages) 1466 returned_rms = [] 1467 query = "name LIKE 'RM%'" 1468 result, token = _search_registered_models( 1469 store, query, page_token=None, order_by=["name DESC"], max_results=5 1470 ) 1471 returned_rms.extend(result) 1472 while token: 1473 result, token = _search_registered_models( 1474 store, query, page_token=token, order_by=["name DESC"], max_results=5 1475 ) 1476 returned_rms.extend(result) 1477 # name descending should be the opposite order of the current order 1478 assert rms[::-1] == returned_rms 1479 # last_updated_timestamp descending should have the newest RMs first 1480 result, _ = _search_registered_models( 1481 store, 1482 query, 1483 page_token=None, 1484 order_by=["last_updated_timestamp DESC"], 1485 max_results=100, 1486 ) 1487 assert rms[::-1] == result 1488 # timestamp returns same result as last_updated_timestamp 1489 result, _ = _search_registered_models( 1490 store, query, page_token=None, order_by=["timestamp DESC"], max_results=100 1491 ) 1492 assert rms[::-1] == result 1493 # last_updated_timestamp ascending should have the oldest RMs first 1494 result, _ = _search_registered_models( 1495 store, 1496 query, 1497 page_token=None, 1498 order_by=["last_updated_timestamp ASC"], 1499 max_results=100, 1500 ) 1501 assert rms == result 1502 # timestamp returns same result as last_updated_timestamp 1503 result, _ = _search_registered_models( 1504 store, query, page_token=None, order_by=["timestamp ASC"], max_results=100 1505 ) 1506 assert rms == result 1507 # timestamp returns same result as last_updated_timestamp 1508 result, _ = _search_registered_models( 1509 store, query, page_token=None, order_by=["timestamp"], max_results=100 1510 ) 1511 assert rms == result 1512 # name ascending should have the original order 1513 result, _ = _search_registered_models( 1514 store, query, page_token=None, order_by=["name ASC"], max_results=100 1515 ) 1516 assert rms == result 1517 # test that no ASC/DESC defaults to ASC 1518 result, _ = _search_registered_models( 1519 store, 1520 query, 1521 page_token=None, 1522 order_by=["last_updated_timestamp"], 1523 max_results=100, 1524 ) 1525 assert rms == result 1526 with mock.patch( 1527 "mlflow.store.model_registry.sqlalchemy_store.get_current_time_millis", 1528 return_value=1, 1529 ): 1530 rm1 = _rm_maker(store, "MR1", _add_go_test_tags([], "1")).name 1531 rm2 = _rm_maker(store, "MR2", _add_go_test_tags([], "1")).name 1532 with mock.patch( 1533 "mlflow.store.model_registry.sqlalchemy_store.get_current_time_millis", 1534 return_value=2, 1535 ): 1536 rm3 = _rm_maker(store, "MR3", _add_go_test_tags([], "2")).name 1537 rm4 = _rm_maker(store, "MR4", _add_go_test_tags([], "2")).name 1538 query = "name LIKE 'MR%'" 1539 # test with multiple clauses 1540 result, _ = _search_registered_models( 1541 store, 1542 query, 1543 page_token=None, 1544 order_by=["last_updated_timestamp ASC", "name DESC"], 1545 max_results=100, 1546 ) 1547 assert result == [rm2, rm1, rm4, rm3] 1548 result, _ = _search_registered_models( 1549 store, 1550 query, 1551 page_token=None, 1552 order_by=["timestamp ASC", "name DESC"], 1553 max_results=100, 1554 ) 1555 assert result == [rm2, rm1, rm4, rm3] 1556 # confirm that name ascending is the default, even if ties exist on other fields 1557 result, _ = _search_registered_models( 1558 store, query, page_token=None, order_by=[], max_results=100 1559 ) 1560 assert result == [rm1, rm2, rm3, rm4] 1561 # test default tiebreak with descending timestamps 1562 result, _ = _search_registered_models( 1563 store, 1564 query, 1565 page_token=None, 1566 order_by=["last_updated_timestamp DESC"], 1567 max_results=100, 1568 ) 1569 assert result == [rm3, rm4, rm1, rm2] 1570 # test timestamp parsing 1571 result, _ = _search_registered_models( 1572 store, query, page_token=None, order_by=["timestamp\tASC"], max_results=100 1573 ) 1574 assert result == [rm1, rm2, rm3, rm4] 1575 result, _ = _search_registered_models( 1576 store, query, page_token=None, order_by=["timestamp\r\rASC"], max_results=100 1577 ) 1578 assert result == [rm1, rm2, rm3, rm4] 1579 result, _ = _search_registered_models( 1580 store, query, page_token=None, order_by=["timestamp\nASC"], max_results=100 1581 ) 1582 assert result == [rm1, rm2, rm3, rm4] 1583 result, _ = _search_registered_models( 1584 store, query, page_token=None, order_by=["timestamp ASC"], max_results=100 1585 ) 1586 assert result == [rm1, rm2, rm3, rm4] 1587 # validate order by key is case-insensitive 1588 result, _ = _search_registered_models( 1589 store, query, page_token=None, order_by=["timestamp asc"], max_results=100 1590 ) 1591 assert result == [rm1, rm2, rm3, rm4] 1592 result, _ = _search_registered_models( 1593 store, query, page_token=None, order_by=["timestamp aSC"], max_results=100 1594 ) 1595 assert result == [rm1, rm2, rm3, rm4] 1596 result, _ = _search_registered_models( 1597 store, 1598 query, 1599 page_token=None, 1600 order_by=["timestamp desc", "name desc"], 1601 max_results=100, 1602 ) 1603 assert result == [rm4, rm3, rm2, rm1] 1604 result, _ = _search_registered_models( 1605 store, 1606 query, 1607 page_token=None, 1608 order_by=["timestamp deSc", "name deSc"], 1609 max_results=100, 1610 ) 1611 assert result == [rm4, rm3, rm2, rm1] 1612 1613 1614 def test_search_registered_model_order_by_errors(store): 1615 query = "name LIKE 'RM%'" 1616 # test that invalid columns throw even if they come after valid columns 1617 with pytest.raises( 1618 MlflowException, match=r"Invalid order by key '.+' specified" 1619 ) as exception_context: 1620 _search_registered_models( 1621 store, 1622 query, 1623 page_token=None, 1624 order_by=["name ASC", "creation_timestamp DESC"], 1625 max_results=5, 1626 ) 1627 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1628 # test that invalid columns with random text throw even if they come after valid columns 1629 with pytest.raises(MlflowException, match=r"Invalid order_by clause '.+'") as exception_context: 1630 _search_registered_models( 1631 store, 1632 query, 1633 page_token=None, 1634 order_by=["name ASC", "last_updated_timestamp DESC blah"], 1635 max_results=5, 1636 ) 1637 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1638 1639 1640 def test_set_model_version_tag(store): 1641 name1 = "SetModelVersionTag_TestMod" 1642 name2 = "SetModelVersionTag_TestMod 2" 1643 initial_tags = [ 1644 ModelVersionTag("key", "value"), 1645 ModelVersionTag("anotherKey", "some other value"), 1646 ] 1647 _rm_maker(store, name1) 1648 _rm_maker(store, name2) 1649 run_id_1 = uuid.uuid4().hex 1650 run_id_2 = uuid.uuid4().hex 1651 run_id_3 = uuid.uuid4().hex 1652 _mv_maker(store, name1, "A/B", run_id_1, initial_tags) 1653 _mv_maker(store, name1, "A/C", run_id_2, initial_tags) 1654 _mv_maker(store, name2, "A/D", run_id_3, initial_tags) 1655 new_tag = ModelVersionTag("randomTag", "not a random value") 1656 store.set_model_version_tag(name1, 1, new_tag) 1657 all_tags = initial_tags + [new_tag] 1658 rm1mv1 = store.get_model_version(name1, 1) 1659 assert rm1mv1.tags == {tag.key: tag.value for tag in all_tags} 1660 1661 # test overriding a tag with the same key 1662 overriding_tag = ModelVersionTag("key", "overriding") 1663 store.set_model_version_tag(name1, 1, overriding_tag) 1664 all_tags = [tag for tag in all_tags if tag.key != "key"] + [overriding_tag] 1665 rm1mv1 = store.get_model_version(name1, 1) 1666 assert rm1mv1.tags == {tag.key: tag.value for tag in all_tags} 1667 # does not affect other model versions with the same key 1668 rm1mv2 = store.get_model_version(name1, 2) 1669 rm2mv1 = store.get_model_version(name2, 1) 1670 assert rm1mv2.tags == {tag.key: tag.value for tag in initial_tags} 1671 assert rm2mv1.tags == {tag.key: tag.value for tag in initial_tags} 1672 1673 # can not set tag on deleted (non-existed) model version 1674 store.delete_model_version(name1, 2) 1675 with pytest.raises( 1676 MlflowException, match=rf"Model Version \(name={name1}, version=2\) not found" 1677 ) as exception_context: 1678 store.set_model_version_tag(name1, 2, overriding_tag) 1679 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 1680 # test cannot set tags that are too long 1681 long_tag = ModelVersionTag("longTagKey", "a" * 100_001) 1682 with pytest.raises( 1683 MlflowException, 1684 match=r"'value' exceeds the maximum length of \d+ characters", 1685 ) as exception_context: 1686 store.set_model_version_tag(name1, 1, long_tag) 1687 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1688 # test can set tags that are somewhat long 1689 long_tag = ModelVersionTag("longTagKey", "a" * 4999) 1690 store.set_model_version_tag(name1, 1, long_tag) 1691 # can not set invalid tag 1692 with pytest.raises( 1693 MlflowException, match=r"Missing value for required parameter 'key'" 1694 ) as exception_context: 1695 store.set_model_version_tag(name2, 1, ModelVersionTag(key=None, value="")) 1696 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1697 # can not use invalid model name or version 1698 with pytest.raises( 1699 MlflowException, match=r"Missing value for required parameter 'name'" 1700 ) as exception_context: 1701 store.set_model_version_tag(None, 1, ModelVersionTag(key="key", value="value")) 1702 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1703 with pytest.raises( 1704 MlflowException, match=r"Parameter 'version' must be an integer, got 'I am not a version'" 1705 ) as exception_context: 1706 store.set_model_version_tag( 1707 name2, "I am not a version", ModelVersionTag(key="key", value="value") 1708 ) 1709 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1710 1711 1712 def test_delete_model_version_tag(store): 1713 name1 = "DeleteModelVersionTag_TestMod" 1714 name2 = "DeleteModelVersionTag_TestMod 2" 1715 initial_tags = [ 1716 ModelVersionTag("key", "value"), 1717 ModelVersionTag("anotherKey", "some other value"), 1718 ] 1719 _rm_maker(store, name1) 1720 _rm_maker(store, name2) 1721 run_id_1 = uuid.uuid4().hex 1722 run_id_2 = uuid.uuid4().hex 1723 run_id_3 = uuid.uuid4().hex 1724 _mv_maker(store, name1, "A/B", run_id_1, initial_tags) 1725 _mv_maker(store, name1, "A/C", run_id_2, initial_tags) 1726 _mv_maker(store, name2, "A/D", run_id_3, initial_tags) 1727 new_tag = ModelVersionTag("randomTag", "not a random value") 1728 store.set_model_version_tag(name1, 1, new_tag) 1729 store.delete_model_version_tag(name1, 1, "randomTag") 1730 rm1mv1 = store.get_model_version(name1, 1) 1731 assert rm1mv1.tags == {tag.key: tag.value for tag in initial_tags} 1732 1733 # testing deleting a key does not affect other model versions with the same key 1734 store.delete_model_version_tag(name1, 1, "key") 1735 rm1mv1 = store.get_model_version(name1, 1) 1736 rm1mv2 = store.get_model_version(name1, 2) 1737 rm2mv1 = store.get_model_version(name2, 1) 1738 assert rm1mv1.tags == {"anotherKey": "some other value"} 1739 assert rm1mv2.tags == {tag.key: tag.value for tag in initial_tags} 1740 assert rm2mv1.tags == {tag.key: tag.value for tag in initial_tags} 1741 1742 # delete tag that is already deleted does nothing 1743 store.delete_model_version_tag(name1, 1, "key") 1744 rm1mv1 = store.get_model_version(name1, 1) 1745 assert rm1mv1.tags == {"anotherKey": "some other value"} 1746 1747 # can not delete tag on deleted (non-existed) model version 1748 store.delete_model_version(name2, 1) 1749 with pytest.raises( 1750 MlflowException, match=rf"Model Version \(name={name2}, version=1\) not found" 1751 ) as exception_context: 1752 store.delete_model_version_tag(name2, 1, "key") 1753 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 1754 # can not delete tag with invalid key 1755 with pytest.raises( 1756 MlflowException, match=r"Missing value for required parameter 'key'" 1757 ) as exception_context: 1758 store.delete_model_version_tag(name1, 2, None) 1759 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1760 # can not use invalid model name or version 1761 with pytest.raises( 1762 MlflowException, match=r"Missing value for required parameter 'name'." 1763 ) as exception_context: 1764 store.delete_model_version_tag(None, 2, "key") 1765 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1766 with pytest.raises( 1767 MlflowException, match=r"Parameter 'version' must be an integer, got 'I am not a version'" 1768 ) as exception_context: 1769 store.delete_model_version_tag(name1, "I am not a version", "key") 1770 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1771 1772 1773 def _setup_and_test_aliases(store, model_name): 1774 store.create_registered_model(model_name) 1775 run_id_1 = uuid.uuid4().hex 1776 run_id_2 = uuid.uuid4().hex 1777 store.create_model_version(model_name, "v1", run_id_1) 1778 store.create_model_version(model_name, "v2", run_id_2) 1779 store.set_registered_model_alias(model_name, "test_alias", "2") 1780 model = store.get_registered_model(model_name) 1781 assert model.aliases == {"test_alias": 2} 1782 mv1 = store.get_model_version(model_name, 1) 1783 mv2 = store.get_model_version(model_name, 2) 1784 assert mv1.aliases == [] 1785 assert mv2.aliases == ["test_alias"] 1786 1787 1788 def test_set_registered_model_alias(store): 1789 _setup_and_test_aliases(store, "SetRegisteredModelAlias_TestMod") 1790 1791 1792 def test_delete_registered_model_alias(store): 1793 model_name = "DeleteRegisteredModelAlias_TestMod" 1794 _setup_and_test_aliases(store, model_name) 1795 store.delete_registered_model_alias(model_name, "test_alias") 1796 model = store.get_registered_model(model_name) 1797 assert model.aliases == {} 1798 mv2 = store.get_model_version(model_name, 2) 1799 assert mv2.aliases == [] 1800 1801 1802 def test_get_model_version_by_alias(store): 1803 model_name = "GetModelVersionByAlias_TestMod" 1804 _setup_and_test_aliases(store, model_name) 1805 mv = store.get_model_version_by_alias(model_name, "test_alias") 1806 assert mv.aliases == ["test_alias"] 1807 1808 1809 def test_delete_model_version_deletes_alias(store): 1810 model_name = "DeleteModelVersionDeletesAlias_TestMod" 1811 _setup_and_test_aliases(store, model_name) 1812 store.delete_model_version(model_name, 2) 1813 model = store.get_registered_model(model_name) 1814 assert model.aliases == {} 1815 with pytest.raises( 1816 MlflowException, 1817 match=r"Registered model alias test_alias not found.", 1818 ): 1819 store.get_model_version_by_alias(model_name, "test_alias") 1820 1821 1822 def test_delete_model_deletes_alias(store): 1823 model_name = "DeleteModelDeletesAlias_TestMod" 1824 _setup_and_test_aliases(store, model_name) 1825 store.delete_registered_model(model_name) 1826 with pytest.raises( 1827 MlflowException, 1828 match=rf"Registered Model with name={model_name} not found", 1829 ): 1830 store.get_model_version_by_alias(model_name, "test_alias") 1831 1832 1833 @pytest.mark.parametrize("copy_to_same_model", [False, True]) 1834 def test_copy_model_version(store, copy_to_same_model): 1835 name1 = "test_for_copy_MV1" 1836 store.create_registered_model(name1) 1837 src_tags = [ 1838 ModelVersionTag("key", "value"), 1839 ModelVersionTag("anotherKey", "some other value"), 1840 ] 1841 src_mv = _mv_maker( 1842 store, 1843 name1, 1844 tags=src_tags, 1845 run_link="dummylink", 1846 description="test description", 1847 ) 1848 1849 # Make some changes to the src MV that won't be copied over 1850 store.transition_model_version_stage( 1851 name1, src_mv.version, "Production", archive_existing_versions=False 1852 ) 1853 1854 copy_rm_name = name1 if copy_to_same_model else "test_for_copy_MV2" 1855 copy_mv_version = 2 if copy_to_same_model else 1 1856 timestamp = time.time() 1857 dst_mv = store.copy_model_version(src_mv, copy_rm_name) 1858 assert dst_mv.name == copy_rm_name 1859 assert dst_mv.version == copy_mv_version 1860 1861 copied_mv = store.get_model_version(dst_mv.name, dst_mv.version) 1862 assert copied_mv.name == copy_rm_name 1863 assert int(copied_mv.version) == copy_mv_version 1864 assert copied_mv.current_stage == "None" 1865 assert copied_mv.creation_timestamp >= timestamp 1866 assert copied_mv.last_updated_timestamp >= timestamp 1867 assert copied_mv.description == "test description" 1868 assert copied_mv.source == f"models:/{src_mv.name}/{src_mv.version}" 1869 assert store.get_model_version_download_uri(dst_mv.name, dst_mv.version) == src_mv.source 1870 assert copied_mv.run_link == "dummylink" 1871 assert copied_mv.run_id == src_mv.run_id 1872 assert copied_mv.status == "READY" 1873 assert copied_mv.status_message is None 1874 assert copied_mv.tags == {"key": "value", "anotherKey": "some other value"} 1875 1876 # Copy a model version copy 1877 double_copy_mv = store.copy_model_version(copied_mv, "test_for_copy_MV3") 1878 assert double_copy_mv.source == f"models:/{copied_mv.name}/{copied_mv.version}" 1879 assert store.get_model_version_download_uri(dst_mv.name, dst_mv.version) == src_mv.source 1880 1881 1882 def test_search_prompts(store): 1883 store.create_registered_model("model", tags=[RegisteredModelTag(key="fruit", value="apple")]) 1884 1885 store.create_registered_model( 1886 "prompt_1", tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")] 1887 ) 1888 store.create_registered_model( 1889 "prompt_2", 1890 tags=[ 1891 RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true"), 1892 RegisteredModelTag(key="fruit", value="apple"), 1893 ], 1894 ) 1895 1896 # By default, should not return prompts 1897 rms = store.search_registered_models(max_results=10) 1898 assert len(rms) == 1 1899 assert rms[0].name == "model" 1900 1901 rms = store.search_registered_models(filter_string="tags.fruit = 'apple'", max_results=10) 1902 assert len(rms) == 1 1903 assert rms[0].name == "model" 1904 1905 rms = store.search_registered_models(filter_string="name = 'prompt_1'", max_results=10) 1906 assert len(rms) == 0 1907 1908 rms = store.search_registered_models( 1909 filter_string="tags.`mlflow.prompt.is_prompt` = 'false'", max_results=10 1910 ) 1911 assert len(rms) == 1 1912 assert rms[0].name == "model" 1913 1914 rms = store.search_registered_models( 1915 filter_string="tags.`mlflow.prompt.is_prompt` != 'true'", max_results=10 1916 ) 1917 assert len(rms) == 1 1918 assert rms[0].name == "model" 1919 1920 # Search for prompts 1921 rms = store.search_registered_models( 1922 filter_string="tags.`mlflow.prompt.is_prompt` = 'true'", max_results=10 1923 ) 1924 assert len(rms) == 2 1925 assert {rm.name for rm in rms} == {"prompt_1", "prompt_2"} 1926 1927 rms = store.search_registered_models( 1928 filter_string="name = 'prompt_1' and tags.`mlflow.prompt.is_prompt` = 'true'", 1929 max_results=10, 1930 ) 1931 assert len(rms) == 1 1932 assert rms[0].name == "prompt_1" 1933 1934 rms = store.search_registered_models( 1935 filter_string="tags.`mlflow.prompt.is_prompt` = 'true' and tags.fruit = 'apple'", 1936 max_results=10, 1937 ) 1938 assert len(rms) == 1 1939 assert rms[0].name == "prompt_2" 1940 1941 1942 def test_search_prompts_versions(store): 1943 # A Model 1944 store.create_registered_model("model") 1945 store.create_model_version( 1946 "model", "1", "dummy_source", tags=[ModelVersionTag(key="fruit", value="apple")] 1947 ) 1948 1949 # A Prompt with 1 version 1950 store.create_registered_model( 1951 "prompt_1", tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")] 1952 ) 1953 store.create_model_version( 1954 "prompt_1", "1", "dummy_source", tags=[ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true")] 1955 ) 1956 1957 # A Prompt with 2 versions 1958 store.create_registered_model( 1959 "prompt_2", 1960 tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")], 1961 ) 1962 store.create_model_version( 1963 "prompt_2", 1964 "1", 1965 "dummy_source", 1966 tags=[ 1967 ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true"), 1968 ModelVersionTag(key="fruit", value="apple"), 1969 ], 1970 ) 1971 store.create_model_version( 1972 "prompt_2", 1973 "2", 1974 "dummy_source", 1975 tags=[ 1976 ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true"), 1977 ModelVersionTag(key="fruit", value="orange"), 1978 ], 1979 ) 1980 1981 # Searching model versions should not return prompts by default either 1982 mvs = store.search_model_versions(max_results=10) 1983 assert len(mvs) == 1 1984 assert mvs[0].name == "model" 1985 1986 mvs = store.search_model_versions(filter_string="tags.fruit = 'apple'", max_results=10) 1987 assert len(mvs) == 1 1988 assert mvs[0].name == "model" 1989 1990 mvs = store.search_model_versions( 1991 filter_string="tags.`mlflow.prompt.is_prompt` = 'false'", max_results=10 1992 ) 1993 assert len(mvs) == 1 1994 assert mvs[0].name == "model" 1995 1996 mvs = store.search_model_versions( 1997 filter_string="tags.`mlflow.prompt.is_prompt` != 'true'", max_results=10 1998 ) 1999 assert len(mvs) == 1 2000 assert mvs[0].name == "model" 2001 2002 # Search for prompts via search_model_versions 2003 mvs = store.search_model_versions( 2004 filter_string="tags.`mlflow.prompt.is_prompt` = 'true'", max_results=10 2005 ) 2006 assert len(mvs) == 3 2007 2008 mvs = store.search_model_versions( 2009 filter_string="tags.`mlflow.prompt.is_prompt` = 'true' and name = 'prompt_2'", 2010 max_results=10, 2011 ) 2012 assert len(mvs) == 2 2013 2014 mvs = store.search_model_versions( 2015 filter_string="tags.`mlflow.prompt.is_prompt` = 'true' and tags.fruit = 'apple'", 2016 max_results=10, 2017 ) 2018 assert len(mvs) == 1 2019 assert mvs[0].name == "prompt_2" 2020 2021 2022 def test_search_prompt_versions(store): 2023 # Create a prompt with 3 versions 2024 store.create_registered_model( 2025 "my_prompt", tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")] 2026 ) 2027 for i in range(1, 4): 2028 store.create_model_version( 2029 "my_prompt", 2030 str(i), 2031 "dummy_source", 2032 tags=[ 2033 ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true"), 2034 ModelVersionTag(key=PROMPT_TEXT_TAG_KEY, value=f"Hello {{{{name}}}} v{i}"), 2035 ], 2036 ) 2037 2038 # Create a different prompt to verify filtering by name 2039 store.create_registered_model( 2040 "other_prompt", tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")] 2041 ) 2042 store.create_model_version( 2043 "other_prompt", 2044 "1", 2045 "dummy_source", 2046 tags=[ 2047 ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true"), 2048 ModelVersionTag(key=PROMPT_TEXT_TAG_KEY, value="Other prompt text"), 2049 ], 2050 ) 2051 2052 # Search all versions of my_prompt 2053 results = store.search_prompt_versions("my_prompt") 2054 assert len(results) == 3 2055 # Should be ordered by version descending 2056 assert [pv.version for pv in results] == [3, 2, 1] 2057 assert all(pv.name == "my_prompt" for pv in results) 2058 2059 # Pagination with max_results 2060 page1 = store.search_prompt_versions("my_prompt", max_results=2) 2061 assert len(page1) == 2 2062 assert [pv.version for pv in page1] == [3, 2] 2063 assert page1.token is not None 2064 2065 # Fetch next page 2066 page2 = store.search_prompt_versions("my_prompt", max_results=2, page_token=page1.token) 2067 assert len(page2) == 1 2068 assert page2[0].version == 1 2069 assert page2.token is None 2070 2071 # Search other_prompt returns only its versions 2072 results = store.search_prompt_versions("other_prompt") 2073 assert len(results) == 1 2074 assert results[0].name == "other_prompt" 2075 2076 # Searching a non-existent prompt raises 2077 with pytest.raises(MlflowException, match="not found"): 2078 store.search_prompt_versions("nonexistent_prompt") 2079 2080 # Searching a model (not a prompt) raises 2081 store.create_registered_model("a_model") 2082 with pytest.raises(MlflowException, match="registered as a model, not a prompt"): 2083 store.search_prompt_versions("a_model") 2084 2085 2086 def test_create_registered_model_handle_prompt_properly(store): 2087 prompt_tags = [RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")] 2088 2089 store.create_registered_model("model") 2090 2091 store.create_registered_model("prompt", tags=prompt_tags) 2092 2093 with pytest.raises(MlflowException, match=r"Registered Model \(name=model\) already exists"): 2094 store.create_registered_model("model") 2095 2096 with pytest.raises(MlflowException, match=r"Prompt \(name=prompt\) already exists"): 2097 store.create_registered_model("prompt", tags=prompt_tags) 2098 2099 with pytest.raises( 2100 MlflowException, 2101 match=r"Tried to create a prompt with name 'model', " 2102 r"but the name is already taken by a registered model.", 2103 ): 2104 store.create_registered_model("model", tags=prompt_tags) 2105 2106 with pytest.raises( 2107 MlflowException, 2108 match=r"Tried to create a registered model with name 'prompt', " 2109 r"but the name is already taken by a prompt.", 2110 ): 2111 store.create_registered_model("prompt") 2112 2113 2114 def test_create_webhook(store): 2115 events = [ 2116 WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED), 2117 WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED), 2118 ] 2119 webhook = store.create_webhook( 2120 name="test_webhook", 2121 url="https://example.com/webhook", 2122 events=events, 2123 description="Test webhook", 2124 secret="secret123", 2125 status=WebhookStatus.ACTIVE, 2126 ) 2127 2128 assert webhook.name == "test_webhook" 2129 assert webhook.url == "https://example.com/webhook" 2130 assert webhook.events == events 2131 assert webhook.description == "Test webhook" 2132 assert webhook.status == WebhookStatus.ACTIVE 2133 assert webhook.webhook_id is not None 2134 assert webhook.creation_timestamp is not None 2135 assert webhook.last_updated_timestamp is not None 2136 assert webhook.secret == "secret123" 2137 2138 2139 # Shared test data for invalid webhook names 2140 INVALID_WEBHOOK_NAMES = [ 2141 ("", r"is invalid"), 2142 (" ", r"is invalid"), 2143 ("webhook<script>", r"is invalid"), 2144 ("webhook@test", r"is invalid"), 2145 ("webhook#hash", r"is invalid"), 2146 ("webhook/slash", r"is invalid"), 2147 ("webhook\\backslash", r"is invalid"), 2148 ("-webhook", r"is invalid"), # Must start with letter or digit 2149 ("webhook-", r"is invalid"), # Must end with letter or digit 2150 ("_webhook", r"is invalid"), # Must start with letter or digit 2151 ("webhook_", r"is invalid"), # Must end with letter or digit 2152 (".webhook", r"is invalid"), # Must start with letter or digit 2153 ("webhook.", r"is invalid"), # Must end with letter or digit 2154 ("a" * 64, r"is invalid"), # Too long (max 63 chars) 2155 ] 2156 2157 2158 # Shared test data for valid webhook names 2159 VALID_WEBHOOK_NAMES = [ 2160 "a", # Single character letter 2161 "1", # Single character digit 2162 "a1", # Two characters 2163 "1a", # Start with digit, end with letter 2164 "webhook123", # Alphanumeric 2165 "web_hook", # With underscore 2166 "web-hook", # With hyphen 2167 "web.hook", # With dot 2168 "web_hook-123.test", # Mixed special chars 2169 "A" * 63, # Maximum length 2170 "1" + "a" * 61 + "1", # Maximum length with digit start/end 2171 "WebHook123", # Mixed case 2172 ] 2173 2174 2175 @pytest.mark.parametrize(("invalid_name", "expected_match"), INVALID_WEBHOOK_NAMES) 2176 def test_create_webhook_invalid_names(store, invalid_name, expected_match): 2177 with pytest.raises(MlflowException, match=expected_match): 2178 store.create_webhook( 2179 name=invalid_name, 2180 url="https://example.com", 2181 events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)], 2182 ) 2183 2184 2185 @pytest.mark.parametrize("valid_name", VALID_WEBHOOK_NAMES) 2186 def test_create_webhook_valid_names(store, valid_name): 2187 webhook = store.create_webhook( 2188 name=valid_name, 2189 url="https://example.com", 2190 events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)], 2191 ) 2192 assert webhook.name == valid_name 2193 2194 2195 @pytest.mark.parametrize( 2196 ("invalid_url", "expected_match"), 2197 [ 2198 ("", r"Webhook URL cannot be empty or just whitespace"), 2199 (" ", r"Webhook URL cannot be empty or just whitespace"), 2200 ("example.com/webhook", r"Invalid webhook URL"), 2201 ("ftp://example.com/webhook", r"Invalid webhook URL scheme"), 2202 ("http://[invalid-url", r"Invalid webhook URL"), 2203 ("invalid_url", r"Invalid webhook URL"), 2204 ], 2205 ) 2206 def test_create_webhook_invalid_urls(store, invalid_url, expected_match): 2207 with pytest.raises(MlflowException, match=expected_match): 2208 store.create_webhook( 2209 name="test", 2210 url=invalid_url, 2211 events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)], 2212 ) 2213 2214 2215 def test_create_webhook_invalid_events(store): 2216 # Test empty events 2217 with pytest.raises(MlflowException, match="Webhook events must be a non-empty list"): 2218 store.create_webhook(name="test", url="https://example.com", events=[]) 2219 2220 # Test non-list events 2221 with pytest.raises(MlflowException, match="Webhook events must be a non-empty list"): 2222 store.create_webhook(name="test", url="https://example.com", events=()) 2223 2224 # Test list with non-WebhookEvent items 2225 with pytest.raises(MlflowException, match="Webhook events must be a non-empty list"): 2226 store.create_webhook(name="test", url="https://example.com", events=[1, 2, 3]) 2227 2228 2229 def test_get_webhook(store): 2230 events = [WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)] 2231 created_webhook = store.create_webhook( 2232 name="test_webhook", url="https://example.com/webhook", events=events 2233 ) 2234 2235 retrieved_webhook = store.get_webhook(created_webhook.webhook_id) 2236 2237 assert retrieved_webhook.webhook_id == created_webhook.webhook_id 2238 assert retrieved_webhook.name == "test_webhook" 2239 assert retrieved_webhook.url == "https://example.com/webhook" 2240 assert retrieved_webhook.events == events 2241 2242 2243 def test_get_webhook_not_found(store): 2244 with pytest.raises(MlflowException, match="Webhook with ID nonexistent not found"): 2245 store.get_webhook("nonexistent") 2246 2247 2248 def test_list_webhooks(store): 2249 # Create multiple webhooks 2250 webhook1 = store.create_webhook( 2251 name="webhook1", 2252 url="https://example.com/1", 2253 events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)], 2254 ) 2255 webhook2 = store.create_webhook( 2256 name="webhook2", 2257 url="https://example.com/2", 2258 events=[WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)], 2259 ) 2260 2261 webhooks_page = store.list_webhooks() 2262 2263 assert len(webhooks_page) == 2 2264 assert webhooks_page.token is None 2265 webhook_ids = {w.webhook_id for w in webhooks_page} 2266 assert webhook1.webhook_id in webhook_ids 2267 assert webhook2.webhook_id in webhook_ids 2268 2269 2270 def test_list_webhooks_pagination(store): 2271 # Create more webhooks than max_results 2272 for i in range(5): 2273 store.create_webhook( 2274 name=f"webhook{i}", 2275 url=f"https://example.com/{i}", 2276 events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)], 2277 ) 2278 2279 # Test pagination with max_results=2 2280 webhooks_page = store.list_webhooks(max_results=2) 2281 assert len(webhooks_page) == 2 2282 assert webhooks_page.token is not None 2283 2284 # Get next page 2285 next_webhooks_page = store.list_webhooks(max_results=2, page_token=webhooks_page.token) 2286 assert len(next_webhooks_page) == 2 2287 assert next_webhooks_page.token is not None 2288 2289 # Verify we don't get duplicates 2290 first_page_ids = {w.webhook_id for w in webhooks_page} 2291 second_page_ids = {w.webhook_id for w in next_webhooks_page} 2292 assert first_page_ids.isdisjoint(second_page_ids) 2293 2294 2295 def test_list_webhooks_invalid_max_results(store): 2296 with pytest.raises(MlflowException, match="max_results must be between 1 and 1000"): 2297 store.list_webhooks(max_results=1001) 2298 2299 2300 def test_update_webhook(store): 2301 events = [WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)] 2302 webhook = store.create_webhook( 2303 name="original_name", url="https://example.com/original", events=events 2304 ) 2305 2306 # Update webhook 2307 new_events = [ 2308 WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED), 2309 WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED), 2310 ] 2311 updated_webhook = store.update_webhook( 2312 webhook_id=webhook.webhook_id, 2313 name="updated_name", 2314 url="https://example.com/updated", 2315 events=new_events, 2316 description="Updated description", 2317 secret="new_secret", 2318 status=WebhookStatus.DISABLED, 2319 ) 2320 2321 assert updated_webhook.webhook_id == webhook.webhook_id 2322 assert updated_webhook.name == "updated_name" 2323 assert updated_webhook.url == "https://example.com/updated" 2324 assert updated_webhook.events == new_events 2325 assert updated_webhook.description == "Updated description" 2326 assert updated_webhook.status == WebhookStatus.DISABLED 2327 assert updated_webhook.last_updated_timestamp > webhook.last_updated_timestamp 2328 2329 2330 def test_update_webhook_partial(store): 2331 events = [WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)] 2332 webhook = store.create_webhook( 2333 name="original_name", url="https://example.com/original", events=events 2334 ) 2335 2336 # Update only name 2337 updated_webhook = store.update_webhook(webhook_id=webhook.webhook_id, name="new_name") 2338 2339 assert updated_webhook.name == "new_name" 2340 assert updated_webhook.url == "https://example.com/original" # Should remain unchanged 2341 assert updated_webhook.events == events # Should remain unchanged 2342 2343 2344 def test_update_webhook_not_found(store): 2345 with pytest.raises(MlflowException, match="Webhook with ID nonexistent not found"): 2346 store.update_webhook( 2347 webhook_id="nonexistent", name="new_name", url="https://example.com/new" 2348 ) 2349 2350 2351 def test_update_webhook_invalid_events(store): 2352 # Create a valid webhook first 2353 webhook = store.create_webhook( 2354 name="test_webhook", 2355 url="https://example.com/webhook", 2356 events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)], 2357 ) 2358 2359 with pytest.raises(MlflowException, match="Webhook events must be a non-empty list"): 2360 store.update_webhook(webhook_id=webhook.webhook_id, events=[]) 2361 2362 # Test non-list events 2363 with pytest.raises(MlflowException, match="Webhook events must be a non-empty list"): 2364 store.update_webhook(webhook_id=webhook.webhook_id, events=()) 2365 2366 # Test list with non-WebhookEvent items 2367 with pytest.raises(MlflowException, match="Webhook events must be a non-empty list"): 2368 store.update_webhook(webhook_id=webhook.webhook_id, events=[1, 2, 3]) 2369 2370 2371 @pytest.mark.parametrize(("invalid_name", "expected_match"), INVALID_WEBHOOK_NAMES) 2372 def test_update_webhook_invalid_names(store, invalid_name, expected_match): 2373 # Create a valid webhook first 2374 webhook = store.create_webhook( 2375 name="test_webhook", 2376 url="https://example.com/webhook", 2377 events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)], 2378 ) 2379 2380 with pytest.raises(MlflowException, match=expected_match): 2381 store.update_webhook(webhook_id=webhook.webhook_id, name=invalid_name) 2382 2383 2384 @pytest.mark.parametrize( 2385 ("invalid_url", "expected_match"), 2386 [ 2387 (" ", r"Webhook URL cannot be empty or just whitespace"), 2388 ("ftp://example.com", r"Invalid webhook URL scheme"), 2389 ("http://[invalid", r"Invalid webhook URL"), 2390 ], 2391 ) 2392 def test_update_webhook_invalid_urls(store, invalid_url, expected_match): 2393 # Create a valid webhook first 2394 webhook = store.create_webhook( 2395 name="test_webhook", 2396 url="https://example.com/webhook", 2397 events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)], 2398 ) 2399 2400 with pytest.raises(MlflowException, match=expected_match): 2401 store.update_webhook(webhook_id=webhook.webhook_id, url=invalid_url) 2402 2403 2404 def test_delete_webhook(store): 2405 events = [WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)] 2406 webhook = store.create_webhook( 2407 name="test_webhook", 2408 url="https://example.com/webhook", 2409 events=events, 2410 ) 2411 2412 store.delete_webhook(webhook.webhook_id) 2413 2414 with pytest.raises(MlflowException, match=r"Webhook with ID .* not found"): 2415 store.get_webhook(webhook.webhook_id) 2416 2417 webhooks_page = store.list_webhooks() 2418 webhook_ids = {w.webhook_id for w in webhooks_page} 2419 assert webhook.webhook_id not in webhook_ids 2420 2421 2422 def test_delete_webhook_not_found(store): 2423 with pytest.raises(MlflowException, match="Webhook with ID nonexistent not found"): 2424 store.delete_webhook("nonexistent") 2425 2426 2427 def test_webhook_status_transitions(store): 2428 events = [WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)] 2429 2430 webhook = store.create_webhook( 2431 name="test_webhook", 2432 url="https://example.com/webhook", 2433 events=events, 2434 status=WebhookStatus.ACTIVE, 2435 ) 2436 assert webhook.status == WebhookStatus.ACTIVE 2437 2438 # Update to inactive 2439 updated_webhook = store.update_webhook( 2440 webhook_id=webhook.webhook_id, status=WebhookStatus.DISABLED 2441 ) 2442 assert updated_webhook.status == WebhookStatus.DISABLED 2443 2444 # Update back to active 2445 updated_webhook = store.update_webhook( 2446 webhook_id=webhook.webhook_id, status=WebhookStatus.ACTIVE 2447 ) 2448 assert updated_webhook.status == WebhookStatus.ACTIVE 2449 2450 2451 def test_webhook_secret_encryption(store): 2452 store.create_webhook( 2453 name="test_webhook", 2454 url="https://example.com/webhook", 2455 events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)], 2456 secret="my_secret", 2457 ) 2458 engine = create_engine(store.db_uri) 2459 with engine.connect() as conn: 2460 (raw_secret,) = conn.execute(text("SELECT secret FROM webhooks")).fetchone() 2461 assert raw_secret is not None 2462 assert raw_secret != "my_secret" # Should be encrypted 2463 2464 2465 def test_create_model_version_with_model_id_and_no_run_id(store): 2466 name = "test_model_with_model_id" 2467 _rm_maker(store, name) 2468 2469 mock_run_id = "mock-run-id-123" 2470 mock_logged_model = mock.MagicMock() 2471 mock_logged_model.source_run_id = mock_run_id 2472 2473 with mock.patch( 2474 "mlflow.store.model_registry.sqlalchemy_store.MlflowClient" 2475 ) as mock_client_class: 2476 mock_client = mock.MagicMock() 2477 mock_client_class.return_value = mock_client 2478 mock_client.get_logged_model.return_value = mock_logged_model 2479 2480 mv = store.create_model_version( 2481 name=name, 2482 source="path/to/source", 2483 run_id=None, 2484 model_id="test-model-id-456", 2485 ) 2486 2487 mock_client.get_logged_model.assert_called_once_with("test-model-id-456") 2488 2489 assert mv.run_id == mock_run_id 2490 2491 mvd = store.get_model_version(name=mv.name, version=mv.version) 2492 assert mvd.run_id == mock_run_id 2493 2494 2495 def test_create_model_version_concurrent(store): 2496 name = "test_concurrent_mv" 2497 _rm_maker(store, name) 2498 2499 num_threads = 4 2500 versions_per_thread = 5 2501 results = [] 2502 2503 def create_versions(): 2504 return [ 2505 store.create_model_version(name, "path/to/source", uuid.uuid4().hex).version 2506 for _ in range(versions_per_thread) 2507 ] 2508 2509 with concurrent.futures.ThreadPoolExecutor( 2510 max_workers=num_threads, thread_name_prefix="create_model_version" 2511 ) as executor: 2512 futures = [executor.submit(create_versions) for _ in range(num_threads)] 2513 for f in concurrent.futures.as_completed(futures): 2514 results.extend(f.result()) 2515 2516 # All versions should be unique 2517 assert len(results) == len(set(results)) 2518 assert len(results) == num_threads * versions_per_thread