/ tests / store / model_registry / test_sqlalchemy_store.py
test_sqlalchemy_store.py
   1  import concurrent.futures
   2  import shutil
   3  import time
   4  import uuid
   5  from pathlib import Path
   6  from unittest import mock
   7  
   8  import pytest
   9  from sqlalchemy import create_engine, text
  10  
  11  from mlflow.entities.model_registry import (
  12      ModelVersion,
  13      ModelVersionTag,
  14      RegisteredModelTag,
  15  )
  16  from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
  17  from mlflow.entities.webhook import WebhookAction, WebhookEntity, WebhookEvent, WebhookStatus
  18  from mlflow.environment_variables import (
  19      _MLFLOW_GO_STORE_TESTING,
  20      MLFLOW_ENABLE_WORKSPACES,
  21      MLFLOW_TRACKING_URI,
  22  )
  23  from mlflow.exceptions import MlflowException
  24  from mlflow.prompt.constants import PROMPT_TEXT_TAG_KEY
  25  from mlflow.protos.databricks_pb2 import (
  26      INVALID_PARAMETER_VALUE,
  27      RESOURCE_ALREADY_EXISTS,
  28      RESOURCE_DOES_NOT_EXIST,
  29      ErrorCode,
  30  )
  31  from mlflow.store.model_registry.dbmodels.models import (
  32      SqlModelVersion,
  33      SqlModelVersionTag,
  34      SqlRegisteredModel,
  35      SqlRegisteredModelTag,
  36      SqlWebhook,
  37  )
  38  from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore
  39  from mlflow.store.model_registry.sqlalchemy_workspace_store import (
  40      WorkspaceAwareSqlAlchemyStore,
  41  )
  42  from mlflow.utils.workspace_context import WorkspaceContext
  43  from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME
  44  
  45  from tests.helper_functions import random_str
  46  
# Module-level ``pytestmark``: pytest applies this mark to every test in the module.
pytestmark = pytest.mark.notrackingurimock

# Key of the extra tag that ``_add_go_test_tags`` appends when ``_MLFLOW_GO_STORE_TESTING`` is set.
GO_MOCK_TIME_TAG = "mock.time.go.testing.tag"
  50  
  51  
  52  @pytest.fixture(autouse=True, params=[False, True], ids=["workspace-disabled", "workspace-enabled"])
  53  def workspaces_enabled(request, monkeypatch, disable_workspace_mode_by_default):
  54      """
  55      Run every test in this module with workspaces disabled and enabled to cover both code paths.
  56      """
  57      enabled = request.param
  58      monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true" if enabled else "false")
  59      if enabled:
  60          with WorkspaceContext(DEFAULT_WORKSPACE_NAME):
  61              yield enabled
  62      else:
  63          yield enabled
  64  
  65  
  66  @pytest.fixture
  67  def store(tmp_path: Path, cached_db: Path, workspaces_enabled):
  68      store_cls = WorkspaceAwareSqlAlchemyStore if workspaces_enabled else SqlAlchemyStore
  69      if db_uri_env := MLFLOW_TRACKING_URI.get():
  70          s = store_cls(db_uri_env)
  71          yield s
  72          _cleanup_database(s)
  73      else:
  74          db_path = tmp_path / "mlflow.db"
  75          shutil.copy(cached_db, db_path)
  76          db_uri = f"sqlite:///{db_path}"
  77          s = store_cls(db_uri)
  78          yield s
  79  
  80      # Dispose the engine to close all pooled connections
  81      s.engine.dispose()
  82  
  83  
  84  def _cleanup_database(store: SqlAlchemyStore):
  85      with store.ManagedSessionMaker() as session:
  86          # Delete all rows in all tables
  87          for model in (
  88              SqlModelVersionTag,
  89              SqlRegisteredModelTag,
  90              SqlModelVersion,
  91              SqlRegisteredModel,
  92              SqlWebhook,
  93          ):
  94              session.query(model).delete()
  95  
  96  
  97  def _rm_maker(store, name, tags=None, description=None):
  98      return store.create_registered_model(name, tags, description)
  99  
 100  
 101  def _add_go_test_tags(tags, val):
 102      if _MLFLOW_GO_STORE_TESTING.get():
 103          return tags + [RegisteredModelTag(GO_MOCK_TIME_TAG, val)]
 104      return tags
 105  
 106  
 107  def _mv_maker(
 108      store,
 109      name,
 110      source="path/to/source",
 111      run_id=uuid.uuid4().hex,
 112      tags=None,
 113      run_link=None,
 114      description=None,
 115  ):
 116      return store.create_model_version(
 117          name, source, run_id, tags, run_link=run_link, description=description
 118      )
 119  
 120  
 121  def _extract_latest_by_stage(latest_versions):
 122      return {mvd.current_stage: str(mvd.version) for mvd in latest_versions}
 123  
 124  
 125  def test_create_registered_model(store):
 126      name = random_str() + "abCD"
 127      rm1 = _rm_maker(store, name)
 128      assert rm1.name == name
 129      assert rm1.description is None
 130  
 131      # error on duplicate
 132      with pytest.raises(
 133          MlflowException, match=rf"Registered Model \(name={name}\) already exists"
 134      ) as exception_context:
 135          _rm_maker(store, name)
 136      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS)
 137  
 138      # slightly different name is ok
 139      for name2 in [name + "extra", name + name]:
 140          rm2 = _rm_maker(store, name2)
 141          assert rm2.name == name2
 142  
 143      # test create model with tags
 144      name2 = random_str() + "tags"
 145      tags = [
 146          RegisteredModelTag("key", "value"),
 147          RegisteredModelTag("anotherKey", "some other value"),
 148      ]
 149      rm2 = _rm_maker(store, name2, tags)
 150      rmd2 = store.get_registered_model(name2)
 151      assert rm2.name == name2
 152      assert rm2.tags == {tag.key: tag.value for tag in tags}
 153      assert rmd2.name == name2
 154      assert rmd2.tags == {tag.key: tag.value for tag in tags}
 155  
 156      # create with description
 157      name3 = random_str() + "-description"
 158      description = "the best model ever"
 159      rm3 = _rm_maker(store, name3, description=description)
 160      rmd3 = store.get_registered_model(name3)
 161      assert rm3.name == name3
 162      assert rm3.description == description
 163      assert rmd3.name == name3
 164      assert rmd3.description == description
 165  
 166      # invalid model name will fail
 167      with pytest.raises(
 168          MlflowException, match=r"Missing value for required parameter 'name'"
 169      ) as exception_context:
 170          _rm_maker(store, None)
 171      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 172      with pytest.raises(
 173          MlflowException, match=r"Missing value for required parameter 'name'"
 174      ) as exception_context:
 175          _rm_maker(store, "")
 176      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 177  
 178  
 179  def test_get_registered_model(store):
 180      name = "model_1"
 181      tags = [
 182          RegisteredModelTag("key", "value"),
 183          RegisteredModelTag("anotherKey", "some other value"),
 184      ]
 185      # use fake clock
 186      with mock.patch("time.time", return_value=1234):
 187          rm = _rm_maker(store, name, _add_go_test_tags(tags, "1234000"))
 188          assert rm.name == name
 189  
 190      rmd = store.get_registered_model(name=name)
 191      assert rmd.name == name
 192      assert rmd.creation_timestamp == 1234000
 193      assert rmd.last_updated_timestamp == 1234000
 194      assert rmd.description is None
 195      assert rmd.latest_versions == []
 196      assert rmd.tags == {tag.key: tag.value for tag in tags}
 197  
 198  
 199  def test_update_registered_model(store):
 200      name = "model_for_update_RM"
 201      rm1 = _rm_maker(store, name)
 202      rmd1 = store.get_registered_model(name=name)
 203      assert rm1.name == name
 204      assert rmd1.description is None
 205  
 206      # update description
 207      rm2 = store.update_registered_model(name=name, description="test model")
 208      rmd2 = store.get_registered_model(name=name)
 209      assert rm2.name == "model_for_update_RM"
 210      assert rmd2.name == "model_for_update_RM"
 211      assert rmd2.description == "test model"
 212  
 213  
 214  def test_rename_registered_model(store):
 215      original_name = "original name"
 216      new_name = "new name"
 217      _rm_maker(store, original_name)
 218      _mv_maker(store, original_name)
 219      _mv_maker(store, original_name)
 220      rm = store.get_registered_model(original_name)
 221      mv1 = store.get_model_version(original_name, 1)
 222      mv2 = store.get_model_version(original_name, 2)
 223      assert rm.name == original_name
 224      assert mv1.name == original_name
 225      assert mv2.name == original_name
 226  
 227      # test renaming registered model also updates its model versions
 228      store.rename_registered_model(original_name, new_name)
 229      rm = store.get_registered_model(new_name)
 230      mv1 = store.get_model_version(new_name, 1)
 231      mv2 = store.get_model_version(new_name, 2)
 232      assert rm.name == new_name
 233      assert mv1.name == new_name
 234      assert mv2.name == new_name
 235  
 236      # test accessing the model with the old name will fail
 237      with pytest.raises(
 238          MlflowException, match=rf"Registered Model with name={original_name} not found"
 239      ) as exception_context:
 240          store.get_registered_model(original_name)
 241      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 242  
 243      # test name another model with the replaced name is ok
 244      _rm_maker(store, original_name)
 245      # cannot rename model to conflict with an existing model
 246      with pytest.raises(
 247          MlflowException,
 248          match=rf"Registered Model \(name={original_name}\) already exists",
 249      ) as exception_context:
 250          store.rename_registered_model(new_name, original_name)
 251      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS)
 252      # invalid model name will fail
 253      with pytest.raises(
 254          MlflowException, match=r"Missing value for required parameter 'new_name'"
 255      ) as exception_context:
 256          store.rename_registered_model(original_name, None)
 257      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 258      with pytest.raises(
 259          MlflowException, match=r"Missing value for required parameter 'new_name'"
 260      ) as exception_context:
 261          store.rename_registered_model(original_name, "")
 262      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 263  
 264  
 265  def test_delete_registered_model(store):
 266      name = "model_for_delete_RM"
 267      _rm_maker(store, name)
 268      _mv_maker(store, name)
 269      rm1 = store.get_registered_model(name=name)
 270      mv1 = store.get_model_version(name, 1)
 271      assert rm1.name == name
 272      assert mv1.name == name
 273  
 274      # delete model
 275      store.delete_registered_model(name=name)
 276  
 277      # cannot get model
 278      with pytest.raises(
 279          MlflowException, match=rf"Registered Model with name={name} not found"
 280      ) as exception_context:
 281          store.get_registered_model(name=name)
 282      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 283  
 284      # cannot update a delete model
 285      with pytest.raises(
 286          MlflowException, match=rf"Registered Model with name={name} not found"
 287      ) as exception_context:
 288          store.update_registered_model(name=name, description="deleted")
 289      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 290  
 291      # cannot delete it again
 292      with pytest.raises(
 293          MlflowException, match=rf"Registered Model with name={name} not found"
 294      ) as exception_context:
 295          store.delete_registered_model(name=name)
 296      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 297  
 298      # model versions are cascade deleted with the registered model
 299      with pytest.raises(
 300          MlflowException, match=rf"Model Version \(name={name}, version=1\) not found"
 301      ) as exception_context:
 302          store.get_model_version(name, 1)
 303      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 304  
 305  
 306  def test_get_latest_versions(store):
 307      name = "test_for_latest_versions"
 308      _rm_maker(store, name)
 309      rmd1 = store.get_registered_model(name=name)
 310      assert rmd1.latest_versions == []
 311  
 312      mv1 = _mv_maker(store, name)
 313      assert mv1.version == 1
 314      rmd2 = store.get_registered_model(name=name)
 315      assert _extract_latest_by_stage(rmd2.latest_versions) == {"None": "1"}
 316  
 317      # add a bunch more
 318      mv2 = _mv_maker(store, name)
 319      assert mv2.version == 2
 320      store.transition_model_version_stage(
 321          name=mv2.name,
 322          version=mv2.version,
 323          stage="Production",
 324          archive_existing_versions=False,
 325      )
 326  
 327      mv3 = _mv_maker(store, name)
 328      assert mv3.version == 3
 329      store.transition_model_version_stage(
 330          name=mv3.name,
 331          version=mv3.version,
 332          stage="Production",
 333          archive_existing_versions=False,
 334      )
 335      mv4 = _mv_maker(store, name)
 336      assert mv4.version == 4
 337      store.transition_model_version_stage(
 338          name=mv4.name,
 339          version=mv4.version,
 340          stage="Staging",
 341          archive_existing_versions=False,
 342      )
 343  
 344      # test that correct latest versions are returned for each stage
 345      rmd4 = store.get_registered_model(name=name)
 346      assert _extract_latest_by_stage(rmd4.latest_versions) == {
 347          "None": "1",
 348          "Production": "3",
 349          "Staging": "4",
 350      }
 351      assert _extract_latest_by_stage(store.get_latest_versions(name=name, stages=None)) == {
 352          "None": "1",
 353          "Production": "3",
 354          "Staging": "4",
 355      }
 356      assert _extract_latest_by_stage(store.get_latest_versions(name=name, stages=[])) == {
 357          "None": "1",
 358          "Production": "3",
 359          "Staging": "4",
 360      }
 361      assert _extract_latest_by_stage(
 362          store.get_latest_versions(name=name, stages=["Production"])
 363      ) == {"Production": "3"}
 364      assert _extract_latest_by_stage(
 365          store.get_latest_versions(name=name, stages=["production"])
 366      ) == {"Production": "3"}  # The stages are case insensitive.
 367      assert _extract_latest_by_stage(
 368          store.get_latest_versions(name=name, stages=["pROduction"])
 369      ) == {"Production": "3"}  # The stages are case insensitive.
 370      assert _extract_latest_by_stage(
 371          store.get_latest_versions(name=name, stages=["None", "Production"])
 372      ) == {"None": "1", "Production": "3"}
 373  
 374      # delete latest Production, and should point to previous one
 375      store.delete_model_version(name=mv3.name, version=mv3.version)
 376      rmd5 = store.get_registered_model(name=name)
 377      assert _extract_latest_by_stage(rmd5.latest_versions) == {
 378          "None": "1",
 379          "Production": "2",
 380          "Staging": "4",
 381      }
 382      assert _extract_latest_by_stage(store.get_latest_versions(name=name, stages=None)) == {
 383          "None": "1",
 384          "Production": "2",
 385          "Staging": "4",
 386      }
 387      assert _extract_latest_by_stage(
 388          store.get_latest_versions(name=name, stages=["Production"])
 389      ) == {"Production": "2"}
 390  
 391  
 392  def test_set_registered_model_tag(store):
 393      name1 = "SetRegisteredModelTag_TestMod"
 394      name2 = "SetRegisteredModelTag_TestMod 2"
 395      initial_tags = [
 396          RegisteredModelTag("key", "value"),
 397          RegisteredModelTag("anotherKey", "some other value"),
 398      ]
 399      _rm_maker(store, name1, initial_tags)
 400      _rm_maker(store, name2, initial_tags)
 401      new_tag = RegisteredModelTag("randomTag", "not a random value")
 402      store.set_registered_model_tag(name1, new_tag)
 403      rm1 = store.get_registered_model(name=name1)
 404      all_tags = initial_tags + [new_tag]
 405      assert rm1.tags == {tag.key: tag.value for tag in all_tags}
 406  
 407      # test overriding a tag with the same key
 408      overriding_tag = RegisteredModelTag("key", "overriding")
 409      store.set_registered_model_tag(name1, overriding_tag)
 410      all_tags = [tag for tag in all_tags if tag.key != "key"] + [overriding_tag]
 411      rm1 = store.get_registered_model(name=name1)
 412      assert rm1.tags == {tag.key: tag.value for tag in all_tags}
 413      # does not affect other models with the same key
 414      rm2 = store.get_registered_model(name=name2)
 415      assert rm2.tags == {tag.key: tag.value for tag in initial_tags}
 416  
 417      # can not set tag on deleted (non-existed) registered model
 418      store.delete_registered_model(name1)
 419      with pytest.raises(
 420          MlflowException, match=rf"Registered Model with name={name1} not found"
 421      ) as exception_context:
 422          store.set_registered_model_tag(name1, overriding_tag)
 423      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 424      # test cannot set tags that are too long
 425      long_tag = RegisteredModelTag("longTagKey", "a" * 100_001)
 426      with pytest.raises(
 427          MlflowException,
 428          match=r"'value' exceeds the maximum length of \d+ characters",
 429      ) as exception_context:
 430          store.set_registered_model_tag(name2, long_tag)
 431      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 432      # test can set tags that are somewhat long
 433      long_tag = RegisteredModelTag("longTagKey", "a" * 4999)
 434      store.set_registered_model_tag(name2, long_tag)
 435      # can not set invalid tag
 436      with pytest.raises(
 437          MlflowException, match=r"Missing value for required parameter 'key'"
 438      ) as exception_context:
 439          store.set_registered_model_tag(name2, RegisteredModelTag(key=None, value=""))
 440      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 441      # can not use invalid model name
 442      with pytest.raises(
 443          MlflowException, match=r"Missing value for required parameter 'name'"
 444      ) as exception_context:
 445          store.set_registered_model_tag(None, RegisteredModelTag(key="key", value="value"))
 446      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 447  
 448  
 449  def test_delete_registered_model_tag(store):
 450      name1 = "DeleteRegisteredModelTag_TestMod"
 451      name2 = "DeleteRegisteredModelTag_TestMod 2"
 452      initial_tags = [
 453          RegisteredModelTag("key", "value"),
 454          RegisteredModelTag("anotherKey", "some other value"),
 455      ]
 456      _rm_maker(store, name1, initial_tags)
 457      _rm_maker(store, name2, initial_tags)
 458      new_tag = RegisteredModelTag("randomTag", "not a random value")
 459      store.set_registered_model_tag(name1, new_tag)
 460      store.delete_registered_model_tag(name1, "randomTag")
 461      rm1 = store.get_registered_model(name=name1)
 462      assert rm1.tags == {tag.key: tag.value for tag in initial_tags}
 463  
 464      # testing deleting a key does not affect other models with the same key
 465      store.delete_registered_model_tag(name1, "key")
 466      rm1 = store.get_registered_model(name=name1)
 467      rm2 = store.get_registered_model(name=name2)
 468      assert rm1.tags == {"anotherKey": "some other value"}
 469      assert rm2.tags == {tag.key: tag.value for tag in initial_tags}
 470  
 471      # delete tag that is already deleted does nothing
 472      store.delete_registered_model_tag(name1, "key")
 473      rm1 = store.get_registered_model(name=name1)
 474      assert rm1.tags == {"anotherKey": "some other value"}
 475  
 476      # can not delete tag on deleted (non-existed) registered model
 477      store.delete_registered_model(name1)
 478      with pytest.raises(
 479          MlflowException, match=rf"Registered Model with name={name1} not found"
 480      ) as exception_context:
 481          store.delete_registered_model_tag(name1, "anotherKey")
 482      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 483      # can not delete tag with invalid key
 484      with pytest.raises(
 485          MlflowException, match=r"Missing value for required parameter 'key'"
 486      ) as exception_context:
 487          store.delete_registered_model_tag(name2, None)
 488      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 489      # can not use invalid model name
 490      with pytest.raises(
 491          MlflowException, match=r"Missing value for required parameter 'name'"
 492      ) as exception_context:
 493          store.delete_registered_model_tag(None, "key")
 494      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 495  
 496  
 497  def test_create_model_version(store):
 498      name = "test_for_update_MV"
 499      _rm_maker(store, name)
 500      run_id = uuid.uuid4().hex
 501      with mock.patch("time.time", return_value=456778):
 502          mv1 = _mv_maker(store, name, "a/b/CD", run_id)
 503          assert mv1.name == name
 504          assert mv1.version == 1
 505  
 506      mvd1 = store.get_model_version(mv1.name, mv1.version)
 507      assert mvd1.name == name
 508      assert int(mvd1.version) == 1
 509      assert mvd1.current_stage == "None"
 510      assert mvd1.creation_timestamp == 456778000
 511      assert mvd1.last_updated_timestamp == 456778000
 512      assert mvd1.description is None
 513      assert mvd1.source == "a/b/CD"
 514      assert mvd1.run_id == run_id
 515      assert mvd1.status == "READY"
 516      assert mvd1.status_message is None
 517      assert mvd1.tags == {}
 518  
 519      # new model versions for same name autoincrement versions
 520      mv2 = _mv_maker(store, name)
 521      mvd2 = store.get_model_version(name=mv2.name, version=mv2.version)
 522      assert mv2.version == 2
 523      assert int(mvd2.version) == 2
 524  
 525      # create model version with tags return model version entity with tags
 526      tags = [
 527          ModelVersionTag("key", "value"),
 528          ModelVersionTag("anotherKey", "some other value"),
 529      ]
 530      mv3 = _mv_maker(store, name, tags=tags)
 531      mvd3 = store.get_model_version(name=mv3.name, version=mv3.version)
 532      assert mv3.version == 3
 533      assert mv3.tags == {tag.key: tag.value for tag in tags}
 534      assert int(mvd3.version) == 3
 535      assert mvd3.tags == {tag.key: tag.value for tag in tags}
 536  
 537      # create model versions with runLink
 538      run_link = "http://localhost:3000/path/to/run/"
 539      mv4 = _mv_maker(store, name, run_link=run_link)
 540      mvd4 = store.get_model_version(name, mv4.version)
 541      assert mv4.version == 4
 542      assert mv4.run_link == run_link
 543      assert int(mvd4.version) == 4
 544      assert mvd4.run_link == run_link
 545  
 546      # create model version with description
 547      description = "the best model ever"
 548      mv5 = _mv_maker(store, name, description=description)
 549      mvd5 = store.get_model_version(name, mv5.version)
 550      assert mv5.version == 5
 551      assert mv5.description == description
 552      assert int(mvd5.version) == 5
 553      assert mvd5.description == description
 554  
 555      # create model version without runId
 556      mv6 = _mv_maker(store, name, run_id=None)
 557      mvd6 = store.get_model_version(name, mv6.version)
 558      assert mv6.version == 6
 559      assert mv6.run_id is None
 560      assert int(mvd6.version) == 6
 561      assert mvd6.run_id is None
 562  
 563  
 564  def test_update_model_version(store):
 565      name = "test_for_update_MV"
 566      _rm_maker(store, name)
 567      mv1 = _mv_maker(store, name)
 568      mvd1 = store.get_model_version(name=mv1.name, version=mv1.version)
 569      assert mvd1.name == name
 570      assert int(mvd1.version) == 1
 571      assert mvd1.current_stage == "None"
 572  
 573      # update stage
 574      store.transition_model_version_stage(
 575          name=mv1.name,
 576          version=mv1.version,
 577          stage="Production",
 578          archive_existing_versions=False,
 579      )
 580      mvd2 = store.get_model_version(name=mv1.name, version=mv1.version)
 581      assert mvd2.name == name
 582      assert int(mvd2.version) == 1
 583      assert mvd2.current_stage == "Production"
 584      assert mvd2.description is None
 585  
 586      # update description
 587      store.update_model_version(name=mv1.name, version=mv1.version, description="test model version")
 588      mvd3 = store.get_model_version(name=mv1.name, version=mv1.version)
 589      assert mvd3.name == name
 590      assert int(mvd3.version) == 1
 591      assert mvd3.current_stage == "Production"
 592      assert mvd3.description == "test model version"
 593  
 594      # only valid stages can be set
 595      with pytest.raises(
 596          MlflowException,
 597          match=(
 598              "Invalid Model Version stage: unknown. "
 599              "Value must be one of None, Staging, Production, Archived."
 600          ),
 601      ) as exception_context:
 602          store.transition_model_version_stage(
 603              mv1.name, mv1.version, stage="unknown", archive_existing_versions=False
 604          )
 605      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 606  
 607      # stages are case-insensitive and auto-corrected to system stage names
 608      for stage_name in ["STAGING", "staging", "StAgInG"]:
 609          store.transition_model_version_stage(
 610              name=mv1.name,
 611              version=mv1.version,
 612              stage=stage_name,
 613              archive_existing_versions=False,
 614          )
 615          mvd5 = store.get_model_version(name=mv1.name, version=mv1.version)
 616          assert mvd5.current_stage == "Staging"
 617  
 618  
 619  def test_transition_model_version_stage_when_archive_existing_versions_is_false(store):
 620      name = "model"
 621      _rm_maker(store, name)
 622      mv1 = _mv_maker(store, name)
 623      mv2 = _mv_maker(store, name)
 624      mv3 = _mv_maker(store, name)
 625  
 626      # test that when `archive_existing_versions` is False, transitioning a model version
 627      # to the inactive stages ("Archived" and "None") does not throw.
 628      for stage in ["Archived", "None"]:
 629          store.transition_model_version_stage(name, mv1.version, stage, False)
 630  
 631      store.transition_model_version_stage(name, mv1.version, "Staging", False)
 632      store.transition_model_version_stage(name, mv2.version, "Production", False)
 633      store.transition_model_version_stage(name, mv3.version, "Staging", False)
 634  
 635      mvd1 = store.get_model_version(name=name, version=mv1.version)
 636      mvd2 = store.get_model_version(name=name, version=mv2.version)
 637      mvd3 = store.get_model_version(name=name, version=mv3.version)
 638  
 639      assert mvd1.current_stage == "Staging"
 640      assert mvd2.current_stage == "Production"
 641      assert mvd3.current_stage == "Staging"
 642  
 643      store.transition_model_version_stage(name, mv3.version, "Production", False)
 644  
 645      mvd1 = store.get_model_version(name=name, version=mv1.version)
 646      mvd2 = store.get_model_version(name=name, version=mv2.version)
 647      mvd3 = store.get_model_version(name=name, version=mv3.version)
 648  
 649      assert mvd1.current_stage == "Staging"
 650      assert mvd2.current_stage == "Production"
 651      assert mvd3.current_stage == "Production"
 652  
 653  
 654  def test_transition_model_version_stage_when_archive_existing_versions_is_true(store):
 655      name = "model"
 656      _rm_maker(store, name)
 657      mv1 = _mv_maker(store, name)
 658      mv2 = _mv_maker(store, name)
 659      mv3 = _mv_maker(store, name)
 660  
 661      msg = (
 662          r"Model version transition cannot archive existing model versions "
 663          r"because .+ is not an Active stage"
 664      )
 665  
 666      # test that when `archive_existing_versions` is True, transitioning a model version
 667      # to the inactive stages ("Archived" and "None") throws.
 668      for stage in ["Archived", "None"]:
 669          with pytest.raises(MlflowException, match=msg):
 670              store.transition_model_version_stage(name, mv1.version, stage, True)
 671  
 672      store.transition_model_version_stage(name, mv1.version, "Staging", False)
 673      store.transition_model_version_stage(name, mv2.version, "Production", False)
 674      store.transition_model_version_stage(name, mv3.version, "Staging", True)
 675  
 676      mvd1 = store.get_model_version(name=name, version=mv1.version)
 677      mvd2 = store.get_model_version(name=name, version=mv2.version)
 678      mvd3 = store.get_model_version(name=name, version=mv3.version)
 679  
 680      assert mvd1.current_stage == "Archived"
 681      assert mvd2.current_stage == "Production"
 682      assert mvd3.current_stage == "Staging"
 683      assert mvd1.last_updated_timestamp == mvd3.last_updated_timestamp
 684  
 685      store.transition_model_version_stage(name, mv3.version, "Production", True)
 686  
 687      mvd1 = store.get_model_version(name=name, version=mv1.version)
 688      mvd2 = store.get_model_version(name=name, version=mv2.version)
 689      mvd3 = store.get_model_version(name=name, version=mv3.version)
 690  
 691      assert mvd1.current_stage == "Archived"
 692      assert mvd2.current_stage == "Archived"
 693      assert mvd3.current_stage == "Production"
 694      assert mvd2.last_updated_timestamp == mvd3.last_updated_timestamp
 695  
 696      for uncanonical_stage_name in ["STAGING", "staging", "StAgInG"]:
 697          store.transition_model_version_stage(mv1.name, mv1.version, "Staging", False)
 698          store.transition_model_version_stage(mv2.name, mv2.version, "None", False)
 699  
 700          # stage names are case-insensitive and auto-corrected to system stage names
 701          store.transition_model_version_stage(mv2.name, mv2.version, uncanonical_stage_name, True)
 702  
 703          mvd1 = store.get_model_version(name=mv1.name, version=mv1.version)
 704          mvd2 = store.get_model_version(name=mv2.name, version=mv2.version)
 705          assert mvd1.current_stage == "Archived"
 706          assert mvd2.current_stage == "Staging"
 707  
 708  
 709  def test_delete_model_version(store):
 710      name = "test_for_delete_MV"
 711      initial_tags = [
 712          ModelVersionTag("key", "value"),
 713          ModelVersionTag("anotherKey", "some other value"),
 714      ]
 715      _rm_maker(store, name)
 716      mv = _mv_maker(store, name, tags=initial_tags)
 717      mvd = store.get_model_version(name=mv.name, version=mv.version)
 718      assert mvd.name == name
 719  
 720      store.delete_model_version(name=mv.name, version=mv.version)
 721  
 722      # cannot get a deleted model version
 723      with pytest.raises(
 724          MlflowException,
 725          match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found",
 726      ) as exception_context:
 727          store.get_model_version(name=mv.name, version=mv.version)
 728      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 729  
 730      # cannot update a delete
 731      with pytest.raises(
 732          MlflowException,
 733          match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found",
 734      ) as exception_context:
 735          store.update_model_version(mv.name, mv.version, description="deleted!")
 736      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 737  
 738      # cannot delete it again
 739      with pytest.raises(
 740          MlflowException,
 741          match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found",
 742      ) as exception_context:
 743          store.delete_model_version(name=mv.name, version=mv.version)
 744      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 745  
 746  
 747  def test_delete_model_version_redaction(store):
 748      name = "test_for_delete_MV_redaction"
 749      run_link = "http://localhost:5000/path/to/run"
 750      run_id = "12345"
 751      source = "path/to/source"
 752      _rm_maker(store, name)
 753      mv = _mv_maker(store, name, source=source, run_id=run_id, run_link=run_link)
 754      mvd = store.get_model_version(name=name, version=mv.version)
 755      assert mvd.run_link == run_link
 756      assert mvd.run_id == run_id
 757      assert mvd.source == source
 758      # delete the MV now
 759      store.delete_model_version(name, mv.version)
 760      # verify that the relevant fields are redacted
 761      mvd_deleted = store._get_sql_model_version_including_deleted(name=name, version=mv.version)
 762      assert "REDACTED" in mvd_deleted.run_link
 763      assert "REDACTED" in mvd_deleted.source
 764      assert "REDACTED" in mvd_deleted.run_id
 765  
 766  
 767  def test_get_model_version_download_uri(store):
 768      name = "test_for_update_MV"
 769      _rm_maker(store, name)
 770      source_path = "path/to/source"
 771      mv = _mv_maker(store, name, source=source_path, run_id=uuid.uuid4().hex)
 772      mvd1 = store.get_model_version(name=mv.name, version=mv.version)
 773      assert mvd1.name == name
 774      assert mvd1.source == source_path
 775  
 776      # download location points to source
 777      assert store.get_model_version_download_uri(name=mv.name, version=mv.version) == source_path
 778  
 779      # download URI does not change even if model version is updated
 780      store.transition_model_version_stage(
 781          name=mv.name,
 782          version=mv.version,
 783          stage="Production",
 784          archive_existing_versions=False,
 785      )
 786      store.update_model_version(name=mv.name, version=mv.version, description="Test for Path")
 787      mvd2 = store.get_model_version(name=mv.name, version=mv.version)
 788      assert mvd2.source == source_path
 789      assert store.get_model_version_download_uri(name=mv.name, version=mv.version) == source_path
 790  
 791      # cannot retrieve download URI for deleted model versions
 792      store.delete_model_version(name=mv.name, version=mv.version)
 793      with pytest.raises(
 794          MlflowException,
 795          match=rf"Model Version \(name={mv.name}, version={mv.version}\) not found",
 796      ) as exception_context:
 797          store.get_model_version_download_uri(name=mv.name, version=mv.version)
 798      assert exception_context.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
 799  
 800  
 801  def test_search_model_versions(store):
 802      # create some model versions
 803      name = "test_for_search_MV"
 804      _rm_maker(store, name)
 805      run_id_1 = uuid.uuid4().hex
 806      run_id_2 = uuid.uuid4().hex
 807      run_id_3 = uuid.uuid4().hex
 808      mv1 = _mv_maker(store, name=name, source="A/B", run_id=run_id_1)
 809      assert mv1.version == 1
 810      mv2 = _mv_maker(store, name=name, source="A/C", run_id=run_id_2)
 811      assert mv2.version == 2
 812      mv3 = _mv_maker(store, name=name, source="A/D", run_id=run_id_2)
 813      assert mv3.version == 3
 814      mv4 = _mv_maker(store, name=name, source="A/D", run_id=run_id_3)
 815      assert mv4.version == 4
 816  
 817      def search_versions(filter_string, max_results=10, order_by=None, page_token=None):
 818          return [
 819              mvd.version
 820              for mvd in store.search_model_versions(filter_string, max_results, order_by, page_token)
 821          ]
 822  
 823      # search using name should return all 4 versions
 824      assert set(search_versions(f"name='{name}'")) == {1, 2, 3, 4}
 825  
 826      # search using version
 827      assert set(search_versions("version_number=2")) == {2}
 828      assert set(search_versions("version_number<=3")) == {1, 2, 3}
 829  
 830      # search using run_id_1 should return version 1
 831      assert set(search_versions(f"run_id='{run_id_1}'")) == {1}
 832  
 833      # search using run_id_2 should return versions 2 and 3
 834      assert set(search_versions(f"run_id='{run_id_2}'")) == {2, 3}
 835  
 836      # search using the IN operator should return all versions
 837      assert set(search_versions(f"run_id IN ('{run_id_1}','{run_id_2}')")) == {1, 2, 3}
 838  
 839      # search IN operator is case sensitive
 840      assert set(search_versions(f"run_id IN ('{run_id_1.upper()}','{run_id_2}')")) == {
 841          2,
 842          3,
 843      }
 844  
 845      # search IN operator with other conditions
 846      assert set(
 847          search_versions(f"version_number=2 AND run_id IN ('{run_id_1.upper()}','{run_id_2}')")
 848      ) == {2}
 849  
 850      # search IN operator with right-hand side value containing whitespaces
 851      assert set(search_versions(f"run_id IN ('{run_id_1}', '{run_id_2}')")) == {1, 2, 3}
 852  
 853      # search using the IN operator with bad lists should return exceptions
 854      with pytest.raises(
 855          MlflowException,
 856          match=(
 857              r"While parsing a list in the query, "
 858              r"expected string value, punctuation, or whitespace, "
 859              r"but got different type in list"
 860          ),
 861      ) as exception_context:
 862          search_versions("run_id IN (1,2,3)")
 863      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 864  
 865      assert set(search_versions(f"run_id LIKE '{run_id_2[:30]}%'")) == {2, 3}
 866  
 867      assert set(search_versions(f"run_id ILIKE '{run_id_2[:30].upper()}%'")) == {2, 3}
 868  
 869      # search using the IN operator with empty lists should return exceptions
 870      with pytest.raises(
 871          MlflowException,
 872          match=(
 873              r"While parsing a list in the query, "
 874              r"expected a non-empty list of string values, "
 875              r"but got empty list"
 876          ),
 877      ) as exception_context:
 878          search_versions("run_id IN ()")
 879      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 880  
 881      # search using an ill-formed IN operator correctly throws exception
 882      with pytest.raises(
 883          MlflowException, match=r"Invalid clause\(s\) in filter string"
 884      ) as exception_context:
 885          search_versions("run_id IN (")
 886      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 887  
 888      with pytest.raises(
 889          MlflowException, match=r"Invalid clause\(s\) in filter string"
 890      ) as exception_context:
 891          search_versions("run_id IN")
 892      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 893  
 894      with pytest.raises(
 895          MlflowException, match=r"Invalid clause\(s\) in filter string"
 896      ) as exception_context:
 897          search_versions("name LIKE")
 898      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 899  
 900      with pytest.raises(
 901          MlflowException,
 902          match=(
 903              r"While parsing a list in the query, "
 904              r"expected a non-empty list of string values, "
 905              r"but got ill-formed list"
 906          ),
 907      ) as exception_context:
 908          search_versions("run_id IN (,)")
 909      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 910  
 911      with pytest.raises(
 912          MlflowException,
 913          match=(
 914              r"While parsing a list in the query, "
 915              r"expected a non-empty list of string values, "
 916              r"but got ill-formed list"
 917          ),
 918      ) as exception_context:
 919          search_versions("run_id IN ('runid1',,'runid2')")
 920      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 921  
 922      # search using source_path "A/D" should return version 3 and 4
 923      assert set(search_versions("source_path = 'A/D'")) == {3, 4}
 924  
 925      # search using source_path "A" should not return anything
 926      assert len(search_versions("source_path = 'A'")) == 0
 927      assert len(search_versions("source_path = 'A/'")) == 0
 928      assert len(search_versions("source_path = ''")) == 0
 929  
 930      # delete mv4. search should not return version 4
 931      store.delete_model_version(name=mv4.name, version=mv4.version)
 932      assert set(search_versions("")) == {1, 2, 3}
 933  
 934      assert set(search_versions(None)) == {1, 2, 3}
 935  
 936      assert set(search_versions(f"name='{name}'")) == {1, 2, 3}
 937  
 938      assert set(search_versions("source_path = 'A/D'")) == {3}
 939  
 940      store.transition_model_version_stage(
 941          name=mv1.name,
 942          version=mv1.version,
 943          stage="production",
 944          archive_existing_versions=False,
 945      )
 946  
 947      store.update_model_version(
 948          name=mv1.name, version=mv1.version, description="Online prediction model!"
 949      )
 950  
 951      mvds = store.search_model_versions(f"run_id = '{run_id_1}'", max_results=10)
 952      assert len(mvds) == 1
 953      assert isinstance(mvds[0], ModelVersion)
 954      assert mvds[0].current_stage == "Production"
 955      assert mvds[0].run_id == run_id_1
 956      assert mvds[0].source == "A/B"
 957      assert mvds[0].description == "Online prediction model!"
 958  
 959  
 960  def test_search_model_versions_order_by_simple(store):
 961      # create some model versions
 962      names = ["RM1", "RM2", "RM3", "RM4", "RM1", "RM4"]
 963      sources = ["A"] * 3 + ["B"] * 3
 964      run_ids = [uuid.uuid4().hex for _ in range(6)]
 965      for name in set(names):
 966          _rm_maker(store, name)
 967      for i in range(6):
 968          time.sleep(0.001)  # sleep to ensure each model version has a different creation_time
 969          _mv_maker(store, name=names[i], source=sources[i], run_id=run_ids[i])
 970  
 971      # by default order by last_updated_timestamp DESC
 972      mvs = store.search_model_versions(filter_string=None)
 973      assert [mv.name for mv in mvs] == names[::-1]
 974      assert [mv.version for mv in mvs] == [2, 2, 1, 1, 1, 1]
 975  
 976      # order by name DESC
 977      mvs = store.search_model_versions(filter_string=None, order_by=["name DESC"])
 978      assert [mv.name for mv in mvs] == sorted(names)[::-1]
 979      assert [mv.version for mv in mvs] == [2, 1, 1, 1, 2, 1]
 980  
 981      # order by version DESC
 982      mvs = store.search_model_versions(filter_string=None, order_by=["version_number DESC"])
 983      assert [mv.name for mv in mvs] == ["RM1", "RM4", "RM1", "RM2", "RM3", "RM4"]
 984      assert [mv.version for mv in mvs] == [2, 2, 1, 1, 1, 1]
 985  
 986      # order by creation_timestamp DESC
 987      mvs = store.search_model_versions(filter_string=None, order_by=["creation_timestamp DESC"])
 988      assert [mv.name for mv in mvs] == names[::-1]
 989      assert [mv.version for mv in mvs] == [2, 2, 1, 1, 1, 1]
 990  
 991      # order by last_updated_timestamp ASC
 992      store.update_model_version(names[0], 1, "latest updated")
 993      mvs = store.search_model_versions(filter_string=None, order_by=["last_updated_timestamp ASC"])
 994      assert mvs[-1].name == names[0]
 995      assert mvs[-1].version == 1
 996  
 997  
 998  def test_search_model_versions_order_by_errors(store):
 999      # create some model versions
1000      name = "RM1"
1001      _rm_maker(store, name)
1002      for _ in range(6):
1003          _mv_maker(store, name=name)
1004      query = "name LIKE 'RM%'"
1005      # test that invalid columns throw even if they come after valid columns
1006      with pytest.raises(
1007          MlflowException, match=r"Invalid attribute key '.+' specified"
1008      ) as exception_context:
1009          store.search_model_versions(
1010              query,
1011              page_token=None,
1012              order_by=["name ASC", "run_id DESC"],
1013              max_results=5,
1014          )
1015      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
1016      # test that invalid columns with random text throw even if they come after valid columns
1017      with pytest.raises(MlflowException, match=r"Invalid order_by clause '.+'") as exception_context:
1018          store.search_model_versions(
1019              query,
1020              page_token=None,
1021              order_by=["name ASC", "last_updated_timestamp DESC blah"],
1022              max_results=5,
1023          )
1024      assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
1025  
1026  
def test_search_model_versions_pagination(store):
    """Page through 50 model versions with fixed and varying page sizes; check bad inputs."""

    def fetch_page(filter_string, page_token=None, max_results=10):
        page = store.search_model_versions(
            filter_string=filter_string, page_token=page_token, max_results=max_results
        )
        return page.to_list(), page.token

    name = "test_for_search_MV_pagination"
    _rm_maker(store, name)
    # expected result order is the reverse of creation order
    mvs = [_mv_maker(store, name) for _ in range(50)]
    mvs.reverse()
    query = "name LIKE 'test_for_search_MV_pagination%'"

    # test flow with fixed max_results: follow tokens until none is returned
    returned_mvs = []
    token = None
    while True:
        page, token = fetch_page(query, page_token=token, max_results=5)
        returned_mvs.extend(page)
        if not token:
            break
    assert mvs == returned_mvs

    # test that pagination with varying page sizes returns consecutive slices
    # of the expected ordering, each with a token for the next page
    token = None
    for page_size, expected_page in [(5, mvs[0:5]), (10, mvs[5:15]), (20, mvs[15:35])]:
        page, token = fetch_page(query, page_token=token, max_results=page_size)
        assert token is not None
        assert page == expected_page

    # the last page holds the remainder and carries no token
    page, token = fetch_page(query, page_token=token, max_results=100)
    assert token is None
    assert page == mvs[35:]

    invalid_requests = [
        # a completely invalid page token throws
        (20, r"Invalid page token, could not base64-decode"),
        # too large of a max_results throws
        (1e15, r"Invalid value for max_results\."),
    ]
    for max_results, message_pattern in invalid_requests:
        with pytest.raises(MlflowException, match=message_pattern) as raised:
            fetch_page(query, page_token="evilhax", max_results=max_results)
        assert raised.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
1080  
1081  
def test_search_model_versions_by_tag(store):
    """Filter model versions by tag with =, !=, LIKE and ILIKE, alone and combined."""
    # create two model versions sharing tag t1 but differing in tag t2
    name = "test_for_search_MV_by_tag"
    _rm_maker(store, name)
    version_specs = [
        ("A/B", [ModelVersionTag("t1", "abc"), ModelVersionTag("t2", "xyz")]),
        ("A/C", [ModelVersionTag("t1", "abc"), ModelVersionTag("t2", "x123")]),
    ]
    for expected_version, (source, tags) in enumerate(version_specs, start=1):
        created = _mv_maker(
            store,
            name=name,
            source=source,
            run_id=uuid.uuid4().hex,
            tags=tags,
        )
        assert created.version == expected_version

    # (filter string, versions expected in result order)
    expectations = [
        (f"name = '{name}' and tag.t2 = 'xyz'", [1]),
        ("name = 'wrong_name' and tag.t2 = 'xyz'", []),
        ("tag.`t2` = 'xyz'", [1]),
        ("tag.t3 = 'xyz'", []),
        ("tag.t2 != 'xy'", [2, 1]),
        ("tag.t2 LIKE 'xy%'", [1]),
        ("tag.t2 LIKE 'xY%'", []),
        ("tag.t2 ILIKE 'xY%'", [1]),
        ("tag.t2 LIKE 'x%'", [2, 1]),
        ("tag.T2 = 'xyz'", []),
        ("tag.t1 = 'abc' and tag.t2 = 'xyz'", [1]),
        ("tag.t1 = 'abc' and tag.t2 LIKE 'x%'", [2, 1]),
        ("tag.t1 = 'abc' and tag.t2 LIKE 'y%'", []),
        # test filter with duplicated keys
        ("tag.t2 like 'x%' and tag.t2 != 'xyz'", [2]),
    ]
    for filter_string, expected_versions in expectations:
        found_versions = [mvd.version for mvd in store.search_model_versions(filter_string)]
        assert found_versions == expected_versions
1124  
1125  
1126  def _assert_workspace_in_main_queries(captured_sql, table_name, context_label):
1127      """Assert that every main search query (ORDER BY + LIMIT) includes workspace in WHERE."""
1128      main_queries = [s for s in captured_sql if "ORDER BY" in s and table_name in s and "LIMIT" in s]
1129      assert main_queries, f"No main search query found for {context_label}"
1130      for sql in main_queries:
1131          where_clause = sql.split("WHERE", 1)[-1] if "WHERE" in sql else ""
1132          assert "workspace" in where_clause.lower(), (
1133              f"{context_label} missing workspace predicate in WHERE — "
1134              f"causes full table scan on (workspace, ...) PK:\n{sql}"
1135          )
1136  
1137  
def test_search_model_versions_includes_workspace_predicate(store):
    """Main queries issued by search_model_versions must filter on workspace."""
    from sqlalchemy import event

    name = "test_ws_predicate_mv"
    _rm_maker(store, name)
    _mv_maker(store, name=name, source="A/B", tags=[ModelVersionTag("t1", "abc")])

    def statements_issued_by_search(filter_string):
        # record every SQL statement the engine executes during one search call
        captured_sql: list[str] = []

        def _capture(conn, cursor, statement, parameters, context, executemany):
            captured_sql.append(statement)

        event.listen(store.engine, "before_cursor_execute", _capture)
        try:
            store.search_model_versions(filter_string)
        finally:
            # always detach the listener so later statements are not recorded
            event.remove(store.engine, "before_cursor_execute", _capture)
        return captured_sql

    filter_strings = (
        f"name = '{name}'",
        "tag.t1 = 'abc'",
        f"name = '{name}' AND tag.t1 = 'abc'",
    )
    for filter_string in filter_strings:
        _assert_workspace_in_main_queries(
            statements_issued_by_search(filter_string),
            "model_versions",
            f"search_model_versions('{filter_string}')",
        )
1164  
1165  
def test_search_registered_models_includes_workspace_predicate(store):
    """Main queries issued by search_registered_models must filter on workspace."""
    from sqlalchemy import event

    name = "test_ws_predicate_rm"
    _rm_maker(store, name, tags=[RegisteredModelTag("t1", "abc")])

    def statements_issued_by_search(filter_string):
        # record every SQL statement the engine executes during one search call
        captured_sql: list[str] = []

        def _capture(conn, cursor, statement, parameters, context, executemany):
            captured_sql.append(statement)

        event.listen(store.engine, "before_cursor_execute", _capture)
        try:
            store.search_registered_models(filter_string)
        finally:
            # always detach the listener so later statements are not recorded
            event.remove(store.engine, "before_cursor_execute", _capture)
        return captured_sql

    filter_strings = (
        f"name = '{name}'",
        "tag.t1 = 'abc'",
        f"name = '{name}' AND tag.t1 = 'abc'",
    )
    for filter_string in filter_strings:
        _assert_workspace_in_main_queries(
            statements_issued_by_search(filter_string),
            "registered_models",
            f"search_registered_models('{filter_string}')",
        )
1193  
1194  
1195  def _search_registered_models(store, filter_string, max_results=10, order_by=None, page_token=None):
1196      result = store.search_registered_models(
1197          filter_string=filter_string,
1198          max_results=max_results,
1199          order_by=order_by,
1200          page_token=page_token,
1201      )
1202      return [registered_model.name for registered_model in result], result.token
1203  
1204  
def test_search_registered_models(store):
    """Cover name filters (=, LIKE, ILIKE), rejected filters, and results after a deletion."""
    # create some registered models
    prefix = "test_for_search_"
    names = [f"{prefix}{suffix}" for suffix in ("RM1", "RM2", "RM3", "RM4", "RM4A", "RM4ab")]
    for model_name in names:
        _rm_maker(store, model_name)

    def matching_names(filter_string):
        found, _ = _search_registered_models(store, filter_string)
        return found

    def assert_rejected(filter_string, message_pattern):
        # the filter must be rejected with INVALID_PARAMETER_VALUE and a matching message
        with pytest.raises(MlflowException, match=message_pattern) as raised:
            _search_registered_models(store, filter_string)
        assert raised.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    rm4a_only = [names[4]]
    rm4a_and_rm4ab = names[4:]

    # (filter string, names expected in result order)
    expectations = [
        # search with no filter should return all registered models
        (None, names),
        # equality search using name should return exactly the 1 name
        (f"name='{names[0]}'", [names[0]]),
        # equality search using a name that does not exist should return nothing
        (f"name='{names[0]}cats'", []),
        # case-sensitive prefix search using LIKE should return all the RMs
        (f"name LIKE '{prefix}%'", names),
        # case-sensitive substring search using LIKE with surrounding % should return all the RMs
        ("name LIKE '%RM%'", names),
        # _e% matches test_for_search_ , so all RMs should match
        ("name LIKE '_e%'", names),
        # case-sensitive prefix search using LIKE should return just RM4A
        (f"name LIKE '{prefix}RM4A%'", rm4a_only),
        # case-sensitive prefix search using LIKE should return no models if no match
        (f"name LIKE '{prefix}cats%'", []),
        # the LIKE keyword itself is accepted in any letter case
        ("name lIkE '%blah%'", []),
        (f"name like '{prefix}RM4A%'", rm4a_only),
        # case-insensitive prefix search using ILIKE should return both RM4A and RM4ab
        (f"name ILIKE '{prefix}RM4A%'", rm4a_and_rm4ab),
        # case-insensitive substring search with ILIKE
        ("name ILIKE '%RM4a%'", rm4a_and_rm4ab),
        # case-insensitive prefix search using ILIKE should return no models if no match
        (f"name ILIKE '{prefix}cats%'", []),
        # the ILIKE keyword itself is accepted in any letter case
        ("name iLike '%blah%'", []),
        # confirm that ILIKE works for empty query
        ("name iLike '%%'", names),
        ("name ilike '%RM4a%'", rm4a_and_rm4ab),
    ]
    for filter_string, expected_names in expectations:
        assert matching_names(filter_string) == expected_names

    # cannot search with an unquoted string value
    assert_rejected(
        "name!=something",
        "Parameter value is either not quoted or unidentified quote types used for "
        "string value something",
    )
    # cannot search by run_id
    assert_rejected("run_id='somerunID'", r"Invalid attribute key 'run_id' specified.")
    # cannot search by source_path
    assert_rejected("source_path = 'A/D'", r"Invalid attribute key 'source_path' specified.")
    # cannot search by other params
    assert_rejected("evilhax = true", r"Invalid clause\(s\) in filter string")

    # delete the last registered model. search should return only the first 5
    deleted_name = names[-1]
    remaining_names = names[:-1]
    store.delete_registered_model(name=deleted_name)
    assert _search_registered_models(store, None, max_results=1000) == (remaining_names, None)

    # equality search using the deleted name should return no names
    assert _search_registered_models(store, f"name='{deleted_name}'") == ([], None)

    # case-sensitive prefix search using LIKE should return all the remaining RMs
    assert _search_registered_models(store, f"name LIKE '{prefix}%'") == (names[0:5], None)

    # case-insensitive prefix search using ILIKE now returns only RM4A
    assert _search_registered_models(store, f"name ILIKE '{prefix}RM4A%'") == (rm4a_only, None)
1326  
1327  
def test_search_registered_models_by_tag(store):
    """Filter registered models by tag with =, !=, LIKE and ILIKE, alone and combined."""
    name1 = "test_for_search_RM_by_tag1"
    name2 = "test_for_search_RM_by_tag2"
    _rm_maker(
        store,
        name1,
        [
            RegisteredModelTag("t1", "abc"),
            RegisteredModelTag("t2", "xyz"),
        ],
    )
    _rm_maker(
        store,
        name2,
        [
            RegisteredModelTag("t1", "abcd"),
            RegisteredModelTag("t2", "xyz123"),
            RegisteredModelTag("t3", "XYZ"),
        ],
    )

    # (filter string, names expected in result order)
    expectations = [
        ("tag.t3 = 'XYZ'", [name2]),
        (f"name = '{name1}' and tag.t1 = 'abc'", [name1]),
        ("tag.t1 LIKE 'ab%'", [name1, name2]),
        ("tag.t1 ILIKE 'aB%'", [name1, name2]),
        ("tag.t1 LIKE 'ab%' AND tag.t2 LIKE 'xy%'", [name1, name2]),
        # tag values and tag keys are matched case-sensitively with "="
        ("tag.t3 = 'XYz'", []),
        ("tag.T3 = 'XYZ'", []),
        ("tag.t1 != 'abc'", [name2]),
        # test filter with duplicated keys
        ("tag.t1 != 'abcd' and tag.t1 LIKE 'ab%'", [name1]),
    ]
    for filter_string, expected_names in expectations:
        found_names, _ = _search_registered_models(store, filter_string)
        assert found_names == expected_names
1370  
1371  
def test_parse_search_registered_models_order_by():
    """Check the default clause, an explicit override, and duplicate-field rejection."""
    parse_order_by = SqlAlchemyStore._parse_search_registered_models_order_by

    def rendered(order_by):
        return [str(clause) for clause in parse_order_by(order_by)]

    # test that "registered_models.name ASC" is returned by default
    assert rendered([]) == ["registered_models.name ASC"]

    # test that the given 'name' replaces the default one ('registered_models.name ASC')
    assert rendered(["name DESC"]) == ["registered_models.name DESC"]

    # test that an exception is raised when order_by contains duplicate fields;
    # "timestamp" paired with "last_updated_timestamp" is expected to count as a
    # duplicate, and a differing sort direction does not make a field distinct
    duplicate_field_lists = [
        ["last_updated_timestamp", "last_updated_timestamp"],
        ["timestamp", "timestamp"],
        ["timestamp", "last_updated_timestamp"],
        ["last_updated_timestamp ASC", "last_updated_timestamp DESC"],
        ["last_updated_timestamp", "last_updated_timestamp DESC"],
    ]
    for order_by in duplicate_field_lists:
        with pytest.raises(MlflowException, match="`order_by` contains duplicate fields:"):
            parse_order_by(order_by)
1406  
1407  
def test_search_registered_model_pagination(store):
    """Page through 50 registered models with fixed and varying page sizes; check bad inputs."""
    rms = []
    for i in range(50):
        rms.append(_rm_maker(store, f"RM{i:03}").name)
    query = "name LIKE 'RM%'"

    # test flow with fixed max_results: follow tokens until none is returned
    returned_rms = []
    token = None
    while True:
        page, token = _search_registered_models(store, query, page_token=token, max_results=5)
        returned_rms.extend(page)
        if not token:
            break
    assert rms == returned_rms

    # test that pagination will return all valid results in sorted order
    # by name ascending, each non-final page carrying a token for the next one
    token = None
    for page_size, expected_page in [(5, rms[0:5]), (10, rms[5:15]), (20, rms[15:35])]:
        page, token = _search_registered_models(
            store, query, page_token=token, max_results=page_size
        )
        assert token is not None
        assert page == expected_page

    # the last page holds the remainder and carries no token
    page, token = _search_registered_models(store, query, page_token=token, max_results=100)
    assert token is None
    assert page == rms[35:]

    invalid_requests = [
        # a completely invalid page token throws
        (20, r"Invalid page token, could not base64-decode"),
        # too large of a max_results throws
        (1e15, r"Invalid value for request parameter max_results"),
    ]
    for max_results, message_pattern in invalid_requests:
        with pytest.raises(MlflowException, match=message_pattern) as raised:
            _search_registered_models(
                store, query, page_token="evilhax", max_results=max_results
            )
        assert raised.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
1453  
1454  
def test_search_registered_model_order_by(store):
    """Exercise ``order_by`` in registered-model search: keys, direction, ties and parsing."""
    clock_target = "mlflow.store.model_registry.sqlalchemy_store.get_current_time_millis"

    def search(filter_string, order_by, max_results=100, page_token=None):
        # Thin wrapper so each case below only spells out what varies.
        return _search_registered_models(
            store,
            filter_string,
            page_token=page_token,
            order_by=order_by,
            max_results=max_results,
        )

    # Explicitly mock the creation timestamps because timestamps seem to be unstable on Windows.
    rm_names = []
    for idx in range(50):
        with mock.patch(clock_target, return_value=idx):
            rm_names.append(_rm_maker(store, f"RM{idx:03}", _add_go_test_tags([], f"{idx}")).name)
    newest_first = rm_names[::-1]
    rm_filter = "name LIKE 'RM%'"

    # Fixed page size plus order_by: the ordering has to stay stable across pages.
    paged_names = []
    token = None
    while True:
        page, token = search(rm_filter, ["name DESC"], max_results=5, page_token=token)
        paged_names.extend(page)
        if not token:
            break
    # Name descending is the opposite of the creation order.
    assert newest_first == paged_names

    single_key_cases = [
        # last_updated_timestamp descending lists the newest models first
        (["last_updated_timestamp DESC"], newest_first),
        # "timestamp" returns the same result as "last_updated_timestamp"
        (["timestamp DESC"], newest_first),
        # last_updated_timestamp ascending lists the oldest models first
        (["last_updated_timestamp ASC"], rm_names),
        (["timestamp ASC"], rm_names),
        (["timestamp"], rm_names),
        # name ascending keeps the original order
        (["name ASC"], rm_names),
        # a missing ASC/DESC defaults to ASC
        (["last_updated_timestamp"], rm_names),
    ]
    for order_by, expected in single_key_cases:
        found, _ = search(rm_filter, order_by)
        assert expected == found

    # Two pairs of models sharing a timestamp, used to check tie-breaking.
    with mock.patch(clock_target, return_value=1):
        rm1 = _rm_maker(store, "MR1", _add_go_test_tags([], "1")).name
        rm2 = _rm_maker(store, "MR2", _add_go_test_tags([], "1")).name
    with mock.patch(clock_target, return_value=2):
        rm3 = _rm_maker(store, "MR3", _add_go_test_tags([], "2")).name
        rm4 = _rm_maker(store, "MR4", _add_go_test_tags([], "2")).name
    mr_filter = "name LIKE 'MR%'"

    tie_cases = [
        # multiple clauses
        (["last_updated_timestamp ASC", "name DESC"], [rm2, rm1, rm4, rm3]),
        (["timestamp ASC", "name   DESC"], [rm2, rm1, rm4, rm3]),
        # name ascending is the default, even if ties exist on other fields
        ([], [rm1, rm2, rm3, rm4]),
        # default tiebreak with descending timestamps
        (["last_updated_timestamp DESC"], [rm3, rm4, rm1, rm2]),
        # timestamp parsing with assorted whitespace separators
        (["timestamp\tASC"], [rm1, rm2, rm3, rm4]),
        (["timestamp\r\rASC"], [rm1, rm2, rm3, rm4]),
        (["timestamp\nASC"], [rm1, rm2, rm3, rm4]),
        (["timestamp  ASC"], [rm1, rm2, rm3, rm4]),
        # the ASC/DESC keyword is case-insensitive
        (["timestamp  asc"], [rm1, rm2, rm3, rm4]),
        (["timestamp  aSC"], [rm1, rm2, rm3, rm4]),
        (["timestamp  desc", "name desc"], [rm4, rm3, rm2, rm1]),
        (["timestamp  deSc", "name deSc"], [rm4, rm3, rm2, rm1]),
    ]
    for order_by, expected in tie_cases:
        found, _ = search(mr_filter, order_by)
        assert found == expected
1612  
1613  
def test_search_registered_model_order_by_errors(store):
    """Invalid ``order_by`` entries are rejected even when they come after a valid entry."""
    failing_cases = [
        # an invalid column following a valid one
        (["name ASC", "creation_timestamp DESC"], r"Invalid order by key '.+' specified"),
        # random trailing text following a valid column and direction
        (["name ASC", "last_updated_timestamp DESC blah"], r"Invalid order_by clause '.+'"),
    ]
    for order_by, message_pattern in failing_cases:
        with pytest.raises(MlflowException, match=message_pattern) as exc_info:
            _search_registered_models(
                store,
                "name LIKE 'RM%'",
                page_token=None,
                order_by=order_by,
                max_results=5,
            )
        assert exc_info.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
1638  
1639  
def test_set_model_version_tag(store):
    """Setting a tag adds or overrides it on one model version only; bad input raises."""
    name1 = "SetModelVersionTag_TestMod"
    name2 = "SetModelVersionTag_TestMod 2"
    initial_tags = [
        ModelVersionTag("key", "value"),
        ModelVersionTag("anotherKey", "some other value"),
    ]

    def tag_map(tags):
        return {tag.key: tag.value for tag in tags}

    for model_name in (name1, name2):
        _rm_maker(store, model_name)
    run_ids = [uuid.uuid4().hex for _ in range(3)]
    for model_name, source, run_id in zip((name1, name1, name2), ("A/B", "A/C", "A/D"), run_ids):
        _mv_maker(store, model_name, source, run_id, initial_tags)

    # A brand-new key is added next to the initial tags.
    new_tag = ModelVersionTag("randomTag", "not a random value")
    store.set_model_version_tag(name1, 1, new_tag)
    expected_tags = [*initial_tags, new_tag]
    assert store.get_model_version(name1, 1).tags == tag_map(expected_tags)

    # Setting an existing key overrides its value ...
    overriding_tag = ModelVersionTag("key", "overriding")
    store.set_model_version_tag(name1, 1, overriding_tag)
    expected_tags = [tag for tag in expected_tags if tag.key != "key"]
    expected_tags.append(overriding_tag)
    assert store.get_model_version(name1, 1).tags == tag_map(expected_tags)
    # ... and leaves other model versions holding the same key untouched.
    for model_name, version in ((name1, 2), (name2, 1)):
        assert store.get_model_version(model_name, version).tags == tag_map(initial_tags)

    invalid_parameter = ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # A deleted (non-existent) model version cannot be tagged.
    store.delete_model_version(name1, 2)
    with pytest.raises(
        MlflowException, match=rf"Model Version \(name={name1}, version=2\) not found"
    ) as exc_info:
        store.set_model_version_tag(name1, 2, overriding_tag)
    assert exc_info.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)

    # Tag values that are too long are rejected.
    too_long_tag = ModelVersionTag("longTagKey", "a" * 100_001)
    with pytest.raises(
        MlflowException,
        match=r"'value' exceeds the maximum length of \d+ characters",
    ) as exc_info:
        store.set_model_version_tag(name1, 1, too_long_tag)
    assert exc_info.value.error_code == invalid_parameter

    # Somewhat long tag values are accepted.
    fairly_long_tag = ModelVersionTag("longTagKey", "a" * 4999)
    store.set_model_version_tag(name1, 1, fairly_long_tag)

    # An invalid tag (no key) is rejected.
    with pytest.raises(
        MlflowException, match=r"Missing value for required parameter 'key'"
    ) as exc_info:
        store.set_model_version_tag(name2, 1, ModelVersionTag(key=None, value=""))
    assert exc_info.value.error_code == invalid_parameter

    # An invalid model name or version is rejected.
    with pytest.raises(
        MlflowException, match=r"Missing value for required parameter 'name'"
    ) as exc_info:
        store.set_model_version_tag(None, 1, ModelVersionTag(key="key", value="value"))
    assert exc_info.value.error_code == invalid_parameter
    with pytest.raises(
        MlflowException, match=r"Parameter 'version' must be an integer, got 'I am not a version'"
    ) as exc_info:
        store.set_model_version_tag(
            name2, "I am not a version", ModelVersionTag(key="key", value="value")
        )
    assert exc_info.value.error_code == invalid_parameter
1710  
1711  
def test_delete_model_version_tag(store):
    """Deleting a tag affects one model version only, is repeatable, and validates its input."""
    name1 = "DeleteModelVersionTag_TestMod"
    name2 = "DeleteModelVersionTag_TestMod 2"
    initial_tags = [
        ModelVersionTag("key", "value"),
        ModelVersionTag("anotherKey", "some other value"),
    ]

    def tag_map(tags):
        return {tag.key: tag.value for tag in tags}

    for model_name in (name1, name2):
        _rm_maker(store, model_name)
    run_ids = [uuid.uuid4().hex for _ in range(3)]
    for model_name, source, run_id in zip((name1, name1, name2), ("A/B", "A/C", "A/D"), run_ids):
        _mv_maker(store, model_name, source, run_id, initial_tags)

    # Add a tag and remove it again: only the initial tags remain.
    store.set_model_version_tag(name1, 1, ModelVersionTag("randomTag", "not a random value"))
    store.delete_model_version_tag(name1, 1, "randomTag")
    assert store.get_model_version(name1, 1).tags == tag_map(initial_tags)

    # Deleting a key does not affect other model versions with the same key.
    store.delete_model_version_tag(name1, 1, "key")
    remaining = {"anotherKey": "some other value"}
    assert store.get_model_version(name1, 1).tags == remaining
    for model_name, version in ((name1, 2), (name2, 1)):
        assert store.get_model_version(model_name, version).tags == tag_map(initial_tags)

    # Deleting a tag that is already gone does nothing.
    store.delete_model_version_tag(name1, 1, "key")
    assert store.get_model_version(name1, 1).tags == remaining

    # A tag cannot be deleted from a deleted (non-existent) model version.
    store.delete_model_version(name2, 1)
    with pytest.raises(
        MlflowException, match=rf"Model Version \(name={name2}, version=1\) not found"
    ) as exc_info:
        store.delete_model_version_tag(name2, 1, "key")
    assert exc_info.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)

    # Invalid key, model name or version: each is rejected as an invalid parameter.
    invalid_calls = [
        ((name1, 2, None), r"Missing value for required parameter 'key'"),
        ((None, 2, "key"), r"Missing value for required parameter 'name'."),
        (
            (name1, "I am not a version", "key"),
            r"Parameter 'version' must be an integer, got 'I am not a version'",
        ),
    ]
    for call_args, message_pattern in invalid_calls:
        with pytest.raises(MlflowException, match=message_pattern) as exc_info:
            store.delete_model_version_tag(*call_args)
        assert exc_info.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
1771  
1772  
def _setup_and_test_aliases(store, model_name):
    """Create ``model_name`` with two versions, alias version 2 as "test_alias" and verify it."""
    store.create_registered_model(model_name)
    run_ids = [uuid.uuid4().hex for _ in range(2)]
    for source, run_id in zip(("v1", "v2"), run_ids):
        store.create_model_version(model_name, source, run_id)
    store.set_registered_model_alias(model_name, "test_alias", "2")

    # The alias is visible on the registered model ...
    registered = store.get_registered_model(model_name)
    assert registered.aliases == {"test_alias": 2}
    # ... and only on the model version it points to.
    first_version = store.get_model_version(model_name, 1)
    second_version = store.get_model_version(model_name, 2)
    assert first_version.aliases == []
    assert second_version.aliases == ["test_alias"]
1786  
1787  
def test_set_registered_model_alias(store):
    """Setting an alias is fully covered by the shared alias setup helper."""
    model_name = "SetRegisteredModelAlias_TestMod"
    _setup_and_test_aliases(store, model_name)
1790  
1791  
def test_delete_registered_model_alias(store):
    """Deleting an alias removes it from both the registered model and its model version."""
    model_name = "DeleteRegisteredModelAlias_TestMod"
    _setup_and_test_aliases(store, model_name)

    store.delete_registered_model_alias(model_name, "test_alias")

    registered = store.get_registered_model(model_name)
    assert registered.aliases == {}
    previously_aliased = store.get_model_version(model_name, 2)
    assert previously_aliased.aliases == []
1800  
1801  
def test_get_model_version_by_alias(store):
    """Looking a model version up by alias returns a version carrying that alias."""
    model_name = "GetModelVersionByAlias_TestMod"
    _setup_and_test_aliases(store, model_name)

    aliased_version = store.get_model_version_by_alias(model_name, "test_alias")
    assert aliased_version.aliases == ["test_alias"]
1807  
1808  
def test_delete_model_version_deletes_alias(store):
    """Deleting the aliased model version also removes the alias pointing at it."""
    model_name = "DeleteModelVersionDeletesAlias_TestMod"
    _setup_and_test_aliases(store, model_name)

    store.delete_model_version(model_name, 2)

    registered = store.get_registered_model(model_name)
    assert registered.aliases == {}
    # The alias can no longer be resolved.
    alias_missing = r"Registered model alias test_alias not found."
    with pytest.raises(MlflowException, match=alias_missing):
        store.get_model_version_by_alias(model_name, "test_alias")
1820  
1821  
def test_delete_model_deletes_alias(store):
    """After deleting the registered model, resolving its alias reports the model as missing."""
    model_name = "DeleteModelDeletesAlias_TestMod"
    _setup_and_test_aliases(store, model_name)

    store.delete_registered_model(model_name)

    model_missing = rf"Registered Model with name={model_name} not found"
    with pytest.raises(MlflowException, match=model_missing):
        store.get_model_version_by_alias(model_name, "test_alias")
1831  
1832  
@pytest.mark.parametrize("copy_to_same_model", [False, True])
def test_copy_model_version(store, copy_to_same_model):
    """Copy a model version into the same or another registered model and check the result.

    The copy must take over description, run link, run id and tags of the source, point its
    ``source`` at the source model version, resolve its download URI to the original source
    location, and start with fresh stage, status and timestamps. Copying the copy must behave
    the same way.

    Args:
        store: Model registry store fixture under test.
        copy_to_same_model: When true the copy becomes version 2 of the source model;
            otherwise it becomes version 1 of a different registered model.
    """
    name1 = "test_for_copy_MV1"
    store.create_registered_model(name1)
    src_tags = [
        ModelVersionTag("key", "value"),
        ModelVersionTag("anotherKey", "some other value"),
    ]
    src_mv = _mv_maker(
        store,
        name1,
        tags=src_tags,
        run_link="dummylink",
        description="test description",
    )

    # Make some changes to the src MV that won't be copied over
    store.transition_model_version_stage(
        name1, src_mv.version, "Production", archive_existing_versions=False
    )

    copy_rm_name = name1 if copy_to_same_model else "test_for_copy_MV2"
    copy_mv_version = 2 if copy_to_same_model else 1
    # The store's clock is `get_current_time_millis`, so the lower bound must be taken in
    # milliseconds; a bound in seconds would make the timestamp checks below pass trivially.
    timestamp = int(time.time() * 1000)
    dst_mv = store.copy_model_version(src_mv, copy_rm_name)
    assert dst_mv.name == copy_rm_name
    assert dst_mv.version == copy_mv_version

    copied_mv = store.get_model_version(dst_mv.name, dst_mv.version)
    assert copied_mv.name == copy_rm_name
    assert int(copied_mv.version) == copy_mv_version
    assert copied_mv.current_stage == "None"
    assert copied_mv.creation_timestamp >= timestamp
    assert copied_mv.last_updated_timestamp >= timestamp
    assert copied_mv.description == "test description"
    assert copied_mv.source == f"models:/{src_mv.name}/{src_mv.version}"
    assert store.get_model_version_download_uri(dst_mv.name, dst_mv.version) == src_mv.source
    assert copied_mv.run_link == "dummylink"
    assert copied_mv.run_id == src_mv.run_id
    assert copied_mv.status == "READY"
    assert copied_mv.status_message is None
    assert copied_mv.tags == {"key": "value", "anotherKey": "some other value"}

    # Copy a model version copy
    double_copy_mv = store.copy_model_version(copied_mv, "test_for_copy_MV3")
    assert double_copy_mv.source == f"models:/{copied_mv.name}/{copied_mv.version}"
    # Check the copy of the copy itself (not the first copy again): its download URI is
    # expected to still resolve to the original source location.
    assert (
        store.get_model_version_download_uri(double_copy_mv.name, double_copy_mv.version)
        == src_mv.source
    )
1880  
1881  
def test_search_prompts(store):
    """Registered-model search leaves prompts out unless the filter asks for them."""

    def prompt_tag():
        return RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")

    def search_names(**search_kwargs):
        found = store.search_registered_models(max_results=10, **search_kwargs)
        return [rm.name for rm in found]

    store.create_registered_model("model", tags=[RegisteredModelTag(key="fruit", value="apple")])
    store.create_registered_model("prompt_1", tags=[prompt_tag()])
    store.create_registered_model(
        "prompt_2",
        tags=[prompt_tag(), RegisteredModelTag(key="fruit", value="apple")],
    )

    # By default, prompts are not returned.
    assert search_names() == ["model"]
    assert search_names(filter_string="tags.fruit = 'apple'") == ["model"]
    assert search_names(filter_string="name = 'prompt_1'") == []
    for non_prompt_filter in (
        "tags.`mlflow.prompt.is_prompt` = 'false'",
        "tags.`mlflow.prompt.is_prompt` != 'true'",
    ):
        assert search_names(filter_string=non_prompt_filter) == ["model"]

    # Search for prompts.
    prompt_names = search_names(filter_string="tags.`mlflow.prompt.is_prompt` = 'true'")
    assert len(prompt_names) == 2
    assert set(prompt_names) == {"prompt_1", "prompt_2"}

    by_name = search_names(
        filter_string="name = 'prompt_1' and tags.`mlflow.prompt.is_prompt` = 'true'"
    )
    assert by_name == ["prompt_1"]

    by_tag = search_names(
        filter_string="tags.`mlflow.prompt.is_prompt` = 'true' and tags.fruit = 'apple'"
    )
    assert by_tag == ["prompt_2"]
1940  
1941  
def test_search_prompts_versions(store):
    """Model-version search leaves prompt versions out unless the filter asks for them."""

    def prompt_version_tag():
        return ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true")

    def search_names(**search_kwargs):
        found = store.search_model_versions(max_results=10, **search_kwargs)
        return [mv.name for mv in found]

    # A model
    store.create_registered_model("model")
    store.create_model_version(
        "model", "1", "dummy_source", tags=[ModelVersionTag(key="fruit", value="apple")]
    )

    # A prompt with 1 version
    store.create_registered_model(
        "prompt_1", tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")]
    )
    store.create_model_version("prompt_1", "1", "dummy_source", tags=[prompt_version_tag()])

    # A prompt with 2 versions, each carrying a different fruit
    store.create_registered_model(
        "prompt_2",
        tags=[RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")],
    )
    for version_source, fruit in (("1", "apple"), ("2", "orange")):
        store.create_model_version(
            "prompt_2",
            version_source,
            "dummy_source",
            tags=[prompt_version_tag(), ModelVersionTag(key="fruit", value=fruit)],
        )

    # Searching model versions should not return prompts by default either.
    assert search_names() == ["model"]
    assert search_names(filter_string="tags.fruit = 'apple'") == ["model"]
    for non_prompt_filter in (
        "tags.`mlflow.prompt.is_prompt` = 'false'",
        "tags.`mlflow.prompt.is_prompt` != 'true'",
    ):
        assert search_names(filter_string=non_prompt_filter) == ["model"]

    # Search for prompts via search_model_versions.
    all_prompt_versions = search_names(filter_string="tags.`mlflow.prompt.is_prompt` = 'true'")
    assert len(all_prompt_versions) == 3

    prompt_2_versions = search_names(
        filter_string="tags.`mlflow.prompt.is_prompt` = 'true' and name = 'prompt_2'"
    )
    assert len(prompt_2_versions) == 2

    apple_versions = search_names(
        filter_string="tags.`mlflow.prompt.is_prompt` = 'true' and tags.fruit = 'apple'"
    )
    assert apple_versions == ["prompt_2"]
2020  
2021  
def test_search_prompt_versions(store):
    """``search_prompt_versions`` lists one prompt's versions, paginates and validates the name."""

    def registered_prompt_tags():
        return [RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")]

    def prompt_version_tags(text):
        return [
            ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true"),
            ModelVersionTag(key=PROMPT_TEXT_TAG_KEY, value=text),
        ]

    # A prompt with 3 versions
    store.create_registered_model("my_prompt", tags=registered_prompt_tags())
    for number in range(1, 4):
        store.create_model_version(
            "my_prompt",
            str(number),
            "dummy_source",
            tags=prompt_version_tags(f"Hello {{{{name}}}} v{number}"),
        )

    # A second prompt, to verify filtering by name
    store.create_registered_model("other_prompt", tags=registered_prompt_tags())
    store.create_model_version(
        "other_prompt",
        "1",
        "dummy_source",
        tags=prompt_version_tags("Other prompt text"),
    )

    # All versions of my_prompt come back, ordered by version descending.
    my_versions = store.search_prompt_versions("my_prompt")
    assert len(my_versions) == 3
    assert [pv.version for pv in my_versions] == [3, 2, 1]
    assert all(pv.name == "my_prompt" for pv in my_versions)

    # max_results limits the first page and yields a token for the next one.
    first_page = store.search_prompt_versions("my_prompt", max_results=2)
    assert len(first_page) == 2
    assert [pv.version for pv in first_page] == [3, 2]
    assert first_page.token is not None

    # The next page holds the remaining version and ends the listing.
    second_page = store.search_prompt_versions(
        "my_prompt", max_results=2, page_token=first_page.token
    )
    assert len(second_page) == 1
    assert second_page[0].version == 1
    assert second_page.token is None

    # Searching other_prompt returns only its own versions.
    other_versions = store.search_prompt_versions("other_prompt")
    assert len(other_versions) == 1
    assert other_versions[0].name == "other_prompt"

    # Searching a non-existent prompt raises.
    with pytest.raises(MlflowException, match="not found"):
        store.search_prompt_versions("nonexistent_prompt")

    # Searching a model (not a prompt) raises.
    store.create_registered_model("a_model")
    with pytest.raises(MlflowException, match="registered as a model, not a prompt"):
        store.search_prompt_versions("a_model")
2084  
2085  
def test_create_registered_model_handle_prompt_properly(store):
    """Name clashes between models and prompts raise errors that name both kinds involved."""
    prompt_tags = [RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")]

    store.create_registered_model("model")
    store.create_registered_model("prompt", tags=prompt_tags)

    # (name to create, extra keyword arguments, expected error message pattern)
    conflicts = [
        ("model", {}, r"Registered Model \(name=model\) already exists"),
        ("prompt", {"tags": prompt_tags}, r"Prompt \(name=prompt\) already exists"),
        (
            "model",
            {"tags": prompt_tags},
            r"Tried to create a prompt with name 'model', "
            r"but the name is already taken by a registered model.",
        ),
        (
            "prompt",
            {},
            r"Tried to create a registered model with name 'prompt', "
            r"but the name is already taken by a prompt.",
        ),
    ]
    for name, extra_kwargs, message_pattern in conflicts:
        with pytest.raises(MlflowException, match=message_pattern):
            store.create_registered_model(name, **extra_kwargs)
2112  
2113  
def test_create_webhook(store):
    """A created webhook echoes back the supplied fields and receives an id and timestamps."""
    requested = {
        "name": "test_webhook",
        "url": "https://example.com/webhook",
        "events": [
            WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED),
            WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED),
        ],
        "description": "Test webhook",
        "secret": "secret123",
        "status": WebhookStatus.ACTIVE,
    }

    webhook = store.create_webhook(**requested)

    # Every supplied field is reflected on the returned entity.
    for field, expected in requested.items():
        assert getattr(webhook, field) == expected
    # Fields assigned during creation are populated.
    for generated_field in ("webhook_id", "creation_timestamp", "last_updated_timestamp"):
        assert getattr(webhook, generated_field) is not None
2137  
2138  
# Shared test data: webhook names expected to be rejected, each paired with the regex the
# error message has to match (the same pattern for every entry).
INVALID_WEBHOOK_NAMES = [
    (bad_name, r"is invalid")
    for bad_name in (
        "",
        "   ",
        "webhook<script>",
        "webhook@test",
        "webhook#hash",
        "webhook/slash",
        "webhook\\backslash",
        # Must start with letter or digit / must end with letter or digit
        "-webhook",
        "webhook-",
        "_webhook",
        "webhook_",
        ".webhook",
        "webhook.",
        # Too long (max 63 chars)
        "a" * 64,
    )
]
2156  
2157  
# Webhook names expected to be accepted by the store.
VALID_WEBHOOK_NAMES = [
    # Single characters: one letter, one digit
    "a",
    "1",
    # Two characters, starting with a letter and then with a digit
    "a1",
    "1a",
    # Plain alphanumeric
    "webhook123",
    # Underscore, hyphen and dot in the middle of the name
    "web_hook",
    "web-hook",
    "web.hook",
    # All three separators mixed together
    "web_hook-123.test",
    # Maximum length (63 characters), letters only
    63 * "A",
    # Maximum length (63 characters) starting and ending with a digit
    "1" + 61 * "a" + "1",
    # Mixed case
    "WebHook123",
]
2173  
2174  
@pytest.mark.parametrize(("invalid_name", "expected_match"), INVALID_WEBHOOK_NAMES)
def test_create_webhook_invalid_names(store, invalid_name, expected_match):
    """Creating a webhook with a malformed name must raise a matching MlflowException."""
    model_version_created = WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)
    create_kwargs = {
        "name": invalid_name,
        "url": "https://example.com",
        "events": [model_version_created],
    }
    with pytest.raises(MlflowException, match=expected_match):
        store.create_webhook(**create_kwargs)
2183  
2184  
@pytest.mark.parametrize("valid_name", VALID_WEBHOOK_NAMES)
def test_create_webhook_valid_names(store, valid_name):
    """Each well-formed name is accepted and returned unchanged on the created webhook."""
    model_version_created = WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)
    created = store.create_webhook(
        name=valid_name, url="https://example.com", events=[model_version_created]
    )
    assert created.name == valid_name
2193  
2194  
@pytest.mark.parametrize(
    ("invalid_url", "expected_match"),
    [
        # Empty or whitespace-only URLs
        ("", r"Webhook URL cannot be empty or just whitespace"),
        ("   ", r"Webhook URL cannot be empty or just whitespace"),
        # No scheme
        ("example.com/webhook", r"Invalid webhook URL"),
        # Scheme the store is expected to reject
        ("ftp://example.com/webhook", r"Invalid webhook URL scheme"),
        # Malformed URLs
        ("http://[invalid-url", r"Invalid webhook URL"),
        ("invalid_url", r"Invalid webhook URL"),
    ],
)
def test_create_webhook_invalid_urls(store, invalid_url, expected_match):
    """Creating a webhook with a bad URL must raise a matching MlflowException."""
    model_version_created = WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)
    with pytest.raises(MlflowException, match=expected_match):
        store.create_webhook(name="test", url=invalid_url, events=[model_version_created])
2213  
2214  
def test_create_webhook_invalid_events(store):
    """create_webhook must reject an events argument that is not a non-empty list of events."""
    rejected_event_args = (
        [],  # empty list
        (),  # not a list
        [1, 2, 3],  # list whose items are not WebhookEvent instances
    )
    for bad_events in rejected_event_args:
        with pytest.raises(MlflowException, match="Webhook events must be a non-empty list"):
            store.create_webhook(name="test", url="https://example.com", events=bad_events)
2227  
2228  
def test_get_webhook(store):
    """A webhook fetched by ID carries the same ID, name, URL and events it was created with."""
    expected_events = [WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)]
    created = store.create_webhook(
        name="test_webhook",
        url="https://example.com/webhook",
        events=expected_events,
    )

    fetched = store.get_webhook(created.webhook_id)

    assert fetched.webhook_id == created.webhook_id
    assert fetched.name == "test_webhook"
    assert fetched.url == "https://example.com/webhook"
    assert fetched.events == expected_events
2241  
2242  
def test_get_webhook_not_found(store):
    """Looking up an ID that was never created must raise a 'not found' MlflowException."""
    not_found = pytest.raises(MlflowException, match="Webhook with ID nonexistent not found")
    with not_found:
        store.get_webhook("nonexistent")
2246  
2247  
def test_list_webhooks(store):
    """Listing with default arguments returns both created webhooks in a single page."""
    model_version_hook = store.create_webhook(
        name="webhook1",
        url="https://example.com/1",
        events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)],
    )
    registered_model_hook = store.create_webhook(
        name="webhook2",
        url="https://example.com/2",
        events=[WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)],
    )

    page = store.list_webhooks()

    assert len(page) == 2
    # No continuation token is expected when everything fits on one page.
    assert page.token is None
    listed_ids = {hook.webhook_id for hook in page}
    for created in (model_version_hook, registered_model_hook):
        assert created.webhook_id in listed_ids
2268  
2269  
def test_list_webhooks_pagination(store):
    """With five webhooks and a page size of two, the first two pages are full and disjoint."""
    page_size = 2
    model_version_created = [WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)]

    # Create more webhooks than fit on one page.
    for index in range(5):
        store.create_webhook(
            name=f"webhook{index}",
            url=f"https://example.com/{index}",
            events=model_version_created,
        )

    # First page: full, with a token pointing at the next one.
    first_page = store.list_webhooks(max_results=page_size)
    assert len(first_page) == 2
    assert first_page.token is not None

    # Second page: also full, and still followed by more results.
    second_page = store.list_webhooks(max_results=page_size, page_token=first_page.token)
    assert len(second_page) == 2
    assert second_page.token is not None

    # The two pages must not share any webhook.
    ids_on_first = {hook.webhook_id for hook in first_page}
    ids_on_second = {hook.webhook_id for hook in second_page}
    assert ids_on_first.isdisjoint(ids_on_second)
2293  
2294  
def test_list_webhooks_invalid_max_results(store):
    """A max_results of 1001 is rejected with the documented range error."""
    out_of_range = pytest.raises(MlflowException, match="max_results must be between 1 and 1000")
    with out_of_range:
        store.list_webhooks(max_results=1001)
2298  
2299  
def test_update_webhook(store):
    """Update every mutable field of a webhook at once and verify the returned entity."""
    original = store.create_webhook(
        name="original_name",
        url="https://example.com/original",
        events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)],
    )

    # Replace every field the update call accepts.
    replacement_events = [
        WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED),
        WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED),
    ]
    changes = {
        "name": "updated_name",
        "url": "https://example.com/updated",
        "events": replacement_events,
        "description": "Updated description",
        "secret": "new_secret",
        "status": WebhookStatus.DISABLED,
    }
    updated = store.update_webhook(webhook_id=original.webhook_id, **changes)

    # The identity is preserved while the updated fields take their new values.
    assert updated.webhook_id == original.webhook_id
    assert updated.name == "updated_name"
    assert updated.url == "https://example.com/updated"
    assert updated.events == replacement_events
    assert updated.description == "Updated description"
    assert updated.status == WebhookStatus.DISABLED
    # The update must move the last-updated timestamp forward.
    assert updated.last_updated_timestamp > original.last_updated_timestamp
2328  
2329  
def test_update_webhook_partial(store):
    """Updating only the name leaves the URL and events untouched."""
    initial_events = [WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)]
    original = store.create_webhook(
        name="original_name",
        url="https://example.com/original",
        events=initial_events,
    )

    renamed = store.update_webhook(webhook_id=original.webhook_id, name="new_name")

    assert renamed.name == "new_name"
    # Fields that were not passed to update_webhook keep their original values.
    assert renamed.url == "https://example.com/original"
    assert renamed.events == initial_events
2342  
2343  
def test_update_webhook_not_found(store):
    """Updating an ID that was never created must raise a 'not found' MlflowException."""
    changes = {"name": "new_name", "url": "https://example.com/new"}
    with pytest.raises(MlflowException, match="Webhook with ID nonexistent not found"):
        store.update_webhook(webhook_id="nonexistent", **changes)
2349  
2350  
def test_update_webhook_invalid_events(store):
    """update_webhook must reject an events argument that is not a non-empty list of events."""
    # A valid webhook is needed first so the update reaches event validation.
    existing = store.create_webhook(
        name="test_webhook",
        url="https://example.com/webhook",
        events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)],
    )

    rejected_event_args = (
        [],  # empty list
        (),  # not a list
        [1, 2, 3],  # list whose items are not WebhookEvent instances
    )
    for bad_events in rejected_event_args:
        with pytest.raises(MlflowException, match="Webhook events must be a non-empty list"):
            store.update_webhook(webhook_id=existing.webhook_id, events=bad_events)
2369  
2370  
@pytest.mark.parametrize(("invalid_name", "expected_match"), INVALID_WEBHOOK_NAMES)
def test_update_webhook_invalid_names(store, invalid_name, expected_match):
    """Renaming an existing webhook to a malformed name must raise a matching MlflowException."""
    model_version_created = WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)
    # Start from a webhook whose name is valid.
    existing = store.create_webhook(
        name="test_webhook",
        url="https://example.com/webhook",
        events=[model_version_created],
    )
    target_id = existing.webhook_id

    with pytest.raises(MlflowException, match=expected_match):
        store.update_webhook(webhook_id=target_id, name=invalid_name)
2382  
2383  
@pytest.mark.parametrize(
    ("invalid_url", "expected_match"),
    [
        # Whitespace-only URL
        ("   ", r"Webhook URL cannot be empty or just whitespace"),
        # Scheme the store is expected to reject
        ("ftp://example.com", r"Invalid webhook URL scheme"),
        # Malformed URL
        ("http://[invalid", r"Invalid webhook URL"),
    ],
)
def test_update_webhook_invalid_urls(store, invalid_url, expected_match):
    """Pointing an existing webhook at a bad URL must raise a matching MlflowException."""
    model_version_created = WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)
    # Start from a webhook whose URL is valid.
    existing = store.create_webhook(
        name="test_webhook",
        url="https://example.com/webhook",
        events=[model_version_created],
    )
    target_id = existing.webhook_id

    with pytest.raises(MlflowException, match=expected_match):
        store.update_webhook(webhook_id=target_id, url=invalid_url)
2402  
2403  
def test_delete_webhook(store):
    """A deleted webhook can no longer be fetched and no longer appears in listings."""
    doomed = store.create_webhook(
        name="test_webhook",
        url="https://example.com/webhook",
        events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)],
    )
    deleted_id = doomed.webhook_id

    store.delete_webhook(deleted_id)

    # A direct lookup must now fail...
    with pytest.raises(MlflowException, match=r"Webhook with ID .* not found"):
        store.get_webhook(deleted_id)

    # ...and the listing must not include the deleted ID either.
    remaining_ids = {hook.webhook_id for hook in store.list_webhooks()}
    assert deleted_id not in remaining_ids
2420  
2421  
def test_delete_webhook_not_found(store):
    """Deleting an ID that was never created must raise a 'not found' MlflowException."""
    not_found = pytest.raises(MlflowException, match="Webhook with ID nonexistent not found")
    with not_found:
        store.delete_webhook("nonexistent")
2425  
2426  
def test_webhook_status_transitions(store):
    """A webhook created as ACTIVE can be switched to DISABLED and back to ACTIVE."""
    hook = store.create_webhook(
        name="test_webhook",
        url="https://example.com/webhook",
        events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)],
        status=WebhookStatus.ACTIVE,
    )
    assert hook.status == WebhookStatus.ACTIVE

    # Disable the webhook, then re-activate it; each update must report the new status.
    for target_status in (WebhookStatus.DISABLED, WebhookStatus.ACTIVE):
        updated = store.update_webhook(webhook_id=hook.webhook_id, status=target_status)
        assert updated.status == target_status
2449  
2450  
def test_webhook_secret_encryption(store):
    """The secret column of the webhooks table must not hold the plaintext secret.

    The row is read through a separate engine with raw SQL so that the value
    is inspected exactly as stored, without going through the store's API.
    """
    store.create_webhook(
        name="test_webhook",
        url="https://example.com/webhook",
        events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)],
        secret="my_secret",
    )

    engine = create_engine(store.db_uri)
    try:
        with engine.connect() as conn:
            (raw_secret,) = conn.execute(text("SELECT secret FROM webhooks")).fetchone()
    finally:
        # This engine is created only for the check above; dispose of it so its
        # connection pool does not keep database connections open after the test.
        engine.dispose()

    assert raw_secret is not None
    assert raw_secret != "my_secret"  # Should be encrypted
2463  
2464  
def test_create_model_version_with_model_id_and_no_run_id(store):
    """With run_id=None and a model_id, the run ID is taken from the logged model.

    MlflowClient is patched inside the store module so that get_logged_model
    returns a stub whose source_run_id is known to the test.
    """
    name = "test_model_with_model_id"
    _rm_maker(store, name)

    expected_run_id = "mock-run-id-123"
    logged_model_stub = mock.MagicMock(source_run_id=expected_run_id)

    patched_client = mock.patch("mlflow.store.model_registry.sqlalchemy_store.MlflowClient")
    with patched_client as client_cls:
        # Every MlflowClient() built by the store is this one mock instance.
        client = client_cls.return_value
        client.get_logged_model.return_value = logged_model_stub

        created = store.create_model_version(
            name=name,
            source="path/to/source",
            run_id=None,
            model_id="test-model-id-456",
        )

        client.get_logged_model.assert_called_once_with("test-model-id-456")

        # The run ID must appear on the returned entity...
        assert created.run_id == expected_run_id

        # ...and on the version read back from the store.
        fetched = store.get_model_version(name=created.name, version=created.version)
        assert fetched.run_id == expected_run_id
2493  
2494  
def test_create_model_version_concurrent(store):
    """Model versions created from several threads at once must get distinct version numbers."""
    name = "test_concurrent_mv"
    _rm_maker(store, name)

    num_threads = 4
    versions_per_thread = 5

    def create_versions():
        # Each worker creates its versions one after another and reports their numbers.
        created_versions = []
        for _ in range(versions_per_thread):
            model_version = store.create_model_version(name, "path/to/source", uuid.uuid4().hex)
            created_versions.append(model_version.version)
        return created_versions

    pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=num_threads, thread_name_prefix="create_model_version"
    )
    with pool as executor:
        futures = [executor.submit(create_versions) for _ in range(num_threads)]
        # f.result() re-raises any exception raised inside a worker.
        results = [
            version
            for f in concurrent.futures.as_completed(futures)
            for version in f.result()
        ]

    # All versions should be unique
    assert len(results) == len(set(results))
    assert len(results) == num_threads * versions_per_thread