# test_model.py
"""Tests for ``mlflow.models.Model``: save/load round-trips, logging, and input examples."""

import json
import os
import pathlib
import time
import uuid
from datetime import date
from unittest import mock

import numpy as np
import pandas as pd
import pydantic
import pytest
import sklearn.datasets
import sklearn.neighbors
from packaging.version import Version
from scipy.sparse import csc_matrix

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models import Model, ModelSignature, infer_signature, set_model, validate_schema
from mlflow.models.model import METADATA_FILES, SET_MODEL_ERROR
from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex
from mlflow.models.utils import _read_example, _save_example
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types.schema import ColSpec, DataType, ParamSchema, ParamSpec, Schema, TensorSpec
from mlflow.utils.databricks_utils import DatabricksRuntimeVersion
from mlflow.utils.file_utils import TempDir
from mlflow.utils.model_utils import _validate_and_prepare_target_save_path
from mlflow.utils.proto_json_utils import dataframe_from_raw_json


@pytest.fixture(scope="module")
def iris_data():
    """Return (X, y) from the iris dataset, using only the first two feature columns."""
    iris = sklearn.datasets.load_iris()
    x = iris.data[:, :2]
    y = iris.target
    return x, y


@pytest.fixture(scope="module")
def sklearn_knn_model(iris_data):
    """Return a KNeighborsClassifier fitted on the iris fixture data."""
    x, y = iris_data
    knn_model = sklearn.neighbors.KNeighborsClassifier()
    knn_model.fit(x, y)
    return knn_model


def test_model_save_load():
    """Model equality, schema accessors, and save/load round-trip fidelity."""
    m = Model(
        artifact_path="model",
        run_id="123",
        flavors={"flavor1": {"a": 1, "b": 2}, "flavor2": {"x": 1, "y": 2}},
        signature=ModelSignature(
            inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        ),
        saved_input_example_info={"x": 1, "y": 2},
    )
    # Schema accessors delegate to the signature; they are None without one.
    assert m.get_input_schema() == m.signature.inputs
    assert m.get_output_schema() == m.signature.outputs
    x = Model(artifact_path="some/other/path", run_id="1234")
    assert x.get_input_schema() is None
    assert x.get_output_schema() is None

    n = Model(
        artifact_path="model",
        run_id="123",
        flavors={"flavor1": {"a": 1, "b": 2}, "flavor2": {"x": 1, "y": 2}},
        signature=ModelSignature(
            inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        ),
        saved_input_example_info={"x": 1, "y": 2},
    )
    # Align auto-generated fields so equality compares only the declared content.
    n.utc_time_created = m.utc_time_created
    n.model_uuid = m.model_uuid
    assert m == n
    n.signature = None
    assert m != n
    with TempDir() as tmp:
        m.save(tmp.path("MLmodel"))
        o = Model.load(tmp.path("MLmodel"))
        assert m == o
        assert m.to_json() == o.to_json()
        assert m.to_yaml() == o.to_yaml()


def test_model_load_remote(tmp_path, mock_s3_bucket):
    """Model.load works with a remote (S3) URI, pointing at the file or its directory."""
    model = Model(
        artifact_path="model",
        run_id="123",
        flavors={"flavor1": {"a": 1, "b": 2}, "flavor2": {"x": 1, "y": 2}},
        signature=ModelSignature(
            inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        ),
        saved_input_example_info={"x": 1, "y": 2},
    )
    model_path = tmp_path / "MLmodel"
    model.save(model_path)

    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_repo = S3ArtifactRepository(artifact_root)
    artifact_repo.log_artifact(str(model_path))

    # Load by explicit MLmodel file URI.
    model_reloaded_1 = Model.load(f"{artifact_root}/MLmodel")
    assert model_reloaded_1 == model

    # Load by directory URI; the MLmodel file is resolved implicitly.
    model_reloaded_2 = Model.load(artifact_root)
    assert model_reloaded_2 == model


class TestFlavor:
    """Minimal stand-in flavor module exposing the ``save_model`` hook Model.log expects."""

    @classmethod
    def save_model(cls, path, mlflow_model, signature=None, input_example=None):
        mlflow_model.flavors["flavor1"] = {"a": 1, "b": 2}
        mlflow_model.flavors["flavor2"] = {"x": 1, "y": 2}
        _validate_and_prepare_target_save_path(path)
        if signature is not None:
            mlflow_model.signature = signature
        if input_example is not None:
            _save_example(mlflow_model, input_example, path)
        mlflow_model.save(os.path.join(path, "MLmodel"))


def _log_model_with_signature_and_example(
    tmp_path, sig, input_example, metadata=None, resources=None
):
    """Log a TestFlavor model in a fresh experiment/run and download its artifacts.

    Returns (local_path, run): the local download directory and the run it was logged in.
    """
    experiment_id = mlflow.create_experiment("test")

    with mlflow.start_run(experiment_id=experiment_id) as run:
        model = Model.log(
            "model",
            TestFlavor,
            signature=sig,
            input_example=input_example,
            metadata=metadata,
            resources=resources,
        )

    # TODO: remove this after replacing all `with TempDir(chdr=True) as tmp`
    # with tmp_path fixture
    output_path = tmp_path if isinstance(tmp_path, pathlib.PosixPath) else tmp_path.path("")
    local_path = _download_artifact_from_uri(model.model_uri, output_path=output_path)
    return local_path, run


def test_model_log():
    """Logged model round-trips run id, flavors, signature, input example, and version."""
    with TempDir(chdr=True) as tmp:
        sig = ModelSignature(
            inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )
        input_example = {"x": 1, "y": 2}
        local_path, r = _log_model_with_signature_and_example(tmp, sig, input_example)

        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        assert loaded_model.run_id == r.info.run_id
        assert loaded_model.flavors == {
            "flavor1": {"a": 1, "b": 2},
            "flavor2": {"x": 1, "y": 2},
        }
        assert loaded_model.signature == sig
        x = _read_example(
            Model(saved_input_example_info=loaded_model.saved_input_example_info), local_path
        )
        assert x == input_example
        # Not logged from Databricks, so no databricks_runtime attribute is recorded.
        assert not hasattr(loaded_model, "databricks_runtime")

        loaded_example = loaded_model.load_input_example(local_path)
        assert loaded_example == input_example

        assert Version(loaded_model.mlflow_version) == Version(mlflow.version.VERSION)
def test_model_log_without_run(tmp_path):
    """Logging with no active run leaves run_id unset."""
    model_info = Model.log("model", TestFlavor)
    assert model_info.run_id is None


def test_model_log_with_active_run(tmp_path):
    """Logging inside an active run picks up that run's id."""
    with mlflow.start_run() as run:
        model_info = Model.log("model", TestFlavor)
    assert model_info.run_id == run.info.run_id


def test_model_log_inactive_run_id(tmp_path):
    """An explicit run_id is honored even when that run is not the active one."""
    experiment_id = mlflow.create_experiment("test", artifact_location=str(tmp_path))
    run = mlflow.MlflowClient().create_run(experiment_id=experiment_id)
    model_info = Model.log("model", TestFlavor, run_id=run.info.run_id)
    assert model_info.run_id == run.info.run_id


def test_model_log_calls_maybe_render_agent_eval_recipe(tmp_path):
    """Model.log triggers the agent-eval recipe display hook exactly once."""
    sig = ModelSignature(
        inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
        outputs=Schema([ColSpec(name=None, type="double")]),
    )
    input_example = {"x": 1, "y": 2}
    with mock.patch("mlflow.models.display_utils.maybe_render_agent_eval_recipe") as render_mock:
        _log_model_with_signature_and_example(tmp_path, sig, input_example)
        render_mock.assert_called_once()


def test_model_info():
    """ModelInfo from log_model and from get_model_info agree with the stored MLmodel."""
    with TempDir(chdr=True) as tmp:
        sig = ModelSignature(
            inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )
        input_example = {"x": 1, "y": 2}

        experiment_id = mlflow.create_experiment("test")
        with mlflow.start_run(experiment_id=experiment_id) as run:
            model_info = Model.log("model", TestFlavor, signature=sig, input_example=input_example)
            model_uri = f"models:/{model_info.model_id}"

        model_info_fetched = mlflow.models.get_model_info(model_uri)
        local_path = _download_artifact_from_uri(model_uri, output_path=tmp.path(""))

        assert model_info.run_id == run.info.run_id
        assert model_info_fetched.run_id == run.info.run_id
        assert model_info.model_uri == model_uri
        assert model_info_fetched.model_uri == model_uri

        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        assert model_info.utc_time_created == loaded_model.utc_time_created
        assert model_info_fetched.utc_time_created == loaded_model.utc_time_created
        assert model_info.model_uuid == loaded_model.model_uuid
        assert model_info_fetched.model_uuid == loaded_model.model_uuid

        assert model_info.flavors == {
            "flavor1": {"a": 1, "b": 2},
            "flavor2": {"x": 1, "y": 2},
        }

        x = _read_example(
            Model(saved_input_example_info=model_info.saved_input_example_info), local_path
        )
        assert x == input_example

        model_signature = model_info_fetched.signature
        assert model_signature.to_dict() == sig.to_dict()

        assert model_info.mlflow_version == loaded_model.mlflow_version
        assert model_info_fetched.mlflow_version == loaded_model.mlflow_version


def test_model_info_with_model_version(tmp_path):
    """registered_model_version increments per registration and is None when unregistered."""
    experiment_id = mlflow.create_experiment("test", artifact_location=str(tmp_path))
    with mlflow.start_run(experiment_id=experiment_id):
        model_info = Model.log("model", TestFlavor, registered_model_name="model_abc")
        assert model_info.registered_model_version == 1
        model_info = Model.log("model", TestFlavor, registered_model_name="model_abc")
        assert model_info.registered_model_version == 2
        model_info = Model.log("model", TestFlavor)
        assert model_info.registered_model_version is None


def test_model_metadata():
    """User-supplied metadata survives the log/load round trip."""
    with TempDir(chdr=True) as tmp:
        metadata = {"metadata_key": "metadata_value"}
        local_path, _ = _log_model_with_signature_and_example(tmp, None, None, metadata)
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        assert loaded_model.metadata["metadata_key"] == "metadata_value"


def test_load_model_without_mlflow_version():
    """An MLmodel file without an mlflow_version field loads with mlflow_version=None."""
    with TempDir(chdr=True) as tmp:
        model = Model(artifact_path="model", run_id="1234", mlflow_version=None)
        path = tmp.path("MLmodel")
        with open(path, "w") as out:
            model.to_yaml(out)
        loaded_model = Model.load(path)

        assert loaded_model.mlflow_version is None


def test_model_log_with_databricks_runtime():
    """The Databricks runtime version is recorded in the MLmodel file when available."""
    dbr_version = "8.3.x"
    with mlflow.start_run():
        with mock.patch(
            "mlflow.models.model.get_databricks_runtime_version", return_value=dbr_version
        ) as mock_get_dbr_version:
            model = Model.log("path", TestFlavor, signature=None, input_example=None)
            mock_get_dbr_version.assert_called()

    loaded_model = Model.load(model.model_uri)
    assert loaded_model.databricks_runtime == dbr_version


def test_model_log_with_databricks_runtime_gpu():
    """A GPU client-image runtime string is preserved verbatim and parses correctly."""
    dbr_version = "client.8.1-gpu"
    with mlflow.start_run():
        with mock.patch(
            "mlflow.models.model.get_databricks_runtime_version", return_value=dbr_version
        ) as mock_get_dbr_version:
            model = Model.log("path", TestFlavor, signature=None, input_example=None)
            mock_get_dbr_version.assert_called()

    # Verify the GPU suffix is preserved in the MLmodel file
    loaded_model = Model.load(model.model_uri)
    assert loaded_model.databricks_runtime == dbr_version

    # Verify that the version can be parsed correctly and is_gpu_image is True
    parsed_version = DatabricksRuntimeVersion.parse(loaded_model.databricks_runtime)
    assert parsed_version.is_client_image is True
    assert parsed_version.major == 8
    assert parsed_version.minor == 1
    assert parsed_version.is_gpu_image is True


def test_model_log_with_input_example_succeeds():
    """A DataFrame input example round-trips, modulo documented dtype conversions."""
    with TempDir(chdr=True) as tmp:
        sig = ModelSignature(
            inputs=Schema([
                ColSpec("integer", "a"),
                ColSpec("string", "b"),
                ColSpec("boolean", "c"),
                ColSpec("string", "d"),
                ColSpec("datetime", "e"),
            ]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )
        input_example = pd.DataFrame(
            {
                "a": np.int32(1),
                "b": "test string",
                "c": True,
                "d": date.today(),
                "e": np.datetime64("2020-01-01T00:00:00"),
            },
            index=[0],
        )

        local_path, _ = _log_model_with_signature_and_example(tmp, sig, input_example)
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        path = os.path.join(local_path, loaded_model.saved_input_example_info["artifact_path"])
        x = dataframe_from_raw_json(path, schema=sig.inputs)

        # date column will get deserialized into string
        input_example["d"] = input_example["d"].apply(lambda x: x.isoformat())
        # datetime Datatype numpy type is [ns]
        input_example["e"] = input_example["e"].astype(np.dtype("datetime64[ns]"))
        pd.testing.assert_frame_equal(x, input_example)

        loaded_example = loaded_model.load_input_example(local_path)
        assert isinstance(loaded_example, pd.DataFrame)
        pd.testing.assert_frame_equal(loaded_example, input_example)


def test_model_input_example_with_params_log_load_succeeds(tmp_path):
    """A (DataFrame, params) tuple example round-trips: data and params load separately."""
    pdf = pd.DataFrame(
        {
            "a": np.int32(1),
            "b": "test string",
            "c": True,
            "d": date.today(),
            "e": np.datetime64("2020-01-01T00:00:00"),
        },
        index=[0],
    )
    input_example = (pdf, {"a": 1, "b": "string"})

    sig = ModelSignature(
        inputs=Schema([
            ColSpec("integer", "a"),
            ColSpec("string", "b"),
            ColSpec("boolean", "c"),
            ColSpec("string", "d"),
            ColSpec("datetime", "e"),
        ]),
        outputs=Schema([ColSpec(name=None, type="double")]),
        params=ParamSchema([
            ParamSpec("a", DataType.long, 1),
            ParamSpec("b", DataType.string, "string"),
        ]),
    )

    local_path, _ = _log_model_with_signature_and_example(tmp_path, sig, input_example)
    loaded_model = Model.load(os.path.join(local_path, "MLmodel"))

    # date column will get deserialized into string
    pdf["d"] = pdf["d"].apply(lambda x: x.isoformat())
    loaded_example = loaded_model.load_input_example(local_path)
    assert isinstance(loaded_example, pd.DataFrame)
    # datetime Datatype numpy type is [ns]
    pdf["e"] = pdf["e"].astype(np.dtype("datetime64[ns]"))
    pd.testing.assert_frame_equal(loaded_example, pdf)

    params = loaded_model.load_input_example_params(local_path)
    assert params == input_example[1]


def test_model_load_input_example_numpy():
    """A numpy ndarray input example round-trips with identical contents."""
    with TempDir(chdr=True) as tmp:
        input_example = np.array([[3, 4, 5]], dtype=np.int32)
        sig = ModelSignature(
            inputs=Schema([TensorSpec(type=input_example.dtype, shape=input_example.shape)]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )

        local_path, _ = _log_model_with_signature_and_example(tmp, sig, input_example)
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        loaded_example = loaded_model.load_input_example(local_path)

        assert isinstance(loaded_example, np.ndarray)
        np.testing.assert_array_equal(input_example, loaded_example)


def test_model_load_input_example_scipy():
    """A scipy csc_matrix input example round-trips as a csc_matrix with equal data."""
    with TempDir(chdr=True) as tmp:
        input_example = csc_matrix(np.arange(0, 12, 0.5).reshape(3, 8))
        sig = ModelSignature(
            inputs=Schema([TensorSpec(type=input_example.data.dtype, shape=input_example.shape)]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )

        local_path, _ = _log_model_with_signature_and_example(tmp, sig, input_example)
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        loaded_example = loaded_model.load_input_example(local_path)

        assert isinstance(loaded_example, csc_matrix)
        np.testing.assert_array_equal(input_example.data, loaded_example.data)


def test_model_load_input_example_failures():
    """Loading an input example from a bad path or after deletion raises MlflowException."""
    with TempDir(chdr=True) as tmp:
        input_example = np.array([[3, 4, 5]], dtype=np.int32)
        sig = ModelSignature(
            inputs=Schema([TensorSpec(type=input_example.dtype, shape=input_example.shape)]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )

        local_path, _ = _log_model_with_signature_and_example(tmp, sig, input_example)
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        loaded_example = loaded_model.load_input_example(local_path)
        assert loaded_example is not None

        with pytest.raises(MlflowException, match="No such artifact"):
            loaded_model.load_input_example(os.path.join(local_path, "folder_which_does_not_exist"))

        # Remove the saved example artifact to simulate a missing file.
        path = os.path.join(local_path, loaded_model.saved_input_example_info["artifact_path"])
        os.remove(path)
        with pytest.raises(MlflowException, match="No such artifact"):
            loaded_model.load_input_example(local_path)


def test_model_load_input_example_no_signature():
    """load_input_example returns None when no example was logged."""
    with TempDir(chdr=True) as tmp:
        input_example = np.array([[3, 4, 5]], dtype=np.int32)
        sig = ModelSignature(
            inputs=Schema([TensorSpec(type=input_example.dtype, shape=input_example.shape)]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )

        local_path, _ = _log_model_with_signature_and_example(tmp, sig, input_example=None)
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        loaded_example = loaded_model.load_input_example(local_path)
        assert loaded_example is None


def _is_valid_uuid(val):
    """Return True iff ``str(val)`` parses as a UUID."""
    try:
        uuid.UUID(str(val))
        return True
    except ValueError:
        return False


def test_model_uuid():
    """model_uuid is a unique valid UUID, serialized in to_dict, and optional on load."""
    m = Model()
    assert m.model_uuid is not None
    assert _is_valid_uuid(m.model_uuid)

    m2 = Model()
    assert m.model_uuid != m2.model_uuid

    m_dict = m.to_dict()
    assert m_dict["model_uuid"] == m.model_uuid
    m3 = Model.from_dict(m_dict)
    assert m3.model_uuid == m.model_uuid

    # Dicts without a model_uuid (e.g. from older MLmodel files) deserialize to None.
    m_dict.pop("model_uuid")
    m4 = Model.from_dict(m_dict)
    assert m4.model_uuid is None


def test_validate_schema(sklearn_knn_model, iris_data, tmp_path):
    """validate_schema accepts inputs/outputs matching an inferred signature."""
    sk_model_path = os.path.join(tmp_path, "sk_model")
    X, y = iris_data
    signature = infer_signature(X, y)
    mlflow.sklearn.save_model(
        sklearn_knn_model,
        sk_model_path,
        signature=signature,
    )

    validate_schema(X, signature.inputs)
    prediction = sklearn_knn_model.predict(X)
    reloaded_model = mlflow.sklearn.load_model(sk_model_path)
    np.testing.assert_array_equal(prediction, reloaded_model.predict(X))
    validate_schema(prediction, signature.outputs)


def test_save_load_input_example_without_conversion(tmp_path):
    """A dict-of-messages example is stored as type json_object and loads unchanged."""
    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            return model_input

    input_example = {
        "messages": [
            {"role": "user", "content": "Hello!"},
        ]
    }
    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name="test_model",
            python_model=MyModel(),
            input_example=input_example,
        )
    local_path = _download_artifact_from_uri(
        f"runs:/{run.info.run_id}/test_model", output_path=tmp_path
    )
    loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
    assert loaded_model.saved_input_example_info["type"] == "json_object"
    loaded_example = loaded_model.load_input_example(local_path)
    assert loaded_example == input_example


def test_save_load_input_example_with_pydantic_model(tmp_path):
    """A pydantic-model input example is stored as json_object and loads as plain dicts."""
    class Message(pydantic.BaseModel):
        role: str
        content: str

    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input: list[Message], params=None):
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=MyModel(),
            input_example=[Message(role="user", content="Hello!")],
        )
    local_path = _download_artifact_from_uri(model_info.model_uri, output_path=tmp_path)
    loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
    assert loaded_model.saved_input_example_info["type"] == "json_object"
    loaded_example = loaded_model.load_input_example(local_path)
    assert loaded_example == [{"role": "user", "content": "Hello!"}]


def test_model_saved_by_save_model_can_be_loaded(tmp_path, sklearn_knn_model):
    """Models saved directly (not logged) have no run_id or artifact_path in ModelInfo."""
    mlflow.sklearn.save_model(sklearn_knn_model, tmp_path)
    info = Model.load(tmp_path).get_model_info()
    assert info.run_id is None
    assert info.artifact_path is None


def test_copy_metadata(mock_is_in_databricks, sklearn_knn_model):
    """The metadata subdirectory is created only when running in Databricks."""
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(sklearn_knn_model, name="model")

    artifact_path = mlflow.artifacts.download_artifacts(model_info.model_uri)
    metadata_path = os.path.join(artifact_path, "metadata")
    # Metadata should be copied only in Databricks
    if mock_is_in_databricks.return_value:
        assert set(os.listdir(metadata_path)) == set(METADATA_FILES)
    else:
        assert not os.path.exists(metadata_path)
    mock_is_in_databricks.assert_called_once()


class LegacyTestFlavor:
    """Flavor stand-in whose save_model lacks the signature/input_example parameters."""

    @classmethod
    def save_model(cls, path, mlflow_model):
        mlflow_model.flavors["flavor1"] = {"a": 1, "b": 2}
        mlflow_model.flavors["flavor2"] = {"x": 1, "y": 2}
        _validate_and_prepare_target_save_path(path)
        mlflow_model.save(os.path.join(path, "MLmodel"))


def test_legacy_flavor(mock_is_in_databricks):
    """Legacy flavors still log; in Databricks only the MLmodel file is copied to metadata."""
    with mlflow.start_run():
        model_info = Model.log("model", LegacyTestFlavor)

    artifact_path = _download_artifact_from_uri(model_info.model_uri)
    metadata_path = os.path.join(artifact_path, "metadata")
    # Metadata should be copied only in Databricks
    if mock_is_in_databricks.return_value:
        assert set(os.listdir(metadata_path)) == {"MLmodel"}
    else:
        assert not os.path.exists(metadata_path)
    mock_is_in_databricks.assert_called_once()


def test_pyfunc_set_model():
    """set_model accepts a PythonModel instance and stores it in __mlflow_model__."""
    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            return model_input

    set_model(MyModel())
    assert isinstance(mlflow.models.model.__mlflow_model__, mlflow.pyfunc.PythonModel)


def test_langchain_set_model():
    """set_model accepts a LangChain runnable, even when called inside another function."""
    from langchain_core.runnables import RunnableLambda

    def create_runnable():
        def my_runnable(input):
            return f"Input was: {input}"

        runnable = RunnableLambda(my_runnable)
        set_model(runnable)

    create_runnable()
    assert isinstance(mlflow.models.model.__mlflow_model__, RunnableLambda)


def test_error_set_model(sklearn_knn_model):
    """set_model rejects unsupported model objects (e.g. a raw sklearn estimator)."""
    with pytest.raises(mlflow.MlflowException, match=SET_MODEL_ERROR):
        set_model(sklearn_knn_model)


def test_model_resources():
    """Databricks resources passed to Model.log serialize to the expected structure."""
    expected_resources = {
        "api_version": "1",
        "databricks": {
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-bge-large-en"},
                {"name": "azure-eastus-model-serving-2_vs_endpoint"},
            ],
            "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
        },
    }
    with TempDir(chdr=True) as tmp:
        resources = [
            DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
            DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
            DatabricksServingEndpoint(endpoint_name="azure-eastus-model-serving-2_vs_endpoint"),
            DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
        ]
        local_path, _ = _log_model_with_signature_and_example(tmp, None, None, resources=resources)
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        assert loaded_model.resources == expected_resources


def test_save_load_model_with_run_uri():
    """Model.load resolves runs:/ URIs to the MLmodel file, directory, or trailing slash."""
    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input: list[str], params=None):
            return model_input

    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name="test_model",
            python_model=MyModel(),
            input_example=["a", "b", "c"],
        )
    mlflow_model = Model.load(f"runs:/{run.info.run_id}/test_model/MLmodel")
    assert mlflow_model.load_input_example() == ["a", "b", "c"]

    model = Model.load(f"runs:/{run.info.run_id}/test_model")
    assert model == mlflow_model

    model = Model.load(f"runs:/{run.info.run_id}/test_model/")
    assert model == mlflow_model


def test_save_model_with_prompts():
    """Prompts (objects or URIs) are recorded in the MLmodel file and linked to the run."""
    prompt_1 = mlflow.register_prompt("prompt-1", "Hello, {{title}} {{name}}!")
    time.sleep(0.001)  # To avoid timestamp precision issue in Windows
    prompt_2 = mlflow.register_prompt("prompt-2", "Hello, {{title}} {{name}}!")

    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, model_input: list[str]):
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=MyModel(),
            # The 'prompts' parameter should accept both prompt object and URI
            prompts=[prompt_1, prompt_2.uri],
        )

    assert model_info.prompts == [prompt_1.uri, prompt_2.uri]

    # Prompts should be recorded in the yaml file
    model = Model.load(model_info.model_uri)
    assert model.prompts == [prompt_1.uri, prompt_2.uri]

    # Check that prompts were linked to the run via the linkedPrompts tag
    from mlflow.tracing.constant import TraceTagKey

    run = mlflow.MlflowClient().get_run(model_info.run_id)
    linked_prompts_tag = run.data.tags.get(TraceTagKey.LINKED_PROMPTS)
    assert linked_prompts_tag is not None

    linked_prompts = json.loads(linked_prompts_tag)
    assert len(linked_prompts) == 2
    assert {p["name"] for p in linked_prompts} == {prompt_1.name, prompt_2.name}


def test_logged_model_status():
    """A successful log yields status READY; a failed save yields status FAILED."""
    def predict_fn(model_input: list[str]):
        return model_input

    model_info = mlflow.pyfunc.log_model(
        name="test_model",
        python_model=predict_fn,
        input_example=["a", "b", "c"],
    )
    logged_model = mlflow.get_logged_model(model_info.model_id)
    assert logged_model.status == "READY"

    with pytest.raises(Exception, match=r"mock exception"):
        with mock.patch(
            "mlflow.pyfunc.model._save_model_with_class_artifacts_params",
            side_effect=Exception("mock exception"),
        ):
            mlflow.pyfunc.log_model(
                name="test_model",
                python_model=predict_fn,
                input_example=["a", "b", "c"],
            )
    logged_model = mlflow.last_logged_model()
    assert logged_model.status == "FAILED"


def test_model_log_links_prompts_to_logged_model():
    """Prompts passed to Model.log are linked to both the run and the LoggedModel."""
    client = mlflow.MlflowClient()

    # Create actual prompts in the registry
    client.create_prompt(name="test_prompt_1")
    prompt_1 = client.create_prompt_version(name="test_prompt_1", template="Hello {{name}}")
    client.create_prompt(name="test_prompt_2")
    prompt_2 = client.create_prompt_version(name="test_prompt_2", template="Goodbye {{name}}")

    with mlflow.start_run() as run:
        model_info = Model.log("model", TestFlavor, prompts=[prompt_1, prompt_2])

    # Verify prompts were linked to the run
    run_data = client.get_run(run.info.run_id)
    linked_prompts_tag = run_data.data.tags.get("mlflow.linkedPrompts")
    assert linked_prompts_tag is not None
    linked_prompts = json.loads(linked_prompts_tag)
    assert len(linked_prompts) == 2
    assert {p["name"] for p in linked_prompts} == {"test_prompt_1", "test_prompt_2"}

    # Verify prompts were linked to the LoggedModel
    logged_model = client.get_logged_model(model_info.model_id)
    model_linked_prompts_tag = logged_model.tags.get("mlflow.linkedPrompts")
    assert model_linked_prompts_tag is not None
    model_linked_prompts = json.loads(model_linked_prompts_tag)
    assert len(model_linked_prompts) == 2
    assert {p["name"] for p in model_linked_prompts} == {"test_prompt_1", "test_prompt_2"}


def test_get_model_info_with_logged_model():
    """get_model_info on a logged model's URI matches the ModelInfo returned by log_model."""
    def model(model_input: list[str]) -> list[str]:
        return model_input

    model_info_log_model = mlflow.pyfunc.log_model(
        name="test_model", python_model=model, input_example=["a", "b", "c"]
    )
    model_info_get_model_info = mlflow.models.get_model_info(model_info_log_model.model_uri)
    assert model_info_log_model.model_id == model_info_get_model_info.model_id
    assert model_info_log_model.name == model_info_get_model_info.name