# test_unity_catalog_utils.py
import pytest

from mlflow.entities.model_registry import (
    ModelVersion,
    ModelVersionDeploymentJobState,
    ModelVersionTag,
    RegisteredModel,
    RegisteredModelAlias,
    RegisteredModelTag,
)
from mlflow.entities.model_registry.model_version_search import ModelVersionSearch
from mlflow.entities.model_registry.registered_model_search import RegisteredModelSearch
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
    EncryptionDetails,
    SseEncryptionAlgorithm,
    SseEncryptionDetails,
    TemporaryCredentials,
)
from mlflow.protos.databricks_uc_registry_messages_pb2 import ModelVersion as ProtoModelVersion
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
    ModelVersionDeploymentJobState as ProtoModelVersionDeploymentJobState,
)
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
    ModelVersionStatus as ProtoModelVersionStatus,
)
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
    ModelVersionTag as ProtoModelVersionTag,
)
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
    RegisteredModel as ProtoRegisteredModel,
)
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
    RegisteredModelAlias as ProtoRegisteredModelAlias,
)
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
    RegisteredModelTag as ProtoRegisteredModelTag,
)
from mlflow.utils._unity_catalog_utils import (
    _parse_aws_sse_credential,
    model_version_from_uc_proto,
    model_version_search_from_uc_proto,
    registered_model_from_uc_proto,
    registered_model_search_from_uc_proto,
)


@pytest.mark.parametrize(
    "run_state",
    [
        "DEPLOYMENT_JOB_RUN_STATE_UNSPECIFIED",
        "NO_VALID_DEPLOYMENT_JOB_FOUND",
        "RUNNING",
        "SUCCEEDED",
        "FAILED",
        "PENDING",
        "APPROVAL",
    ],
)
def test_model_version_from_uc_proto(run_state):
    """Convert a fully populated UC ModelVersion proto into a ``ModelVersion``.

    Runs once per deployment-job run state. Checks that the state name
    round-trips from the proto enum into the entity's deployment job state.
    """
    expected_model_version = ModelVersion(
        name="name",
        version="1",
        creation_timestamp=1,
        last_updated_timestamp=2,
        description="description",
        user_id="user_id",
        source="source",
        run_id="run_id",
        status="READY",
        status_message="status_message",
        # The proto aliases carry versions; the entity is expected to keep only alias names.
        aliases=["alias1", "alias2"],
        tags=[
            ModelVersionTag(key="key1", value="value"),
            ModelVersionTag(key="key2", value=""),
        ],
        # Fields the proto below does not set are expected to come back empty.
        metrics=[],
        model_id="",
        params=[],
        deployment_job_state=ModelVersionDeploymentJobState(
            "job_123",
            "run_456",
            "DEPLOYMENT_JOB_CONNECTION_STATE_UNSPECIFIED",
            run_state,
            "task_name",
        ),
    )

    # Create protobuf with deployment job state
    deployment_job_state_proto = ProtoModelVersionDeploymentJobState(
        job_id="job_123",
        run_id="run_456",
        job_state=0,  # DEPLOYMENT_JOB_CONNECTION_STATE_UNSPECIFIED
        run_state=ProtoModelVersionDeploymentJobState.DeploymentJobRunState.Value(run_state),
        current_task_name="task_name",
    )

    uc_proto = ProtoModelVersion(
        name="name",
        version="1",
        creation_timestamp=1,
        last_updated_timestamp=2,
        description="description",
        user_id="user_id",
        source="source",
        run_id="run_id",
        status=ProtoModelVersionStatus.Value("READY"),
        status_message="status_message",
        aliases=[
            ProtoRegisteredModelAlias(alias="alias1", version="1"),
            ProtoRegisteredModelAlias(alias="alias2", version="2"),
        ],
        tags=[
            ProtoModelVersionTag(key="key1", value="value"),
            ProtoModelVersionTag(key="key2", value=""),
        ],
        deployment_job_state=deployment_job_state_proto,
    )
    actual_model_version = model_version_from_uc_proto(uc_proto)
    assert actual_model_version == expected_model_version


def test_model_version_search_from_uc_proto():
    """Convert a UC ModelVersion proto into a ``ModelVersionSearch``.

    The search entity is expected to drop aliases and tags, and accessing them
    is expected to raise. With no deployment job state on the proto, the
    entity carries empty ids and UNSPECIFIED states.
    """
    expected_model_version = ModelVersionSearch(
        name="name",
        version="1",
        creation_timestamp=1,
        last_updated_timestamp=2,
        description="description",
        user_id="user_id",
        source="source",
        run_id="run_id",
        status="READY",
        status_message="status_message",
        aliases=[],
        tags=[],
        deployment_job_state=ModelVersionDeploymentJobState(
            "",
            "",
            "DEPLOYMENT_JOB_CONNECTION_STATE_UNSPECIFIED",
            "DEPLOYMENT_JOB_RUN_STATE_UNSPECIFIED",
            "",
        ),
    )
    uc_proto = ProtoModelVersion(
        name="name",
        version="1",
        creation_timestamp=1,
        last_updated_timestamp=2,
        description="description",
        user_id="user_id",
        source="source",
        run_id="run_id",
        status=ProtoModelVersionStatus.Value("READY"),
        status_message="status_message",
        aliases=[
            ProtoRegisteredModelAlias(alias="alias1", version="1"),
            ProtoRegisteredModelAlias(alias="alias2", version="2"),
        ],
        tags=[
            ProtoModelVersionTag(key="key1", value="value"),
            ProtoModelVersionTag(key="key2", value=""),
        ],
    )
    actual_model_version = model_version_search_from_uc_proto(uc_proto)
    assert actual_model_version == expected_model_version

    # Search results do not expose tags/aliases; touching them must fail.
    with pytest.raises(Exception):  # noqa: PT011
        actual_model_version.tags()

    with pytest.raises(Exception):  # noqa: PT011
        actual_model_version.aliases()


def test_model_version_and_model_version_search_equality():
    """Compare ``ModelVersion`` with ``ModelVersionSearch`` built from the same fields.

    They are expected to differ while tags/aliases are populated, and to
    compare equal once both are empty.
    """
    kwargs = {
        "name": "name",
        "version": "1",
        "creation_timestamp": 1,
        "last_updated_timestamp": 2,
        "description": "description",
        "user_id": "user_id",
        "source": "source",
        "run_id": "run_id",
        "status": "READY",
        "status_message": "status_message",
        "aliases": ["alias1", "alias2"],
        "tags": [
            ModelVersionTag(key="key1", value="value"),
            ModelVersionTag(key="key2", value=""),
        ],
    }
    model_version = ModelVersion(**kwargs)
    model_version_search = ModelVersionSearch(**kwargs)

    assert model_version != model_version_search

    # Rebind (not mutate) so the objects built above keep their original lists.
    kwargs["tags"] = []
    kwargs["aliases"] = []

    model_version_2 = ModelVersion(**kwargs)
    model_version_search_2 = ModelVersionSearch(**kwargs)

    assert model_version_2 == model_version_search_2


def test_registered_model_from_uc_proto():
    """Convert a UC RegisteredModel proto into a ``RegisteredModel``.

    Aliases and tags are preserved. With no deployment job set on the proto,
    the entity has an empty job id and an UNSPECIFIED connection state.
    """
    expected_registered_model = RegisteredModel(
        name="name",
        creation_timestamp=1,
        last_updated_timestamp=2,
        description="description",
        aliases=[
            RegisteredModelAlias(alias="alias1", version="1"),
            RegisteredModelAlias(alias="alias2", version="2"),
        ],
        tags=[
            RegisteredModelTag(key="key1", value="value"),
            RegisteredModelTag(key="key2", value=""),
        ],
        deployment_job_id="",
        deployment_job_state="DEPLOYMENT_JOB_CONNECTION_STATE_UNSPECIFIED",
    )
    uc_proto = ProtoRegisteredModel(
        name="name",
        creation_timestamp=1,
        last_updated_timestamp=2,
        description="description",
        aliases=[
            ProtoRegisteredModelAlias(alias="alias1", version="1"),
            ProtoRegisteredModelAlias(alias="alias2", version="2"),
        ],
        tags=[
            ProtoRegisteredModelTag(key="key1", value="value"),
            ProtoRegisteredModelTag(key="key2", value=""),
        ],
    )
    actual_registered_model = registered_model_from_uc_proto(uc_proto)
    assert actual_registered_model == expected_registered_model


def test_registered_model_search_from_uc_proto():
    """Convert a UC RegisteredModel proto into a ``RegisteredModelSearch``.

    The search entity is expected to drop aliases and tags, and accessing them
    is expected to raise.
    """
    expected_registered_model = RegisteredModelSearch(
        name="name",
        creation_timestamp=1,
        last_updated_timestamp=2,
        description="description",
        aliases=[],
        tags=[],
    )
    uc_proto = ProtoRegisteredModel(
        name="name",
        creation_timestamp=1,
        last_updated_timestamp=2,
        description="description",
        aliases=[
            ProtoRegisteredModelAlias(alias="alias1", version="1"),
            ProtoRegisteredModelAlias(alias="alias2", version="2"),
        ],
        tags=[
            ProtoRegisteredModelTag(key="key1", value="value"),
            ProtoRegisteredModelTag(key="key2", value=""),
        ],
    )
    actual_registered_model = registered_model_search_from_uc_proto(uc_proto)
    assert actual_registered_model == expected_registered_model

    # Search results do not expose tags/aliases; touching them must fail.
    with pytest.raises(Exception):  # noqa: PT011
        actual_registered_model.tags()

    with pytest.raises(Exception):  # noqa: PT011
        actual_registered_model.aliases()


def test_registered_model_and_registered_model_search_equality():
    """Compare ``RegisteredModel`` with ``RegisteredModelSearch`` built from the same fields.

    They are expected to differ while tags/aliases are populated, and to
    compare equal once both are empty.
    """
    kwargs = {
        "name": "name",
        "creation_timestamp": 1,
        "last_updated_timestamp": 2,
        "description": "description",
        "aliases": [
            RegisteredModelAlias(alias="alias1", version="1"),
            RegisteredModelAlias(alias="alias2", version="2"),
        ],
        "tags": [
            RegisteredModelTag(key="key1", value="value"),
            RegisteredModelTag(key="key2", value=""),
        ],
    }
    registered_model = RegisteredModel(**kwargs)
    registered_model_search = RegisteredModelSearch(**kwargs)

    assert registered_model != registered_model_search

    # Rebind (not mutate) so the objects built above keep their original lists.
    kwargs["tags"] = []
    kwargs["aliases"] = []

    registered_model_2 = RegisteredModel(**kwargs)
    registered_model_search_2 = RegisteredModelSearch(**kwargs)

    assert registered_model_2 == registered_model_search_2


@pytest.mark.parametrize(
    ("temp_credentials", "parsed"),
    [
        # No encryption details at all -> nothing to pass through.
        (TemporaryCredentials(), {}),
        # Explicitly unspecified algorithm -> also nothing to pass through.
        (
            TemporaryCredentials(
                encryption_details=EncryptionDetails(
                    sse_encryption_details=SseEncryptionDetails(
                        algorithm=SseEncryptionAlgorithm.SSE_ENCRYPTION_ALGORITHM_UNSPECIFIED
                    )
                )
            ),
            {},
        ),
        # SSE-KMS -> algorithm plus the KMS key ARN.
        (
            TemporaryCredentials(
                encryption_details=EncryptionDetails(
                    sse_encryption_details=SseEncryptionDetails(
                        algorithm=SseEncryptionAlgorithm.AWS_SSE_KMS,
                        aws_kms_key_arn="arn:aws:kms:us-west-2:111111111111:key/test-key-id",
                    )
                )
            ),
            {
                "ServerSideEncryption": "aws:kms",
                "SSEKMSKeyId": "arn:aws:kms:us-west-2:111111111111:key/test-key-id",
            },
        ),
        # SSE-S3 -> algorithm only.
        (
            TemporaryCredentials(
                encryption_details=EncryptionDetails(
                    sse_encryption_details=SseEncryptionDetails(
                        algorithm=SseEncryptionAlgorithm.AWS_SSE_S3,
                    )
                )
            ),
            {
                "ServerSideEncryption": "AES256",
            },
        ),
    ],
)
def test_parse_aws_sse_credential(temp_credentials, parsed):
    """Check that ``_parse_aws_sse_credential`` maps SSE encryption details to the expected dict."""
    assert _parse_aws_sse_credential(temp_credentials) == parsed