tests/store/model_registry/test_rest_store.py
test_rest_store.py
  1  import json
  2  import uuid
  3  from unittest import mock
  4  
  5  import pytest
  6  
  7  from mlflow.entities.model_registry import ModelVersion, ModelVersionTag, RegisteredModelTag
  8  from mlflow.entities.model_registry.model_version_status import ModelVersionStatus
  9  from mlflow.exceptions import MlflowException
 10  from mlflow.prompt.registry_utils import IS_PROMPT_TAG_KEY
 11  from mlflow.protos.model_registry_pb2 import (
 12      CreateModelVersion,
 13      CreateRegisteredModel,
 14      DeleteModelVersion,
 15      DeleteModelVersionTag,
 16      DeleteRegisteredModel,
 17      DeleteRegisteredModelAlias,
 18      DeleteRegisteredModelTag,
 19      GetLatestVersions,
 20      GetModelVersion,
 21      GetModelVersionByAlias,
 22      GetModelVersionDownloadUri,
 23      GetRegisteredModel,
 24      RenameRegisteredModel,
 25      SearchModelVersions,
 26      SearchRegisteredModels,
 27      SetModelVersionTag,
 28      SetRegisteredModelAlias,
 29      SetRegisteredModelTag,
 30      TransitionModelVersionStage,
 31      UpdateModelVersion,
 32      UpdateRegisteredModel,
 33  )
 34  from mlflow.store.model_registry.rest_store import RestStore
 35  from mlflow.utils.proto_json_utils import message_to_json
 36  from mlflow.utils.rest_utils import MlflowHostCreds
 37  from mlflow.utils.workspace_context import WorkspaceContext
 38  from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME
 39  
 40  from tests.helper_functions import mock_http_request_200, mock_http_request_403_200
 41  
 42  
 43  @pytest.fixture(autouse=True, params=[False, True], ids=["workspace-disabled", "workspace-enabled"])
 44  def workspaces_enabled(request):
 45      """
 46      Run every test in this module with workspaces disabled and enabled to cover both code paths.
 47      """
 48  
 49      enabled = request.param
 50      if enabled:
 51          with (
 52              WorkspaceContext(DEFAULT_WORKSPACE_NAME),
 53              mock.patch(
 54                  "mlflow.store.workspace_rest_store_mixin.WorkspaceRestStoreMixin.supports_workspaces",
 55                  new_callable=mock.PropertyMock,
 56                  return_value=True,
 57              ),
 58          ):
 59              yield enabled
 60      else:
 61          yield enabled
 62  
 63  
 64  @pytest.fixture
 65  def creds():
 66      return MlflowHostCreds("https://hello")
 67  
 68  
 69  @pytest.fixture
 70  def store(creds):
 71      return RestStore(lambda: creds)
 72  
 73  
 74  def _args(host_creds, endpoint, method, json_body):
 75      res = {"host_creds": host_creds, "endpoint": f"/api/2.0/mlflow/{endpoint}", "method": method}
 76      if method == "GET":
 77          res["params"] = json.loads(json_body)
 78      else:
 79          res["json"] = json.loads(json_body)
 80      return res
 81  
 82  
 83  def _verify_requests(http_request, creds, endpoint, method, proto_message):
 84      json_body = message_to_json(proto_message)
 85      http_request.assert_any_call(**(_args(creds, endpoint, method, json_body)))
 86  
 87  
 88  def _verify_all_requests(http_request, creds, endpoints, proto_message):
 89      json_body = message_to_json(proto_message)
 90      http_request.assert_has_calls([
 91          mock.call(**(_args(creds, endpoint, method, json_body))) for endpoint, method in endpoints
 92      ])
 93  
 94  
 95  def test_create_registered_model(store, creds):
 96      tags = [
 97          RegisteredModelTag(key="key", value="value"),
 98          RegisteredModelTag(key="anotherKey", value="some other value"),
 99      ]
100      description = "best model ever"
101      with mock_http_request_200() as mock_http:
102          store.create_registered_model("model_1", tags, description)
103      _verify_requests(
104          mock_http,
105          creds,
106          "registered-models/create",
107          "POST",
108          CreateRegisteredModel(
109              name="model_1", tags=[tag.to_proto() for tag in tags], description=description
110          ),
111      )
112  
113  
def test_update_registered_model_name(store, creds):
    """Renaming a registered model POSTs both the current and the new name."""
    old_name, new_name = "model_1", "model_2"
    with mock_http_request_200() as mock_http:
        store.rename_registered_model(name=old_name, new_name=new_name)
    expected = RenameRegisteredModel(name=old_name, new_name=new_name)
    _verify_requests(mock_http, creds, "registered-models/rename", "POST", expected)
126  
127  
def test_update_registered_model_description(store, creds):
    """Updating a registered model's description sends a PATCH with name and description."""
    model_name = "model_1"
    new_description = "test model"
    with mock_http_request_200() as mock_http:
        store.update_registered_model(name=model_name, description=new_description)
    expected = UpdateRegisteredModel(name=model_name, description=new_description)
    _verify_requests(mock_http, creds, "registered-models/update", "PATCH", expected)
140  
141  
def test_delete_registered_model(store, creds):
    """Deleting a registered model sends a DELETE carrying only its name."""
    target = "model_1"
    with mock_http_request_200() as mock_http:
        store.delete_registered_model(name=target)
    _verify_requests(
        mock_http,
        creds,
        "registered-models/delete",
        "DELETE",
        DeleteRegisteredModel(name=target),
    )
149  
150  
def test_search_registered_models(store, creds):
    """A search with no arguments sends an empty ``SearchRegisteredModels`` request via GET."""
    expected = SearchRegisteredModels()
    with mock_http_request_200() as mock_http:
        store.search_registered_models()
    _verify_requests(mock_http, creds, "registered-models/search", "GET", expected)
155  
156  
@pytest.mark.parametrize("filter_string", [None, "model = 'yo'"])
@pytest.mark.parametrize("max_results", [None, 400])
@pytest.mark.parametrize("page_token", [None, "blah"])
@pytest.mark.parametrize("order_by", [None, ["x", "Y"]])
def test_search_registered_models_params(
    store, creds, filter_string, max_results, page_token, order_by
):
    """Only explicitly provided search arguments are forwarded to the request body."""
    candidates = (
        ("filter_string", filter_string),
        ("max_results", max_results),
        ("page_token", page_token),
        ("order_by", order_by),
    )
    call_kwargs = {arg_name: value for arg_name, value in candidates if value is not None}

    with mock_http_request_200() as mock_http:
        store.search_registered_models(**call_kwargs)

    # The store API calls the filter ``filter_string``; the proto field is named ``filter``.
    proto_kwargs = {
        ("filter" if arg_name == "filter_string" else arg_name): value
        for arg_name, value in call_kwargs.items()
    }
    _verify_requests(
        mock_http,
        creds,
        "registered-models/search",
        "GET",
        SearchRegisteredModels(**proto_kwargs),
    )
182  
183  
def test_get_registered_model(store, creds):
    """Fetching a registered model issues a GET keyed by the model name."""
    model_name = "model_1"
    expected = GetRegisteredModel(name=model_name)
    with mock_http_request_200() as mock_http:
        store.get_registered_model(name=model_name)
    _verify_requests(mock_http, creds, "registered-models/get", "GET", expected)
191  
192  
def test_get_latest_versions(store, creds):
    """``get_latest_versions`` issues a POST and then a GET to the same endpoint.

    The transport is mocked with ``mock_http_request_403_200``, which presumably answers the
    first request with 403 and the next with 200. Both requests must carry the same body.
    """
    model_name = "model_1"
    latest_versions_endpoint = "registered-models/get-latest-versions"
    with mock_http_request_403_200() as mock_http:
        store.get_latest_versions(name=model_name)
    _verify_all_requests(
        mock_http,
        creds,
        [(latest_versions_endpoint, http_method) for http_method in ("POST", "GET")],
        GetLatestVersions(name=model_name),
    )
200  
201  
def test_get_latest_versions_with_stages(store, creds):
    """A ``stages`` filter is included in both the POST and the follow-up GET request.

    The transport is mocked with ``mock_http_request_403_200``, which presumably answers the
    first request with 403 and the next with 200.
    """
    model_name = "model_1"
    stages = ["blaah"]
    latest_versions_endpoint = "registered-models/get-latest-versions"
    with mock_http_request_403_200() as mock_http:
        store.get_latest_versions(name=model_name, stages=stages)
    _verify_all_requests(
        mock_http,
        creds,
        [(latest_versions_endpoint, http_method) for http_method in ("POST", "GET")],
        GetLatestVersions(name=model_name, stages=stages),
    )
211  
212  
def test_set_registered_model_tag(store, creds):
    """Setting a registered-model tag POSTs the model name with the tag's key and value."""
    model_name = "model_1"
    model_tag = RegisteredModelTag(key="key", value="value")
    with mock_http_request_200() as mock_http:
        store.set_registered_model_tag(name=model_name, tag=model_tag)
    expected = SetRegisteredModelTag(name=model_name, key=model_tag.key, value=model_tag.value)
    _verify_requests(mock_http, creds, "registered-models/set-tag", "POST", expected)
225  
226  
def test_delete_registered_model_tag(store, creds):
    """Deleting a registered-model tag sends a DELETE with the model name and tag key."""
    model_name = "model_1"
    tag_key = "key"
    with mock_http_request_200() as mock_http:
        store.delete_registered_model_tag(name=model_name, key=tag_key)
    expected = DeleteRegisteredModelTag(name=model_name, key=tag_key)
    _verify_requests(mock_http, creds, "registered-models/delete-tag", "DELETE", expected)
238  
239  
def test_create_model_version(store, creds):
    """``create_model_version`` POSTs the required fields, plus any optional ones supplied."""
    model_name = "model_1"
    source = "path/to/source"
    create_endpoint = "model-versions/create"

    # Minimal request: only name and source.
    with mock_http_request_200() as mock_http:
        store.create_model_version(model_name, source)
    _verify_requests(
        mock_http,
        creds,
        create_endpoint,
        "POST",
        CreateModelVersion(name=model_name, source=source),
    )

    # Full request: also run id, tags, run link and description.
    run_id = uuid.uuid4().hex
    version_tags = [
        ModelVersionTag(key="key", value="value"),
        ModelVersionTag(key="anotherKey", value="some other value"),
    ]
    run_link = "localhost:5000/path/to/run"
    description = "version description"
    with mock_http_request_200() as mock_http:
        store.create_model_version(
            model_name,
            source,
            run_id,
            version_tags,
            run_link=run_link,
            description=description,
        )
    expected_full = CreateModelVersion(
        name=model_name,
        source=source,
        run_id=run_id,
        run_link=run_link,
        tags=[t.to_proto() for t in version_tags],
        description=description,
    )
    _verify_requests(mock_http, creds, create_endpoint, "POST", expected_full)
281  
282  
def test_transition_model_version_stage(store, creds):
    """A stage transition POSTs the target stage and the archive-existing-versions flag."""
    # The store method and the proto share these field names, so one dict feeds both.
    transition = {
        "name": "model_1",
        "version": "5",
        "stage": "prod",
        "archive_existing_versions": True,
    }
    with mock_http_request_200() as mock_http:
        store.transition_model_version_stage(**transition)
    _verify_requests(
        mock_http,
        creds,
        "model-versions/transition-stage",
        "POST",
        TransitionModelVersionStage(**transition),
    )
299  
300  
def test_update_model_version_description(store, creds):
    """Updating a model version's description sends a PATCH with name, version and description."""
    model_name, model_version = "model_1", "5"
    new_description = "test model version"
    with mock_http_request_200() as mock_http:
        store.update_model_version(
            name=model_name, version=model_version, description=new_description
        )
    expected = UpdateModelVersion(
        name=model_name, version=model_version, description=new_description
    )
    _verify_requests(mock_http, creds, "model-versions/update", "PATCH", expected)
314  
315  
def test_delete_model_version(store, creds):
    """Deleting a model version sends a DELETE identifying the model name and version."""
    version_key = {"name": "model_1", "version": "12"}
    with mock_http_request_200() as mock_http:
        store.delete_model_version(**version_key)
    _verify_requests(
        mock_http,
        creds,
        "model-versions/delete",
        "DELETE",
        DeleteModelVersion(**version_key),
    )
328  
329  
def test_get_model_version_details(store, creds):
    """Fetching a model version issues a GET with the model name and version."""
    version_key = {"name": "model_11", "version": "8"}
    with mock_http_request_200() as mock_http:
        store.get_model_version(**version_key)
    _verify_requests(mock_http, creds, "model-versions/get", "GET", GetModelVersion(**version_key))
338  
339  
def test_get_model_version_download_uri(store, creds):
    """Resolving a model version's download URI issues a GET with the model name and version."""
    version_key = {"name": "model_11", "version": "8"}
    expected = GetModelVersionDownloadUri(**version_key)
    with mock_http_request_200() as mock_http:
        store.get_model_version_download_uri(**version_key)
    _verify_requests(mock_http, creds, "model-versions/get-download-uri", "GET", expected)
352  
353  
def test_search_model_versions(store, creds):
    """A search with no arguments sends an empty ``SearchModelVersions`` request via GET."""
    expected = SearchModelVersions()
    with mock_http_request_200() as mock_http:
        store.search_model_versions()
    _verify_requests(mock_http, creds, "model-versions/search", "GET", expected)
358  
359  
@pytest.mark.parametrize("filter_string", [None, "name = 'model_12'"])
@pytest.mark.parametrize("max_results", [None, 400])
@pytest.mark.parametrize("page_token", [None, "blah"])
# ``order_by`` is a list of ordering clauses (a repeated proto field). A bare string such as
# "version DESC" gets iterated character by character when the repeated field is filled, on both
# the store side and the expected-proto side. The test would then pass without checking a real
# clause, so pass a list, mirroring test_search_registered_models_params.
@pytest.mark.parametrize("order_by", [None, ["version DESC", "creation_time DESC"]])
def test_search_model_versions_params(
    store, creds, filter_string, max_results, page_token, order_by
):
    """Only explicitly provided search arguments are forwarded to the request body."""
    params = {
        "filter_string": filter_string,
        "max_results": max_results,
        "page_token": page_token,
        "order_by": order_by,
    }
    params = {k: v for k, v in params.items() if v is not None}
    with mock_http_request_200() as mock_http:
        store.search_model_versions(**params)
    # The store API calls the filter ``filter_string``; the proto field is named ``filter``.
    if "filter_string" in params:
        params["filter"] = params.pop("filter_string")
    _verify_requests(
        mock_http,
        creds,
        "model-versions/search",
        "GET",
        SearchModelVersions(**params),
    )
385  
386  
def test_set_model_version_tag(store, creds):
    """Setting a model-version tag POSTs name, version and the tag's key and value."""
    model_name = "model_1"
    model_version = "1"
    version_tag = ModelVersionTag(key="key", value="value")
    with mock_http_request_200() as mock_http:
        store.set_model_version_tag(name=model_name, version=model_version, tag=version_tag)
    expected = SetModelVersionTag(
        name=model_name, version=model_version, key=version_tag.key, value=version_tag.value
    )
    _verify_requests(mock_http, creds, "model-versions/set-tag", "POST", expected)
399  
400  
def test_delete_model_version_tag(store, creds):
    """Deleting a model-version tag sends a DELETE with name, version and tag key."""
    tag_ref = {"name": "model_1", "version": "1", "key": "key"}
    with mock_http_request_200() as mock_http:
        store.delete_model_version_tag(**tag_ref)
    _verify_requests(
        mock_http,
        creds,
        "model-versions/delete-tag",
        "DELETE",
        DeleteModelVersionTag(**tag_ref),
    )
412  
413  
def test_set_registered_model_alias(store, creds):
    """Setting an alias POSTs the model name, alias and target version to the alias endpoint."""
    alias_spec = {"name": "model_1", "alias": "test_alias", "version": "1"}
    with mock_http_request_200() as mock_http:
        store.set_registered_model_alias(**alias_spec)
    _verify_requests(
        mock_http,
        creds,
        "registered-models/alias",
        "POST",
        SetRegisteredModelAlias(**alias_spec),
    )
425  
426  
def test_delete_registered_model_alias(store, creds):
    """Deleting an alias sends a DELETE with the model name and alias to the alias endpoint."""
    model_name = "model_1"
    alias = "test_alias"
    with mock_http_request_200() as mock_http:
        store.delete_registered_model_alias(name=model_name, alias=alias)
    expected = DeleteRegisteredModelAlias(name=model_name, alias=alias)
    _verify_requests(mock_http, creds, "registered-models/alias", "DELETE", expected)
438  
439  
def test_get_model_version_by_alias(store, creds):
    """Resolving a version by alias issues a GET with the model name and alias."""
    model_name = "model_1"
    alias = "test_alias"
    expected = GetModelVersionByAlias(name=model_name, alias=alias)
    with mock_http_request_200() as mock_http:
        store.get_model_version_by_alias(name=model_name, alias=alias)
    _verify_requests(mock_http, creds, "registered-models/alias", "GET", expected)
451  
452  
def test_await_model_version_creation_pending(store):
    """A version that stays PENDING_REGISTRATION past the timeout raises "Exceeded max wait time".

    ``get_model_version`` keeps returning the pending version. The sleep interval is patched to
    1 while the wait budget passed in is 0.5.
    """
    still_pending = ModelVersion(
        name="Model 1",
        version="1",
        creation_timestamp=123,
        status=ModelVersionStatus.to_string(ModelVersionStatus.PENDING_REGISTRATION),
    )
    sleep_interval_patch = mock.patch(
        "mlflow.store.model_registry.abstract_store.AWAIT_MODEL_VERSION_CREATE_SLEEP_INTERVAL_SECONDS",
        1,
    )
    get_version_patch = mock.patch.object(store, "get_model_version", return_value=still_pending)
    with sleep_interval_patch, get_version_patch:
        with pytest.raises(MlflowException, match="Exceeded max wait time"):
            store._await_model_version_creation(still_pending, 0.5)
469  
470  
def test_await_model_version_creation_failed(store):
    """A version reporting FAILED_REGISTRATION makes the wait raise a creation-failure error."""
    failed_mv = ModelVersion(
        name="Model 1",
        version="1",
        creation_timestamp=123,
        status=ModelVersionStatus.to_string(ModelVersionStatus.FAILED_REGISTRATION),
    )
    with mock.patch.object(store, "get_model_version", return_value=failed_mv):
        with pytest.raises(MlflowException, match="Model version creation failed for model name"):
            store._await_model_version_creation(failed_mv, 0.5)
483  
484  
@pytest.mark.parametrize("is_prompt", [True, False], ids=["prompt", "model"])
def test_await_model_version_creation_show_correct_message_for_prompt(store, is_prompt):
    """The single info log emitted while waiting names a prompt or a model, as appropriate.

    A version carrying the ``IS_PROMPT_TAG_KEY`` tag must be described as a prompt. Without the
    tag it must be described as a model. ``get_model_version`` immediately returns a READY
    version, so the wait completes without timing out.
    """
    tags = [ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true")] if is_prompt else []
    pending = ModelVersion(
        name="test",
        version="1",
        creation_timestamp=123,
        tags=tags,
        status=ModelVersionStatus.to_string(ModelVersionStatus.PENDING_REGISTRATION),
    )
    completed = ModelVersion(
        name="test",
        version="1",
        creation_timestamp=123,
        tags=tags,
        status=ModelVersionStatus.to_string(ModelVersionStatus.READY),
    )

    with (
        mock.patch("mlflow.store.model_registry.abstract_store._logger") as mock_logger,
        mock.patch.object(store, "get_model_version", return_value=completed),
    ):
        store._await_model_version_creation(pending, 10)

    mock_logger.info.assert_called_once()
    # Read the message from the ``info`` call itself. ``mock_logger.mock_calls[0]`` is whichever
    # logger method was called first (e.g. a debug call), not necessarily ``info``.
    info_message = mock_logger.info.call_args.args[0]
    if is_prompt:
        assert "prompt" in info_message
        assert "model" not in info_message
    else:
        assert "prompt" not in info_message
        assert "model" in info_message