/ tests / store / model_registry / test_file_store.py
test_file_store.py
   1  import pytest
   2  
# Module-level marker: pytest applies this skip to every test in this file,
# because the FileStore model-registry backend is no longer supported.
pytestmark = pytest.mark.skip(reason="FileStore is no longer supported")
   4  
   5  import time
   6  import uuid
   7  from typing import NamedTuple
   8  from unittest import mock
   9  
  10  import pytest
  11  
  12  import mlflow
  13  from mlflow.entities.model_registry import (
  14      ModelVersion,
  15      ModelVersionTag,
  16      RegisteredModelTag,
  17  )
  18  from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
  19  from mlflow.exceptions import MlflowException
  20  from mlflow.protos.databricks_pb2 import (
  21      INVALID_PARAMETER_VALUE,
  22      RESOURCE_DOES_NOT_EXIST,
  23      ErrorCode,
  24  )
  25  from mlflow.pyfunc import PythonModel
  26  from mlflow.store.model_registry.file_store import FileStore
  27  from mlflow.utils.file_utils import path_to_local_file_uri
  28  from mlflow.utils.time import get_current_time_millis
  29  from mlflow.utils.yaml_utils import write_yaml
  30  
  31  from tests.helper_functions import random_int, random_str
  32  
  33  
  34  @pytest.fixture
  35  def store(tmp_path):
  36      return FileStore(str(tmp_path))
  37  
  38  
  39  @pytest.fixture
  40  def registered_model_names():
  41      return [random_str() for _ in range(3)]
  42  
  43  
  44  @pytest.fixture
  45  def rm_data(registered_model_names, tmp_path):
  46      rm_data = {}
  47      for name in registered_model_names:
  48          # create registered model
  49          rm_folder = tmp_path.joinpath(FileStore.MODELS_FOLDER_NAME, name)
  50          rm_folder.mkdir(parents=True, exist_ok=True)
  51          creation_time = get_current_time_millis()
  52          d = {
  53              "name": name,
  54              "creation_timestamp": creation_time,
  55              "last_updated_timestamp": creation_time,
  56              "description": None,
  57              "latest_versions": [],
  58              "tags": {},
  59          }
  60          rm_data[name] = d
  61          write_yaml(rm_folder, FileStore.META_DATA_FILE_NAME, d)
  62          tags_dir = rm_folder.joinpath(FileStore.TAGS_FOLDER_NAME)
  63          tags_dir.mkdir(parents=True, exist_ok=True)
  64      return rm_data
  65  
  66  
  67  def test_file_store_deprecation_warning(tmp_path):
  68      with pytest.warns(FutureWarning, match="filesystem model registry backend.*is deprecated"):
  69          FileStore(str(tmp_path / "model_registry"))
  70  
  71  
  72  def test_create_registered_model(store):
  73      # Error cases
  74      with pytest.raises(MlflowException, match=r"Missing value for required parameter 'name'\."):
  75          store.create_registered_model(None)
  76      with pytest.raises(MlflowException, match=r"Missing value for required parameter 'name'\."):
  77          store.create_registered_model("")
  78  
  79      name = random_str()
  80      model = store.create_registered_model(name)
  81      assert model.name == name
  82      assert model.latest_versions == []
  83      assert model.creation_timestamp == model.last_updated_timestamp
  84      assert model.tags == {}
  85  
  86  
  87  def test_create_registered_model_with_name_that_looks_like_path(store, tmp_path):
  88      name = str(tmp_path.joinpath("test"))
  89      with pytest.raises(MlflowException, match=r"Names cannot contain '/' or ':'"):
  90          store.get_registered_model(name)
  91  
  92  
  93  def test_create_registered_model_with_percent_in_name(store, tmp_path):
  94      with pytest.raises(
  95          MlflowException, match=r"Registered model name cannot contain '%' character"
  96      ):
  97          store.get_registered_model("m%6fdel")
  98  
  99  
 100  def _verify_registered_model(fs, name, rm_data):
 101      rm = fs.get_registered_model(name)
 102      assert rm.name == name
 103      assert rm.creation_timestamp == rm_data[name]["creation_timestamp"]
 104      assert rm.last_updated_timestamp == rm_data[name]["last_updated_timestamp"]
 105      assert rm.description == rm_data[name]["description"]
 106      assert rm.latest_versions == rm_data[name]["latest_versions"]
 107      assert rm.tags == rm_data[name]["tags"]
 108  
 109  
 110  def test_get_registered_model(store, registered_model_names, rm_data):
 111      for name in registered_model_names:
 112          _verify_registered_model(store, name, rm_data)
 113  
 114      # test that fake registered models dont exist.
 115      name = random_str()
 116      with pytest.raises(MlflowException, match=f"Registered Model with name={name} not found"):
 117          store.get_registered_model(name)
 118  
 119      name = "../../path"
 120      with pytest.raises(MlflowException, match="Names cannot contain '/' or ':'"):
 121          store.get_registered_model(name)
 122  
 123  
 124  def test_list_registered_model(store, registered_model_names, rm_data):
 125      for rm in store.list_registered_models(max_results=10, page_token=None):
 126          name = rm.name
 127          assert name in registered_model_names
 128          assert name == rm_data[name]["name"]
 129  
 130  
 131  @pytest.mark.usefixtures(rm_data.__name__)
 132  def test_rename_registered_model(store, registered_model_names):
 133      # Error cases
 134      model_name = registered_model_names[0]
 135      with pytest.raises(MlflowException, match=r"Missing value for required parameter 'name'\."):
 136          store.rename_registered_model(model_name, None)
 137  
 138      # test that names of existing registered models are checked before renaming
 139      other_model_name = registered_model_names[1]
 140      with pytest.raises(
 141          MlflowException,
 142          match=rf"Registered Model \(name={other_model_name}\) already exists\.",
 143      ):
 144          store.rename_registered_model(model_name, other_model_name)
 145  
 146      new_name = model_name + "!!!"
 147      store.rename_registered_model(model_name, new_name)
 148      assert store.get_registered_model(new_name).name == new_name
 149  
 150  
 151  def _extract_names(registered_models):
 152      return [rm.name for rm in registered_models]
 153  
 154  
 155  @pytest.mark.usefixtures(rm_data.__name__)
 156  def test_delete_registered_model(store, registered_model_names):
 157      model_name = registered_model_names[random_int(0, len(registered_model_names) - 1)]
 158  
 159      # Error cases
 160      with pytest.raises(
 161          MlflowException, match=f"Registered Model with name={model_name}!!! not found"
 162      ):
 163          store.delete_registered_model(model_name + "!!!")
 164  
 165      store.delete_registered_model(model_name)
 166      assert model_name not in _extract_names(
 167          store.list_registered_models(max_results=10, page_token=None)
 168      )
 169      # Cannot delete a deleted model
 170      with pytest.raises(MlflowException, match=f"Registered Model with name={model_name} not found"):
 171          store.delete_registered_model(model_name)
 172  
 173  
 174  def test_list_registered_model_paginated(store):
 175      for _ in range(10):
 176          store.create_registered_model(random_str())
 177      rms1 = store.list_registered_models(max_results=4, page_token=None)
 178      assert len(rms1) == 4
 179      assert rms1.token is not None
 180      rms2 = store.list_registered_models(max_results=4, page_token=None)
 181      assert len(rms2) == 4
 182      assert rms2.token is not None
 183      assert rms1 == rms2
 184      rms3 = store.list_registered_models(max_results=500, page_token=rms2.token)
 185      assert len(rms3) == 6
 186      assert rms3.token is None
 187  
 188  
 189  def test_list_registered_model_paginated_returns_in_correct_order(store):
 190      rms = [store.create_registered_model(f"RM{i:03}").name for i in range(50)]
 191  
 192      # test that pagination will return all valid results in sorted order
 193      # by name ascending
 194      result = store.list_registered_models(max_results=5, page_token=None)
 195      assert result.token is not None
 196      assert _extract_names(result) == rms[0:5]
 197  
 198      result = store.list_registered_models(page_token=result.token, max_results=10)
 199      assert result.token is not None
 200      assert _extract_names(result) == rms[5:15]
 201  
 202      result = store.list_registered_models(page_token=result.token, max_results=20)
 203      assert result.token is not None
 204      assert _extract_names(result) == rms[15:35]
 205  
 206      result = store.list_registered_models(page_token=result.token, max_results=100)
 207      assert result.token is None
 208      assert _extract_names(result) == rms[35:]
 209  
 210  
 211  def test_list_registered_model_paginated_errors(store):
 212      rms = [store.create_registered_model(f"RM{i:03}").name for i in range(50)]
 213      # test that providing a completely invalid page token throws
 214      with pytest.raises(
 215          MlflowException, match=r"Invalid page token, could not base64-decode"
 216      ) as exception_context:
 217          store.list_registered_models(page_token="evilhax", max_results=20)
 218      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 219  
 220      # test that providing too large of a max_results throws
 221      with pytest.raises(
 222          MlflowException, match=r"Invalid value for max_results"
 223      ) as exception_context:
 224          store.list_registered_models(page_token="evilhax", max_results=1e15)
 225      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 226      # list should not return deleted models
 227      store.delete_registered_model(name="RM000")
 228      assert set(
 229          _extract_names(store.list_registered_models(max_results=100, page_token=None))
 230      ) == set(rms[1:])
 231  
 232  
 233  def _create_model_version(
 234      fs,
 235      name,
 236      source="path/to/source",
 237      run_id=uuid.uuid4().hex,
 238      tags=None,
 239      run_link=None,
 240      description=None,
 241  ):
 242      time.sleep(0.001)
 243      return fs.create_model_version(
 244          name, source, run_id, tags, run_link=run_link, description=description
 245      )
 246  
 247  
 248  def _stage_to_version_map(latest_versions):
 249      return {mvd.current_stage: mvd.version for mvd in latest_versions}
 250  
 251  
 252  def test_get_latest_versions(store):
 253      name = "test_for_latest_versions"
 254      rmd1 = store.create_registered_model(name)
 255      assert rmd1.latest_versions == []
 256  
 257      mv1 = _create_model_version(store, name)
 258      assert mv1.version == 1
 259      rmd2 = store.get_registered_model(name)
 260      assert _stage_to_version_map(rmd2.latest_versions) == {"None": 1}
 261  
 262      # add a bunch more
 263      mv2 = _create_model_version(store, name)
 264      assert mv2.version == 2
 265      store.transition_model_version_stage(
 266          name=mv2.name,
 267          version=mv2.version,
 268          stage="Production",
 269          archive_existing_versions=False,
 270      )
 271  
 272      mv3 = _create_model_version(store, name)
 273      assert mv3.version == 3
 274      store.transition_model_version_stage(
 275          name=mv3.name,
 276          version=mv3.version,
 277          stage="Production",
 278          archive_existing_versions=False,
 279      )
 280      mv4 = _create_model_version(store, name)
 281      assert mv4.version == 4
 282      store.transition_model_version_stage(
 283          name=mv4.name,
 284          version=mv4.version,
 285          stage="Staging",
 286          archive_existing_versions=False,
 287      )
 288  
 289      # test that correct latest versions are returned for each stage
 290      rmd4 = store.get_registered_model(name)
 291      assert _stage_to_version_map(rmd4.latest_versions) == {
 292          "None": 1,
 293          "Production": 3,
 294          "Staging": 4,
 295      }
 296      assert _stage_to_version_map(store.get_latest_versions(name=name, stages=None)) == {
 297          "None": 1,
 298          "Production": 3,
 299          "Staging": 4,
 300      }
 301      assert _stage_to_version_map(store.get_latest_versions(name=name, stages=[])) == {
 302          "None": 1,
 303          "Production": 3,
 304          "Staging": 4,
 305      }
 306      assert _stage_to_version_map(store.get_latest_versions(name=name, stages=["Production"])) == {
 307          "Production": 3
 308      }
 309      assert _stage_to_version_map(store.get_latest_versions(name=name, stages=["production"])) == {
 310          "Production": 3
 311      }  # The stages are case insensitive.
 312      assert _stage_to_version_map(store.get_latest_versions(name=name, stages=["pROduction"])) == {
 313          "Production": 3
 314      }  # The stages are case insensitive.
 315      assert _stage_to_version_map(
 316          store.get_latest_versions(name=name, stages=["None", "Production"])
 317      ) == {"None": 1, "Production": 3}
 318  
 319      # delete latest Production, and should point to previous one
 320      store.delete_model_version(name=mv3.name, version=mv3.version)
 321      rmd5 = store.get_registered_model(name=name)
 322      assert _stage_to_version_map(rmd5.latest_versions) == {
 323          "None": 1,
 324          "Production": 2,
 325          "Staging": 4,
 326      }
 327      assert _stage_to_version_map(store.get_latest_versions(name=name, stages=None)) == {
 328          "None": 1,
 329          "Production": 2,
 330          "Staging": 4,
 331      }
 332      assert _stage_to_version_map(store.get_latest_versions(name=name, stages=["Production"])) == {
 333          "Production": 2
 334      }
 335  
 336  
 337  def test_set_registered_model_tag(store):
 338      name1 = "SetRegisteredModelTag_TestMod"
 339      name2 = "SetRegisteredModelTag_TestMod 2"
 340      initial_tags = [
 341          RegisteredModelTag("key", "value"),
 342          RegisteredModelTag("anotherKey", "some other value"),
 343      ]
 344      store.create_registered_model(name1, initial_tags)
 345      store.create_registered_model(name2, initial_tags)
 346      new_tag = RegisteredModelTag("randomTag", "not a random value")
 347      store.set_registered_model_tag(name1, new_tag)
 348      rm1 = store.get_registered_model(name=name1)
 349      all_tags = [*initial_tags, new_tag]
 350      assert rm1.tags == {tag.key: tag.value for tag in all_tags}
 351  
 352      # test overriding a tag with the same key
 353      overriding_tag = RegisteredModelTag("key", "overriding")
 354      store.set_registered_model_tag(name1, overriding_tag)
 355      all_tags = [tag for tag in all_tags if tag.key != "key"] + [overriding_tag]
 356      rm1 = store.get_registered_model(name=name1)
 357      assert rm1.tags == {tag.key: tag.value for tag in all_tags}
 358      # does not affect other models with the same key
 359      rm2 = store.get_registered_model(name=name2)
 360      assert rm2.tags == {tag.key: tag.value for tag in initial_tags}
 361  
 362      # can not set tag on deleted (non-existed) registered model
 363      store.delete_registered_model(name1)
 364      with pytest.raises(
 365          MlflowException, match=f"Registered Model with name={name1} not found"
 366      ) as exception_context:
 367          store.set_registered_model_tag(name1, overriding_tag)
 368      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 369      # test cannot set tags that are too long
 370      long_tag = RegisteredModelTag("longTagKey", "a" * 100_001)
 371      with pytest.raises(
 372          MlflowException,
 373          match=r"'value' exceeds the maximum length of \d+ characters",
 374      ) as exception_context:
 375          store.set_registered_model_tag(name2, long_tag)
 376      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 377      # test can set tags that are somewhat long
 378      long_tag = RegisteredModelTag("longTagKey", "a" * 4999)
 379      store.set_registered_model_tag(name2, long_tag)
 380      # can not set invalid tag
 381      with pytest.raises(
 382          MlflowException, match=r"Missing value for required parameter 'key'"
 383      ) as exception_context:
 384          store.set_registered_model_tag(name2, RegisteredModelTag(key=None, value=""))
 385      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 386      # can not use invalid model name
 387      with pytest.raises(
 388          MlflowException, match=r"Missing value for required parameter 'name'\."
 389      ) as exception_context:
 390          store.set_registered_model_tag(None, RegisteredModelTag(key="key", value="value"))
 391      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 392  
 393  
 394  def test_delete_registered_model_tag(store):
 395      name1 = "DeleteRegisteredModelTag_TestMod"
 396      name2 = "DeleteRegisteredModelTag_TestMod 2"
 397      initial_tags = [
 398          RegisteredModelTag("key", "value"),
 399          RegisteredModelTag("anotherKey", "some other value"),
 400      ]
 401      store.create_registered_model(name1, initial_tags)
 402      store.create_registered_model(name2, initial_tags)
 403      new_tag = RegisteredModelTag("randomTag", "not a random value")
 404      store.set_registered_model_tag(name1, new_tag)
 405      store.delete_registered_model_tag(name1, "randomTag")
 406      rm1 = store.get_registered_model(name=name1)
 407      assert rm1.tags == {tag.key: tag.value for tag in initial_tags}
 408  
 409      # testing deleting a key does not affect other models with the same key
 410      store.delete_registered_model_tag(name1, "key")
 411      rm1 = store.get_registered_model(name=name1)
 412      rm2 = store.get_registered_model(name=name2)
 413      assert rm1.tags == {"anotherKey": "some other value"}
 414      assert rm2.tags == {tag.key: tag.value for tag in initial_tags}
 415  
 416      # delete tag that is already deleted does nothing
 417      store.delete_registered_model_tag(name1, "key")
 418      rm1 = store.get_registered_model(name=name1)
 419      assert rm1.tags == {"anotherKey": "some other value"}
 420  
 421      # can not delete tag on deleted (non-existed) registered model
 422      store.delete_registered_model(name1)
 423      with pytest.raises(
 424          MlflowException, match=f"Registered Model with name={name1} not found"
 425      ) as exception_context:
 426          store.delete_registered_model_tag(name1, "anotherKey")
 427      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 428      # can not delete tag with invalid key
 429      with pytest.raises(
 430          MlflowException, match=r"Missing value for required parameter 'key'"
 431      ) as exception_context:
 432          store.delete_registered_model_tag(name2, None)
 433      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 434      # can not use invalid model name
 435      with pytest.raises(
 436          MlflowException, match=r"Missing value for required parameter 'name'\."
 437      ) as exception_context:
 438          store.delete_registered_model_tag(None, "key")
 439      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 440  
 441  
 442  def test_create_model_version(store):
 443      name = "test_for_create_MV"
 444      store.create_registered_model(name)
 445      run_id = uuid.uuid4().hex
 446      with mock.patch("time.time", return_value=456778):
 447          mv1 = _create_model_version(store, name, "a/b/CD", run_id)
 448          assert mv1.name == name
 449          assert mv1.version == 1
 450  
 451      mvd1 = store.get_model_version(mv1.name, mv1.version)
 452      assert mvd1.name == name
 453      assert mvd1.version == 1
 454      assert mvd1.current_stage == "None"
 455      assert mvd1.creation_timestamp == 456778000
 456      assert mvd1.last_updated_timestamp == 456778000
 457      assert mvd1.description is None
 458      assert mvd1.source == "a/b/CD"
 459      assert mvd1.run_id == run_id
 460      assert mvd1.status == "READY"
 461      assert mvd1.status_message is None
 462      assert mvd1.tags == {}
 463  
 464      # new model versions for same name autoincrement versions
 465      mv2 = _create_model_version(store, name)
 466      mvd2 = store.get_model_version(name=mv2.name, version=mv2.version)
 467      assert mv2.version == 2
 468      assert mvd2.version == 2
 469  
 470      # create model version with tags return model version entity with tags
 471      tags = [
 472          ModelVersionTag("key", "value"),
 473          ModelVersionTag("anotherKey", "some other value"),
 474      ]
 475      mv3 = _create_model_version(store, name, tags=tags)
 476      mvd3 = store.get_model_version(name=mv3.name, version=mv3.version)
 477      assert mv3.version == 3
 478      assert mv3.tags == {tag.key: tag.value for tag in tags}
 479      assert mvd3.version == 3
 480      assert mvd3.tags == {tag.key: tag.value for tag in tags}
 481  
 482      # create model versions with runLink
 483      run_link = "http://localhost:3000/path/to/run/"
 484      mv4 = _create_model_version(store, name, run_link=run_link)
 485      mvd4 = store.get_model_version(name, mv4.version)
 486      assert mv4.version == 4
 487      assert mv4.run_link == run_link
 488      assert mvd4.version == 4
 489      assert mvd4.run_link == run_link
 490  
 491      # create model version with description
 492      description = "the best model ever"
 493      mv5 = _create_model_version(store, name, description=description)
 494      mvd5 = store.get_model_version(name, mv5.version)
 495      assert mv5.version == 5
 496      assert mv5.description == description
 497      assert mvd5.version == 5
 498      assert mvd5.description == description
 499  
 500      # create model version without runId
 501      mv6 = _create_model_version(store, name, run_id=None)
 502      mvd6 = store.get_model_version(name, mv6.version)
 503      assert mv6.version == 6
 504      assert mv6.run_id is None
 505      assert mvd6.version == 6
 506      assert mvd6.run_id is None
 507  
 508  
 509  def test_update_model_version(store):
 510      name = "test_for_update_MV"
 511      store.create_registered_model(name)
 512      mv1 = _create_model_version(store, name)
 513      mvd1 = store.get_model_version(name=mv1.name, version=mv1.version)
 514      assert mvd1.name == name
 515      assert mvd1.version == 1
 516      assert mvd1.current_stage == "None"
 517  
 518      # update stage
 519      store.transition_model_version_stage(
 520          name=mv1.name,
 521          version=mv1.version,
 522          stage="Production",
 523          archive_existing_versions=False,
 524      )
 525      mvd2 = store.get_model_version(name=mv1.name, version=mv1.version)
 526      assert mvd2.name == name
 527      assert mvd2.version == 1
 528      assert mvd2.current_stage == "Production"
 529      assert mvd2.description is None
 530  
 531      # update description
 532      store.update_model_version(name=mv1.name, version=mv1.version, description="test model version")
 533      mvd3 = store.get_model_version(name=mv1.name, version=mv1.version)
 534      assert mvd3.name == name
 535      assert mvd3.version == 1
 536      assert mvd3.current_stage == "Production"
 537      assert mvd3.description == "test model version"
 538  
 539      # only valid stages can be set
 540      with pytest.raises(
 541          MlflowException,
 542          match=(
 543              r"Invalid Model Version stage: unknown\. "
 544              r"Value must be one of None, Staging, Production, Archived\."
 545          ),
 546      ) as exception_context:
 547          store.transition_model_version_stage(
 548              mv1.name, mv1.version, stage="unknown", archive_existing_versions=False
 549          )
 550      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 551  
 552      # stages are case-insensitive and auto-corrected to system stage names
 553      for stage_name in ["STAGING", "staging", "StAgInG"]:
 554          store.transition_model_version_stage(
 555              name=mv1.name,
 556              version=mv1.version,
 557              stage=stage_name,
 558              archive_existing_versions=False,
 559          )
 560          mvd5 = store.get_model_version(name=mv1.name, version=mv1.version)
 561          assert mvd5.current_stage == "Staging"
 562  
 563  
 564  def test_transition_model_version_stage_when_archive_existing_versions_is_false(store):
 565      name = "model"
 566      store.create_registered_model(name)
 567      mv1 = _create_model_version(store, name)
 568      mv2 = _create_model_version(store, name)
 569      mv3 = _create_model_version(store, name)
 570  
 571      # test that when `archive_existing_versions` is False, transitioning a model version
 572      # to the inactive stages ("Archived" and "None") does not throw.
 573      for stage in ["Archived", "None"]:
 574          store.transition_model_version_stage(name, mv1.version, stage, False)
 575  
 576      store.transition_model_version_stage(name, mv1.version, "Staging", False)
 577      store.transition_model_version_stage(name, mv2.version, "Production", False)
 578      store.transition_model_version_stage(name, mv3.version, "Staging", False)
 579  
 580      mvd1 = store.get_model_version(name=name, version=mv1.version)
 581      mvd2 = store.get_model_version(name=name, version=mv2.version)
 582      mvd3 = store.get_model_version(name=name, version=mv3.version)
 583  
 584      assert mvd1.current_stage == "Staging"
 585      assert mvd2.current_stage == "Production"
 586      assert mvd3.current_stage == "Staging"
 587  
 588      store.transition_model_version_stage(name, mv3.version, "Production", False)
 589  
 590      mvd1 = store.get_model_version(name=name, version=mv1.version)
 591      mvd2 = store.get_model_version(name=name, version=mv2.version)
 592      mvd3 = store.get_model_version(name=name, version=mv3.version)
 593  
 594      assert mvd1.current_stage == "Staging"
 595      assert mvd2.current_stage == "Production"
 596      assert mvd3.current_stage == "Production"
 597  
 598  
 599  def test_transition_model_version_stage_when_archive_existing_versions_is_true(store):
 600      name = "model"
 601      store.create_registered_model(name)
 602      mv1 = _create_model_version(store, name)
 603      mv2 = _create_model_version(store, name)
 604      mv3 = _create_model_version(store, name)
 605  
 606      msg = (
 607          r"Model version transition cannot archive existing model versions "
 608          r"because .+ is not an Active stage"
 609      )
 610  
 611      # test that when `archive_existing_versions` is True, transitioning a model version
 612      # to the inactive stages ("Archived" and "None") throws.
 613      for stage in ["Archived", "None"]:
 614          with pytest.raises(MlflowException, match=msg):
 615              store.transition_model_version_stage(name, mv1.version, stage, True)
 616  
 617      store.transition_model_version_stage(name, mv1.version, "Staging", False)
 618      store.transition_model_version_stage(name, mv2.version, "Production", False)
 619      store.transition_model_version_stage(name, mv3.version, "Staging", True)
 620  
 621      mvd1 = store.get_model_version(name=name, version=mv1.version)
 622      mvd2 = store.get_model_version(name=name, version=mv2.version)
 623      mvd3 = store.get_model_version(name=name, version=mv3.version)
 624  
 625      assert mvd1.current_stage == "Archived"
 626      assert mvd2.current_stage == "Production"
 627      assert mvd3.current_stage == "Staging"
 628      assert mvd1.last_updated_timestamp == mvd3.last_updated_timestamp
 629  
 630      store.transition_model_version_stage(name, mv3.version, "Production", True)
 631  
 632      mvd1 = store.get_model_version(name=name, version=mv1.version)
 633      mvd2 = store.get_model_version(name=name, version=mv2.version)
 634      mvd3 = store.get_model_version(name=name, version=mv3.version)
 635  
 636      assert mvd1.current_stage == "Archived"
 637      assert mvd2.current_stage == "Archived"
 638      assert mvd3.current_stage == "Production"
 639      assert mvd2.last_updated_timestamp == mvd3.last_updated_timestamp
 640  
 641      for uncanonical_stage_name in ["STAGING", "staging", "StAgInG"]:
 642          store.transition_model_version_stage(mv1.name, mv1.version, "Staging", False)
 643          store.transition_model_version_stage(mv2.name, mv2.version, "None", False)
 644  
 645          # stage names are case-insensitive and auto-corrected to system stage names
 646          store.transition_model_version_stage(mv2.name, mv2.version, uncanonical_stage_name, True)
 647  
 648          mvd1 = store.get_model_version(name=mv1.name, version=mv1.version)
 649          mvd2 = store.get_model_version(name=mv2.name, version=mv2.version)
 650          assert mvd1.current_stage == "Archived"
 651          assert mvd2.current_stage == "Staging"
 652  
 653  
 654  def test_delete_model_version(store):
 655      name = "test_for_delete_MV"
 656      initial_tags = [
 657          ModelVersionTag("key", "value"),
 658          ModelVersionTag("anotherKey", "some other value"),
 659      ]
 660      store.create_registered_model(name)
 661      mv = _create_model_version(store, name, tags=initial_tags)
 662      mvd = store.get_model_version(name=mv.name, version=mv.version)
 663      assert mvd.name == name
 664  
 665      store.delete_model_version(name=mv.name, version=mv.version)
 666  
 667      # cannot get a deleted model version
 668      with pytest.raises(
 669          MlflowException,
 670          match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found",
 671      ) as exception_context:
 672          store.get_model_version(name=mv.name, version=mv.version)
 673      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 674  
 675      # cannot update a delete
 676      with pytest.raises(
 677          MlflowException,
 678          match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found",
 679      ) as exception_context:
 680          store.update_model_version(mv.name, mv.version, description="deleted!")
 681      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 682  
 683      # cannot delete it again
 684      with pytest.raises(
 685          MlflowException,
 686          match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found",
 687      ) as exception_context:
 688          store.delete_model_version(name=mv.name, version=mv.version)
 689      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 690  
 691  
 692  def _search_model_versions(fs, filter_string=None, max_results=10, order_by=None, page_token=None):
 693      return fs.search_model_versions(
 694          filter_string=filter_string,
 695          max_results=max_results,
 696          order_by=order_by,
 697          page_token=page_token,
 698      )
 699  
 700  
 701  def test_search_model_versions(store):
 702      # create some model versions
 703      name = "test_for_search_MV"
 704      store.create_registered_model(name)
 705      run_id_1 = uuid.uuid4().hex
 706      run_id_2 = uuid.uuid4().hex
 707      run_id_3 = uuid.uuid4().hex
 708      mv1 = _create_model_version(store, name=name, source="A/B", run_id=run_id_1)
 709      assert mv1.version == 1
 710      mv2 = _create_model_version(store, name=name, source="A/C", run_id=run_id_2)
 711      assert mv2.version == 2
 712      mv3 = _create_model_version(store, name=name, source="A/D", run_id=run_id_2)
 713      assert mv3.version == 3
 714      mv4 = _create_model_version(store, name=name, source="A/D", run_id=run_id_3)
 715      assert mv4.version == 4
 716  
 717      def search_versions(filter_string):
 718          return [mvd.version for mvd in _search_model_versions(store, filter_string)]
 719  
 720      # search using name should return all 4 versions
 721      assert set(search_versions(f"name='{name}'")) == {1, 2, 3, 4}
 722  
 723      # search using version
 724      assert set(search_versions("version_number=2")) == {2}
 725      assert set(search_versions("version_number<=3")) == {1, 2, 3}
 726  
 727      # search using run_id_1 should return version 1
 728      assert set(search_versions(f"run_id='{run_id_1}'")) == {1}
 729  
 730      # search using run_id_2 should return versions 2 and 3
 731      assert set(search_versions(f"run_id='{run_id_2}'")) == {2, 3}
 732  
 733      # search using the IN operator should return all versions
 734      assert set(search_versions(f"run_id IN ('{run_id_1}','{run_id_2}')")) == {1, 2, 3}
 735  
 736      # search IN operator is case sensitive
 737      assert set(search_versions(f"run_id IN ('{run_id_1.upper()}','{run_id_2}')")) == {
 738          2,
 739          3,
 740      }
 741  
 742      # search IN operator with right-hand side value containing whitespaces
 743      assert set(search_versions(f"run_id IN ('{run_id_1}', '{run_id_2}')")) == {1, 2, 3}
 744  
 745      # search IN operator with other conditions
 746      assert set(
 747          search_versions(f"version_number=2 AND run_id IN ('{run_id_1.upper()}','{run_id_2}')")
 748      ) == {2}
 749  
 750      # search using the IN operator with bad lists should return exceptions
 751      with pytest.raises(
 752          MlflowException,
 753          match=(
 754              r"While parsing a list in the query, "
 755              r"expected string value, punctuation, or whitespace, "
 756              r"but got different type in list"
 757          ),
 758      ) as exception_context:
 759          search_versions("run_id IN (1,2,3)")
 760      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 761  
 762      assert set(search_versions(f"run_id LIKE '{run_id_2[:30]}%'")) == {2, 3}
 763  
 764      assert set(search_versions(f"run_id ILIKE '{run_id_2[:30].upper()}%'")) == {2, 3}
 765  
 766      # search using the IN operator with empty lists should return exceptions
 767      with pytest.raises(
 768          MlflowException,
 769          match=(
 770              r"While parsing a list in the query, "
 771              r"expected a non-empty list of string values, "
 772              r"but got empty list"
 773          ),
 774      ) as exception_context:
 775          search_versions("run_id IN ()")
 776      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 777  
 778      # search using an ill-formed IN operator correctly throws exception
 779      with pytest.raises(
 780          MlflowException, match=r"Invalid clause\(s\) in filter string"
 781      ) as exception_context:
 782          search_versions("run_id IN (")
 783      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 784  
 785      with pytest.raises(
 786          MlflowException, match=r"Invalid clause\(s\) in filter string"
 787      ) as exception_context:
 788          search_versions("run_id IN")
 789      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 790  
 791      with pytest.raises(
 792          MlflowException, match=r"Invalid clause\(s\) in filter string"
 793      ) as exception_context:
 794          search_versions("name LIKE")
 795      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 796  
 797      with pytest.raises(
 798          MlflowException,
 799          match=(
 800              r"While parsing a list in the query, "
 801              r"expected a non-empty list of string values, "
 802              r"but got ill-formed list"
 803          ),
 804      ) as exception_context:
 805          search_versions("run_id IN (,)")
 806      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 807  
 808      with pytest.raises(
 809          MlflowException,
 810          match=(
 811              r"While parsing a list in the query, "
 812              r"expected a non-empty list of string values, "
 813              r"but got ill-formed list"
 814          ),
 815      ) as exception_context:
 816          search_versions("run_id IN ('runid1',,'runid2')")
 817      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 818  
 819      # search using source_path "A/D" should return version 3 and 4
 820      assert set(search_versions("source_path = 'A/D'")) == {3, 4}
 821  
 822      # search using source_path "A" should not return anything
 823      assert len(search_versions("source_path = 'A'")) == 0
 824      assert len(search_versions("source_path = 'A/'")) == 0
 825      assert len(search_versions("source_path = ''")) == 0
 826  
 827      # delete mv4. search should not return version 4
 828      store.delete_model_version(name=mv4.name, version=mv4.version)
 829      assert set(search_versions("")) == {1, 2, 3}
 830  
 831      assert set(search_versions(None)) == {1, 2, 3}
 832  
 833      assert set(search_versions(f"name='{name}'")) == {1, 2, 3}
 834  
 835      store.transition_model_version_stage(
 836          name=mv1.name,
 837          version=mv1.version,
 838          stage="production",
 839          archive_existing_versions=False,
 840      )
 841  
 842      store.update_model_version(
 843          name=mv1.name, version=mv1.version, description="Online prediction model!"
 844      )
 845  
 846      mvds = store.search_model_versions(f"run_id = '{run_id_1}'", max_results=10)
 847      assert len(mvds) == 1
 848      assert isinstance(mvds[0], ModelVersion)
 849      assert mvds[0].current_stage == "Production"
 850      assert mvds[0].run_id == run_id_1
 851      assert mvds[0].source == "A/B"
 852      assert mvds[0].description == "Online prediction model!"
 853  
 854  
 855  def test_search_model_versions_order_by_simple(store):
 856      # create some model versions
 857      names = ["RM1", "RM2", "RM3", "RM4", "RM1", "RM4"]
 858      sources = ["A"] * 3 + ["B"] * 3
 859      run_ids = [uuid.uuid4().hex for _ in range(6)]
 860      for name in set(names):
 861          store.create_registered_model(name)
 862      for i in range(6):
 863          _create_model_version(store, name=names[i], source=sources[i], run_id=run_ids[i])
 864          time.sleep(0.001)  # sleep for windows fs timestamp precision issues
 865  
 866      # by default order by last_updated_timestamp DESC
 867      mvs = _search_model_versions(store).to_list()
 868      assert [mv.name for mv in mvs] == names[::-1]
 869      assert [mv.version for mv in mvs] == [2, 2, 1, 1, 1, 1]
 870  
 871      # order by name DESC
 872      mvs = _search_model_versions(store, order_by=["name DESC"])
 873      assert [mv.name for mv in mvs] == sorted(names)[::-1]
 874      assert [mv.version for mv in mvs] == [2, 1, 1, 1, 2, 1]
 875  
 876      # order by version DESC
 877      mvs = _search_model_versions(store, order_by=["version_number DESC"])
 878      assert [mv.name for mv in mvs] == ["RM1", "RM4", "RM1", "RM2", "RM3", "RM4"]
 879      assert [mv.version for mv in mvs] == [2, 2, 1, 1, 1, 1]
 880  
 881      # order by creation_timestamp DESC
 882      mvs = _search_model_versions(store, order_by=["creation_timestamp DESC"])
 883      assert [mv.name for mv in mvs] == names[::-1]
 884      assert [mv.version for mv in mvs] == [2, 2, 1, 1, 1, 1]
 885  
 886      # order by last_updated_timestamp ASC
 887      store.update_model_version(names[0], 1, "latest updated")
 888      mvs = _search_model_versions(store, order_by=["last_updated_timestamp ASC"])
 889      assert mvs[-1].name == names[0]
 890      assert mvs[-1].version == 1
 891  
 892  
 893  def test_search_model_versions_pagination(store):
 894      def search_versions(filter_string, page_token=None, max_results=10):
 895          result = _search_model_versions(
 896              store,
 897              filter_string=filter_string,
 898              page_token=page_token,
 899              max_results=max_results,
 900          )
 901          return result.to_list(), result.token
 902  
 903      name = "test_for_search_MV_pagination"
 904      store.create_registered_model(name)
 905      mvs = [_create_model_version(store, name) for _ in range(50)][::-1]
 906  
 907      # test flow with fixed max_results
 908      returned_mvs = []
 909      query = "name LIKE 'test_for_search_MV_pagination%'"
 910      result, token = search_versions(query, page_token=None, max_results=5)
 911      returned_mvs.extend(result)
 912      while token:
 913          result, token = search_versions(query, page_token=token, max_results=5)
 914          returned_mvs.extend(result)
 915      assert mvs == returned_mvs
 916  
 917      # test that pagination will return all valid results in sorted order
 918      # by name ascending
 919      result, token1 = search_versions(query, max_results=5)
 920      assert token1 is not None
 921      assert result == mvs[0:5]
 922  
 923      result, token2 = search_versions(query, page_token=token1, max_results=10)
 924      assert token2 is not None
 925      assert result == mvs[5:15]
 926  
 927      result, token3 = search_versions(query, page_token=token2, max_results=20)
 928      assert token3 is not None
 929      assert result == mvs[15:35]
 930  
 931      result, token4 = search_versions(query, page_token=token3, max_results=100)
 932      # assert that page token is None
 933      assert token4 is None
 934      assert result == mvs[35:]
 935  
 936      # test that providing a completely invalid page token throws
 937      with pytest.raises(
 938          MlflowException, match=r"Invalid page token, could not base64-decode"
 939      ) as exception_context:
 940          search_versions(query, page_token="evilhax", max_results=20)
 941      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 942  
 943      # test that providing too large of a max_results throws
 944      with pytest.raises(
 945          MlflowException, match=r"Invalid value for max_results."
 946      ) as exception_context:
 947          search_versions(query, page_token="evilhax", max_results=1e15)
 948      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 949  
 950  
 951  def test_search_model_versions_by_tag(store):
 952      # create some model versions
 953      name = "test_for_search_MV_by_tag"
 954      store.create_registered_model(name)
 955      run_id_1 = uuid.uuid4().hex
 956      run_id_2 = uuid.uuid4().hex
 957  
 958      mv1 = _create_model_version(
 959          store,
 960          name=name,
 961          source="A/B",
 962          run_id=run_id_1,
 963          tags=[ModelVersionTag("t1", "abc"), ModelVersionTag("t2", "xyz")],
 964      )
 965      assert mv1.version == 1
 966      mv2 = _create_model_version(
 967          store,
 968          name=name,
 969          source="A/C",
 970          run_id=run_id_2,
 971          tags=[ModelVersionTag("t1", "abc"), ModelVersionTag("t2", "x123")],
 972      )
 973      assert mv2.version == 2
 974  
 975      def search_versions(filter_string):
 976          return [mvd.version for mvd in _search_model_versions(store, filter_string)]
 977  
 978      assert search_versions(f"name = '{name}' and tag.t2 = 'xyz'") == [1]
 979      assert search_versions("name = 'wrong_name' and tag.t2 = 'xyz'") == []
 980      assert search_versions("tag.`t2` = 'xyz'") == [1]
 981      assert search_versions("tag.t3 = 'xyz'") == []
 982      assert set(search_versions("tag.t2 != 'xy'")) == {2, 1}
 983      assert search_versions("tag.t2 LIKE 'xy%'") == [1]
 984      assert search_versions("tag.t2 LIKE 'xY%'") == []
 985      assert search_versions("tag.t2 ILIKE 'xY%'") == [1]
 986      assert set(search_versions("tag.t2 LIKE 'x%'")) == {2, 1}
 987      assert search_versions("tag.T2 = 'xyz'") == []
 988      assert search_versions("tag.t1 = 'abc' and tag.t2 = 'xyz'") == [1]
 989      assert set(search_versions("tag.t1 = 'abc' and tag.t2 LIKE 'x%'")) == {2, 1}
 990      assert search_versions("tag.t1 = 'abc' and tag.t2 LIKE 'y%'") == []
 991      # test filter with duplicated keys
 992      assert search_versions("tag.t2 like 'x%' and tag.t2 != 'xyz'") == [2]
 993  
 994  
 995  class SearchRegisteredModelsResult(NamedTuple):
 996      names: list[str]
 997      token: str
 998  
 999  
def _search_registered_models(
    store, filter_string=None, max_results=10, order_by=None, page_token=None
):
    """Run ``store.search_registered_models`` and keep only model names and the page token.

    Returns:
        SearchRegisteredModelsResult whose ``names`` lists the matched model names in the
        order the store returned them, and whose ``token`` is the store's page token.
    """
    page = store.search_registered_models(
        filter_string=filter_string,
        max_results=max_results,
        order_by=order_by,
        page_token=page_token,
    )
    matched_names = [model.name for model in page]
    return SearchRegisteredModelsResult(matched_names, page.token)
1013  
1014  
def test_search_registered_models(store):
    """Exercise ``search_registered_models`` name filters, rejected filters, and deletion.

    Covers ``=``, ``LIKE`` (case-sensitive) and ``ILIKE`` (case-insensitive) on ``name``,
    filters that must fail with ``INVALID_PARAMETER_VALUE``, and that a deleted model
    disappears from results.
    """
    # create some registered models
    prefix = "test_for_search_"
    names = [prefix + name for name in ["RM1", "RM2", "RM3", "RM4", "RM4A", "RM4ab"]]
    for name in names:
        store.create_registered_model(name)

    # search with no filter should return all registered models
    res = _search_registered_models(store, None)
    assert res.names == names

    # equality search using name should return exactly the 1 name
    res = _search_registered_models(store, f"name='{names[0]}'")
    assert res.names == [names[0]]

    # equality search using name that is not valid should return nothing
    res = _search_registered_models(store, f"name='{names[0]}cats'")
    assert res.names == []

    # case-sensitive prefix search using LIKE should return all the RMs
    res = _search_registered_models(store, f"name LIKE '{prefix}%'")
    assert res.names == names

    # case-sensitive substring search using LIKE with surrounding % should return all the RMs
    res = _search_registered_models(store, "name LIKE '%RM%'")
    assert res.names == names

    # case-sensitive LIKE with the single-character wildcard "_":
    # _e% matches test_for_search_ , so all RMs should match
    res = _search_registered_models(store, "name LIKE '_e%'")
    assert res.names == names

    # case-sensitive prefix search using LIKE should return just RM4A (not RM4ab)
    res = _search_registered_models(store, f"name LIKE '{prefix}RM4A%'")
    assert res.names == [names[4]]

    # case-sensitive prefix search using LIKE should return no models if no match
    res = _search_registered_models(store, f"name LIKE '{prefix}cats%'")
    assert res.names == []

    # confirm that the LIKE keyword itself is not case-sensitive
    res = _search_registered_models(store, "name lIkE '%blah%'")
    assert res.names == []

    res = _search_registered_models(store, f"name like '{prefix}RM4A%'")
    assert res.names == [names[4]]

    # case-insensitive prefix search using ILIKE should return both RM4A and RM4ab
    res = _search_registered_models(store, f"name ILIKE '{prefix}RM4A%'")
    assert res.names == names[4:]

    # case-insensitive substring search with ILIKE
    res = _search_registered_models(store, "name ILIKE '%RM4a%'")
    assert res.names == names[4:]

    # case-insensitive prefix search using ILIKE should return no models if no match
    res = _search_registered_models(store, f"name ILIKE '{prefix}cats%'")
    assert res.names == []

    # confirm that the ILIKE keyword itself is not case-sensitive
    res = _search_registered_models(store, "name iLike '%blah%'")
    assert res.names == []

    # confirm that ILIKE works for empty query
    res = _search_registered_models(store, "name iLike '%%'")
    assert res.names == names

    res = _search_registered_models(store, "name ilike '%RM4a%'")
    assert res.names == names[4:]

    # cannot search with an unquoted value
    with pytest.raises(
        MlflowException,
        match="Parameter value is either not quoted or unidentified quote types used for "
        "string value something",
    ) as exception_context:
        _search_registered_models(store, "name!=something")
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # cannot search by run_id
    with pytest.raises(
        MlflowException,
        match=r"Invalid attribute key 'run_id' specified\. Valid keys are '{'name'}'",
    ) as exception_context:
        _search_registered_models(store, "run_id='somerunID'")
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # cannot search by source_path
    with pytest.raises(
        MlflowException,
        match=r"Invalid attribute key 'source_path' specified\. Valid keys are '{'name'}'",
    ) as exception_context:
        _search_registered_models(store, "source_path = 'A/D'")
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # cannot search by other params
    with pytest.raises(
        MlflowException, match=r"Invalid clause\(s\) in filter string"
    ) as exception_context:
        _search_registered_models(store, "evilhax = true")
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # delete last registered model. search should return only the first 5
    store.delete_registered_model(name=names[-1])
    res = _search_registered_models(store, None, max_results=1000)
    assert res.names == names[:-1]
    assert res.token is None

    # equality search using the deleted model's name should return no names
    assert _search_registered_models(store, f"name='{names[-1]}'") == ([], None)

    # case-sensitive prefix search using LIKE should return the 5 remaining RMs
    res = _search_registered_models(store, f"name LIKE '{prefix}%'")
    assert res.names == names[0:5]
    assert res.token is None

    # case-insensitive prefix search using ILIKE should now return only RM4A
    res = _search_registered_models(store, f"name ILIKE '{prefix}RM4A%'")
    assert res.names == [names[4]]
    assert res.token is None
1135  
1136  
def test_search_registered_models_by_tag(store):
    """Filter registered models by tag predicates, alone and combined with ``name``."""
    name1 = "test_for_search_RM_by_tag1"
    name2 = "test_for_search_RM_by_tag2"
    store.create_registered_model(
        name1,
        [
            RegisteredModelTag("t1", "abc"),
            RegisteredModelTag("t2", "xyz"),
        ],
    )
    store.create_registered_model(
        name2,
        [
            RegisteredModelTag("t1", "abcd"),
            RegisteredModelTag("t2", "xyz123"),
            RegisteredModelTag("t3", "XYZ"),
        ],
    )

    # (filter string, expected model names in the order the store returns them)
    expected_matches = [
        ("tag.t3 = 'XYZ'", [name2]),
        (f"name = '{name1}' and tag.t1 = 'abc'", [name1]),
        ("tag.t1 LIKE 'ab%'", [name1, name2]),
        ("tag.t1 ILIKE 'aB%'", [name1, name2]),
        ("tag.t1 LIKE 'ab%' AND tag.t2 LIKE 'xy%'", [name1, name2]),
        # "=" compares both the tag value and the tag key case-sensitively
        ("tag.t3 = 'XYz'", []),
        ("tag.T3 = 'XYZ'", []),
        ("tag.t1 != 'abc'", [name2]),
        # the same tag key may appear in more than one clause
        ("tag.t1 != 'abcd' and tag.t1 LIKE 'ab%'", [name1]),
    ]
    for filter_string, expected_names in expected_matches:
        found = _search_registered_models(store, filter_string).names
        assert found == expected_names, filter_string
1179  
1180  
def test_search_registered_models_order_by_simple(store):
    """Check the default ordering (name ASC) and single-key order_by for registered models."""
    names = ["RM1", "RM2", "RM3", "RM4", "RM4A", "RM4ab"]
    for model_name in names:
        store.create_registered_model(model_name)
        time.sleep(0.001)  # sleep for windows store timestamp precision issues

    # With no order_by, models are sorted by name ascending (which is creation order here).
    assert _search_registered_models(store).names == names

    assert _search_registered_models(store, order_by=["name DESC"]).names == names[::-1]

    # Updating the first model makes it the most recently touched one, so it must
    # sort last under last_updated_timestamp ASC.
    store.update_registered_model(names[0], "latest updated")
    oldest_first = _search_registered_models(store, order_by=["last_updated_timestamp ASC"])
    assert oldest_first.names[-1] == names[0]
1200  
1201  
def test_search_registered_model_pagination(store):
    """Page through 50 registered models and check page-token and max_results validation."""
    rms = [store.create_registered_model(f"RM{i:03}").name for i in range(50)]

    # test flow with fixed max_results
    returned_rms = []
    query = "name LIKE 'RM%'"
    res = _search_registered_models(store, query, page_token=None, max_results=5)
    returned_rms.extend(res.names)
    while res.token:
        res = _search_registered_models(store, query, page_token=res.token, max_results=5)
        returned_rms.extend(res.names)
    assert returned_rms == rms

    # test that pagination will return all valid results in sorted order
    # by name ascending (the default), even when the page size changes between requests
    res = _search_registered_models(store, query, max_results=5)
    assert res.token is not None
    assert res.names == rms[0:5]

    res = _search_registered_models(store, query, page_token=res.token, max_results=10)
    assert res.token is not None
    assert res.names == rms[5:15]

    res = _search_registered_models(store, query, page_token=res.token, max_results=20)
    assert res.token is not None
    assert res.names == rms[15:35]

    res = _search_registered_models(store, query, page_token=res.token, max_results=100)
    # assert that page token is None
    assert res.token is None
    assert res.names == rms[35:]

    # test that providing a completely invalid page token throws
    with pytest.raises(
        MlflowException, match=r"Invalid page token, could not base64-decode"
    ) as exception_context:
        _search_registered_models(store, query, page_token="evilhax", max_results=20)
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # test that providing an out-of-range max_results throws; no page token is passed so
    # the error can only come from max_results validation, not from decoding a bad token
    with pytest.raises(
        MlflowException, match=r"Invalid value for max_results."
    ) as exception_context:
        _search_registered_models(store, query, page_token=None, max_results=1e15)
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
1247  
1248  
def test_search_registered_model_order_by(store):
    """Check registered-model order_by: page stability, multi-key ordering and tiebreaks."""
    rms = []
    # sleep between creations so each model gets a distinct, increasing timestamp
    # (filesystem timestamp precision on Windows is coarse)
    for i in range(50):
        rms.append(store.create_registered_model(f"RM{i:03}").name)
        time.sleep(0.01)

    # test flow with fixed max_results and order_by (test stable order across pages)
    returned_rms = []
    query = "name LIKE 'RM%'"
    page_names, token = _search_registered_models(
        store, query, page_token=None, order_by=["name DESC"], max_results=5
    )
    returned_rms.extend(page_names)
    while token:
        page_names, token = _search_registered_models(
            store, query, page_token=token, order_by=["name DESC"], max_results=5
        )
        returned_rms.extend(page_names)
    # name descending should be the opposite order of the current order
    assert returned_rms == rms[::-1]
    # last_updated_timestamp descending should have the newest RMs first
    res = _search_registered_models(
        store,
        query,
        page_token=None,
        order_by=["last_updated_timestamp DESC"],
        max_results=100,
    )
    assert res.names == rms[::-1]
    # last_updated_timestamp ascending should have the oldest RMs first
    res = _search_registered_models(
        store,
        query,
        page_token=None,
        order_by=["last_updated_timestamp ASC"],
        max_results=100,
    )
    assert res.names == rms
    # name ascending should have the original order
    res = _search_registered_models(
        store, query, page_token=None, order_by=["name ASC"], max_results=100
    )
    assert res.names == rms
    # test that no ASC/DESC defaults to ASC
    res = _search_registered_models(
        store,
        query,
        page_token=None,
        order_by=["last_updated_timestamp"],
        max_results=100,
    )
    assert res.names == rms
    # pin the store's clock so MR1/MR2 share one timestamp and MR3/MR4 share another,
    # creating ties that the remaining assertions resolve
    with mock.patch(
        "mlflow.store.model_registry.file_store.get_current_time_millis", return_value=1
    ):
        rm1 = store.create_registered_model("MR1").name
        rm2 = store.create_registered_model("MR2").name
    with mock.patch(
        "mlflow.store.model_registry.file_store.get_current_time_millis", return_value=2
    ):
        rm3 = store.create_registered_model("MR3").name
        rm4 = store.create_registered_model("MR4").name
    query = "name LIKE 'MR%'"
    # test with multiple clauses
    res = _search_registered_models(
        store,
        query,
        page_token=None,
        order_by=["last_updated_timestamp ASC", "name DESC"],
        max_results=100,
    )
    assert res.names == [rm2, rm1, rm4, rm3]
    # confirm that name ascending is the default, even if ties exist on other fields
    res = _search_registered_models(store, query, page_token=None, order_by=[], max_results=100)
    assert res.names == [rm1, rm2, rm3, rm4]
    # test default tiebreak (name ascending) with descending timestamps
    res = _search_registered_models(
        store,
        query,
        page_token=None,
        order_by=["last_updated_timestamp DESC"],
        max_results=100,
    )
    assert res.names == [rm3, rm4, rm1, rm2]
1334  
1335  
def test_search_registered_model_order_by_errors(store):
    """Invalid order_by clauses are rejected even when they follow a valid clause."""
    store.create_registered_model("dummy")
    query = "name LIKE 'RM%'"
    invalid_cases = [
        # an unsupported column placed after a valid one
        (["name ASC", "description DESC"], "Invalid attribute key 'description' specified."),
        # trailing random text after a valid column/direction pair
        (["name ASC", "last_updated_timestamp DESC blah"], r"Invalid order_by clause '.+'"),
    ]
    for order_by, expected_message in invalid_cases:
        with pytest.raises(MlflowException, match=expected_message) as exception_context:
            _search_registered_models(
                store,
                query,
                page_token=None,
                order_by=order_by,
                max_results=5,
            )
        assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
1361  
1362  
def test_set_model_version_tag(store):
    """Set, override and validate model version tags, including error cases.

    Checks that a tag applies only to the targeted version, that re-setting a key
    overrides its value, and that deleted versions, overlong values, missing keys/names
    and non-integer versions are rejected with the expected error codes.
    """
    name1 = "SetModelVersionTag_TestMod"
    name2 = "SetModelVersionTag_TestMod 2"
    initial_tags = [
        ModelVersionTag("key", "value"),
        ModelVersionTag("anotherKey", "some other value"),
    ]
    store.create_registered_model(name1)
    store.create_registered_model(name2)
    run_id_1 = uuid.uuid4().hex
    run_id_2 = uuid.uuid4().hex
    run_id_3 = uuid.uuid4().hex
    store.create_model_version(name1, "A/B", run_id_1, initial_tags)
    store.create_model_version(name1, "A/C", run_id_2, initial_tags)
    store.create_model_version(name2, "A/D", run_id_3, initial_tags)
    new_tag = ModelVersionTag("randomTag", "not a random value")
    store.set_model_version_tag(name1, 1, new_tag)
    all_tags = [*initial_tags, new_tag]
    rm1mv1 = store.get_model_version(name1, 1)
    assert rm1mv1.tags == {tag.key: tag.value for tag in all_tags}

    # test overriding a tag with the same key
    overriding_tag = ModelVersionTag("key", "overriding")
    store.set_model_version_tag(name1, 1, overriding_tag)
    all_tags = [tag for tag in all_tags if tag.key != "key"] + [overriding_tag]
    rm1mv1 = store.get_model_version(name1, 1)
    assert rm1mv1.tags == {tag.key: tag.value for tag in all_tags}
    # does not affect other model versions with the same key
    rm1mv2 = store.get_model_version(name1, 2)
    rm2mv1 = store.get_model_version(name2, 1)
    assert rm1mv2.tags == {tag.key: tag.value for tag in initial_tags}
    assert rm2mv1.tags == {tag.key: tag.value for tag in initial_tags}

    # cannot set a tag on a deleted (non-existent) model version
    store.delete_model_version(name1, 2)
    with pytest.raises(
        MlflowException, match=rf"Model Version \(name={name1}, version=2\) not found"
    ) as exception_context:
        store.set_model_version_tag(name1, 2, overriding_tag)
    assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
    # test cannot set tags that are too long
    long_tag = ModelVersionTag("longTagKey", "a" * 100_001)
    with pytest.raises(
        MlflowException,
        match=r"'value' exceeds the maximum length of \d+ characters",
    ) as exception_context:
        store.set_model_version_tag(name1, 1, long_tag)
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    # test can set tags that are somewhat long
    long_tag = ModelVersionTag("longTagKey", "a" * 4999)
    store.set_model_version_tag(name1, 1, long_tag)
    # cannot set a tag without a key
    with pytest.raises(
        MlflowException, match=r"Missing value for required parameter 'key'"
    ) as exception_context:
        store.set_model_version_tag(name2, 1, ModelVersionTag(key=None, value=""))
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    # cannot use an invalid model name or version
    with pytest.raises(
        MlflowException, match=r"Missing value for required parameter 'name'\."
    ) as exception_context:
        store.set_model_version_tag(None, 1, ModelVersionTag(key="key", value="value"))
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    with pytest.raises(
        MlflowException,
        match=r"Parameter 'version' must be an integer, got 'I am not a version'\.",
    ) as exception_context:
        store.set_model_version_tag(
            name2, "I am not a version", ModelVersionTag(key="key", value="value")
        )
    assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
1435  
def test_delete_model_version_tag(store):
    """Deleting a model version tag only affects the targeted model version.

    Also checks that deleting an absent tag is a no-op, and that deleting from a
    removed model version, or with a missing key/name or a non-integer version,
    raises ``MlflowException`` with the expected error code.
    """
    first_model = "DeleteModelVersionTag_TestMod"
    second_model = "DeleteModelVersionTag_TestMod 2"
    starting_tags = [
        ModelVersionTag("key", "value"),
        ModelVersionTag("anotherKey", "some other value"),
    ]
    expected_starting = {t.key: t.value for t in starting_tags}
    expected_without_key = {"anotherKey": "some other value"}

    for model_name in (first_model, second_model):
        store.create_registered_model(model_name)
    for model_name, source in ((first_model, "A/B"), (first_model, "A/C"), (second_model, "A/D")):
        store.create_model_version(model_name, source, uuid.uuid4().hex, starting_tags)

    # A freshly added tag is gone again once it is deleted.
    store.set_model_version_tag(first_model, 1, ModelVersionTag("randomTag", "not a random value"))
    store.delete_model_version_tag(first_model, 1, "randomTag")
    assert store.get_model_version(first_model, 1).tags == expected_starting

    # Removing "key" from one version leaves same-named tags on other versions untouched.
    store.delete_model_version_tag(first_model, 1, "key")
    assert store.get_model_version(first_model, 1).tags == expected_without_key
    assert store.get_model_version(first_model, 2).tags == expected_starting
    assert store.get_model_version(second_model, 1).tags == expected_starting

    # Deleting a tag that is already gone changes nothing.
    store.delete_model_version_tag(first_model, 1, "key")
    assert store.get_model_version(first_model, 1).tags == expected_without_key

    # Tags cannot be deleted from a model version that no longer exists.
    store.delete_model_version(second_model, 1)
    with pytest.raises(
        MlflowException, match=rf"Model Version \(name={second_model}, version=1\) not found"
    ) as exc_info:
        store.delete_model_version_tag(second_model, 1, "key")
    assert exc_info.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)

    # Each invalid argument combination is rejected with INVALID_PARAMETER_VALUE.
    invalid_calls = [
        (r"Missing value for required parameter 'key'", (first_model, 2, None)),
        (r"Missing value for required parameter 'name'\.", (None, 2, "key")),
        (
            r"Parameter 'version' must be an integer, got 'I am not a version'\.",
            (first_model, "I am not a version", "key"),
        ),
    ]
    for pattern, call_args in invalid_calls:
        with pytest.raises(MlflowException, match=pattern) as exc_info:
            store.delete_model_version_tag(*call_args)
        assert exc_info.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
1495  
1496  
def _setup_and_test_aliases(store, model_name):
    """Register ``model_name`` with two versions and point ``test_alias`` at version 2.

    Asserts that the alias appears on the registered model and only on version 2.
    """
    store.create_registered_model(model_name)
    for source in ("v1", "v2"):
        store.create_model_version(model_name, source, uuid.uuid4().hex)
    store.set_registered_model_alias(model_name, "test_alias", "2")

    assert store.get_registered_model(model_name).aliases == {"test_alias": "2"}
    first_version, second_version = (store.get_model_version(model_name, v) for v in (1, 2))
    assert first_version.aliases == []
    assert second_version.aliases == ["test_alias"]
1510  
1511  
def test_set_registered_model_alias(store):
    """Setting an alias is verified entirely by the shared alias setup helper."""
    _setup_and_test_aliases(store, model_name="SetRegisteredModelAlias_TestMod")
1514  
1515  
def test_delete_registered_model_alias(store):
    """Deleting an alias removes it from the model, from its version, and from lookups."""
    model_name = "DeleteRegisteredModelAlias_TestMod"
    _setup_and_test_aliases(store, model_name)
    store.delete_registered_model_alias(model_name, "test_alias")
    model = store.get_registered_model(model_name)
    assert model.aliases == {}
    mv2 = store.get_model_version(model_name, 2)
    assert mv2.aliases == []
    # Escape the trailing period so it is matched literally rather than as "any character".
    with pytest.raises(MlflowException, match=r"Registered model alias test_alias not found\."):
        store.get_model_version_by_alias(model_name, "test_alias")
1526  
1527  
def test_get_model_version_by_alias(store):
    """Resolving ``test_alias`` returns the model version that carries that alias."""
    model_name = "GetModelVersionByAlias_TestMod"
    _setup_and_test_aliases(store, model_name)
    resolved = store.get_model_version_by_alias(model_name, "test_alias")
    assert resolved.aliases == ["test_alias"]
1533  
1534  
def test_delete_model_version_deletes_alias(store):
    """Deleting the aliased model version also removes the alias from the model."""
    model_name = "DeleteModelVersionDeletesAlias_TestMod"
    _setup_and_test_aliases(store, model_name)
    store.delete_model_version(model_name, 2)
    model = store.get_registered_model(model_name)
    assert model.aliases == {}
    # Escape the trailing period so it is matched literally rather than as "any character".
    with pytest.raises(MlflowException, match=r"Registered model alias test_alias not found\."):
        store.get_model_version_by_alias(model_name, "test_alias")
1543  
1544  
def test_delete_model_deletes_alias(store):
    """Alias lookups fail once the owning registered model has been deleted."""
    model_name = "DeleteModelDeletesAlias_TestMod"
    _setup_and_test_aliases(store, model_name)
    store.delete_registered_model(model_name)
    expected_error = rf"Registered Model with name={model_name} not found"
    with pytest.raises(MlflowException, match=expected_error):
        store.get_model_version_by_alias(model_name, "test_alias")
1554  
1555  
def test_pyfunc_model_registry_with_file_store(store):
    """Models logged with ``registered_model_name`` are registered in the file store.

    Points the registry URI at ``store``'s root directory and logs three pyfunc
    models: two registered as ``model1`` and one as ``model2``. It then checks that
    ``store`` reports both registered models with the matching number of versions.
    """
    # ``mlflow`` and ``PythonModel`` are provided by the module-level imports.

    class MyModel(PythonModel):
        def predict(self, context, model_input, params=None):
            return 7

    # NOTE(review): this changes the process-wide registry URI and does not restore it
    # here; presumably a fixture resets it between tests — confirm.
    mlflow.set_registry_uri(path_to_local_file_uri(store.root_directory))
    with mlflow.start_run():
        mlflow.pyfunc.log_model(name="foo", python_model=MyModel(), registered_model_name="model1")
        mlflow.pyfunc.log_model(name="foo", python_model=MyModel(), registered_model_name="model2")
        mlflow.pyfunc.log_model(
            name="model", python_model=MyModel(), registered_model_name="model1"
        )

    with mlflow.start_run():
        mlflow.log_param("A", "B")

        models = store.search_registered_models(max_results=10)
        assert len(models) == 2
        assert models[0].name == "model1"
        assert models[1].name == "model2"
        mv1 = store.search_model_versions("name = 'model1'", max_results=10)
        assert len(mv1) == 2
        assert mv1[0].name == "model1"
        mv2 = store.search_model_versions("name = 'model2'", max_results=10)
        assert len(mv2) == 1
        assert mv2[0].name == "model2"
1585  
1586  
@pytest.mark.parametrize("copy_to_same_model", [False, True])
def test_copy_model_version(store, copy_to_same_model):
    """Copying a model version carries over its metadata, but not its stage.

    Parametrized to copy either into the source's own registered model, where the
    copy becomes version 2, or into a new registered model, where it becomes
    version 1. Also checks that a copy of a copy still resolves its download URI
    to the original source.
    """
    name1 = "test_for_copy_MV1"
    store.create_registered_model(name1)
    src_tags = [
        ModelVersionTag("key", "value"),
        ModelVersionTag("anotherKey", "some other value"),
    ]
    src_mv = _create_model_version(
        store,
        name1,
        tags=src_tags,
        run_link="dummylink",
        description="test description",
    )

    # Make some changes to the src MV that won't be copied over
    store.transition_model_version_stage(
        name1, src_mv.version, "Production", archive_existing_versions=False
    )

    copy_rm_name = name1 if copy_to_same_model else "test_for_copy_MV2"
    copy_mv_version = 2 if copy_to_same_model else 1
    # Model version timestamps are recorded in epoch milliseconds; comparing them
    # against ``time.time()`` (epoch seconds) would make the checks below always pass.
    timestamp = get_current_time_millis()
    dst_mv = store.copy_model_version(src_mv, copy_rm_name)
    assert dst_mv.name == copy_rm_name
    assert dst_mv.version == copy_mv_version

    copied_mv = store.get_model_version(dst_mv.name, dst_mv.version)
    assert copied_mv.name == copy_rm_name
    assert copied_mv.version == copy_mv_version
    assert copied_mv.current_stage == "None"
    assert copied_mv.creation_timestamp >= timestamp
    assert copied_mv.last_updated_timestamp >= timestamp
    assert copied_mv.description == "test description"
    assert copied_mv.source == f"models:/{src_mv.name}/{src_mv.version}"
    assert store.get_model_version_download_uri(dst_mv.name, dst_mv.version) == src_mv.source
    assert copied_mv.run_link == "dummylink"
    assert copied_mv.run_id == src_mv.run_id
    assert copied_mv.status == "READY"
    assert copied_mv.status_message is None
    assert copied_mv.tags == {"key": "value", "anotherKey": "some other value"}

    # Copy a model version copy
    double_copy_mv = store.copy_model_version(copied_mv, "test_for_copy_MV3")
    assert double_copy_mv.source == f"models:/{copied_mv.name}/{copied_mv.version}"
    # The copy of a copy must still resolve to the original artifact location
    # (previously this re-checked ``dst_mv`` and never exercised the double copy).
    assert (
        store.get_model_version_download_uri(double_copy_mv.name, double_copy_mv.version)
        == src_mv.source
    )
1634  
1635  
def test_writing_model_version_preserves_storage_location(store):
    """Operations that modify a model version keep its ``storage_location`` intact."""
    name = "test_storage_location_MV1"
    source = "/special/source"
    store.create_registered_model(name)
    for _ in range(2):
        _create_model_version(store, name, source=source)

    def assert_location_kept(model_name):
        stored_version = store._fetch_file_model_version_if_exists(model_name, 1)
        assert stored_version.storage_location == source

    # Each operation below modifies version 1; its storage location must survive.
    store.transition_model_version_stage(name, 1, "Production", archive_existing_versions=False)
    assert_location_kept(name)
    store.update_model_version(name, 1, description="test description")
    assert_location_kept(name)
    store.transition_model_version_stage(name, 1, "Production", archive_existing_versions=True)
    assert_location_kept(name)
    store.rename_registered_model(name, "test_storage_location_new")
    assert_location_kept("test_storage_location_new")
1656  
1657  
def test_search_prompts(store):
    """Registered-model search hides prompts unless the filter explicitly targets them."""

    def matching_names(filter_string=None):
        # Omit ``filter_string`` entirely when none is given, mirroring a bare call.
        if filter_string is None:
            found = store.search_registered_models(max_results=10)
        else:
            found = store.search_registered_models(filter_string=filter_string, max_results=10)
        return [rm.name for rm in found]

    store.create_registered_model("model", tags=[RegisteredModelTag(key="fruit", value="apple")])
    store.create_registered_model(
        "prompt_1", tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")]
    )
    store.create_registered_model(
        "prompt_2",
        tags=[
            RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true"),
            RegisteredModelTag(key="fruit", value="apple"),
        ],
    )

    # Default searches, and filters that exclude prompts, only see the plain model.
    assert matching_names() == ["model"]
    assert matching_names("tags.fruit = 'apple'") == ["model"]
    assert matching_names("name = 'prompt_1'") == []
    assert matching_names("tags.`mlflow.prompt.is_prompt` = 'false'") == ["model"]
    assert matching_names("tags.`mlflow.prompt.is_prompt` != 'true'") == ["model"]

    # Filtering on the prompt tag surfaces prompts.
    all_prompts = matching_names("tags.`mlflow.prompt.is_prompt` = 'true'")
    assert len(all_prompts) == 2
    assert set(all_prompts) == {"prompt_1", "prompt_2"}

    named_prompt_filter = "name = 'prompt_1' and tags.`mlflow.prompt.is_prompt` = 'true'"
    assert matching_names(named_prompt_filter) == ["prompt_1"]

    fruit_prompt_filter = "tags.`mlflow.prompt.is_prompt` = 'true' and tags.fruit = 'apple'"
    assert matching_names(fruit_prompt_filter) == ["prompt_2"]
1716  
1717  
def test_search_prompts_versions(store):
    """Model-version search hides prompt versions unless the filter targets them.

    Sets up three registered models: an ordinary model with one version, a prompt
    with one version, and a prompt with two versions. It then exercises several
    ``search_model_versions`` filters.
    """

    def matching_names(filter_string=None):
        # Omit ``filter_string`` entirely when none is given, mirroring a bare call.
        if filter_string is None:
            versions = store.search_model_versions(max_results=10)
        else:
            versions = store.search_model_versions(filter_string=filter_string, max_results=10)
        return [mv.name for mv in versions]

    # NOTE(review): the create_model_version calls below pass arguments positionally as
    # (name, source, run_id), so "1"/"2" land in ``source`` and "dummy_source" in
    # ``run_id`` — confirm this is intended.

    # A Model
    store.create_registered_model("model")
    store.create_model_version(
        "model", "1", "dummy_source", tags=[ModelVersionTag(key="fruit", value="apple")]
    )

    # A Prompt with 1 version
    store.create_registered_model(
        "prompt_1", tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")]
    )
    store.create_model_version(
        "prompt_1", "1", "dummy_source", tags=[ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true")]
    )

    # A Prompt with 2 versions
    store.create_registered_model(
        "prompt_2",
        tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")],
    )
    for version_source, fruit in (("1", "apple"), ("2", "orange")):
        store.create_model_version(
            "prompt_2",
            version_source,
            "dummy_source",
            tags=[
                ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true"),
                ModelVersionTag(key="fruit", value=fruit),
            ],
        )

    # Default searches, and filters that exclude prompts, only see the plain model.
    assert matching_names() == ["model"]
    assert matching_names("tags.fruit = 'apple'") == ["model"]
    assert matching_names("tags.`mlflow.prompt.is_prompt` = 'false'") == ["model"]
    assert matching_names("tags.`mlflow.prompt.is_prompt` != 'true'") == ["model"]

    # Filtering on the prompt tag surfaces prompt versions.
    assert len(matching_names("tags.`mlflow.prompt.is_prompt` = 'true'")) == 3

    prompt_2_filter = "tags.`mlflow.prompt.is_prompt` = 'true' and name = 'prompt_2'"
    assert len(matching_names(prompt_2_filter)) == 2

    fruit_filter = "tags.`mlflow.prompt.is_prompt` = 'true' and tags.fruit = 'apple'"
    assert matching_names(fruit_filter) == ["prompt_2"]
1796  
1797  
def test_create_registered_model_handle_prompt_properly(store):
    """Name collisions between registered models and prompts are rejected with clear errors."""
    prompt_tags = [RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")]

    store.create_registered_model("model")
    store.create_registered_model("prompt", tags=prompt_tags)

    # Each entry: (name, tags to pass or None to omit the argument, expected error pattern).
    collisions = [
        ("model", None, r"Registered Model \(name=model\) already exists"),
        ("prompt", prompt_tags, r"Prompt \(name=prompt\) already exists"),
        (
            "model",
            prompt_tags,
            r"Tried to create a prompt with name 'model', "
            r"but the name is already taken by a registered model.",
        ),
        (
            "prompt",
            None,
            r"Tried to create a registered model with name 'prompt', "
            r"but the name is already taken by a prompt.",
        ),
    ]
    for name, tags, pattern in collisions:
        extra_kwargs = {} if tags is None else {"tags": tags}
        with pytest.raises(MlflowException, match=pattern):
            store.create_registered_model(name, **extra_kwargs)
1824  
1825  
def test_create_model_version_with_model_id_and_no_run_id(store: FileStore):
    """A version created from a logged model's id, with no run id, gets that run's id."""

    class SimpleModel(PythonModel):
        def predict(self, context, model_input, params=None):
            return model_input

    name = "test_model_with_model_id"
    store.create_registered_model(name)

    with mlflow.start_run() as active_run:
        logged = mlflow.pyfunc.log_model(name="model", python_model=SimpleModel())
        expected_run_id = active_run.info.run_id
        logged_model_id = logged.model_id

    created = store.create_model_version(
        name=name,
        source="/absolute/path/to/source",
        run_id=None,
        model_id=logged_model_id,
    )
    assert created.run_id == expected_run_id

    fetched = store.get_model_version(name=created.name, version=created.version)
    assert fetched.run_id == expected_run_id
1853  
1854  
def test_update_model_version_with_model_id_and_metrics(store: FileStore):
    """Versions tied to a logged model expose its run params and model metrics.

    The metrics and params must still be reported after the version's
    description is updated.
    """

    class SimpleModel(PythonModel):
        def predict(self, context, model_input, params=None):
            return model_input

    name = "test_model_with_model_id_and_metrics"
    store.create_registered_model(name)

    with mlflow.start_run() as active_run:
        mlflow.log_param("learning_rate", "0.001")
        logged = mlflow.pyfunc.log_model(name="model", python_model=SimpleModel())
        mlflow.log_metric("execution_time", 33.74682, step=0, model_id=logged.model_id)
        expected_run_id = active_run.info.run_id
        logged_model_id = logged.model_id

    created = store.create_model_version(
        name=name,
        source="/absolute/path/to/source",
        run_id=None,
        model_id=logged_model_id,
    )
    assert created.run_id == expected_run_id

    fetched = store.get_model_version(name=created.name, version=created.version)
    assert fetched.run_id == expected_run_id
    assert fetched.metrics is not None
    assert len(fetched.metrics) == 1
    only_metric = fetched.metrics[0]
    assert only_metric.key == "execution_time"
    assert only_metric.value == 33.74682
    assert fetched.params == {"learning_rate": "0.001"}

    new_description = "Test description with metrics"
    updated = store.update_model_version(
        name=created.name,
        version=created.version,
        description=new_description,
    )
    assert updated.description == new_description

    # Re-read from the store: the update must not drop the metrics or params.
    refetched = store.get_model_version(name=created.name, version=created.version)
    assert refetched.description == new_description
    assert refetched.metrics is not None
    assert len(refetched.metrics) == 1
    assert refetched.metrics[0].key == "execution_time"
    assert refetched.params == {"learning_rate": "0.001"}