# test_pytorch_model_export.py
1 import importlib 2 import json 3 import logging 4 import os 5 import pickle 6 import re 7 from pathlib import Path 8 from unittest import mock 9 10 import numpy as np 11 import pandas as pd 12 import pytest 13 import torch 14 import yaml 15 from packaging.version import Version 16 from sklearn import datasets 17 from torch import nn 18 from torch.utils.data import DataLoader 19 20 import mlflow.pyfunc.scoring_server as pyfunc_scoring_server 21 import mlflow.pytorch 22 from mlflow import pyfunc 23 from mlflow.exceptions import MlflowException 24 from mlflow.models import Model, ModelSignature 25 from mlflow.models.utils import _read_example, load_serving_example 26 from mlflow.pytorch import pickle_module as mlflow_pytorch_pickle_module 27 from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository 28 from mlflow.tracking.artifact_utils import _download_artifact_from_uri 29 from mlflow.types.schema import DataType, Schema, TensorSpec 30 from mlflow.utils.environment import _mlflow_conda_env 31 from mlflow.utils.file_utils import TempDir 32 from mlflow.utils.model_utils import _get_flavor_configuration 33 34 from tests.helper_functions import ( 35 _assert_pip_requirements, 36 _compare_conda_env_requirements, 37 _compare_logged_code_paths, 38 _is_available_on_pypi, 39 _is_importable, 40 _mlflow_major_version_string, 41 assert_array_almost_equal, 42 assert_register_model_called_with_local_model_path, 43 ) 44 45 _logger = logging.getLogger(__name__) 46 47 # This test suite is included as a code dependency when testing PyTorch model scoring in new 48 # processes and docker containers. In these environments, the `tests` module is not available. 49 # Therefore, we attempt to import from `tests` and gracefully emit a warning if it's unavailable. 50 try: 51 from tests.helper_functions import pyfunc_serve_and_score_model 52 except ImportError: 53 _logger.warning( 54 "Failed to import test helper functions. Tests depending on these functions may fail!" 
55 ) 56 57 EXTRA_PYFUNC_SERVING_TEST_ARGS = ( 58 [] if _is_available_on_pypi("torch") else ["--env-manager", "local"] 59 ) 60 61 # in pytorch >= 2.6.0, the `weights_only` kwarg default has been changed from 62 # `False` to `True`. this can cause pickle deserialization errors when loading 63 # models, unless the model classes have been explicitly marked as safe using 64 # `torch.serialization.add_safe_globals()` 65 ENABLE_LEGACY_DESERIALIZATION = Version(torch.__version__) >= Version("2.6.0") 66 67 68 @pytest.fixture(scope="module") 69 def data(): 70 iris = datasets.load_iris() 71 data = pd.DataFrame( 72 data=np.c_[iris["data"], iris["target"]], columns=iris["feature_names"] + ["target"] 73 ) 74 y = data["target"] 75 x = data.drop("target", axis=1) 76 return x, y 77 78 79 @pytest.fixture(scope="module") 80 def iris_tensor_spec(): 81 return ModelSignature( 82 inputs=Schema([TensorSpec(np.dtype("float32"), (-1, 4))]), 83 outputs=Schema([TensorSpec(np.dtype("float32"), (-1, 1))]), 84 ) 85 86 87 def get_dataset(data): 88 x, y = data 89 return [(xi.astype(np.float32), yi.astype(np.float32)) for xi, yi in zip(x.values, y.values)] 90 91 92 def train_model(model, data): 93 dataset = get_dataset(data) 94 criterion = nn.MSELoss() 95 optimizer = torch.optim.SGD(model.parameters(), lr=0.01) 96 batch_size = 16 97 num_workers = 4 98 dataloader = DataLoader( 99 dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, drop_last=False 100 ) 101 102 model.train() 103 for _ in range(5): 104 for batch in dataloader: 105 optimizer.zero_grad() 106 batch_size = batch[0].shape[0] 107 y_pred = model(batch[0]).squeeze(dim=1) 108 loss = criterion(y_pred, batch[1]) 109 loss.backward() 110 optimizer.step() 111 112 113 def get_sequential_model(): 114 return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1)) 115 116 117 @pytest.fixture 118 def sequential_model(data, scripted_model): 119 model = get_sequential_model() 120 if scripted_model: 121 model = torch.jit.script(model) 
122 123 train_model(model=model, data=data) 124 return model 125 126 127 def get_subclassed_model_definition(): 128 """ 129 Defines a PyTorch model class that inherits from ``torch.nn.Module``. This method can be invoked 130 within a pytest fixture to define the model class in the ``__main__`` scope. Alternatively, it 131 can be invoked within a module to define the class in the module's scope. 132 """ 133 134 class SubclassedModel(torch.nn.Module): 135 def __init__(self): 136 super().__init__() 137 self.linear = torch.nn.Linear(4, 1) 138 139 def forward(self, x): 140 return self.linear(x) 141 142 return SubclassedModel 143 144 145 @pytest.fixture(scope="module") 146 def main_scoped_subclassed_model(data): 147 """ 148 A custom PyTorch model inheriting from ``torch.nn.Module`` whose class is defined in the 149 "__main__" scope. 150 """ 151 model_class = get_subclassed_model_definition() 152 model = model_class() 153 train_model(model=model, data=data) 154 return model 155 156 157 class ModuleScopedSubclassedModel(get_subclassed_model_definition()): 158 """ 159 A custom PyTorch model class defined in the test module scope. This is a subclass of 160 ``torch.nn.Module``. 161 """ 162 163 164 @pytest.fixture(scope="module") 165 def module_scoped_subclassed_model(data): 166 """ 167 A custom PyTorch model inheriting from ``torch.nn.Module`` whose class is defined in the test 168 module scope. 
169 """ 170 model = ModuleScopedSubclassedModel() 171 train_model(model=model, data=data) 172 return model 173 174 175 @pytest.fixture 176 def model_path(tmp_path): 177 return os.path.join(tmp_path, "model") 178 179 180 @pytest.fixture 181 def pytorch_custom_env(tmp_path): 182 conda_env = os.path.join(tmp_path, "conda_env.yml") 183 _mlflow_conda_env(conda_env, additional_pip_deps=["pytorch", "torchvision", "pytest"]) 184 return conda_env 185 186 187 def _predict(model, data): 188 from torch.fx import GraphModule 189 190 dataset = get_dataset(data) 191 batch_size = 16 192 num_workers = 4 193 dataloader = DataLoader( 194 dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False 195 ) 196 predictions = np.zeros((len(dataloader.sampler),)) 197 198 if not isinstance(model, GraphModule): 199 model.eval() 200 with torch.no_grad(): 201 for i, batch in enumerate(dataloader): 202 y_preds = model(batch[0]).squeeze(dim=1).numpy() 203 predictions[i * batch_size : (i + 1) * batch_size] = y_preds 204 return predictions 205 206 207 @pytest.fixture 208 def sequential_predicted(sequential_model, data): 209 return _predict(sequential_model, data) 210 211 212 @pytest.mark.parametrize("scripted_model", [True, False]) 213 def test_signature_and_examples_are_saved_correctly(sequential_model, data, iris_tensor_spec): 214 model = sequential_model 215 example_ = data[0].head(3).values.astype(np.float32) 216 for signature in (None, iris_tensor_spec): 217 for example in (None, example_): 218 with TempDir() as tmp: 219 path = tmp.path("model") 220 mlflow.pytorch.save_model( 221 model, path=path, signature=signature, input_example=example 222 ) 223 mlflow_model = Model.load(path) 224 if signature is None and example is None: 225 assert mlflow_model.signature is None 226 else: 227 assert mlflow_model.signature == iris_tensor_spec 228 if example is None: 229 assert mlflow_model.saved_input_example_info is None 230 else: 231 
np.testing.assert_allclose(_read_example(mlflow_model, path), example) 232 233 234 @pytest.mark.parametrize("scripted_model", [True, False]) 235 def test_log_model(sequential_model, data, sequential_predicted): 236 try: 237 artifact_path = "pytorch" 238 model_info = mlflow.pytorch.log_model(sequential_model, name=artifact_path) 239 240 sequential_model_loaded = mlflow.pytorch.load_model(model_uri=model_info.model_uri) 241 test_predictions = _predict(sequential_model_loaded, data) 242 np.testing.assert_array_equal(test_predictions, sequential_predicted) 243 finally: 244 mlflow.end_run() 245 246 247 def test_log_model_calls_register_model(module_scoped_subclassed_model): 248 custom_pickle_module = pickle 249 artifact_path = "model" 250 register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model") 251 with mlflow.start_run(), register_model_patch: 252 model_info = mlflow.pytorch.log_model( 253 module_scoped_subclassed_model, 254 name=artifact_path, 255 pickle_module=custom_pickle_module, 256 registered_model_name="AdsModel1", 257 ) 258 assert_register_model_called_with_local_model_path( 259 register_model_mock=mlflow.tracking._model_registry.fluent._register_model, 260 model_uri=model_info.model_uri, 261 registered_model_name="AdsModel1", 262 ) 263 264 265 def test_log_model_no_registered_model_name(module_scoped_subclassed_model): 266 custom_pickle_module = pickle 267 artifact_path = "model" 268 register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model") 269 with mlflow.start_run(), register_model_patch: 270 mlflow.pytorch.log_model( 271 module_scoped_subclassed_model, 272 name=artifact_path, 273 pickle_module=custom_pickle_module, 274 ) 275 mlflow.tracking._model_registry.fluent._register_model.assert_not_called() 276 277 278 @pytest.mark.parametrize("scripted_model", [True, False]) 279 def test_raise_exception(sequential_model): 280 with TempDir(chdr=True, remove_on_exit=True) as tmp: 281 path = 
tmp.path("model") 282 with pytest.raises(MlflowException, match="No such artifact"): 283 mlflow.pytorch.load_model(path) 284 285 with pytest.raises(TypeError, match="Argument 'pytorch_model' should be a torch.nn.Module"): 286 mlflow.pytorch.save_model([1, 2, 3], path) 287 288 mlflow.pytorch.save_model(sequential_model, path) 289 with pytest.raises(MlflowException, match=f"Path '{os.path.abspath(path)}' already exists"): 290 mlflow.pytorch.save_model(sequential_model, path) 291 292 import sklearn.neighbors as knn 293 294 from mlflow import sklearn 295 296 path = tmp.path("knn.pkl") 297 knn = knn.KNeighborsClassifier() 298 with open(path, "wb") as f: 299 pickle.dump(knn, f) 300 path = tmp.path("knn") 301 sklearn.save_model(knn, path=path) 302 with pytest.raises(MlflowException, match='Model does not have the "pytorch" flavor'): 303 mlflow.pytorch.load_model(path) 304 305 306 @pytest.mark.parametrize("scripted_model", [True, False]) 307 def test_save_and_load_model(sequential_model, model_path, data, sequential_predicted): 308 mlflow.pytorch.save_model(sequential_model, model_path) 309 310 # Loading pytorch model 311 sequential_model_loaded = mlflow.pytorch.load_model(model_path) 312 np.testing.assert_array_equal(_predict(sequential_model_loaded, data), sequential_predicted) 313 314 # Loading pyfunc model 315 pyfunc_loaded = mlflow.pyfunc.load_model(model_path) 316 np.testing.assert_array_almost_equal( 317 pyfunc_loaded.predict(data[0]).values[:, 0], sequential_predicted, decimal=4 318 ) 319 320 321 @pytest.mark.parametrize("scripted_model", [True, False]) 322 def test_pyfunc_model_works_with_np_input_type( 323 sequential_model, model_path, data, sequential_predicted 324 ): 325 mlflow.pytorch.save_model(sequential_model, model_path) 326 327 # Loading pyfunc model 328 pyfunc_loaded = mlflow.pyfunc.load_model(model_path) 329 330 # predict works with dataframes 331 df_result = pyfunc_loaded.predict(data[0]) 332 assert type(df_result) == pd.DataFrame 333 
np.testing.assert_array_almost_equal(df_result.values[:, 0], sequential_predicted, decimal=4) 334 335 # predict works with numpy ndarray 336 np_result = pyfunc_loaded.predict(data[0].values.astype(np.float32)) 337 assert type(np_result) == np.ndarray 338 np.testing.assert_array_almost_equal(np_result[:, 0], sequential_predicted, decimal=4) 339 340 # predict does not work with lists 341 with pytest.raises( 342 TypeError, match="The PyTorch flavor does not support List or Dict input types" 343 ): 344 pyfunc_loaded.predict([1, 2, 3, 4]) 345 346 # predict does not work with scalars 347 with pytest.raises(TypeError, match="Input data should be pandas.DataFrame or numpy.ndarray"): 348 pyfunc_loaded.predict(4) 349 350 351 @pytest.mark.parametrize("scripted_model", [True, False]) 352 def test_load_model_from_remote_uri_succeeds( 353 sequential_model, model_path, mock_s3_bucket, data, sequential_predicted 354 ): 355 mlflow.pytorch.save_model(sequential_model, model_path) 356 357 artifact_root = f"s3://{mock_s3_bucket}" 358 artifact_path = "model" 359 artifact_repo = S3ArtifactRepository(artifact_root) 360 artifact_repo.log_artifacts(model_path, artifact_path=artifact_path) 361 362 model_uri = artifact_root + "/" + artifact_path 363 sequential_model_loaded = mlflow.pytorch.load_model(model_uri=model_uri) 364 np.testing.assert_array_equal(_predict(sequential_model_loaded, data), sequential_predicted) 365 366 367 @pytest.mark.parametrize("scripted_model", [True, False]) 368 def test_model_save_persists_specified_conda_env_in_mlflow_model_directory( 369 sequential_model, model_path, pytorch_custom_env 370 ): 371 mlflow.pytorch.save_model( 372 pytorch_model=sequential_model, path=model_path, conda_env=pytorch_custom_env 373 ) 374 375 pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME) 376 saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"]) 377 assert os.path.exists(saved_conda_env_path) 378 assert 
saved_conda_env_path != pytorch_custom_env 379 380 with open(pytorch_custom_env) as f: 381 pytorch_custom_env_text = f.read() 382 with open(saved_conda_env_path) as f: 383 saved_conda_env_text = f.read() 384 assert saved_conda_env_text == pytorch_custom_env_text 385 386 387 @pytest.mark.parametrize("scripted_model", [True, False]) 388 def test_model_save_persists_requirements_in_mlflow_model_directory( 389 sequential_model, model_path, pytorch_custom_env 390 ): 391 mlflow.pytorch.save_model( 392 pytorch_model=sequential_model, path=model_path, conda_env=pytorch_custom_env 393 ) 394 395 saved_pip_req_path = os.path.join(model_path, "requirements.txt") 396 _compare_conda_env_requirements(pytorch_custom_env, saved_pip_req_path) 397 398 399 @pytest.mark.parametrize("scripted_model", [False]) 400 def test_save_model_with_pip_requirements(sequential_model, tmp_path): 401 expected_mlflow_version = _mlflow_major_version_string() 402 # Path to a requirements file 403 tmpdir1 = tmp_path.joinpath("1") 404 req_file = tmp_path.joinpath("requirements.txt") 405 req_file.write_text("a") 406 mlflow.pytorch.save_model(sequential_model, tmpdir1, pip_requirements=str(req_file)) 407 _assert_pip_requirements(tmpdir1, [expected_mlflow_version, "a"], strict=True) 408 409 # List of requirements 410 tmpdir2 = tmp_path.joinpath("2") 411 mlflow.pytorch.save_model(sequential_model, tmpdir2, pip_requirements=[f"-r {req_file}", "b"]) 412 _assert_pip_requirements(tmpdir2, [expected_mlflow_version, "a", "b"], strict=True) 413 414 # Constraints file 415 tmpdir3 = tmp_path.joinpath("3") 416 mlflow.pytorch.save_model(sequential_model, tmpdir3, pip_requirements=[f"-c {req_file}", "b"]) 417 _assert_pip_requirements( 418 tmpdir3, [expected_mlflow_version, "b", "-c constraints.txt"], ["a"], strict=True 419 ) 420 421 422 @pytest.mark.parametrize("scripted_model", [False]) 423 def test_save_model_with_extra_pip_requirements(sequential_model, tmp_path): 424 expected_mlflow_version = 
_mlflow_major_version_string() 425 default_reqs = mlflow.pytorch.get_default_pip_requirements() 426 427 # Path to a requirements file 428 tmpdir1 = tmp_path.joinpath("1") 429 req_file = tmp_path.joinpath("requirements.txt") 430 req_file.write_text("a") 431 mlflow.pytorch.save_model(sequential_model, tmpdir1, extra_pip_requirements=str(req_file)) 432 _assert_pip_requirements(tmpdir1, [expected_mlflow_version, *default_reqs, "a"]) 433 434 # List of requirements 435 tmpdir2 = tmp_path.joinpath("2") 436 mlflow.pytorch.save_model( 437 sequential_model, tmpdir2, extra_pip_requirements=[f"-r {req_file}", "b"] 438 ) 439 _assert_pip_requirements(tmpdir2, [expected_mlflow_version, *default_reqs, "a", "b"]) 440 441 # Constraints file 442 tmpdir3 = tmp_path.joinpath("3") 443 mlflow.pytorch.save_model( 444 sequential_model, tmpdir3, extra_pip_requirements=[f"-c {req_file}", "b"] 445 ) 446 _assert_pip_requirements( 447 tmpdir3, [expected_mlflow_version, *default_reqs, "b", "-c constraints.txt"], ["a"] 448 ) 449 450 451 @pytest.mark.parametrize("scripted_model", [True, False]) 452 def test_model_save_accepts_conda_env_as_dict(sequential_model, model_path): 453 conda_env = dict(mlflow.pytorch.get_default_conda_env()) 454 conda_env["dependencies"].append("pytest") 455 mlflow.pytorch.save_model(pytorch_model=sequential_model, path=model_path, conda_env=conda_env) 456 457 pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME) 458 saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"]) 459 assert os.path.exists(saved_conda_env_path) 460 461 with open(saved_conda_env_path) as f: 462 saved_conda_env_parsed = yaml.safe_load(f) 463 assert saved_conda_env_parsed == conda_env 464 465 466 @pytest.mark.parametrize("scripted_model", [True, False]) 467 def test_model_log_persists_specified_conda_env_in_mlflow_model_directory( 468 sequential_model, pytorch_custom_env 469 ): 470 artifact_path = "model" 471 with mlflow.start_run(): 
472 model_info = mlflow.pytorch.log_model( 473 sequential_model, 474 name=artifact_path, 475 conda_env=pytorch_custom_env, 476 ) 477 model_path = _download_artifact_from_uri(model_info.model_uri) 478 479 pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME) 480 saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"]) 481 assert os.path.exists(saved_conda_env_path) 482 assert saved_conda_env_path != pytorch_custom_env 483 484 with open(pytorch_custom_env) as f: 485 pytorch_custom_env_text = f.read() 486 with open(saved_conda_env_path) as f: 487 saved_conda_env_text = f.read() 488 assert saved_conda_env_text == pytorch_custom_env_text 489 490 491 @pytest.mark.parametrize("scripted_model", [True, False]) 492 def test_model_log_persists_requirements_in_mlflow_model_directory( 493 sequential_model, pytorch_custom_env 494 ): 495 artifact_path = "model" 496 with mlflow.start_run(): 497 model_info = mlflow.pytorch.log_model( 498 sequential_model, 499 name=artifact_path, 500 conda_env=pytorch_custom_env, 501 ) 502 model_path = _download_artifact_from_uri(model_info.model_uri) 503 504 saved_pip_req_path = os.path.join(model_path, "requirements.txt") 505 _compare_conda_env_requirements(pytorch_custom_env, saved_pip_req_path) 506 507 508 @pytest.mark.parametrize("scripted_model", [True, False]) 509 def test_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies( 510 sequential_model, model_path 511 ): 512 mlflow.pytorch.save_model(pytorch_model=sequential_model, path=model_path) 513 _assert_pip_requirements(model_path, mlflow.pytorch.get_default_pip_requirements()) 514 515 516 @pytest.mark.parametrize("scripted_model", [True, False]) 517 def test_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies( 518 sequential_model, 519 ): 520 with mlflow.start_run(): 521 model_info = mlflow.pytorch.log_model(sequential_model, name="model") 522 523 
_assert_pip_requirements(model_info.model_uri, mlflow.pytorch.get_default_pip_requirements()) 524 525 526 @pytest.mark.parametrize("scripted_model", [True, False]) 527 def test_load_model_with_differing_pytorch_version_logs_warning(sequential_model, model_path): 528 mlflow.pytorch.save_model(pytorch_model=sequential_model, path=model_path) 529 saver_pytorch_version = "1.0" 530 model_config_path = os.path.join(model_path, "MLmodel") 531 model_config = Model.load(model_config_path) 532 model_config.flavors[mlflow.pytorch.FLAVOR_NAME]["pytorch_version"] = saver_pytorch_version 533 model_config.save(model_config_path) 534 535 log_messages = [] 536 537 def custom_warn(message_text, *args, **kwargs): 538 log_messages.append(message_text % args % kwargs) 539 540 loader_pytorch_version = "0.8.2" 541 with ( 542 mock.patch("mlflow.pytorch._logger.warning") as warn_mock, 543 mock.patch("torch.__version__", loader_pytorch_version), 544 ): 545 warn_mock.side_effect = custom_warn 546 mlflow.pytorch.load_model(model_uri=model_path) 547 548 assert any( 549 "does not match installed PyTorch version" in log_message 550 and saver_pytorch_version in log_message 551 and loader_pytorch_version in log_message 552 for log_message in log_messages 553 ) 554 555 556 def test_pyfunc_model_serving_with_module_scoped_subclassed_model_and_default_conda_env( 557 module_scoped_subclassed_model, data 558 ): 559 with mlflow.start_run(): 560 model_info = mlflow.pytorch.log_model( 561 module_scoped_subclassed_model, 562 name="pytorch_model", 563 code_paths=[__file__], 564 input_example=data[0], 565 ) 566 567 inference_payload = load_serving_example(model_info.model_uri) 568 scoring_response = pyfunc_serve_and_score_model( 569 model_uri=model_info.model_uri, 570 data=inference_payload, 571 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 572 extra_args=["--env-manager", "local"], 573 ) 574 assert scoring_response.status_code == 200 575 576 deployed_model_preds = 
pd.DataFrame(json.loads(scoring_response.content)["predictions"]) 577 np.testing.assert_array_almost_equal( 578 deployed_model_preds.values[:, 0], 579 _predict(model=module_scoped_subclassed_model, data=data), 580 decimal=4, 581 ) 582 583 584 def test_save_model_with_wrong_codepaths_fails_correctly( 585 module_scoped_subclassed_model, model_path, data 586 ): 587 with pytest.raises(TypeError, match="Argument code_paths should be a list, not <class 'str'>"): 588 mlflow.pytorch.save_model( 589 path=model_path, pytorch_model=module_scoped_subclassed_model, code_paths="some string" 590 ) 591 592 593 def test_pyfunc_model_serving_with_main_scoped_subclassed_model_and_custom_pickle_module( 594 main_scoped_subclassed_model, data 595 ): 596 with mlflow.start_run(): 597 model_info = mlflow.pytorch.log_model( 598 main_scoped_subclassed_model, 599 name="pytorch_model", 600 pickle_module=mlflow_pytorch_pickle_module, 601 input_example=data[0], 602 ) 603 604 inference_payload = load_serving_example(model_info.model_uri) 605 scoring_response = pyfunc_serve_and_score_model( 606 model_uri=model_info.model_uri, 607 data=inference_payload, 608 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 609 extra_args=["--env-manager", "local"], 610 ) 611 assert scoring_response.status_code == 200 612 613 deployed_model_preds = pd.DataFrame(json.loads(scoring_response.content)["predictions"]) 614 np.testing.assert_array_almost_equal( 615 deployed_model_preds.values[:, 0], 616 _predict(model=main_scoped_subclassed_model, data=data), 617 decimal=4, 618 ) 619 620 621 def test_load_model_succeeds_with_dependencies_specified_via_code_paths( 622 module_scoped_subclassed_model, model_path, data 623 ): 624 # Save a PyTorch model whose class is defined in the current test suite. 
Because the 625 # `tests` module is not available when the model is deployed for local scoring, we include 626 # the test suite file as a code dependency 627 mlflow.pytorch.save_model( 628 path=model_path, 629 pytorch_model=module_scoped_subclassed_model, 630 code_paths=[__file__], 631 ) 632 633 # Define a custom pyfunc model that loads a PyTorch model artifact using 634 # `mlflow.pytorch.load_model` 635 class TorchValidatorModel(pyfunc.PythonModel): 636 def load_context(self, context): 637 self.pytorch_model = mlflow.pytorch.load_model(context.artifacts["pytorch_model"]) 638 639 def predict(self, context, model_input, params=None): 640 with torch.no_grad(): 641 input_tensor = torch.from_numpy(model_input.values.astype(np.float32)) 642 output_tensor = self.pytorch_model(input_tensor) 643 return pd.DataFrame(output_tensor.numpy()) 644 645 pyfunc_artifact_path = "pyfunc_model" 646 with mlflow.start_run(): 647 model_info = pyfunc.log_model( 648 pyfunc_artifact_path, 649 python_model=TorchValidatorModel(), 650 artifacts={"pytorch_model": model_path}, 651 input_example=data[0], 652 # save file into code_paths, otherwise after first model loading (happens when 653 # validating input_example) then we can not load the model again 654 code_paths=[__file__], 655 ) 656 657 # Deploy the custom pyfunc model and ensure that it is able to successfully load its 658 # constituent PyTorch model via `mlflow.pytorch.load_model` 659 inference_payload = load_serving_example(model_info.model_uri) 660 scoring_response = pyfunc_serve_and_score_model( 661 model_uri=model_info.model_uri, 662 data=inference_payload, 663 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 664 extra_args=["--env-manager", "local"], 665 ) 666 assert scoring_response.status_code == 200 667 668 deployed_model_preds = pd.DataFrame(json.loads(scoring_response.content)["predictions"]) 669 np.testing.assert_array_almost_equal( 670 deployed_model_preds.values[:, 0], 671 _predict(model=module_scoped_subclassed_model, 
data=data), 672 decimal=4, 673 ) 674 675 676 def test_load_pyfunc_loads_torch_model_using_pickle_module_specified_at_save_time( 677 module_scoped_subclassed_model, model_path 678 ): 679 custom_pickle_module = pickle 680 681 mlflow.pytorch.save_model( 682 path=model_path, 683 pytorch_model=module_scoped_subclassed_model, 684 pickle_module=custom_pickle_module, 685 ) 686 687 import_module_fn = importlib.import_module 688 imported_modules = [] 689 690 def track_module_imports(module_name): 691 imported_modules.append(module_name) 692 return import_module_fn(module_name) 693 694 with ( 695 mock.patch("importlib.import_module") as import_mock, 696 mock.patch("torch.load") as torch_load_mock, 697 ): 698 import_mock.side_effect = track_module_imports 699 pyfunc.load_model(model_path) 700 701 expected_kwargs = {"pickle_module": custom_pickle_module} 702 if ENABLE_LEGACY_DESERIALIZATION: 703 expected_kwargs["weights_only"] = False 704 705 torch_load_mock.assert_called_with(mock.ANY, **expected_kwargs) 706 assert custom_pickle_module.__name__ in imported_modules 707 708 709 def test_load_model_loads_torch_model_using_pickle_module_specified_at_save_time( 710 module_scoped_subclassed_model, 711 ): 712 custom_pickle_module = pickle 713 714 artifact_path = "pytorch_model" 715 with mlflow.start_run(): 716 model_info = mlflow.pytorch.log_model( 717 module_scoped_subclassed_model, 718 name=artifact_path, 719 pickle_module=custom_pickle_module, 720 ) 721 model_uri = model_info.model_uri 722 723 import_module_fn = importlib.import_module 724 imported_modules = [] 725 726 def track_module_imports(module_name): 727 imported_modules.append(module_name) 728 return import_module_fn(module_name) 729 730 with ( 731 mock.patch("importlib.import_module") as import_mock, 732 mock.patch("torch.load") as torch_load_mock, 733 ): 734 import_mock.side_effect = track_module_imports 735 pyfunc.load_model(model_uri=model_uri) 736 737 expected_kwargs = {"pickle_module": custom_pickle_module} 738 if 
ENABLE_LEGACY_DESERIALIZATION: 739 expected_kwargs["weights_only"] = False 740 741 torch_load_mock.assert_called_with(mock.ANY, **expected_kwargs) 742 assert custom_pickle_module.__name__ in imported_modules 743 744 745 def test_load_pyfunc_succeeds_when_data_is_model_file_instead_of_directory( 746 module_scoped_subclassed_model, model_path, data 747 ): 748 """ 749 This test verifies that PyTorch models saved in older versions of MLflow are loaded successfully 750 by ``mlflow.pytorch.load_model``. The ``data`` path associated with these older models is 751 serialized PyTorch model file, as opposed to the current format: a directory containing a 752 serialized model file and pickle module information. 753 """ 754 mlflow.pytorch.save_model(path=model_path, pytorch_model=module_scoped_subclassed_model) 755 756 model_conf_path = os.path.join(model_path, "MLmodel") 757 model_conf = Model.load(model_conf_path) 758 pyfunc_conf = model_conf.flavors.get(pyfunc.FLAVOR_NAME) 759 assert pyfunc_conf is not None 760 model_data_path = os.path.join(model_path, pyfunc_conf[pyfunc.DATA]) 761 assert os.path.exists(model_data_path) 762 assert mlflow.pytorch._SERIALIZED_TORCH_MODEL_FILE_NAME in os.listdir(model_data_path) 763 pyfunc_conf[pyfunc.DATA] = os.path.join( 764 model_data_path, mlflow.pytorch._SERIALIZED_TORCH_MODEL_FILE_NAME 765 ) 766 model_conf.save(model_conf_path) 767 768 loaded_pyfunc = pyfunc.load_model(model_path) 769 770 np.testing.assert_array_almost_equal( 771 loaded_pyfunc.predict(data[0]), 772 pd.DataFrame(_predict(model=module_scoped_subclassed_model, data=data)), 773 decimal=4, 774 ) 775 776 777 def test_load_model_succeeds_when_data_is_model_file_instead_of_directory( 778 module_scoped_subclassed_model, model_path, data 779 ): 780 """ 781 This test verifies that PyTorch models saved in older versions of MLflow are loaded successfully 782 by ``mlflow.pytorch.load_model``. 
The ``data`` path associated with these older models is 783 serialized PyTorch model file, as opposed to the current format: a directory containing a 784 serialized model file and pickle module information. 785 """ 786 artifact_path = "pytorch_model" 787 with mlflow.start_run(): 788 model_info = mlflow.pytorch.log_model(module_scoped_subclassed_model, name=artifact_path) 789 model_path = _download_artifact_from_uri(model_info.model_uri) 790 791 model_conf_path = os.path.join(model_path, "MLmodel") 792 model_conf = Model.load(model_conf_path) 793 pyfunc_conf = model_conf.flavors.get(pyfunc.FLAVOR_NAME) 794 assert pyfunc_conf is not None 795 model_data_path = os.path.join(model_path, pyfunc_conf[pyfunc.DATA]) 796 assert os.path.exists(model_data_path) 797 assert mlflow.pytorch._SERIALIZED_TORCH_MODEL_FILE_NAME in os.listdir(model_data_path) 798 pyfunc_conf[pyfunc.DATA] = os.path.join( 799 model_data_path, mlflow.pytorch._SERIALIZED_TORCH_MODEL_FILE_NAME 800 ) 801 model_conf.save(model_conf_path) 802 803 loaded_pyfunc = pyfunc.load_model(model_path) 804 805 np.testing.assert_array_almost_equal( 806 loaded_pyfunc.predict(data[0]), 807 pd.DataFrame(_predict(model=module_scoped_subclassed_model, data=data)), 808 decimal=4, 809 ) 810 811 812 def test_load_model_allows_user_to_override_pickle_module_via_keyword_argument( 813 module_scoped_subclassed_model, model_path 814 ): 815 mlflow.pytorch.save_model( 816 path=model_path, pytorch_model=module_scoped_subclassed_model, pickle_module=pickle 817 ) 818 819 with ( 820 mock.patch("torch.load") as torch_load_mock, 821 mock.patch("mlflow.pytorch._logger.warning") as warn_mock, 822 ): 823 mlflow.pytorch.load_model(model_uri=model_path, pickle_module=mlflow_pytorch_pickle_module) 824 torch_load_mock.assert_called_with(mock.ANY, pickle_module=mlflow_pytorch_pickle_module) 825 warn_mock.assert_any_call(mock.ANY, mlflow_pytorch_pickle_module.__name__, pickle.__name__) 826 827 828 def 
def test_load_model_raises_exception_when_pickle_module_cannot_be_imported(
    main_scoped_subclassed_model, model_path
):
    """Loading fails with MlflowException when the recorded pickle module can't be imported.

    Corrupts the pickle-module info file inside a saved model so it names a
    non-importable module, then asserts that ``load_model`` surfaces the bad
    module name in the error message.
    """
    mlflow.pytorch.save_model(path=model_path, pytorch_model=main_scoped_subclassed_model)

    bad_pickle_module_name = "not.a.real.module"

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    model_data_path = os.path.join(model_path, pyfunc_conf[pyfunc.DATA])
    assert os.path.exists(model_data_path)
    assert mlflow.pytorch._PICKLE_MODULE_INFO_FILE_NAME in os.listdir(model_data_path)
    # Overwrite the stored pickle-module name with one that cannot be imported.
    with open(
        os.path.join(model_data_path, mlflow.pytorch._PICKLE_MODULE_INFO_FILE_NAME), "w"
    ) as f:
        f.write(bad_pickle_module_name)

    with pytest.raises(
        MlflowException,
        match=r"Failed to import the pickle module.+" + re.escape(bad_pickle_module_name),
    ):
        mlflow.pytorch.load_model(model_uri=model_path)


def test_pyfunc_serve_and_score(data):
    """Score a logged linear model through the pyfunc scoring server and compare
    the server's predictions against direct local inference.

    NOTE(review): `train_model`, `_predict`, and `EXTRA_PYFUNC_SERVING_TEST_ARGS`
    are defined earlier in this file (not visible in this chunk).
    """
    model = torch.nn.Linear(4, 1)
    train_model(model=model, data=data)

    with mlflow.start_run():
        model_info = mlflow.pytorch.log_model(model, name="model", input_example=data[0])

    # The saved input example doubles as the serving payload.
    inference_payload = load_serving_example(model_info.model_uri)
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        inference_payload,
        pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    from mlflow.deployments import PredictionsResponse

    scores = PredictionsResponse.from_json(resp.content).get_predictions()
    np.testing.assert_array_almost_equal(scores.values[:, 0], _predict(model=model, data=data))


@pytest.mark.skipif(not _is_importable("transformers"), reason="This test requires transformers")
def test_pyfunc_serve_and_score_transformers():
    """Serve a logged ``transformers`` Bert subclass via pyfunc and compare the
    server's ndarray predictions against direct model inference."""
    from transformers import BertConfig, BertModel

    from mlflow.deployments import PredictionsResponse

    class MyBertModel(BertModel):
        # Return a plain tensor so the pyfunc output is a simple ndarray
        # instead of a transformers output object.
        def forward(self, *args, **kwargs):
            return super().forward(*args, **kwargs).last_hidden_state

    # Deliberately tiny config to keep the test fast.
    model = MyBertModel(
        BertConfig(
            vocab_size=16,
            hidden_size=2,
            num_hidden_layers=2,
            num_attention_heads=2,
            intermediate_size=2,
        )
    )
    model.eval()

    input_ids = model.dummy_inputs["input_ids"]

    with mlflow.start_run():
        model_info = mlflow.pytorch.log_model(
            model, name="model", input_example=np.array(input_ids.tolist())
        )

    inference_payload = load_serving_example(model_info.model_uri)
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        inference_payload,
        pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )

    scores = PredictionsResponse.from_json(resp.content.decode("utf-8")).get_predictions(
        predictions_format="ndarray"
    )
    assert_array_almost_equal(scores, model(input_ids).detach().numpy(), rtol=1e-6)


@pytest.fixture
def create_requirements_file(tmp_path):
    """Write a one-line requirements.txt and return (path, contents)."""
    requirement_file_name = "requirements.txt"
    fp = tmp_path.joinpath(requirement_file_name)
    test_string = "mlflow"
    fp.write_text(test_string)
    return str(fp), test_string


@pytest.fixture
def create_extra_files(tmp_path):
    """Create two small extra files and return ([paths], [expected contents])."""
    fp1 = tmp_path.joinpath("extra1.txt")
    fp2 = tmp_path.joinpath("extra2.txt")
    fp1.write_text("1")
    fp2.write_text("2")
    return [str(fp1), str(fp2)], ["1", "2"]


# NOTE(review): `scripted_model` appears to parametrize the `sequential_model`
# fixture (defined earlier in this file) — confirm against the fixture definition.
@pytest.mark.parametrize("scripted_model", [True, False])
def test_extra_files_log_model(create_extra_files, sequential_model):
    """``log_model(extra_files=...)`` records the files in the flavor config and
    stores their contents alongside the model artifacts."""
    extra_files, contents_expected = create_extra_files
    with mlflow.start_run():
        mlflow.pytorch.log_model(sequential_model, name="models", extra_files=extra_files)

        model_uri = "runs:/{run_id}/{model_path}".format(
            run_id=mlflow.active_run().info.run_id, model_path="models"
        )
        with TempDir(remove_on_exit=True) as tmp:
            model_path = _download_artifact_from_uri(model_uri, tmp.path())
            model_config_path = os.path.join(model_path, "MLmodel")
            model_config = Model.load(model_config_path)
            flavor_config = model_config.flavors["pytorch"]

            assert "extra_files" in flavor_config
            loaded_extra_files = flavor_config["extra_files"]

            # Each recorded entry must point to a file whose content matches
            # what was written by the fixture.
            for loaded_extra_file, content_expected in zip(loaded_extra_files, contents_expected):
                assert "path" in loaded_extra_file
                extra_file_path = os.path.join(model_path, loaded_extra_file["path"])
                with open(extra_file_path) as fp:
                    assert fp.read() == content_expected


@pytest.mark.parametrize("scripted_model", [True, False])
def test_extra_files_save_model(create_extra_files, sequential_model):
    """Same as the log_model variant above, but exercising ``save_model`` directly."""
    extra_files, contents_expected = create_extra_files
    with TempDir(remove_on_exit=True) as tmp:
        model_path = os.path.join(tmp.path(), "models")
        mlflow.pytorch.save_model(
            pytorch_model=sequential_model, path=model_path, extra_files=extra_files
        )
        model_config_path = os.path.join(model_path, "MLmodel")
        model_config = Model.load(model_config_path)
        flavor_config = model_config.flavors["pytorch"]

        assert "extra_files" in flavor_config
        loaded_extra_files = flavor_config["extra_files"]

        for loaded_extra_file, content_expected in zip(loaded_extra_files, contents_expected):
            assert "path" in loaded_extra_file
            extra_file_path = os.path.join(model_path, loaded_extra_file["path"])
            with open(extra_file_path) as fp:
                assert fp.read() == content_expected


@pytest.mark.parametrize("scripted_model", [True, False])
def test_log_model_invalid_extra_file_path(sequential_model):
    """A non-existent extra-file path raises MlflowException at log time."""
    with (
        mlflow.start_run(),
        pytest.raises(MlflowException, match="No such artifact: 'non_existing_file.txt'"),
    ):
        mlflow.pytorch.log_model(
            sequential_model,
            name="models",
            extra_files=["non_existing_file.txt"],
        )


@pytest.mark.parametrize("scripted_model", [True, False])
def test_log_model_invalid_extra_file_type(sequential_model):
    """Passing a bare string (instead of a list) for ``extra_files`` raises TypeError."""
    with (
        mlflow.start_run(),
        pytest.raises(TypeError, match="Extra files argument should be a list"),
    ):
        mlflow.pytorch.log_model(
            sequential_model,
            name="models",
            extra_files="non_existing_file.txt",
        )


def state_dict_equal(state_dict1, state_dict2) -> bool:
    """Recursively compare two (possibly nested) state dicts.

    Tensors are compared with ``torch.equal``, nested dicts recursively, and
    everything else with ``==``; values of differing types are unequal.
    Note: only keys of ``state_dict1`` are checked, so extra keys in
    ``state_dict2`` are ignored.
    """
    for key1 in state_dict1:
        if key1 not in state_dict2:
            return False

        value1 = state_dict1[key1]
        value2 = state_dict2[key1]

        if type(value1) != type(value2):
            return False
        elif isinstance(value1, dict):
            if not state_dict_equal(value1, value2):
                return False
        elif isinstance(value1, torch.Tensor):
            if not torch.equal(value1, value2):
                return False
        elif value1 != value2:
            return False
        else:
            continue

    return True


@pytest.mark.parametrize("scripted_model", [True, False])
def test_save_state_dict(sequential_model, model_path, data):
    """A saved state_dict round-trips through save/load and reproduces predictions
    when loaded into a freshly constructed model."""
    state_dict = sequential_model.state_dict()
    mlflow.pytorch.save_state_dict(state_dict, model_path)

    loaded_state_dict = mlflow.pytorch.load_state_dict(model_path)
    assert state_dict_equal(loaded_state_dict, state_dict)
    model = get_sequential_model()
    model.load_state_dict(loaded_state_dict)
    np.testing.assert_array_almost_equal(
        _predict(model, data),
        _predict(sequential_model, data),
        decimal=4,
    )


def test_save_state_dict_can_save_nested_state_dict(model_path):
    """
    This test ensures that `save_state_dict` supports a use case described in the page below
    where a user bundles multiple objects (e.g., model, optimizer, learning-rate scheduler)
    into a single nested state_dict and loads it back later for inference or re-training:
    https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
    """
    model = get_sequential_model()
    optim = torch.optim.Adam(model.parameters())
    state_dict = {"model": model.state_dict(), "optim": optim.state_dict()}
    mlflow.pytorch.save_state_dict(state_dict, model_path)

    loaded_state_dict = mlflow.pytorch.load_state_dict(model_path)
    assert state_dict_equal(loaded_state_dict, state_dict)
    # Both sub-dicts must be loadable back into their respective objects.
    model.load_state_dict(loaded_state_dict["model"])
    optim.load_state_dict(loaded_state_dict["optim"])


def test_load_state_dict_disallows_pickle_deserialization(model_path, monkeypatch):
    """With MLFLOW_ALLOW_PICKLE_DESERIALIZATION=false, ``load_state_dict`` refuses
    to deserialize and raises an MlflowException mentioning the env var."""
    model = get_sequential_model()
    mlflow.pytorch.save_state_dict(model.state_dict(), model_path)

    monkeypatch.setenv("MLFLOW_ALLOW_PICKLE_DESERIALIZATION", "false")
    with pytest.raises(MlflowException, match="MLFLOW_ALLOW_PICKLE_DESERIALIZATION"):
        mlflow.pytorch.load_state_dict(model_path)


@pytest.mark.parametrize("not_state_dict", [0, "", get_sequential_model()])
def test_save_state_dict_throws_for_invalid_object_type(not_state_dict, model_path):
    """``save_state_dict`` rejects anything that isn't a dict (int, str, nn.Module)."""
    with pytest.raises(TypeError, match="Invalid object type for `state_dict`"):
        mlflow.pytorch.save_state_dict(not_state_dict, model_path)


@pytest.mark.parametrize("scripted_model", [True, False])
def test_log_state_dict(sequential_model, data):
    """A state_dict logged as a run artifact round-trips via its artifact URI and
    reproduces predictions in a fresh model."""
    artifact_path = "model"
    state_dict = sequential_model.state_dict()
    with mlflow.start_run():
        mlflow.pytorch.log_state_dict(state_dict, artifact_path)
        state_dict_uri = mlflow.get_artifact_uri(artifact_path)

    loaded_state_dict = mlflow.pytorch.load_state_dict(state_dict_uri)
    assert state_dict_equal(loaded_state_dict, state_dict)
    model = get_sequential_model()
    model.load_state_dict(loaded_state_dict)
    np.testing.assert_array_almost_equal(
        _predict(model, data),
        _predict(sequential_model, data),
        decimal=4,
    )


@pytest.mark.parametrize("scripted_model", [True, False])
def test_log_model_with_code_paths(sequential_model):
    """Logging with ``code_paths`` copies this file into the model artifacts, and
    loading the model adds the code to ``sys.path`` (mocked and asserted)."""
    artifact_path = "model"
    with (
        mlflow.start_run(),
        mock.patch("mlflow.pytorch._add_code_from_conf_to_system_path") as add_mock,
    ):
        model_info = mlflow.pytorch.log_model(
            sequential_model, name=artifact_path, code_paths=[__file__]
        )
        _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.pytorch.FLAVOR_NAME)
        mlflow.pytorch.load_model(model_info.model_uri)
        add_mock.assert_called()


def test_virtualenv_subfield_points_to_correct_path(model_path):
    """The pyfunc env config's 'virtualenv' subfield points to an existing file."""
    model = get_sequential_model()
    mlflow.pytorch.save_model(model, path=model_path)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    python_env_path = Path(model_path, pyfunc_conf[pyfunc.ENV]["virtualenv"])
    assert python_env_path.exists()
    assert python_env_path.is_file()


@pytest.mark.parametrize("scripted_model", [True, False])
def test_model_save_load_with_metadata(sequential_model, model_path):
    """Custom metadata passed to ``save_model`` survives a pyfunc reload."""
    mlflow.pytorch.save_model(
        sequential_model, path=model_path, metadata={"metadata_key": "metadata_value"}
    )

    reloaded_model = mlflow.pyfunc.load_model(model_uri=model_path)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


@pytest.mark.parametrize("scripted_model", [True, False])
def test_model_log_with_metadata(sequential_model):
    """Custom metadata passed to ``log_model`` survives a pyfunc reload."""
    artifact_path = "model"

    with mlflow.start_run():
        model_info = mlflow.pytorch.log_model(
            sequential_model,
            name=artifact_path,
            metadata={"metadata_key": "metadata_value"},
        )

    reloaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


@pytest.mark.parametrize("scripted_model", [True, False])
def test_model_log_with_signature_inference(sequential_model, data):
    """Logging with an input example infers a (-1, 4) -> (-1, 1) float32 tensor
    signature, and the served model's predictions match local inference."""
    artifact_path = "model"
    example_ = data[0].head(3).values.astype(np.float32)

    with mlflow.start_run():
        model_info = mlflow.pytorch.log_model(
            sequential_model, name=artifact_path, input_example=example_
        )

    assert model_info.signature == ModelSignature(
        inputs=Schema([TensorSpec(np.dtype("float32"), (-1, 4))]),
        outputs=Schema([TensorSpec(np.dtype("float32"), (-1, 1))]),
    )
    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        inference_payload,
        pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert response.status_code == 200
    deployed_model_preds = pd.DataFrame(json.loads(response.content)["predictions"])
    np.testing.assert_array_almost_equal(
        deployed_model_preds.values[:, 0],
        _predict(model=sequential_model, data=(data[0].head(3), data[1].head(3))),
        decimal=4,
    )


@pytest.mark.parametrize("scripted_model", [False])
def test_load_model_to_device(sequential_model):
    """Both pyfunc ``model_config`` and ``load_model`` kwargs forward device
    options to the internal ``_load_model`` (verified via mock; no real GPU used).

    NOTE(review): `ENABLE_LEGACY_DESERIALIZATION` is defined earlier in this file.
    """
    with mock.patch("mlflow.pytorch._load_model") as load_model_mock:
        with mlflow.start_run():
            model_info = mlflow.pytorch.log_model(sequential_model, name="pytorch")
            model_config = {"device": "cuda"}
            if ENABLE_LEGACY_DESERIALIZATION:
                model_config["weights_only"] = False

            mlflow.pyfunc.load_model(model_uri=model_info.model_uri, model_config=model_config)

            load_model_mock.assert_called_with(mock.ANY, **model_config)
            mlflow.pytorch.load_model(model_uri=model_info.model_uri, **model_config)
            load_model_mock.assert_called_with(path=mock.ANY, **model_config)


def test_passing_params_to_model(data):
    """Params declared in the signature (here ``y``) are forwarded to the model's
    ``forward`` by the pyfunc wrapper, both as defaults and as overrides."""

    class CustomModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 1)

        def forward(self, x, y):
            # pyfunc delivers numpy input; convert lazily to a tensor.
            if not torch.is_tensor(x):
                x = torch.from_numpy(x)
            y = torch.tensor(y)
            combined = x * y
            return self.linear(combined)

    model = CustomModel()
    x = np.random.randn(8, 4).astype(np.float32)

    # Declares param "y" with default 1 in the model signature.
    signature = mlflow.models.infer_signature(x, None, {"y": 1})
    with mlflow.start_run():
        model_info = mlflow.pytorch.log_model(model, name="model", signature=signature)

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    with torch.no_grad():
        np.testing.assert_array_almost_equal(pyfunc_model.predict(x), model(x, 1), decimal=4)
        np.testing.assert_array_almost_equal(
            pyfunc_model.predict(x, {"y": 2}), model(x, 2), decimal=4
        )


def test_log_model_with_datetime_input():
    """A DataFrame input example with a datetime column yields a datetime-typed
    signature input, and pyfunc prediction matches direct tensor inference."""
    df = pd.DataFrame({
        "datetime": pd.date_range("2022-01-01", periods=5, freq="D"),
        "x": np.random.uniform(20, 30, 5),
        "y": np.random.uniform(2, 4, 5),
        "z": np.random.uniform(0, 10, 5),
    })
    model = get_sequential_model()
    model_info = mlflow.pytorch.log_model(model, name="pytorch", input_example=df)
    assert model_info.signature.inputs.inputs[0].type == DataType.datetime
    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    with torch.no_grad():
        # Expected result: feed the whole frame (datetime coerced to float32).
        input_tensor = torch.from_numpy(df.to_numpy(dtype=np.float32))
        expected_result = model(input_tensor)
    with torch.no_grad():
        np.testing.assert_array_almost_equal(pyfunc_model.predict(df), expected_result, decimal=4)


@pytest.mark.skipif(
    Version(torch.__version__) < Version("2.4"), reason="This test requires torch>=2.4"
)
@pytest.mark.parametrize("scripted_model", [False])
def test_save_and_load_exported_model(sequential_model, model_path, data, sequential_predicted):
    """A model saved with serialization_format="pt2" (torch.export) loads back via
    both the pytorch and pyfunc flavors and reproduces predictions."""
    input_example = data[0].to_numpy(dtype=np.float32)

    mlflow.pytorch.save_model(
        sequential_model,
        model_path,
        serialization_format="pt2",
        input_example=input_example,
    )

    # Loading pytorch model
    sequential_model_loaded = mlflow.pytorch.load_model(model_path)
    np.testing.assert_array_equal(_predict(sequential_model_loaded, data), sequential_predicted)

    # Loading pyfunc model
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
    np.testing.assert_array_almost_equal(
        pyfunc_loaded.predict(input_example)[:, 0], sequential_predicted, decimal=4
    )


@pytest.mark.skipif(
    Version(torch.__version__) < Version("2.4"), reason="This test requires torch>=2.4"
)
def test_exported_model_infer_dynamic_dim(tmp_path):
    """pt2 export marks dynamic dims from the signature: auto-inferred signatures
    make the batch (first) dim dynamic; an explicit signature can instead mark
    another dim (here the second) as dynamic via -1."""

    class MyModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.sin(x)

    origin_model = MyModule()

    input_example = torch.randn(3, 4, 5).numpy()

    save_path1 = tmp_path / "model1"

    # test exporting model with auto inferred signature,
    # which sets the first dim (batch dim) of input data as dynamic dim.
    mlflow.pytorch.save_model(
        origin_model,
        save_path1,
        serialization_format="pt2",
        input_example=input_example,
    )

    # Test the exported model works with test data that changes the first dim (batch dim) size.
    loaded_model1 = mlflow.pytorch.load_model(save_path1)

    test_data1 = torch.randn(6, 4, 5)
    np.testing.assert_array_almost_equal(
        loaded_model1(test_data1),
        origin_model(test_data1),
        decimal=4,
    )

    save_path2 = tmp_path / "model2"
    # test exporting model with provided signature,
    # which sets the second dim of input data as dynamic dim.
    mlflow.pytorch.save_model(
        origin_model,
        save_path2,
        serialization_format="pt2",
        input_example=input_example,
        signature=ModelSignature(
            inputs=Schema([TensorSpec(np.dtype("float32"), (3, -1, 5))]),
        ),
    )

    # Test the exported model works with test data that changes the second dim (batch dim) size.
    loaded_model2 = mlflow.pytorch.load_model(save_path2)

    test_data2 = torch.randn(3, 2, 5)
    np.testing.assert_array_almost_equal(
        loaded_model2(test_data2),
        origin_model(test_data2),
        decimal=4,
    )


@pytest.mark.skipif(
    Version(torch.__version__) < Version("2.4"), reason="This test requires torch>=2.4"
)
@pytest.mark.parametrize("scripted_model", [False])
def test_load_exported_model_check_device_mismatch(sequential_model, model_path):
    """A pt2 model exported on CPU loads on CPU but raises a device-mismatch
    MlflowException when asked to load on 'cuda'."""
    mlflow.pytorch.save_model(
        sequential_model,
        model_path,
        serialization_format="pt2",
        input_example=torch.randn(3, 4).numpy(),
    )

    # test loading model to CPU works
    mlflow.pytorch.load_model(model_path, device="cpu")

    with pytest.raises(
        MlflowException,
        match="it can't be loaded on 'cuda' device.",
    ):
        mlflow.pytorch.load_model(model_path, device="cuda")


@pytest.mark.skipif(
    Version(torch.__version__) < Version("2.4"), reason="This test requires torch>=2.4"
)
def test_save_and_load_exported_model_with_multi_inputs(model_path):
    """pt2 export supports models with multiple positional inputs when the
    signature names each tensor; the loaded model matches the original."""

    class CustomModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 1)

        def forward(self, x, y):
            with torch.no_grad():
                return self.linear(x + y)

    model = CustomModel()
    input_example = (torch.randn(10, 4), torch.randn(10, 4))

    mlflow.pytorch.save_model(
        model,
        model_path,
        serialization_format="pt2",
        input_example=input_example,
        signature=ModelSignature(
            inputs=Schema([
                TensorSpec(np.dtype("float32"), (-1, 4), "v1"),
                TensorSpec(np.dtype("float32"), (-1, 4), "v2"),
            ]),
        ),
    )

    model_loaded = mlflow.pytorch.load_model(model_path)

    np.testing.assert_array_almost_equal(
        model(*input_example),
        model_loaded(*input_example),
        decimal=4,
    )