/ tests / entities / model_registry / test_registered_model.py
test_registered_model.py
  1  from mlflow.entities.model_registry import RegisteredModelAlias
  2  from mlflow.entities.model_registry.model_version import ModelVersion
  3  from mlflow.entities.model_registry.registered_model import RegisteredModel
  4  from mlflow.entities.model_registry.registered_model_tag import RegisteredModelTag
  5  from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME
  6  
  7  from tests.helper_functions import random_str
  8  
  9  
 10  def _check(
 11      registered_model,
 12      name,
 13      creation_timestamp,
 14      last_updated_timestamp,
 15      description,
 16      latest_versions,
 17      tags,
 18      aliases,
 19      workspace=DEFAULT_WORKSPACE_NAME,
 20  ):
 21      assert isinstance(registered_model, RegisteredModel)
 22      assert registered_model.name == name
 23      assert registered_model.creation_timestamp == creation_timestamp
 24      assert registered_model.last_updated_timestamp == last_updated_timestamp
 25      assert registered_model.description == description
 26      assert registered_model.last_updated_timestamp == last_updated_timestamp
 27      assert registered_model.latest_versions == latest_versions
 28      assert registered_model.tags == tags
 29      assert registered_model.aliases == aliases
 30      assert registered_model.workspace == workspace
 31  
 32  
 33  def test_creation_and_hydration():
 34      name = random_str()
 35      description = random_str()
 36      rmd_1 = RegisteredModel(name, 1, 2, description, [], [])
 37      _check(rmd_1, name, 1, 2, description, [], {}, {})
 38  
 39      as_dict = {
 40          "name": name,
 41          "creation_timestamp": 1,
 42          "last_updated_timestamp": 2,
 43          "description": description,
 44          "latest_versions": [],
 45          "tags": {},
 46          "aliases": {},
 47          "deployment_job_id": None,
 48          "deployment_job_state": None,
 49          "workspace": DEFAULT_WORKSPACE_NAME,
 50      }
 51      assert dict(rmd_1) == as_dict
 52  
 53      proto = rmd_1.to_proto()
 54      assert proto.name == name
 55      assert proto.creation_timestamp == 1
 56      assert proto.last_updated_timestamp == 2
 57      assert proto.description == description
 58      rmd_2 = RegisteredModel.from_proto(proto)
 59      _check(rmd_2, name, 1, 2, description, [], {}, {})
 60      as_dict["tags"] = []
 61      rmd_3 = RegisteredModel.from_dictionary(as_dict)
 62      _check(rmd_3, name, 1, 2, description, [], {}, {})
 63  
 64  
 65  def test_with_latest_model_versions():
 66      name = random_str()
 67      mvd_1 = ModelVersion(
 68          name,
 69          "1",
 70          1000,
 71          2000,
 72          "version 1",
 73          "user 1",
 74          "Production",
 75          "source 1",
 76          "run ID 1",
 77          "PENDING_REGISTRATION",
 78          "Model version is in production!",
 79      )
 80      mvd_2 = ModelVersion(
 81          name,
 82          "4",
 83          1300,
 84          2002,
 85          "version 4",
 86          "user 2",
 87          "Staging",
 88          "source 4",
 89          "run ID 12",
 90          "READY",
 91          "Model copied over!",
 92      )
 93      as_dict = {
 94          "name": name,
 95          "creation_timestamp": 1,
 96          "last_updated_timestamp": 4000,
 97          "description": random_str(),
 98          "latest_versions": [mvd_1, mvd_2],
 99          "tags": [],
100          "aliases": {},
101          "deployment_job_id": None,
102          "deployment_job_state": None,
103          "workspace": DEFAULT_WORKSPACE_NAME,
104      }
105      rmd_1 = RegisteredModel.from_dictionary(as_dict)
106      as_dict["tags"] = {}
107      assert dict(rmd_1) == as_dict
108  
109      proto = rmd_1.to_proto()
110      assert proto.creation_timestamp == 1
111      assert proto.last_updated_timestamp == 4000
112      assert {mvd.version for mvd in proto.latest_versions} == {"1", "4"}
113      assert {mvd.name for mvd in proto.latest_versions} == {name}
114      assert {mvd.current_stage for mvd in proto.latest_versions} == {"Production", "Staging"}
115      assert {mvd.last_updated_timestamp for mvd in proto.latest_versions} == {2000, 2002}
116  
117      assert {mvd.creation_timestamp for mvd in proto.latest_versions} == {1300, 1000}
118  
119  
def test_with_tags():
    """Hydrate a model from a list of ``RegisteredModelTag`` and check its dict and proto."""
    model_name = random_str()
    tag_list = [
        RegisteredModelTag("key", "value"),
        RegisteredModelTag("randomKey", "not a random value"),
    ]
    # ``dict(model)`` is expected to expose the tags as a key -> value mapping.
    expected_dict = {
        "name": model_name,
        "creation_timestamp": 1,
        "last_updated_timestamp": 4000,
        "description": random_str(),
        "latest_versions": [],
        "tags": {tag.key: tag.value for tag in tag_list},
        "aliases": {},
        "deployment_job_id": None,
        "deployment_job_state": None,
        "workspace": DEFAULT_WORKSPACE_NAME,
    }
    # Hydration input carries the tag entities themselves rather than the mapping.
    model = RegisteredModel.from_dictionary({**expected_dict, "tags": tag_list})
    assert dict(model) == expected_dict

    proto = model.to_proto()
    assert proto.creation_timestamp == 1
    assert proto.last_updated_timestamp == 4000
    proto_keys = {proto_tag.key for proto_tag in proto.tags}
    proto_values = {proto_tag.value for proto_tag in proto.tags}
    assert proto_keys == {"key", "randomKey"}
    assert proto_values == {"value", "not a random value"}
145  
146  
def test_with_aliases():
    """Hydrate a model from a list of ``RegisteredModelAlias`` and check its dict and proto."""
    model_name = random_str()
    alias_list = [
        RegisteredModelAlias("test_alias", "1"),
        RegisteredModelAlias("other_alias", "2"),
    ]
    # ``dict(model)`` is expected to expose the aliases as an alias -> version mapping.
    expected_dict = {
        "name": model_name,
        "creation_timestamp": 1,
        "last_updated_timestamp": 4000,
        "description": random_str(),
        "latest_versions": [],
        "tags": {},
        "aliases": {entry.alias: entry.version for entry in alias_list},
        "deployment_job_id": None,
        "deployment_job_state": None,
        "workspace": DEFAULT_WORKSPACE_NAME,
    }
    # Hydration input carries the alias entities themselves rather than the mapping.
    model = RegisteredModel.from_dictionary({**expected_dict, "aliases": alias_list})
    assert dict(model) == expected_dict

    proto = model.to_proto()
    assert proto.creation_timestamp == 1
    assert proto.last_updated_timestamp == 4000
    proto_alias_names = {proto_alias.alias for proto_alias in proto.aliases}
    proto_alias_versions = {proto_alias.version for proto_alias in proto.aliases}
    assert proto_alias_names == {"test_alias", "other_alias"}
    assert proto_alias_versions == {"1", "2"}
172  
173  
def test_string_repr():
    """Pin the exact ``str()`` rendering of a ``RegisteredModel`` built with keyword arguments."""
    model = RegisteredModel(
        name="myname",
        creation_timestamp=1000,
        last_updated_timestamp=2002,
        description="something about a model",
        latest_versions=["1", "2", "3"],
        tags=[],
        aliases=[],
    )
    # Adjacent literals are concatenated into the single expected string.
    expected = (
        "<RegisteredModel: "
        "aliases={}, "
        "creation_timestamp=1000, "
        "deployment_job_id=None, "
        "deployment_job_state=None, "
        "description='something about a model', "
        "last_updated_timestamp=2002, "
        "latest_versions=['1', '2', '3'], "
        "name='myname', "
        "tags={}, "
        "workspace='default'>"
    )
    assert str(model) == expected
190  
191  
def test_registered_model_non_default_workspace_round_trip():
    """Check that an explicit workspace survives a ``dict`` -> ``from_dictionary`` round trip."""
    model_name = random_str()
    custom_workspace = f"team-{random_str()}"
    original = RegisteredModel(
        name=model_name,
        creation_timestamp=1,
        last_updated_timestamp=2,
        description="custom model",
        latest_versions=[],
        tags=[],
        aliases=[],
        workspace=custom_workspace,
    )

    serialized = dict(original)
    assert serialized["workspace"] == custom_workspace

    restored = RegisteredModel.from_dictionary(serialized)
    assert restored.workspace == custom_workspace
    # The workspace name must also show up in the string rendering of the restored model.
    assert custom_workspace in str(restored)