test_file_store.py
import pytest

# Skip every test in this module: FileStore is no longer supported.
pytestmark = pytest.mark.skip(reason="FileStore is no longer supported")

import time
import uuid
from typing import NamedTuple
from unittest import mock

import pytest

import mlflow
from mlflow.entities.model_registry import (
    ModelVersion,
    ModelVersionTag,
    RegisteredModelTag,
)
from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import (
    INVALID_PARAMETER_VALUE,
    RESOURCE_DOES_NOT_EXIST,
    ErrorCode,
)
from mlflow.pyfunc import PythonModel
from mlflow.store.model_registry.file_store import FileStore
from mlflow.utils.file_utils import path_to_local_file_uri
from mlflow.utils.time import get_current_time_millis
from mlflow.utils.yaml_utils import write_yaml

from tests.helper_functions import random_int, random_str


@pytest.fixture
def store(tmp_path):
    """Return a FileStore rooted at the per-test temporary directory."""
    return FileStore(str(tmp_path))


@pytest.fixture
def registered_model_names():
    """Return three random registered-model names."""
    return [random_str() for _ in range(3)]


@pytest.fixture
def rm_data(registered_model_names, tmp_path):
    """Write registered-model metadata straight to disk, bypassing the store API.

    For each name, creates ``<tmp_path>/<MODELS_FOLDER_NAME>/<name>`` with a metadata
    YAML file and an empty tags folder.

    Returns:
        dict mapping each model name to the metadata dict that was written.
    """
    rm_data = {}
    for name in registered_model_names:
        # create registered model
        rm_folder = tmp_path.joinpath(FileStore.MODELS_FOLDER_NAME, name)
        rm_folder.mkdir(parents=True, exist_ok=True)
        creation_time = get_current_time_millis()
        d = {
            "name": name,
            "creation_timestamp": creation_time,
            "last_updated_timestamp": creation_time,
            "description": None,
            "latest_versions": [],
            "tags": {},
        }
        rm_data[name] = d
        write_yaml(rm_folder, FileStore.META_DATA_FILE_NAME, d)
        tags_dir = rm_folder.joinpath(FileStore.TAGS_FOLDER_NAME)
        tags_dir.mkdir(parents=True, exist_ok=True)
    return rm_data


def test_file_store_deprecation_warning(tmp_path):
    """Constructing a FileStore emits a FutureWarning about the deprecated backend."""
    with pytest.warns(FutureWarning, match="filesystem model registry backend.*is deprecated"):
        FileStore(str(tmp_path / "model_registry"))


def test_create_registered_model(store):
    """Creation rejects missing names and returns an empty, freshly created model."""
    # Error cases
    with pytest.raises(MlflowException, match=r"Missing value for required parameter 'name'\."):
        store.create_registered_model(None)
    with pytest.raises(MlflowException, match=r"Missing value for required parameter 'name'\."):
        store.create_registered_model("")

    name = random_str()
    model = store.create_registered_model(name)
    assert model.name == name
    assert model.latest_versions == []
    assert model.creation_timestamp == model.last_updated_timestamp
    assert model.tags == {}


def test_create_registered_model_with_name_that_looks_like_path(store, tmp_path):
    """A filesystem-path-like name is rejected by name validation."""
    # NOTE(review): despite the test name, this exercises get_registered_model; creation
    # presumably shares the same name validation -- confirm.
    name = str(tmp_path.joinpath("test"))
    with pytest.raises(MlflowException, match=r"Names cannot contain '/' or ':'"):
        store.get_registered_model(name)


def test_create_registered_model_with_percent_in_name(store, tmp_path):
    """A name containing '%' is rejected by name validation."""
    # NOTE(review): as above, this goes through get_registered_model, not creation.
    with pytest.raises(
        MlflowException, match=r"Registered model name cannot contain '%' character"
    ):
        store.get_registered_model("m%6fdel")


def _verify_registered_model(fs, name, rm_data):
    """Assert that the model fetched from ``fs`` matches the fixture metadata ``rm_data[name]``."""
    rm = fs.get_registered_model(name)
    assert rm.name == name
    assert rm.creation_timestamp == rm_data[name]["creation_timestamp"]
    assert rm.last_updated_timestamp == rm_data[name]["last_updated_timestamp"]
    assert rm.description == rm_data[name]["description"]
    assert rm.latest_versions == rm_data[name]["latest_versions"]
    assert rm.tags == rm_data[name]["tags"]


def test_get_registered_model(store, registered_model_names, rm_data):
    """Models written to disk are readable; unknown and path-like names raise."""
    for name in registered_model_names:
        _verify_registered_model(store, name, rm_data)

    # test that fake registered models do not exist.
    name = random_str()
    with pytest.raises(MlflowException, match=f"Registered Model with name={name} not found"):
        store.get_registered_model(name)

    name = "../../path"
    with pytest.raises(MlflowException, match="Names cannot contain '/' or ':'"):
        store.get_registered_model(name)


def test_list_registered_model(store, registered_model_names, rm_data):
    """Listing returns exactly the models written by the ``rm_data`` fixture."""
    listed_names = []
    for rm in store.list_registered_models(max_results=10, page_token=None):
        name = rm.name
        assert name in registered_model_names
        assert name == rm_data[name]["name"]
        listed_names.append(name)
    # Without this check the loop above passes vacuously when nothing is listed.
    assert sorted(listed_names) == sorted(registered_model_names)


@pytest.mark.usefixtures(rm_data.__name__)
def test_rename_registered_model(store, registered_model_names):
    """Renaming validates the new name, refuses existing names, and applies the rename."""
    # Error cases
    model_name = registered_model_names[0]
    with pytest.raises(MlflowException, match=r"Missing value for required parameter 'name'\."):
        store.rename_registered_model(model_name, None)

    # test that names of existing registered models are checked before renaming
    other_model_name = registered_model_names[1]
    with pytest.raises(
        MlflowException,
        match=rf"Registered Model \(name={other_model_name}\) already exists\.",
    ):
        store.rename_registered_model(model_name, other_model_name)

    new_name = model_name + "!!!"
    store.rename_registered_model(model_name, new_name)
    assert store.get_registered_model(new_name).name == new_name


def _extract_names(registered_models):
    """Return the ``name`` attribute of each registered model, preserving order."""
    return [rm.name for rm in registered_models]


@pytest.mark.usefixtures(rm_data.__name__)
def test_delete_registered_model(store, registered_model_names):
    """Deleting removes the model from listings; unknown or already-deleted names raise."""
    model_name = registered_model_names[random_int(0, len(registered_model_names) - 1)]

    # Error cases
    with pytest.raises(
        MlflowException, match=f"Registered Model with name={model_name}!!! not found"
    ):
        store.delete_registered_model(model_name + "!!!")

    store.delete_registered_model(model_name)
    assert model_name not in _extract_names(
        store.list_registered_models(max_results=10, page_token=None)
    )
    # Cannot delete a deleted model
    with pytest.raises(MlflowException, match=f"Registered Model with name={model_name} not found"):
        store.delete_registered_model(model_name)


def test_list_registered_model_paginated(store):
    """Page size is honoured, identical requests agree, and the final page has no token."""
    for _ in range(10):
        store.create_registered_model(random_str())
    rms1 = store.list_registered_models(max_results=4, page_token=None)
    assert len(rms1) == 4
    assert rms1.token is not None
    rms2 = store.list_registered_models(max_results=4, page_token=None)
    assert len(rms2) == 4
    assert rms2.token is not None
    assert rms1 == rms2
    rms3 = store.list_registered_models(max_results=500, page_token=rms2.token)
    assert len(rms3) == 6
    assert rms3.token is None


def test_list_registered_model_paginated_returns_in_correct_order(store):
    """Successive pages of varying size cover all models in ascending name order."""
    rms = [store.create_registered_model(f"RM{i:03}").name for i in range(50)]

    # test that pagination will return all valid results in sorted order
    # by name ascending
    result = store.list_registered_models(max_results=5, page_token=None)
    assert result.token is not None
    assert _extract_names(result) == rms[0:5]

    result = store.list_registered_models(page_token=result.token, max_results=10)
    assert result.token is not None
    assert _extract_names(result) == rms[5:15]

    result = store.list_registered_models(page_token=result.token, max_results=20)
    assert result.token is not None
    assert _extract_names(result) == rms[15:35]

    result = store.list_registered_models(page_token=result.token, max_results=100)
    assert result.token is None
    assert _extract_names(result) == rms[35:]


def test_list_registered_model_paginated_errors(store):
    """Invalid page tokens and max_results raise INVALID_PARAMETER_VALUE; deleted models are hidden."""
    rms = [store.create_registered_model(f"RM{i:03}").name for i in range(50)]
    # test that providing a completely invalid page token throws
    with pytest.raises(
        MlflowException, match=r"Invalid page token, could not base64-decode"
    ) as exception_context:
        store.list_registered_models(page_token="evilhax", max_results=20)
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # test that providing too large of a max_results throws.
    # The page token is left empty so this isolates max_results validation instead of
    # depending on which argument the store happens to validate first.
    with pytest.raises(
        MlflowException, match=r"Invalid value for max_results"
    ) as exception_context:
        store.list_registered_models(page_token=None, max_results=1e15)
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    # list should not return deleted models
    store.delete_registered_model(name="RM000")
    assert set(
        _extract_names(store.list_registered_models(max_results=100, page_token=None))
    ) == set(rms[1:])


def _create_model_version(
    fs,
    name,
    source="path/to/source",
    # NOTE: evaluated once when the module is imported, so every call that relies on
    # this default shares the same run id. Pass run_id explicitly when distinct ids matter.
    run_id=uuid.uuid4().hex,
    tags=None,
    run_link=None,
    description=None,
):
    """Create a model version on ``fs`` and return it.

    Sleeps 1 ms first, presumably so that consecutive versions get distinct timestamps.
    """
    time.sleep(0.001)
    return fs.create_model_version(
        name, source, run_id, tags, run_link=run_link, description=description
    )


def _stage_to_version_map(latest_versions):
    """Map each model version's ``current_stage`` to its ``version``."""
    return {mvd.current_stage: mvd.version for mvd in latest_versions}


def test_get_latest_versions(store):
    """Latest versions are tracked per stage, matched case-insensitively, and updated on delete."""
    name = "test_for_latest_versions"
    rmd1 = store.create_registered_model(name)
    assert rmd1.latest_versions == []

    mv1 = _create_model_version(store, name)
    assert mv1.version == 1
    rmd2 = store.get_registered_model(name)
    assert _stage_to_version_map(rmd2.latest_versions) == {"None": 1}

    # add a bunch more
    mv2 = _create_model_version(store, name)
    assert mv2.version == 2
    store.transition_model_version_stage(
        name=mv2.name,
        version=mv2.version,
        stage="Production",
        archive_existing_versions=False,
    )

    mv3 = _create_model_version(store, name)
    assert mv3.version == 3
    store.transition_model_version_stage(
        name=mv3.name,
        version=mv3.version,
        stage="Production",
        archive_existing_versions=False,
    )
    mv4 = _create_model_version(store, name)
    assert mv4.version == 4
    store.transition_model_version_stage(
        name=mv4.name,
        version=mv4.version,
        stage="Staging",
        archive_existing_versions=False,
    )

    # test that correct latest versions are returned for each stage
    rmd4 = store.get_registered_model(name)
    assert _stage_to_version_map(rmd4.latest_versions) == {
        "None": 1,
        "Production": 3,
        "Staging": 4,
    }
    assert _stage_to_version_map(store.get_latest_versions(name=name, stages=None)) == {
        "None": 1,
        "Production": 3,
        "Staging": 4,
    }
    assert _stage_to_version_map(store.get_latest_versions(name=name, stages=[])) == {
        "None": 1,
        "Production": 3,
        "Staging": 4,
    }
    assert _stage_to_version_map(store.get_latest_versions(name=name, stages=["Production"])) == {
        "Production": 3
    }
    assert _stage_to_version_map(store.get_latest_versions(name=name, stages=["production"])) == {
        "Production": 3
    }  # The stages are case insensitive.
    assert _stage_to_version_map(store.get_latest_versions(name=name, stages=["pROduction"])) == {
        "Production": 3
    }  # The stages are case insensitive.
    assert _stage_to_version_map(
        store.get_latest_versions(name=name, stages=["None", "Production"])
    ) == {"None": 1, "Production": 3}

    # delete latest Production, and should point to previous one
    store.delete_model_version(name=mv3.name, version=mv3.version)
    rmd5 = store.get_registered_model(name=name)
    assert _stage_to_version_map(rmd5.latest_versions) == {
        "None": 1,
        "Production": 2,
        "Staging": 4,
    }
    assert _stage_to_version_map(store.get_latest_versions(name=name, stages=None)) == {
        "None": 1,
        "Production": 2,
        "Staging": 4,
    }
    assert _stage_to_version_map(store.get_latest_versions(name=name, stages=["Production"])) == {
        "Production": 2
    }


def test_set_registered_model_tag(store):
    """Tags can be added and overridden per model; invalid tags, names and deleted models raise."""
    name1 = "SetRegisteredModelTag_TestMod"
    name2 = "SetRegisteredModelTag_TestMod 2"
    initial_tags = [
        RegisteredModelTag("key", "value"),
        RegisteredModelTag("anotherKey", "some other value"),
    ]
    store.create_registered_model(name1, initial_tags)
    store.create_registered_model(name2, initial_tags)
    new_tag = RegisteredModelTag("randomTag", "not a random value")
    store.set_registered_model_tag(name1, new_tag)
    rm1 = store.get_registered_model(name=name1)
    all_tags = [*initial_tags, new_tag]
    assert rm1.tags == {tag.key: tag.value for tag in all_tags}

    # test overriding a tag with the same key
    overriding_tag = RegisteredModelTag("key", "overriding")
    store.set_registered_model_tag(name1, overriding_tag)
    all_tags = [tag for tag in all_tags if tag.key != "key"] + [overriding_tag]
    rm1 = store.get_registered_model(name=name1)
    assert rm1.tags == {tag.key: tag.value for tag in all_tags}
    # does not affect other models with the same key
    rm2 = store.get_registered_model(name=name2)
    assert rm2.tags == {tag.key: tag.value for tag in initial_tags}

    # can not set tag on deleted (non-existed) registered model
    store.delete_registered_model(name1)
    with pytest.raises(
        MlflowException, match=f"Registered Model with name={name1} not found"
    ) as exception_context:
        store.set_registered_model_tag(name1, overriding_tag)
    assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
    # test cannot set tags that are too long
    long_tag = RegisteredModelTag("longTagKey", "a" * 100_001)
    with pytest.raises(
        MlflowException,
        match=r"'value' exceeds the maximum length of \d+ characters",
    ) as exception_context:
        store.set_registered_model_tag(name2, long_tag)
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    # test can set tags that are somewhat long
    long_tag = RegisteredModelTag("longTagKey", "a" * 4999)
    store.set_registered_model_tag(name2, long_tag)
    # can not set invalid tag
    with pytest.raises(
        MlflowException, match=r"Missing value for required parameter 'key'"
    ) as exception_context:
        store.set_registered_model_tag(name2, RegisteredModelTag(key=None, value=""))
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    # can not use invalid model name
    with pytest.raises(
        MlflowException, match=r"Missing value for required parameter 'name'\."
    ) as exception_context:
        store.set_registered_model_tag(None, RegisteredModelTag(key="key", value="value"))
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_delete_registered_model_tag(store):
    """Tag deletion is per-model and idempotent; invalid keys, names and deleted models raise."""
    name1 = "DeleteRegisteredModelTag_TestMod"
    name2 = "DeleteRegisteredModelTag_TestMod 2"
    initial_tags = [
        RegisteredModelTag("key", "value"),
        RegisteredModelTag("anotherKey", "some other value"),
    ]
    store.create_registered_model(name1, initial_tags)
    store.create_registered_model(name2, initial_tags)
    new_tag = RegisteredModelTag("randomTag", "not a random value")
    store.set_registered_model_tag(name1, new_tag)
    store.delete_registered_model_tag(name1, "randomTag")
    rm1 = store.get_registered_model(name=name1)
    assert rm1.tags == {tag.key: tag.value for tag in initial_tags}

    # testing deleting a key does not affect other models with the same key
    store.delete_registered_model_tag(name1, "key")
    rm1 = store.get_registered_model(name=name1)
    rm2 = store.get_registered_model(name=name2)
    assert rm1.tags == {"anotherKey": "some other value"}
    assert rm2.tags == {tag.key: tag.value for tag in initial_tags}

    # delete tag that is already deleted does nothing
    store.delete_registered_model_tag(name1, "key")
    rm1 = store.get_registered_model(name=name1)
    assert rm1.tags == {"anotherKey": "some other value"}

    # can not delete tag on deleted (non-existed) registered model
    store.delete_registered_model(name1)
    with pytest.raises(
        MlflowException, match=f"Registered Model with name={name1} not found"
    ) as exception_context:
        store.delete_registered_model_tag(name1, "anotherKey")
    assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
    # can not delete tag with invalid key
    with pytest.raises(
        MlflowException, match=r"Missing value for required parameter 'key'"
    ) as exception_context:
        store.delete_registered_model_tag(name2, None)
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    # can not use invalid model name
    with pytest.raises(
        MlflowException, match=r"Missing value for required parameter 'name'\."
    ) as exception_context:
        store.delete_registered_model_tag(None, "key")
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_create_model_version(store):
    """New versions auto-increment and persist timestamps, tags, run link, description and run id."""
    name = "test_for_create_MV"
    store.create_registered_model(name)
    run_id = uuid.uuid4().hex
    # Freeze time.time so the stored timestamps can be asserted exactly (456778 s -> ms).
    with mock.patch("time.time", return_value=456778):
        mv1 = _create_model_version(store, name, "a/b/CD", run_id)
        assert mv1.name == name
        assert mv1.version == 1

    mvd1 = store.get_model_version(mv1.name, mv1.version)
    assert mvd1.name == name
    assert mvd1.version == 1
    assert mvd1.current_stage == "None"
    assert mvd1.creation_timestamp == 456778000
    assert mvd1.last_updated_timestamp == 456778000
    assert mvd1.description is None
    assert mvd1.source == "a/b/CD"
    assert mvd1.run_id == run_id
    assert mvd1.status == "READY"
    assert mvd1.status_message is None
    assert mvd1.tags == {}

    # new model versions for same name autoincrement versions
    mv2 = _create_model_version(store, name)
    mvd2 = store.get_model_version(name=mv2.name, version=mv2.version)
    assert mv2.version == 2
    assert mvd2.version == 2

    # create model version with tags return model version entity with tags
    tags = [
        ModelVersionTag("key", "value"),
        ModelVersionTag("anotherKey", "some other value"),
    ]
    mv3 = _create_model_version(store, name, tags=tags)
    mvd3 = store.get_model_version(name=mv3.name, version=mv3.version)
    assert mv3.version == 3
    assert mv3.tags == {tag.key: tag.value for tag in tags}
    assert mvd3.version == 3
    assert mvd3.tags == {tag.key: tag.value for tag in tags}

    # create model versions with runLink
    run_link = "http://localhost:3000/path/to/run/"
    mv4 = _create_model_version(store, name, run_link=run_link)
    mvd4 = store.get_model_version(name, mv4.version)
    assert mv4.version == 4
    assert mv4.run_link == run_link
    assert mvd4.version == 4
    assert mvd4.run_link == run_link

    # create model version with description
    description = "the best model ever"
    mv5 = _create_model_version(store, name, description=description)
    mvd5 = store.get_model_version(name, mv5.version)
    assert mv5.version == 5
    assert mv5.description == description
    assert mvd5.version == 5
    assert mvd5.description == description

    # create model version without runId
    mv6 = _create_model_version(store, name, run_id=None)
    mvd6 = store.get_model_version(name, mv6.version)
    assert mv6.version == 6
    assert mv6.run_id is None
    assert mvd6.version == 6
    assert mvd6.run_id is None


def test_update_model_version(store):
    """Stage and description updates persist; stage names are validated and canonicalised."""
    name = "test_for_update_MV"
    store.create_registered_model(name)
    mv1 = _create_model_version(store, name)
    mvd1 = store.get_model_version(name=mv1.name, version=mv1.version)
    assert mvd1.name == name
    assert mvd1.version == 1
    assert mvd1.current_stage == "None"

    # update stage
    store.transition_model_version_stage(
        name=mv1.name,
        version=mv1.version,
        stage="Production",
        archive_existing_versions=False,
    )
    mvd2 = store.get_model_version(name=mv1.name, version=mv1.version)
    assert mvd2.name == name
    assert mvd2.version == 1
    assert mvd2.current_stage == "Production"
    assert mvd2.description is None

    # update description
    store.update_model_version(name=mv1.name, version=mv1.version, description="test model version")
    mvd3 = store.get_model_version(name=mv1.name, version=mv1.version)
    assert mvd3.name == name
    assert mvd3.version == 1
    assert mvd3.current_stage == "Production"
    assert mvd3.description == "test model version"

    # only valid stages can be set
    with pytest.raises(
        MlflowException,
        match=(
            r"Invalid Model Version stage: unknown\. "
            r"Value must be one of None, Staging, Production, Archived\."
        ),
    ) as exception_context:
        store.transition_model_version_stage(
            mv1.name, mv1.version, stage="unknown", archive_existing_versions=False
        )
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # stages are case-insensitive and auto-corrected to system stage names
    for stage_name in ["STAGING", "staging", "StAgInG"]:
        store.transition_model_version_stage(
            name=mv1.name,
            version=mv1.version,
            stage=stage_name,
            archive_existing_versions=False,
        )
        mvd5 = store.get_model_version(name=mv1.name, version=mv1.version)
        assert mvd5.current_stage == "Staging"


def test_transition_model_version_stage_when_archive_existing_versions_is_false(store):
    """With archive_existing_versions=False, transitions leave other versions' stages untouched."""
    name = "model"
    store.create_registered_model(name)
    mv1 = _create_model_version(store, name)
    mv2 = _create_model_version(store, name)
    mv3 = _create_model_version(store, name)

    # test that when `archive_existing_versions` is False, transitioning a model version
    # to the inactive stages ("Archived" and "None") does not throw.
    for stage in ["Archived", "None"]:
        store.transition_model_version_stage(name, mv1.version, stage, False)

    store.transition_model_version_stage(name, mv1.version, "Staging", False)
    store.transition_model_version_stage(name, mv2.version, "Production", False)
    store.transition_model_version_stage(name, mv3.version, "Staging", False)

    mvd1 = store.get_model_version(name=name, version=mv1.version)
    mvd2 = store.get_model_version(name=name, version=mv2.version)
    mvd3 = store.get_model_version(name=name, version=mv3.version)

    assert mvd1.current_stage == "Staging"
    assert mvd2.current_stage == "Production"
    assert mvd3.current_stage == "Staging"

    store.transition_model_version_stage(name, mv3.version, "Production", False)

    mvd1 = store.get_model_version(name=name, version=mv1.version)
    mvd2 = store.get_model_version(name=name, version=mv2.version)
    mvd3 = store.get_model_version(name=name, version=mv3.version)

    assert mvd1.current_stage == "Staging"
    assert mvd2.current_stage == "Production"
    assert mvd3.current_stage == "Production"


def test_transition_model_version_stage_when_archive_existing_versions_is_true(store):
    """With archive_existing_versions=True, other versions in the target stage become Archived."""
    name = "model"
    store.create_registered_model(name)
    mv1 = _create_model_version(store, name)
    mv2 = _create_model_version(store, name)
    mv3 = _create_model_version(store, name)

    msg = (
        r"Model version transition cannot archive existing model versions "
        r"because .+ is not an Active stage"
    )

    # test that when `archive_existing_versions` is True, transitioning a model version
    # to the inactive stages ("Archived" and "None") throws.
    for stage in ["Archived", "None"]:
        with pytest.raises(MlflowException, match=msg):
            store.transition_model_version_stage(name, mv1.version, stage, True)

    store.transition_model_version_stage(name, mv1.version, "Staging", False)
    store.transition_model_version_stage(name, mv2.version, "Production", False)
    store.transition_model_version_stage(name, mv3.version, "Staging", True)

    mvd1 = store.get_model_version(name=name, version=mv1.version)
    mvd2 = store.get_model_version(name=name, version=mv2.version)
    mvd3 = store.get_model_version(name=name, version=mv3.version)

    assert mvd1.current_stage == "Archived"
    assert mvd2.current_stage == "Production"
    assert mvd3.current_stage == "Staging"
    assert mvd1.last_updated_timestamp == mvd3.last_updated_timestamp

    store.transition_model_version_stage(name, mv3.version, "Production", True)

    mvd1 = store.get_model_version(name=name, version=mv1.version)
    mvd2 = store.get_model_version(name=name, version=mv2.version)
    mvd3 = store.get_model_version(name=name, version=mv3.version)

    assert mvd1.current_stage == "Archived"
    assert mvd2.current_stage == "Archived"
    assert mvd3.current_stage == "Production"
    assert mvd2.last_updated_timestamp == mvd3.last_updated_timestamp

    for uncanonical_stage_name in ["STAGING", "staging", "StAgInG"]:
        store.transition_model_version_stage(mv1.name, mv1.version, "Staging", False)
        store.transition_model_version_stage(mv2.name, mv2.version, "None", False)

        # stage names are case-insensitive and auto-corrected to system stage names
        store.transition_model_version_stage(mv2.name, mv2.version, uncanonical_stage_name, True)

        mvd1 = store.get_model_version(name=mv1.name, version=mv1.version)
        mvd2 = store.get_model_version(name=mv2.name, version=mv2.version)
        assert mvd1.current_stage == "Archived"
        assert mvd2.current_stage == "Staging"


def test_delete_model_version(store):
    """A deleted model version cannot be fetched, updated, or deleted again."""
    name = "test_for_delete_MV"
    initial_tags = [
        ModelVersionTag("key", "value"),
        ModelVersionTag("anotherKey", "some other value"),
    ]
    store.create_registered_model(name)
    mv = _create_model_version(store, name, tags=initial_tags)
    mvd = store.get_model_version(name=mv.name, version=mv.version)
    assert mvd.name == name

    store.delete_model_version(name=mv.name, version=mv.version)

    # cannot get a deleted model version
    with pytest.raises(
        MlflowException,
        match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found",
    ) as exception_context:
        store.get_model_version(name=mv.name, version=mv.version)
    assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)

    # cannot update a deleted model version
    with pytest.raises(
        MlflowException,
        match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found",
    ) as exception_context:
        store.update_model_version(mv.name, mv.version, description="deleted!")
    assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)

    # cannot delete it again
    with pytest.raises(
        MlflowException,
        match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found",
    ) as exception_context:
        store.delete_model_version(name=mv.name, version=mv.version)
    assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)


def _search_model_versions(fs, filter_string=None, max_results=10, order_by=None, page_token=None):
    """Call ``fs.search_model_versions`` with keyword arguments and test-friendly defaults."""
    return fs.search_model_versions(
        filter_string=filter_string,
        max_results=max_results,
        order_by=order_by,
        page_token=page_token,
    )


def test_search_model_versions(store):
    """Filter strings select the expected versions; malformed filters raise INVALID_PARAMETER_VALUE."""
    # create some model versions
    name = "test_for_search_MV"
    store.create_registered_model(name)
    run_id_1 = uuid.uuid4().hex
    run_id_2 = uuid.uuid4().hex
    run_id_3 = uuid.uuid4().hex
    mv1 = _create_model_version(store, name=name, source="A/B", run_id=run_id_1)
    assert mv1.version == 1
    mv2 = _create_model_version(store, name=name, source="A/C", run_id=run_id_2)
    assert mv2.version == 2
    mv3 = _create_model_version(store, name=name, source="A/D", run_id=run_id_2)
    assert mv3.version == 3
    mv4 = _create_model_version(store, name=name, source="A/D", run_id=run_id_3)
    assert mv4.version == 4

    def search_versions(filter_string):
        # Return only the version numbers matched by ``filter_string``.
        return [mvd.version for mvd in _search_model_versions(store, filter_string)]

    # search using name should return all 4 versions
    assert set(search_versions(f"name='{name}'")) == {1, 2, 3, 4}

    # search using version
    assert set(search_versions("version_number=2")) == {2}
    assert set(search_versions("version_number<=3")) == {1, 2, 3}

    # search using run_id_1 should return version 1
    assert set(search_versions(f"run_id='{run_id_1}'")) == {1}

    # search using run_id_2 should return versions 2 and 3
    assert set(search_versions(f"run_id='{run_id_2}'")) == {2, 3}

    # search using the IN operator should return all versions
    assert set(search_versions(f"run_id IN ('{run_id_1}','{run_id_2}')")) == {1, 2, 3}

    # search IN operator is case sensitive
    assert set(search_versions(f"run_id IN ('{run_id_1.upper()}','{run_id_2}')")) == {
        2,
        3,
    }

    # search IN operator with right-hand side value containing whitespaces
    assert set(search_versions(f"run_id IN ('{run_id_1}', '{run_id_2}')")) == {1, 2, 3}

    # search IN operator with other conditions
    assert set(
        search_versions(f"version_number=2 AND run_id IN ('{run_id_1.upper()}','{run_id_2}')")
    ) == {2}

    # search using the IN operator with bad lists should return exceptions
    with pytest.raises(
        MlflowException,
        match=(
            r"While parsing a list in the query, "
            r"expected string value, punctuation, or whitespace, "
            r"but got different type in list"
        ),
    ) as exception_context:
        search_versions("run_id IN (1,2,3)")
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    assert set(search_versions(f"run_id LIKE '{run_id_2[:30]}%'")) == {2, 3}

    assert set(search_versions(f"run_id ILIKE '{run_id_2[:30].upper()}%'")) == {2, 3}

    # search using the IN operator with empty lists should return exceptions
    with pytest.raises(
        MlflowException,
        match=(
            r"While parsing a list in the query, "
            r"expected a non-empty list of string values, "
            r"but got empty list"
        ),
    ) as exception_context:
        search_versions("run_id IN ()")
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # search using an ill-formed IN operator correctly throws exception
    with pytest.raises(
        MlflowException, match=r"Invalid clause\(s\) in filter string"
    ) as exception_context:
        search_versions("run_id IN (")
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(
        MlflowException, match=r"Invalid clause\(s\) in filter string"
    ) as exception_context:
        search_versions("run_id IN")
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(
        MlflowException, match=r"Invalid clause\(s\) in filter string"
    ) as exception_context:
        search_versions("name LIKE")
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(
        MlflowException,
        match=(
            r"While parsing a list in the query, "
            r"expected a non-empty list of string values, "
            r"but got ill-formed list"
        ),
    ) as exception_context:
        search_versions("run_id IN (,)")
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(
        MlflowException,
        match=(
            r"While parsing a list in the query, "
            r"expected a non-empty list of string values, "
            r"but got ill-formed list"
        ),
    ) as exception_context:
        search_versions("run_id IN ('runid1',,'runid2')")
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # search using source_path "A/D" should return version 3 and 4
    assert set(search_versions("source_path = 'A/D'")) == {3, 4}

    # search using source_path "A" should not return anything
    assert len(search_versions("source_path = 'A'")) == 0
    assert len(search_versions("source_path = 'A/'")) == 0
    assert len(search_versions("source_path = ''")) == 0

    # delete mv4. search should not return version 4
    store.delete_model_version(name=mv4.name, version=mv4.version)
    assert set(search_versions("")) == {1, 2, 3}

    assert set(search_versions(None)) == {1, 2, 3}

    assert set(search_versions(f"name='{name}'")) == {1, 2, 3}

    store.transition_model_version_stage(
        name=mv1.name,
        version=mv1.version,
        stage="production",
        archive_existing_versions=False,
    )

    store.update_model_version(
        name=mv1.name, version=mv1.version, description="Online prediction model!"
    )

    mvds = store.search_model_versions(f"run_id = '{run_id_1}'", max_results=10)
    assert len(mvds) == 1
    assert isinstance(mvds[0], ModelVersion)
    assert mvds[0].current_stage == "Production"
    assert mvds[0].run_id == run_id_1
    assert mvds[0].source == "A/B"
    assert mvds[0].description == "Online prediction model!"
853 854 855 def test_search_model_versions_order_by_simple(store): 856 # create some model versions 857 names = ["RM1", "RM2", "RM3", "RM4", "RM1", "RM4"] 858 sources = ["A"] * 3 + ["B"] * 3 859 run_ids = [uuid.uuid4().hex for _ in range(6)] 860 for name in set(names): 861 store.create_registered_model(name) 862 for i in range(6): 863 _create_model_version(store, name=names[i], source=sources[i], run_id=run_ids[i]) 864 time.sleep(0.001) # sleep for windows fs timestamp precision issues 865 866 # by default order by last_updated_timestamp DESC 867 mvs = _search_model_versions(store).to_list() 868 assert [mv.name for mv in mvs] == names[::-1] 869 assert [mv.version for mv in mvs] == [2, 2, 1, 1, 1, 1] 870 871 # order by name DESC 872 mvs = _search_model_versions(store, order_by=["name DESC"]) 873 assert [mv.name for mv in mvs] == sorted(names)[::-1] 874 assert [mv.version for mv in mvs] == [2, 1, 1, 1, 2, 1] 875 876 # order by version DESC 877 mvs = _search_model_versions(store, order_by=["version_number DESC"]) 878 assert [mv.name for mv in mvs] == ["RM1", "RM4", "RM1", "RM2", "RM3", "RM4"] 879 assert [mv.version for mv in mvs] == [2, 2, 1, 1, 1, 1] 880 881 # order by creation_timestamp DESC 882 mvs = _search_model_versions(store, order_by=["creation_timestamp DESC"]) 883 assert [mv.name for mv in mvs] == names[::-1] 884 assert [mv.version for mv in mvs] == [2, 2, 1, 1, 1, 1] 885 886 # order by last_updated_timestamp ASC 887 store.update_model_version(names[0], 1, "latest updated") 888 mvs = _search_model_versions(store, order_by=["last_updated_timestamp ASC"]) 889 assert mvs[-1].name == names[0] 890 assert mvs[-1].version == 1 891 892 893 def test_search_model_versions_pagination(store): 894 def search_versions(filter_string, page_token=None, max_results=10): 895 result = _search_model_versions( 896 store, 897 filter_string=filter_string, 898 page_token=page_token, 899 max_results=max_results, 900 ) 901 return result.to_list(), result.token 902 903 name = 
"test_for_search_MV_pagination" 904 store.create_registered_model(name) 905 mvs = [_create_model_version(store, name) for _ in range(50)][::-1] 906 907 # test flow with fixed max_results 908 returned_mvs = [] 909 query = "name LIKE 'test_for_search_MV_pagination%'" 910 result, token = search_versions(query, page_token=None, max_results=5) 911 returned_mvs.extend(result) 912 while token: 913 result, token = search_versions(query, page_token=token, max_results=5) 914 returned_mvs.extend(result) 915 assert mvs == returned_mvs 916 917 # test that pagination will return all valid results in sorted order 918 # by name ascending 919 result, token1 = search_versions(query, max_results=5) 920 assert token1 is not None 921 assert result == mvs[0:5] 922 923 result, token2 = search_versions(query, page_token=token1, max_results=10) 924 assert token2 is not None 925 assert result == mvs[5:15] 926 927 result, token3 = search_versions(query, page_token=token2, max_results=20) 928 assert token3 is not None 929 assert result == mvs[15:35] 930 931 result, token4 = search_versions(query, page_token=token3, max_results=100) 932 # assert that page token is None 933 assert token4 is None 934 assert result == mvs[35:] 935 936 # test that providing a completely invalid page token throws 937 with pytest.raises( 938 MlflowException, match=r"Invalid page token, could not base64-decode" 939 ) as exception_context: 940 search_versions(query, page_token="evilhax", max_results=20) 941 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 942 943 # test that providing too large of a max_results throws 944 with pytest.raises( 945 MlflowException, match=r"Invalid value for max_results." 
946 ) as exception_context: 947 search_versions(query, page_token="evilhax", max_results=1e15) 948 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 949 950 951 def test_search_model_versions_by_tag(store): 952 # create some model versions 953 name = "test_for_search_MV_by_tag" 954 store.create_registered_model(name) 955 run_id_1 = uuid.uuid4().hex 956 run_id_2 = uuid.uuid4().hex 957 958 mv1 = _create_model_version( 959 store, 960 name=name, 961 source="A/B", 962 run_id=run_id_1, 963 tags=[ModelVersionTag("t1", "abc"), ModelVersionTag("t2", "xyz")], 964 ) 965 assert mv1.version == 1 966 mv2 = _create_model_version( 967 store, 968 name=name, 969 source="A/C", 970 run_id=run_id_2, 971 tags=[ModelVersionTag("t1", "abc"), ModelVersionTag("t2", "x123")], 972 ) 973 assert mv2.version == 2 974 975 def search_versions(filter_string): 976 return [mvd.version for mvd in _search_model_versions(store, filter_string)] 977 978 assert search_versions(f"name = '{name}' and tag.t2 = 'xyz'") == [1] 979 assert search_versions("name = 'wrong_name' and tag.t2 = 'xyz'") == [] 980 assert search_versions("tag.`t2` = 'xyz'") == [1] 981 assert search_versions("tag.t3 = 'xyz'") == [] 982 assert set(search_versions("tag.t2 != 'xy'")) == {2, 1} 983 assert search_versions("tag.t2 LIKE 'xy%'") == [1] 984 assert search_versions("tag.t2 LIKE 'xY%'") == [] 985 assert search_versions("tag.t2 ILIKE 'xY%'") == [1] 986 assert set(search_versions("tag.t2 LIKE 'x%'")) == {2, 1} 987 assert search_versions("tag.T2 = 'xyz'") == [] 988 assert search_versions("tag.t1 = 'abc' and tag.t2 = 'xyz'") == [1] 989 assert set(search_versions("tag.t1 = 'abc' and tag.t2 LIKE 'x%'")) == {2, 1} 990 assert search_versions("tag.t1 = 'abc' and tag.t2 LIKE 'y%'") == [] 991 # test filter with duplicated keys 992 assert search_versions("tag.t2 like 'x%' and tag.t2 != 'xyz'") == [2] 993 994 995 class SearchRegisteredModelsResult(NamedTuple): 996 names: list[str] 997 token: str 998 999 1000 def 
_search_registered_models( 1001 store, filter_string=None, max_results=10, order_by=None, page_token=None 1002 ): 1003 result = store.search_registered_models( 1004 filter_string=filter_string, 1005 max_results=max_results, 1006 order_by=order_by, 1007 page_token=page_token, 1008 ) 1009 return SearchRegisteredModelsResult( 1010 names=[registered_model.name for registered_model in result], 1011 token=result.token, 1012 ) 1013 1014 1015 def test_search_registered_models(store): 1016 # create some registered models 1017 prefix = "test_for_search_" 1018 names = [prefix + name for name in ["RM1", "RM2", "RM3", "RM4", "RM4A", "RM4ab"]] 1019 for name in names: 1020 store.create_registered_model(name) 1021 1022 # search with no filter should return all registered models 1023 res = _search_registered_models(store, None) 1024 assert res.names == names 1025 1026 # equality search using name should return exactly the 1 name 1027 res = _search_registered_models(store, f"name='{names[0]}'") 1028 assert res.names == [names[0]] 1029 1030 # equality search using name that is not valid should return nothing 1031 res = _search_registered_models(store, f"name='{names[0]}cats'") 1032 assert res.names == [] 1033 1034 # case-sensitive prefix search using LIKE should return all the RMs 1035 res = _search_registered_models(store, f"name LIKE '{prefix}%'") 1036 assert res.names == names 1037 1038 # case-sensitive prefix search using LIKE with surrounding % should return all the RMs 1039 res = _search_registered_models(store, "name LIKE '%RM%'") 1040 assert res.names == names 1041 1042 # case-sensitive prefix search using LIKE with surrounding % should return all the RMs 1043 # _e% matches test_for_search_ , so all RMs should match 1044 res = _search_registered_models(store, "name LIKE '_e%'") 1045 assert res.names == names 1046 1047 # case-sensitive prefix search using LIKE should return just rm4 1048 res = _search_registered_models(store, f"name LIKE '{prefix}RM4A%'") 1049 assert res.names 
== [names[4]] 1050 1051 # case-sensitive prefix search using LIKE should return no models if no match 1052 res = _search_registered_models(store, f"name LIKE '{prefix}cats%'") 1053 assert res.names == [] 1054 1055 # confirm that LIKE is not case-sensitive 1056 res = _search_registered_models(store, "name lIkE '%blah%'") 1057 assert res.names == [] 1058 1059 res = _search_registered_models(store, f"name like '{prefix}RM4A%'") 1060 assert res.names == [names[4]] 1061 1062 # case-insensitive prefix search using ILIKE should return both rm5 and rm6 1063 res = _search_registered_models(store, f"name ILIKE '{prefix}RM4A%'") 1064 assert res.names == names[4:] 1065 1066 # case-insensitive postfix search with ILIKE 1067 res = _search_registered_models(store, "name ILIKE '%RM4a%'") 1068 assert res.names == names[4:] 1069 1070 # case-insensitive prefix search using ILIKE should return both rm5 and rm6 1071 res = _search_registered_models(store, f"name ILIKE '{prefix}cats%'") 1072 assert res.names == [] 1073 1074 # confirm that ILIKE is not case-sensitive 1075 res = _search_registered_models(store, "name iLike '%blah%'") 1076 assert res.names == [] 1077 1078 # confirm that ILIKE works for empty query 1079 res = _search_registered_models(store, "name iLike '%%'") 1080 assert res.names == names 1081 1082 res = _search_registered_models(store, "name ilike '%RM4a%'") 1083 assert res.names == names[4:] 1084 1085 # cannot search by invalid comparator types 1086 with pytest.raises( 1087 MlflowException, 1088 match="Parameter value is either not quoted or unidentified quote types used for " 1089 "string value something", 1090 ) as exception_context: 1091 _search_registered_models(store, "name!=something") 1092 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1093 1094 # cannot search by run_id 1095 with pytest.raises( 1096 MlflowException, 1097 match=r"Invalid attribute key 'run_id' specified. 
Valid keys are '{'name'}'", 1098 ) as exception_context: 1099 _search_registered_models(store, "run_id='somerunID'") 1100 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1101 1102 # cannot search by source_path 1103 with pytest.raises( 1104 MlflowException, 1105 match=r"Invalid attribute key 'source_path' specified\. Valid keys are '{'name'}'", 1106 ) as exception_context: 1107 _search_registered_models(store, "source_path = 'A/D'") 1108 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1109 1110 # cannot search by other params 1111 with pytest.raises( 1112 MlflowException, match=r"Invalid clause\(s\) in filter string" 1113 ) as exception_context: 1114 _search_registered_models(store, "evilhax = true") 1115 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1116 1117 # delete last registered model. search should not return the first 5 1118 store.delete_registered_model(name=names[-1]) 1119 res = _search_registered_models(store, None, max_results=1000) 1120 assert res.names == names[:-1] 1121 assert res.token is None 1122 1123 # equality search using name should return no names 1124 assert _search_registered_models(store, f"name='{names[-1]}'") == ([], None) 1125 1126 # case-sensitive prefix search using LIKE should return all the RMs 1127 res = _search_registered_models(store, f"name LIKE '{prefix}%'") 1128 assert res.names == names[0:5] 1129 assert res.token is None 1130 1131 # case-insensitive prefix search using ILIKE should return both rm5 and rm6 1132 res = _search_registered_models(store, f"name ILIKE '{prefix}RM4A%'") 1133 assert res.names == [names[4]] 1134 assert res.token is None 1135 1136 1137 def test_search_registered_models_by_tag(store): 1138 name1 = "test_for_search_RM_by_tag1" 1139 name2 = "test_for_search_RM_by_tag2" 1140 tags1 = [ 1141 RegisteredModelTag("t1", "abc"), 1142 RegisteredModelTag("t2", "xyz"), 1143 ] 1144 tags2 = [ 1145 
RegisteredModelTag("t1", "abcd"), 1146 RegisteredModelTag("t2", "xyz123"), 1147 RegisteredModelTag("t3", "XYZ"), 1148 ] 1149 store.create_registered_model(name1, tags1) 1150 store.create_registered_model(name2, tags2) 1151 1152 res = _search_registered_models(store, "tag.t3 = 'XYZ'") 1153 assert res.names == [name2] 1154 1155 res = _search_registered_models(store, f"name = '{name1}' and tag.t1 = 'abc'") 1156 assert res.names == [name1] 1157 1158 res = _search_registered_models(store, "tag.t1 LIKE 'ab%'") 1159 assert res.names == [name1, name2] 1160 1161 res = _search_registered_models(store, "tag.t1 ILIKE 'aB%'") 1162 assert res.names == [name1, name2] 1163 1164 res = _search_registered_models(store, "tag.t1 LIKE 'ab%' AND tag.t2 LIKE 'xy%'") 1165 assert res.names == [name1, name2] 1166 1167 res = _search_registered_models(store, "tag.t3 = 'XYz'") 1168 assert res.names == [] 1169 1170 res = _search_registered_models(store, "tag.T3 = 'XYZ'") 1171 assert res.names == [] 1172 1173 res = _search_registered_models(store, "tag.t1 != 'abc'") 1174 assert res.names == [name2] 1175 1176 # test filter with duplicated keys 1177 res = _search_registered_models(store, "tag.t1 != 'abcd' and tag.t1 LIKE 'ab%'") 1178 assert res.names == [name1] 1179 1180 1181 def test_search_registered_models_order_by_simple(store): 1182 # create some registered models 1183 names = ["RM1", "RM2", "RM3", "RM4", "RM4A", "RM4ab"] 1184 for name in names: 1185 store.create_registered_model(name) 1186 time.sleep(0.001) # sleep for windows store timestamp precision issues 1187 1188 # by default order by name ASC 1189 res = _search_registered_models(store) 1190 assert res.names == names 1191 1192 # order by name DESC 1193 res = _search_registered_models(store, order_by=["name DESC"]) 1194 assert res.names == names[::-1] 1195 1196 # order by last_updated_timestamp ASC 1197 store.update_registered_model(names[0], "latest updated") 1198 res = _search_registered_models(store, order_by=["last_updated_timestamp 
ASC"]) 1199 assert res.names[-1] == names[0] 1200 1201 1202 def test_search_registered_model_pagination(store): 1203 rms = [store.create_registered_model(f"RM{i:03}").name for i in range(50)] 1204 1205 # test flow with fixed max_results 1206 returned_rms = [] 1207 query = "name LIKE 'RM%'" 1208 res = _search_registered_models(store, query, page_token=None, max_results=5) 1209 returned_rms.extend(res.names) 1210 while res.token: 1211 res = _search_registered_models(store, query, page_token=res.token, max_results=5) 1212 returned_rms.extend(res.names) 1213 assert returned_rms == rms 1214 1215 # test that pagination will return all valid results in sorted order 1216 # by name ascending 1217 res = _search_registered_models(store, query, max_results=5) 1218 assert res.token is not None 1219 assert res.names == rms[0:5] 1220 1221 res = _search_registered_models(store, query, page_token=res.token, max_results=10) 1222 assert res.token is not None 1223 assert res.names == rms[5:15] 1224 1225 res = _search_registered_models(store, query, page_token=res.token, max_results=20) 1226 assert res.token is not None 1227 assert res.names == rms[15:35] 1228 1229 res = _search_registered_models(store, query, page_token=res.token, max_results=100) 1230 # assert that page token is None 1231 assert res.token is None 1232 assert res.names == rms[35:] 1233 1234 # test that providing a completely invalid page token throws 1235 with pytest.raises( 1236 MlflowException, match=r"Invalid page token, could not base64-decode" 1237 ) as exception_context: 1238 _search_registered_models(store, query, page_token="evilhax", max_results=20) 1239 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1240 1241 # test that providing too large of a max_results throws 1242 with pytest.raises( 1243 MlflowException, match=r"Invalid value for max_results." 
1244 ) as exception_context: 1245 _search_registered_models(store, query, page_token="evilhax", max_results=1e15) 1246 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1247 1248 1249 def test_search_registered_model_order_by(store): 1250 rms = [] 1251 # explicitly mock the creation_timestamps because timestamps seem to be unstable in Windows 1252 for i in range(50): 1253 rms.append(store.create_registered_model(f"RM{i:03}").name) 1254 time.sleep(0.01) 1255 1256 # test flow with fixed max_results and order_by (test stable order across pages) 1257 returned_rms = [] 1258 query = "name LIKE 'RM%'" 1259 result, token = _search_registered_models( 1260 store, query, page_token=None, order_by=["name DESC"], max_results=5 1261 ) 1262 returned_rms.extend(result) 1263 while token: 1264 result, token = _search_registered_models( 1265 store, query, page_token=token, order_by=["name DESC"], max_results=5 1266 ) 1267 returned_rms.extend(result) 1268 # name descending should be the opposite order of the current order 1269 assert returned_rms == rms[::-1] 1270 # last_updated_timestamp descending should have the newest RMs first 1271 res = _search_registered_models( 1272 store, 1273 query, 1274 page_token=None, 1275 order_by=["last_updated_timestamp DESC"], 1276 max_results=100, 1277 ) 1278 assert res.names == rms[::-1] 1279 # last_updated_timestamp ascending should have the oldest RMs first 1280 res = _search_registered_models( 1281 store, 1282 query, 1283 page_token=None, 1284 order_by=["last_updated_timestamp ASC"], 1285 max_results=100, 1286 ) 1287 assert res.names == rms 1288 # name ascending should have the original order 1289 res = _search_registered_models( 1290 store, query, page_token=None, order_by=["name ASC"], max_results=100 1291 ) 1292 assert res.names == rms 1293 # test that no ASC/DESC defaults to ASC 1294 res = _search_registered_models( 1295 store, 1296 query, 1297 page_token=None, 1298 order_by=["last_updated_timestamp"], 1299 
max_results=100, 1300 ) 1301 assert res.names == rms 1302 with mock.patch( 1303 "mlflow.store.model_registry.file_store.get_current_time_millis", return_value=1 1304 ): 1305 rm1 = store.create_registered_model("MR1").name 1306 rm2 = store.create_registered_model("MR2").name 1307 with mock.patch( 1308 "mlflow.store.model_registry.file_store.get_current_time_millis", return_value=2 1309 ): 1310 rm3 = store.create_registered_model("MR3").name 1311 rm4 = store.create_registered_model("MR4").name 1312 query = "name LIKE 'MR%'" 1313 # test with multiple clauses 1314 res = _search_registered_models( 1315 store, 1316 query, 1317 page_token=None, 1318 order_by=["last_updated_timestamp ASC", "name DESC"], 1319 max_results=100, 1320 ) 1321 assert res.names == [rm2, rm1, rm4, rm3] 1322 # confirm that name ascending is the default, even if ties exist on other fields 1323 res = _search_registered_models(store, query, page_token=None, order_by=[], max_results=100) 1324 assert res.names == [rm1, rm2, rm3, rm4] 1325 # test default tiebreak with descending timestamps 1326 res = _search_registered_models( 1327 store, 1328 query, 1329 page_token=None, 1330 order_by=["last_updated_timestamp DESC"], 1331 max_results=100, 1332 ) 1333 assert res.names == [rm3, rm4, rm1, rm2] 1334 1335 1336 def test_search_registered_model_order_by_errors(store): 1337 store.create_registered_model("dummy") 1338 query = "name LIKE 'RM%'" 1339 # test that invalid columns throw even if they come after valid columns 1340 with pytest.raises( 1341 MlflowException, match="Invalid attribute key 'description' specified." 
1342 ) as exception_context: 1343 _search_registered_models( 1344 store, 1345 query, 1346 page_token=None, 1347 order_by=["name ASC", "description DESC"], 1348 max_results=5, 1349 ) 1350 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1351 # test that invalid columns with random text throw even if they come after valid columns 1352 with pytest.raises(MlflowException, match=r"Invalid order_by clause '.+'") as exception_context: 1353 _search_registered_models( 1354 store, 1355 query, 1356 page_token=None, 1357 order_by=["name ASC", "last_updated_timestamp DESC blah"], 1358 max_results=5, 1359 ) 1360 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1361 1362 1363 def test_set_model_version_tag(store): 1364 name1 = "SetModelVersionTag_TestMod" 1365 name2 = "SetModelVersionTag_TestMod 2" 1366 initial_tags = [ 1367 ModelVersionTag("key", "value"), 1368 ModelVersionTag("anotherKey", "some other value"), 1369 ] 1370 store.create_registered_model(name1) 1371 store.create_registered_model(name2) 1372 run_id_1 = uuid.uuid4().hex 1373 run_id_2 = uuid.uuid4().hex 1374 run_id_3 = uuid.uuid4().hex 1375 store.create_model_version(name1, "A/B", run_id_1, initial_tags) 1376 store.create_model_version(name1, "A/C", run_id_2, initial_tags) 1377 store.create_model_version(name2, "A/D", run_id_3, initial_tags) 1378 new_tag = ModelVersionTag("randomTag", "not a random value") 1379 store.set_model_version_tag(name1, 1, new_tag) 1380 all_tags = [*initial_tags, new_tag] 1381 rm1mv1 = store.get_model_version(name1, 1) 1382 assert rm1mv1.tags == {tag.key: tag.value for tag in all_tags} 1383 1384 # test overriding a tag with the same key 1385 overriding_tag = ModelVersionTag("key", "overriding") 1386 store.set_model_version_tag(name1, 1, overriding_tag) 1387 all_tags = [tag for tag in all_tags if tag.key != "key"] + [overriding_tag] 1388 rm1mv1 = store.get_model_version(name1, 1) 1389 assert rm1mv1.tags == {tag.key: 
tag.value for tag in all_tags} 1390 # does not affect other model versions with the same key 1391 rm1mv2 = store.get_model_version(name1, 2) 1392 rm2mv1 = store.get_model_version(name2, 1) 1393 assert rm1mv2.tags == {tag.key: tag.value for tag in initial_tags} 1394 assert rm2mv1.tags == {tag.key: tag.value for tag in initial_tags} 1395 1396 # can not set tag on deleted (non-existed) model version 1397 store.delete_model_version(name1, 2) 1398 with pytest.raises( 1399 MlflowException, match=rf"Model Version \(name={name1}, version=2\) not found" 1400 ) as exception_context: 1401 store.set_model_version_tag(name1, 2, overriding_tag) 1402 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 1403 # test cannot set tags that are too long 1404 long_tag = ModelVersionTag("longTagKey", "a" * 100_001) 1405 with pytest.raises( 1406 MlflowException, 1407 match=r"'value' exceeds the maximum length of \d+ characters", 1408 ) as exception_context: 1409 store.set_model_version_tag(name1, 1, long_tag) 1410 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1411 # test can set tags that are somewhat long 1412 long_tag = ModelVersionTag("longTagKey", "a" * 4999) 1413 store.set_model_version_tag(name1, 1, long_tag) 1414 # can not set invalid tag 1415 with pytest.raises( 1416 MlflowException, match=r"Missing value for required parameter 'key'" 1417 ) as exception_context: 1418 store.set_model_version_tag(name2, 1, ModelVersionTag(key=None, value="")) 1419 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1420 # can not use invalid model name or version 1421 with pytest.raises( 1422 MlflowException, match=r"Missing value for required parameter 'name'\." 
1423 ) as exception_context: 1424 store.set_model_version_tag(None, 1, ModelVersionTag(key="key", value="value")) 1425 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1426 with pytest.raises( 1427 MlflowException, 1428 match=r"Parameter 'version' must be an integer, got 'I am not a version'.", 1429 ) as exception_context: 1430 store.set_model_version_tag( 1431 name2, "I am not a version", ModelVersionTag(key="key", value="value") 1432 ) 1433 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1434 1435 1436 def test_delete_model_version_tag(store): 1437 name1 = "DeleteModelVersionTag_TestMod" 1438 name2 = "DeleteModelVersionTag_TestMod 2" 1439 initial_tags = [ 1440 ModelVersionTag("key", "value"), 1441 ModelVersionTag("anotherKey", "some other value"), 1442 ] 1443 store.create_registered_model(name1) 1444 store.create_registered_model(name2) 1445 run_id_1 = uuid.uuid4().hex 1446 run_id_2 = uuid.uuid4().hex 1447 run_id_3 = uuid.uuid4().hex 1448 store.create_model_version(name1, "A/B", run_id_1, initial_tags) 1449 store.create_model_version(name1, "A/C", run_id_2, initial_tags) 1450 store.create_model_version(name2, "A/D", run_id_3, initial_tags) 1451 new_tag = ModelVersionTag("randomTag", "not a random value") 1452 store.set_model_version_tag(name1, 1, new_tag) 1453 store.delete_model_version_tag(name1, 1, "randomTag") 1454 rm1mv1 = store.get_model_version(name1, 1) 1455 assert rm1mv1.tags == {tag.key: tag.value for tag in initial_tags} 1456 1457 # testing deleting a key does not affect other model versions with the same key 1458 store.delete_model_version_tag(name1, 1, "key") 1459 rm1mv1 = store.get_model_version(name1, 1) 1460 rm1mv2 = store.get_model_version(name1, 2) 1461 rm2mv1 = store.get_model_version(name2, 1) 1462 assert rm1mv1.tags == {"anotherKey": "some other value"} 1463 assert rm1mv2.tags == {tag.key: tag.value for tag in initial_tags} 1464 assert rm2mv1.tags == {tag.key: tag.value 
for tag in initial_tags} 1465 1466 # delete tag that is already deleted does nothing 1467 store.delete_model_version_tag(name1, 1, "key") 1468 rm1mv1 = store.get_model_version(name1, 1) 1469 assert rm1mv1.tags == {"anotherKey": "some other value"} 1470 1471 # can not delete tag on deleted (non-existed) model version 1472 store.delete_model_version(name2, 1) 1473 with pytest.raises( 1474 MlflowException, match=rf"Model Version \(name={name2}, version=1\) not found" 1475 ) as exception_context: 1476 store.delete_model_version_tag(name2, 1, "key") 1477 assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST) 1478 # can not delete tag with invalid key 1479 with pytest.raises( 1480 MlflowException, match=r"Missing value for required parameter 'key'" 1481 ) as exception_context: 1482 store.delete_model_version_tag(name1, 2, None) 1483 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1484 # can not use invalid model name or version 1485 with pytest.raises( 1486 MlflowException, match=r"Missing value for required parameter 'name'\." 1487 ) as exception_context: 1488 store.delete_model_version_tag(None, 2, "key") 1489 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1490 with pytest.raises( 1491 MlflowException, match=r"Parameter 'version' must be an integer, got 'I am not a version'\." 
1492 ) as exception_context: 1493 store.delete_model_version_tag(name1, "I am not a version", "key") 1494 assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 1495 1496 1497 def _setup_and_test_aliases(store, model_name): 1498 store.create_registered_model(model_name) 1499 run_id_1 = uuid.uuid4().hex 1500 run_id_2 = uuid.uuid4().hex 1501 store.create_model_version(model_name, "v1", run_id_1) 1502 store.create_model_version(model_name, "v2", run_id_2) 1503 store.set_registered_model_alias(model_name, "test_alias", "2") 1504 model = store.get_registered_model(model_name) 1505 assert model.aliases == {"test_alias": "2"} 1506 mv1 = store.get_model_version(model_name, 1) 1507 mv2 = store.get_model_version(model_name, 2) 1508 assert mv1.aliases == [] 1509 assert mv2.aliases == ["test_alias"] 1510 1511 1512 def test_set_registered_model_alias(store): 1513 _setup_and_test_aliases(store, "SetRegisteredModelAlias_TestMod") 1514 1515 1516 def test_delete_registered_model_alias(store): 1517 model_name = "DeleteRegisteredModelAlias_TestMod" 1518 _setup_and_test_aliases(store, model_name) 1519 store.delete_registered_model_alias(model_name, "test_alias") 1520 model = store.get_registered_model(model_name) 1521 assert model.aliases == {} 1522 mv2 = store.get_model_version(model_name, 2) 1523 assert mv2.aliases == [] 1524 with pytest.raises(MlflowException, match=r"Registered model alias test_alias not found."): 1525 store.get_model_version_by_alias(model_name, "test_alias") 1526 1527 1528 def test_get_model_version_by_alias(store): 1529 model_name = "GetModelVersionByAlias_TestMod" 1530 _setup_and_test_aliases(store, model_name) 1531 mv = store.get_model_version_by_alias(model_name, "test_alias") 1532 assert mv.aliases == ["test_alias"] 1533 1534 1535 def test_delete_model_version_deletes_alias(store): 1536 model_name = "DeleteModelVersionDeletesAlias_TestMod" 1537 _setup_and_test_aliases(store, model_name) 1538 
store.delete_model_version(model_name, 2) 1539 model = store.get_registered_model(model_name) 1540 assert model.aliases == {} 1541 with pytest.raises(MlflowException, match=r"Registered model alias test_alias not found."): 1542 store.get_model_version_by_alias(model_name, "test_alias") 1543 1544 1545 def test_delete_model_deletes_alias(store): 1546 model_name = "DeleteModelDeletesAlias_TestMod" 1547 _setup_and_test_aliases(store, model_name) 1548 store.delete_registered_model(model_name) 1549 with pytest.raises( 1550 MlflowException, 1551 match=r"Registered Model with name=DeleteModelDeletesAlias_TestMod not found", 1552 ): 1553 store.get_model_version_by_alias(model_name, "test_alias") 1554 1555 1556 def test_pyfunc_model_registry_with_file_store(store): 1557 import mlflow 1558 from mlflow.pyfunc import PythonModel 1559 1560 class MyModel(PythonModel): 1561 def predict(self, context, model_input, params=None): 1562 return 7 1563 1564 mlflow.set_registry_uri(path_to_local_file_uri(store.root_directory)) 1565 with mlflow.start_run(): 1566 mlflow.pyfunc.log_model(name="foo", python_model=MyModel(), registered_model_name="model1") 1567 mlflow.pyfunc.log_model(name="foo", python_model=MyModel(), registered_model_name="model2") 1568 mlflow.pyfunc.log_model( 1569 name="model", python_model=MyModel(), registered_model_name="model1" 1570 ) 1571 1572 with mlflow.start_run(): 1573 mlflow.log_param("A", "B") 1574 1575 models = store.search_registered_models(max_results=10) 1576 assert len(models) == 2 1577 assert models[0].name == "model1" 1578 assert models[1].name == "model2" 1579 mv1 = store.search_model_versions("name = 'model1'", max_results=10) 1580 assert len(mv1) == 2 1581 assert mv1[0].name == "model1" 1582 mv2 = store.search_model_versions("name = 'model2'", max_results=10) 1583 assert len(mv2) == 1 1584 assert mv2[0].name == "model2" 1585 1586 1587 @pytest.mark.parametrize("copy_to_same_model", [False, True]) 1588 def test_copy_model_version(store, 
copy_to_same_model): 1589 name1 = "test_for_copy_MV1" 1590 store.create_registered_model(name1) 1591 src_tags = [ 1592 ModelVersionTag("key", "value"), 1593 ModelVersionTag("anotherKey", "some other value"), 1594 ] 1595 src_mv = _create_model_version( 1596 store, 1597 name1, 1598 tags=src_tags, 1599 run_link="dummylink", 1600 description="test description", 1601 ) 1602 1603 # Make some changes to the src MV that won't be copied over 1604 store.transition_model_version_stage( 1605 name1, src_mv.version, "Production", archive_existing_versions=False 1606 ) 1607 1608 copy_rm_name = name1 if copy_to_same_model else "test_for_copy_MV2" 1609 copy_mv_version = 2 if copy_to_same_model else 1 1610 timestamp = time.time() 1611 dst_mv = store.copy_model_version(src_mv, copy_rm_name) 1612 assert dst_mv.name == copy_rm_name 1613 assert dst_mv.version == copy_mv_version 1614 1615 copied_mv = store.get_model_version(dst_mv.name, dst_mv.version) 1616 assert copied_mv.name == copy_rm_name 1617 assert copied_mv.version == copy_mv_version 1618 assert copied_mv.current_stage == "None" 1619 assert copied_mv.creation_timestamp >= timestamp 1620 assert copied_mv.last_updated_timestamp >= timestamp 1621 assert copied_mv.description == "test description" 1622 assert copied_mv.source == f"models:/{src_mv.name}/{src_mv.version}" 1623 assert store.get_model_version_download_uri(dst_mv.name, dst_mv.version) == src_mv.source 1624 assert copied_mv.run_link == "dummylink" 1625 assert copied_mv.run_id == src_mv.run_id 1626 assert copied_mv.status == "READY" 1627 assert copied_mv.status_message is None 1628 assert copied_mv.tags == {"key": "value", "anotherKey": "some other value"} 1629 1630 # Copy a model version copy 1631 double_copy_mv = store.copy_model_version(copied_mv, "test_for_copy_MV3") 1632 assert double_copy_mv.source == f"models:/{copied_mv.name}/{copied_mv.version}" 1633 assert store.get_model_version_download_uri(dst_mv.name, dst_mv.version) == src_mv.source 1634 1635 1636 def 
test_writing_model_version_preserves_storage_location(store): 1637 name = "test_storage_location_MV1" 1638 source = "/special/source" 1639 store.create_registered_model(name) 1640 _create_model_version(store, name, source=source) 1641 _create_model_version(store, name, source=source) 1642 1643 # Run through all the operations that modify model versions and make sure that the 1644 # `storage_location` property is not dropped. 1645 store.transition_model_version_stage(name, 1, "Production", archive_existing_versions=False) 1646 assert store._fetch_file_model_version_if_exists(name, 1).storage_location == source 1647 store.update_model_version(name, 1, description="test description") 1648 assert store._fetch_file_model_version_if_exists(name, 1).storage_location == source 1649 store.transition_model_version_stage(name, 1, "Production", archive_existing_versions=True) 1650 assert store._fetch_file_model_version_if_exists(name, 1).storage_location == source 1651 store.rename_registered_model(name, "test_storage_location_new") 1652 assert ( 1653 store._fetch_file_model_version_if_exists("test_storage_location_new", 1).storage_location 1654 == source 1655 ) 1656 1657 1658 def test_search_prompts(store): 1659 store.create_registered_model("model", tags=[RegisteredModelTag(key="fruit", value="apple")]) 1660 1661 store.create_registered_model( 1662 "prompt_1", tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")] 1663 ) 1664 store.create_registered_model( 1665 "prompt_2", 1666 tags=[ 1667 RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true"), 1668 RegisteredModelTag(key="fruit", value="apple"), 1669 ], 1670 ) 1671 1672 # By default, should not return prompts 1673 rms = store.search_registered_models(max_results=10) 1674 assert len(rms) == 1 1675 assert rms[0].name == "model" 1676 1677 rms = store.search_registered_models(filter_string="tags.fruit = 'apple'", max_results=10) 1678 assert len(rms) == 1 1679 assert rms[0].name == "model" 1680 1681 rms = 
store.search_registered_models(filter_string="name = 'prompt_1'", max_results=10) 1682 assert len(rms) == 0 1683 1684 rms = store.search_registered_models( 1685 filter_string="tags.`mlflow.prompt.is_prompt` = 'false'", max_results=10 1686 ) 1687 assert len(rms) == 1 1688 assert rms[0].name == "model" 1689 1690 rms = store.search_registered_models( 1691 filter_string="tags.`mlflow.prompt.is_prompt` != 'true'", max_results=10 1692 ) 1693 assert len(rms) == 1 1694 assert rms[0].name == "model" 1695 1696 # Search for prompts 1697 rms = store.search_registered_models( 1698 filter_string="tags.`mlflow.prompt.is_prompt` = 'true'", max_results=10 1699 ) 1700 assert len(rms) == 2 1701 assert {rm.name for rm in rms} == {"prompt_1", "prompt_2"} 1702 1703 rms = store.search_registered_models( 1704 filter_string="name = 'prompt_1' and tags.`mlflow.prompt.is_prompt` = 'true'", 1705 max_results=10, 1706 ) 1707 assert len(rms) == 1 1708 assert rms[0].name == "prompt_1" 1709 1710 rms = store.search_registered_models( 1711 filter_string="tags.`mlflow.prompt.is_prompt` = 'true' and tags.fruit = 'apple'", 1712 max_results=10, 1713 ) 1714 assert len(rms) == 1 1715 assert rms[0].name == "prompt_2" 1716 1717 1718 def test_search_prompts_versions(store): 1719 # A Model 1720 store.create_registered_model("model") 1721 store.create_model_version( 1722 "model", "1", "dummy_source", tags=[ModelVersionTag(key="fruit", value="apple")] 1723 ) 1724 1725 # A Prompt with 1 version 1726 store.create_registered_model( 1727 "prompt_1", tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")] 1728 ) 1729 store.create_model_version( 1730 "prompt_1", "1", "dummy_source", tags=[ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true")] 1731 ) 1732 1733 # A Prompt with 2 versions 1734 store.create_registered_model( 1735 "prompt_2", 1736 tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")], 1737 ) 1738 store.create_model_version( 1739 "prompt_2", 1740 "1", 1741 "dummy_source", 1742 tags=[ 1743 
ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true"), 1744 ModelVersionTag(key="fruit", value="apple"), 1745 ], 1746 ) 1747 store.create_model_version( 1748 "prompt_2", 1749 "2", 1750 "dummy_source", 1751 tags=[ 1752 ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true"), 1753 ModelVersionTag(key="fruit", value="orange"), 1754 ], 1755 ) 1756 1757 # Searching model versions should not return prompts by default either 1758 mvs = store.search_model_versions(max_results=10) 1759 assert len(mvs) == 1 1760 assert mvs[0].name == "model" 1761 1762 mvs = store.search_model_versions(filter_string="tags.fruit = 'apple'", max_results=10) 1763 assert len(mvs) == 1 1764 assert mvs[0].name == "model" 1765 1766 mvs = store.search_model_versions( 1767 filter_string="tags.`mlflow.prompt.is_prompt` = 'false'", max_results=10 1768 ) 1769 assert len(mvs) == 1 1770 assert mvs[0].name == "model" 1771 1772 mvs = store.search_model_versions( 1773 filter_string="tags.`mlflow.prompt.is_prompt` != 'true'", max_results=10 1774 ) 1775 assert len(mvs) == 1 1776 assert mvs[0].name == "model" 1777 1778 # Search for prompts via search_model_versions 1779 mvs = store.search_model_versions( 1780 filter_string="tags.`mlflow.prompt.is_prompt` = 'true'", max_results=10 1781 ) 1782 assert len(mvs) == 3 1783 1784 mvs = store.search_model_versions( 1785 filter_string="tags.`mlflow.prompt.is_prompt` = 'true' and name = 'prompt_2'", 1786 max_results=10, 1787 ) 1788 assert len(mvs) == 2 1789 1790 mvs = store.search_model_versions( 1791 filter_string="tags.`mlflow.prompt.is_prompt` = 'true' and tags.fruit = 'apple'", 1792 max_results=10, 1793 ) 1794 assert len(mvs) == 1 1795 assert mvs[0].name == "prompt_2" 1796 1797 1798 def test_create_registered_model_handle_prompt_properly(store): 1799 prompt_tags = [RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")] 1800 1801 store.create_registered_model("model") 1802 1803 store.create_registered_model("prompt", tags=prompt_tags) 1804 1805 with 
pytest.raises(MlflowException, match=r"Registered Model \(name=model\) already exists"): 1806 store.create_registered_model("model") 1807 1808 with pytest.raises(MlflowException, match=r"Prompt \(name=prompt\) already exists"): 1809 store.create_registered_model("prompt", tags=prompt_tags) 1810 1811 with pytest.raises( 1812 MlflowException, 1813 match=r"Tried to create a prompt with name 'model', " 1814 r"but the name is already taken by a registered model.", 1815 ): 1816 store.create_registered_model("model", tags=prompt_tags) 1817 1818 with pytest.raises( 1819 MlflowException, 1820 match=r"Tried to create a registered model with name 'prompt', " 1821 r"but the name is already taken by a prompt.", 1822 ): 1823 store.create_registered_model("prompt") 1824 1825 1826 def test_create_model_version_with_model_id_and_no_run_id(store: FileStore): 1827 class SimpleModel(PythonModel): 1828 def predict(self, context, model_input, params=None): 1829 return model_input 1830 1831 name = "test_model_with_model_id" 1832 store.create_registered_model(name) 1833 1834 with mlflow.start_run() as run: 1835 model_info = mlflow.pyfunc.log_model( 1836 name="model", 1837 python_model=SimpleModel(), 1838 ) 1839 run_id = run.info.run_id 1840 model_id = model_info.model_id 1841 1842 mv = store.create_model_version( 1843 name=name, 1844 source="/absolute/path/to/source", 1845 run_id=None, 1846 model_id=model_id, 1847 ) 1848 1849 assert mv.run_id == run_id 1850 1851 mvd = store.get_model_version(name=mv.name, version=mv.version) 1852 assert mvd.run_id == run_id 1853 1854 1855 def test_update_model_version_with_model_id_and_metrics(store: FileStore): 1856 class SimpleModel(PythonModel): 1857 def predict(self, context, model_input, params=None): 1858 return model_input 1859 1860 name = "test_model_with_model_id_and_metrics" 1861 store.create_registered_model(name) 1862 1863 with mlflow.start_run() as run: 1864 mlflow.log_param("learning_rate", "0.001") 1865 model_info = mlflow.pyfunc.log_model( 
1866 name="model", 1867 python_model=SimpleModel(), 1868 ) 1869 mlflow.log_metric("execution_time", 33.74682, step=0, model_id=model_info.model_id) 1870 run_id = run.info.run_id 1871 model_id = model_info.model_id 1872 1873 mv = store.create_model_version( 1874 name=name, 1875 source="/absolute/path/to/source", 1876 run_id=None, 1877 model_id=model_id, 1878 ) 1879 1880 assert mv.run_id == run_id 1881 1882 mvd = store.get_model_version(name=mv.name, version=mv.version) 1883 assert mvd.run_id == run_id 1884 assert mvd.metrics is not None 1885 assert len(mvd.metrics) == 1 1886 assert mvd.metrics[0].key == "execution_time" 1887 assert mvd.metrics[0].value == 33.74682 1888 assert mvd.params == {"learning_rate": "0.001"} 1889 1890 updated_mvd = store.update_model_version( 1891 name=mv.name, 1892 version=mv.version, 1893 description="Test description with metrics", 1894 ) 1895 assert updated_mvd.description == "Test description with metrics" 1896 1897 retrieved_mvd = store.get_model_version(name=mv.name, version=mv.version) 1898 assert retrieved_mvd.description == "Test description with metrics" 1899 assert retrieved_mvd.metrics is not None 1900 assert len(retrieved_mvd.metrics) == 1 1901 assert retrieved_mvd.metrics[0].key == "execution_time" 1902 assert retrieved_mvd.params == {"learning_rate": "0.001"}