/ tests / utils / test_unity_catalog_utils.py
test_unity_catalog_utils.py
  1  import pytest
  2  
  3  from mlflow.entities.model_registry import (
  4      ModelVersion,
  5      ModelVersionDeploymentJobState,
  6      ModelVersionTag,
  7      RegisteredModel,
  8      RegisteredModelAlias,
  9      RegisteredModelTag,
 10  )
 11  from mlflow.entities.model_registry.model_version_search import ModelVersionSearch
 12  from mlflow.entities.model_registry.registered_model_search import RegisteredModelSearch
 13  from mlflow.protos.databricks_uc_registry_messages_pb2 import (
 14      EncryptionDetails,
 15      SseEncryptionAlgorithm,
 16      SseEncryptionDetails,
 17      TemporaryCredentials,
 18  )
 19  from mlflow.protos.databricks_uc_registry_messages_pb2 import ModelVersion as ProtoModelVersion
 20  from mlflow.protos.databricks_uc_registry_messages_pb2 import (
 21      ModelVersionStatus as ProtoModelVersionStatus,
 22  )
 23  from mlflow.protos.databricks_uc_registry_messages_pb2 import (
 24      ModelVersionTag as ProtoModelVersionTag,
 25  )
 26  from mlflow.protos.databricks_uc_registry_messages_pb2 import (
 27      RegisteredModel as ProtoRegisteredModel,
 28  )
 29  from mlflow.protos.databricks_uc_registry_messages_pb2 import (
 30      RegisteredModelAlias as ProtoRegisteredModelAlias,
 31  )
 32  from mlflow.protos.databricks_uc_registry_messages_pb2 import (
 33      RegisteredModelTag as ProtoRegisteredModelTag,
 34  )
 35  from mlflow.utils._unity_catalog_utils import (
 36      _parse_aws_sse_credential,
 37      model_version_from_uc_proto,
 38      model_version_search_from_uc_proto,
 39      registered_model_from_uc_proto,
 40      registered_model_search_from_uc_proto,
 41  )
 42  
 43  
 44  @pytest.mark.parametrize(
 45      "run_state",
 46      [
 47          "DEPLOYMENT_JOB_RUN_STATE_UNSPECIFIED",
 48          "NO_VALID_DEPLOYMENT_JOB_FOUND",
 49          "RUNNING",
 50          "SUCCEEDED",
 51          "FAILED",
 52          "PENDING",
 53          "APPROVAL",
 54      ],
 55  )
 56  def test_model_version_from_uc_proto(run_state):
 57      from mlflow.protos.databricks_uc_registry_messages_pb2 import (
 58          ModelVersionDeploymentJobState as ProtoModelVersionDeploymentJobState,
 59      )
 60  
 61      expected_model_version = ModelVersion(
 62          name="name",
 63          version="1",
 64          creation_timestamp=1,
 65          last_updated_timestamp=2,
 66          description="description",
 67          user_id="user_id",
 68          source="source",
 69          run_id="run_id",
 70          status="READY",
 71          status_message="status_message",
 72          aliases=["alias1", "alias2"],
 73          tags=[
 74              ModelVersionTag(key="key1", value="value"),
 75              ModelVersionTag(key="key2", value=""),
 76          ],
 77          metrics=[],
 78          model_id="",
 79          params=[],
 80          deployment_job_state=ModelVersionDeploymentJobState(
 81              "job_123",
 82              "run_456",
 83              "DEPLOYMENT_JOB_CONNECTION_STATE_UNSPECIFIED",
 84              run_state,
 85              "task_name",
 86          ),
 87      )
 88  
 89      # Create protobuf with deployment job state
 90      deployment_job_state_proto = ProtoModelVersionDeploymentJobState(
 91          job_id="job_123",
 92          run_id="run_456",
 93          job_state=0,  # DEPLOYMENT_JOB_CONNECTION_STATE_UNSPECIFIED
 94          run_state=ProtoModelVersionDeploymentJobState.DeploymentJobRunState.Value(run_state),
 95          current_task_name="task_name",
 96      )
 97  
 98      uc_proto = ProtoModelVersion(
 99          name="name",
100          version="1",
101          creation_timestamp=1,
102          last_updated_timestamp=2,
103          description="description",
104          user_id="user_id",
105          source="source",
106          run_id="run_id",
107          status=ProtoModelVersionStatus.Value("READY"),
108          status_message="status_message",
109          aliases=[
110              ProtoRegisteredModelAlias(alias="alias1", version="1"),
111              ProtoRegisteredModelAlias(alias="alias2", version="2"),
112          ],
113          tags=[
114              ProtoModelVersionTag(key="key1", value="value"),
115              ProtoModelVersionTag(key="key2", value=""),
116          ],
117          deployment_job_state=deployment_job_state_proto,
118      )
119      actual_model_version = model_version_from_uc_proto(uc_proto)
120      assert actual_model_version == expected_model_version
121  
122  
def test_model_version_search_from_uc_proto():
    """Check conversion of a UC ``ModelVersion`` proto into a ``ModelVersionSearch``.

    The proto carries aliases and tags, but the expected search entity is built with
    empty lists for both, and calling ``tags()`` / ``aliases()`` on the converted
    entity must raise.
    """
    # Scalar fields that take the same value on the proto and on the entity.
    shared_fields = {
        "name": "name",
        "version": "1",
        "creation_timestamp": 1,
        "last_updated_timestamp": 2,
        "description": "description",
        "user_id": "user_id",
        "source": "source",
        "run_id": "run_id",
        "status_message": "status_message",
    }

    proto = ProtoModelVersion(
        status=ProtoModelVersionStatus.Value("READY"),
        aliases=[
            ProtoRegisteredModelAlias(alias="alias1", version="1"),
            ProtoRegisteredModelAlias(alias="alias2", version="2"),
        ],
        tags=[
            ProtoModelVersionTag(key="key1", value="value"),
            ProtoModelVersionTag(key="key2", value=""),
        ],
        **shared_fields,
    )

    # The proto sets no deployment job state; the expected entity spells out empty
    # strings and the UNSPECIFIED state names.
    unset_job_state = ModelVersionDeploymentJobState(
        "",
        "",
        "DEPLOYMENT_JOB_CONNECTION_STATE_UNSPECIFIED",
        "DEPLOYMENT_JOB_RUN_STATE_UNSPECIFIED",
        "",
    )
    expected = ModelVersionSearch(
        status="READY",
        aliases=[],
        tags=[],
        deployment_job_state=unset_job_state,
        **shared_fields,
    )

    converted = model_version_search_from_uc_proto(proto)
    assert converted == expected

    with pytest.raises(Exception):  # noqa: PT011
        converted.tags()

    with pytest.raises(Exception):  # noqa: PT011
        converted.aliases()
173  
174  
def test_model_version_and_model_version_search_equality():
    """Compare ``ModelVersion`` with ``ModelVersionSearch`` built from the same kwargs.

    With aliases and tags populated the two must compare unequal; with both set to
    empty lists they must compare equal.
    """
    base_fields = {
        "name": "name",
        "version": "1",
        "creation_timestamp": 1,
        "last_updated_timestamp": 2,
        "description": "description",
        "user_id": "user_id",
        "source": "source",
        "run_id": "run_id",
        "status": "READY",
        "status_message": "status_message",
    }
    populated = {
        **base_fields,
        "aliases": ["alias1", "alias2"],
        "tags": [
            ModelVersionTag(key="key1", value="value"),
            ModelVersionTag(key="key2", value=""),
        ],
    }
    emptied = {**base_fields, "aliases": [], "tags": []}

    # Populated aliases/tags: the plain entity and the search entity differ.
    full_version = ModelVersion(**populated)
    full_version_search = ModelVersionSearch(**populated)
    assert full_version != full_version_search

    # Empty aliases/tags: the two entities are interchangeable.
    bare_version = ModelVersion(**emptied)
    bare_version_search = ModelVersionSearch(**emptied)
    assert bare_version == bare_version_search
205  
206  
def test_registered_model_from_uc_proto():
    """Check that a UC ``RegisteredModel`` proto converts to the matching entity.

    The proto sets no deployment job fields; the expected entity spells out an empty
    job id and the UNSPECIFIED connection state name.
    """
    # Scalar fields that take the same value on the proto and on the entity.
    shared_fields = {
        "name": "name",
        "creation_timestamp": 1,
        "last_updated_timestamp": 2,
        "description": "description",
    }
    alias_pairs = [("alias1", "1"), ("alias2", "2")]
    tag_pairs = [("key1", "value"), ("key2", "")]

    proto = ProtoRegisteredModel(
        aliases=[
            ProtoRegisteredModelAlias(alias=alias, version=version)
            for alias, version in alias_pairs
        ],
        tags=[ProtoRegisteredModelTag(key=key, value=value) for key, value in tag_pairs],
        **shared_fields,
    )
    expected = RegisteredModel(
        aliases=[
            RegisteredModelAlias(alias=alias, version=version) for alias, version in alias_pairs
        ],
        tags=[RegisteredModelTag(key=key, value=value) for key, value in tag_pairs],
        deployment_job_id="",
        deployment_job_state="DEPLOYMENT_JOB_CONNECTION_STATE_UNSPECIFIED",
        **shared_fields,
    )

    converted = registered_model_from_uc_proto(proto)
    assert converted == expected
240  
241  
def test_registered_model_search_from_uc_proto():
    """Check conversion of a UC ``RegisteredModel`` proto into a ``RegisteredModelSearch``.

    The proto carries aliases and tags, but the expected search entity is built with
    empty lists for both, and calling ``tags()`` / ``aliases()`` on the converted
    entity must raise.
    """
    # Scalar fields that take the same value on the proto and on the entity.
    shared_fields = {
        "name": "name",
        "creation_timestamp": 1,
        "last_updated_timestamp": 2,
        "description": "description",
    }

    proto = ProtoRegisteredModel(
        aliases=[
            ProtoRegisteredModelAlias(alias="alias1", version="1"),
            ProtoRegisteredModelAlias(alias="alias2", version="2"),
        ],
        tags=[
            ProtoRegisteredModelTag(key="key1", value="value"),
            ProtoRegisteredModelTag(key="key2", value=""),
        ],
        **shared_fields,
    )
    expected = RegisteredModelSearch(aliases=[], tags=[], **shared_fields)

    converted = registered_model_search_from_uc_proto(proto)
    assert converted == expected

    with pytest.raises(Exception):  # noqa: PT011
        converted.tags()

    with pytest.raises(Exception):  # noqa: PT011
        converted.aliases()
273  
274  
def test_registered_model_and_registered_model_search_equality():
    """Compare ``RegisteredModel`` with ``RegisteredModelSearch`` built from the same kwargs.

    With aliases and tags populated the two must compare unequal; with both set to
    empty lists they must compare equal.
    """
    base_fields = {
        "name": "name",
        "creation_timestamp": 1,
        "last_updated_timestamp": 2,
        "description": "description",
    }
    populated = {
        **base_fields,
        "aliases": [
            RegisteredModelAlias(alias="alias1", version="1"),
            RegisteredModelAlias(alias="alias2", version="2"),
        ],
        "tags": [
            RegisteredModelTag(key="key1", value="value"),
            RegisteredModelTag(key="key2", value=""),
        ],
    }
    emptied = {**base_fields, "aliases": [], "tags": []}

    # Populated aliases/tags: the plain entity and the search entity differ.
    full_model = RegisteredModel(**populated)
    full_model_search = RegisteredModelSearch(**populated)
    assert full_model != full_model_search

    # Empty aliases/tags: the two entities are interchangeable.
    bare_model = RegisteredModel(**emptied)
    bare_model_search = RegisteredModelSearch(**emptied)
    assert bare_model == bare_model_search
302  
303  
def _credentials_with_sse(**sse_fields):
    """Build ``TemporaryCredentials`` whose encryption details wrap the given SSE fields."""
    return TemporaryCredentials(
        encryption_details=EncryptionDetails(
            sse_encryption_details=SseEncryptionDetails(**sse_fields)
        )
    )


@pytest.mark.parametrize(
    ("temp_credentials", "parsed"),
    [
        # No encryption details at all.
        (TemporaryCredentials(), {}),
        # Unspecified algorithm.
        (
            _credentials_with_sse(
                algorithm=SseEncryptionAlgorithm.SSE_ENCRYPTION_ALGORITHM_UNSPECIFIED
            ),
            {},
        ),
        # KMS algorithm together with a key ARN.
        (
            _credentials_with_sse(
                algorithm=SseEncryptionAlgorithm.AWS_SSE_KMS,
                aws_kms_key_arn="arn:aws:kms:us-west-2:111111111111:key/test-key-id",
            ),
            {
                "ServerSideEncryption": "aws:kms",
                "SSEKMSKeyId": "arn:aws:kms:us-west-2:111111111111:key/test-key-id",
            },
        ),
        # S3-managed algorithm.
        (
            _credentials_with_sse(algorithm=SseEncryptionAlgorithm.AWS_SSE_S3),
            {
                "ServerSideEncryption": "AES256",
            },
        ),
    ],
)
def test_parse_aws_sse_credential(temp_credentials, parsed):
    """Check the dict ``_parse_aws_sse_credential`` returns for each credentials case."""
    actual = _parse_aws_sse_credential(temp_credentials)
    assert actual == parsed