# tests/pytorch/test_pytorch_model_export.py
   1  import importlib
   2  import json
   3  import logging
   4  import os
   5  import pickle
   6  import re
   7  from pathlib import Path
   8  from unittest import mock
   9  
  10  import numpy as np
  11  import pandas as pd
  12  import pytest
  13  import torch
  14  import yaml
  15  from packaging.version import Version
  16  from sklearn import datasets
  17  from torch import nn
  18  from torch.utils.data import DataLoader
  19  
  20  import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
  21  import mlflow.pytorch
  22  from mlflow import pyfunc
  23  from mlflow.exceptions import MlflowException
  24  from mlflow.models import Model, ModelSignature
  25  from mlflow.models.utils import _read_example, load_serving_example
  26  from mlflow.pytorch import pickle_module as mlflow_pytorch_pickle_module
  27  from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
  28  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
  29  from mlflow.types.schema import DataType, Schema, TensorSpec
  30  from mlflow.utils.environment import _mlflow_conda_env
  31  from mlflow.utils.file_utils import TempDir
  32  from mlflow.utils.model_utils import _get_flavor_configuration
  33  
  34  from tests.helper_functions import (
  35      _assert_pip_requirements,
  36      _compare_conda_env_requirements,
  37      _compare_logged_code_paths,
  38      _is_available_on_pypi,
  39      _is_importable,
  40      _mlflow_major_version_string,
  41      assert_array_almost_equal,
  42      assert_register_model_called_with_local_model_path,
  43  )
  44  
_logger = logging.getLogger(__name__)

# This test suite is included as a code dependency when testing PyTorch model scoring in new
# processes and docker containers. In these environments, the `tests` module is not available.
# Therefore, we attempt to import from `tests` and gracefully emit a warning if it's unavailable.
try:
    from tests.helper_functions import pyfunc_serve_and_score_model
except ImportError:
    _logger.warning(
        "Failed to import test helper functions. Tests depending on these functions may fail!"
    )

# Extra CLI args for pyfunc serving tests: fall back to the local environment manager
# when `torch` cannot be installed from PyPI inside the scoring process.
EXTRA_PYFUNC_SERVING_TEST_ARGS = (
    [] if _is_available_on_pypi("torch") else ["--env-manager", "local"]
)

# in pytorch >= 2.6.0, the `weights_only` kwarg default has been changed from
# `False` to `True`. this can cause pickle deserialization errors when loading
# models, unless the model classes have been explicitly marked as safe using
# `torch.serialization.add_safe_globals()`
ENABLE_LEGACY_DESERIALIZATION = Version(torch.__version__) >= Version("2.6.0")
  66  
  67  
  68  @pytest.fixture(scope="module")
  69  def data():
  70      iris = datasets.load_iris()
  71      data = pd.DataFrame(
  72          data=np.c_[iris["data"], iris["target"]], columns=iris["feature_names"] + ["target"]
  73      )
  74      y = data["target"]
  75      x = data.drop("target", axis=1)
  76      return x, y
  77  
  78  
  79  @pytest.fixture(scope="module")
  80  def iris_tensor_spec():
  81      return ModelSignature(
  82          inputs=Schema([TensorSpec(np.dtype("float32"), (-1, 4))]),
  83          outputs=Schema([TensorSpec(np.dtype("float32"), (-1, 1))]),
  84      )
  85  
  86  
  87  def get_dataset(data):
  88      x, y = data
  89      return [(xi.astype(np.float32), yi.astype(np.float32)) for xi, yi in zip(x.values, y.values)]
  90  
  91  
  92  def train_model(model, data):
  93      dataset = get_dataset(data)
  94      criterion = nn.MSELoss()
  95      optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
  96      batch_size = 16
  97      num_workers = 4
  98      dataloader = DataLoader(
  99          dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, drop_last=False
 100      )
 101  
 102      model.train()
 103      for _ in range(5):
 104          for batch in dataloader:
 105              optimizer.zero_grad()
 106              batch_size = batch[0].shape[0]
 107              y_pred = model(batch[0]).squeeze(dim=1)
 108              loss = criterion(y_pred, batch[1])
 109              loss.backward()
 110              optimizer.step()
 111  
 112  
 113  def get_sequential_model():
 114      return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
 115  
 116  
 117  @pytest.fixture
 118  def sequential_model(data, scripted_model):
 119      model = get_sequential_model()
 120      if scripted_model:
 121          model = torch.jit.script(model)
 122  
 123      train_model(model=model, data=data)
 124      return model
 125  
 126  
def get_subclassed_model_definition():
    """
    Defines a PyTorch model class that inherits from ``torch.nn.Module``. This method can be invoked
    within a pytest fixture to define the model class in the ``__main__`` scope. Alternatively, it
    can be invoked within a module to define the class in the module's scope.
    """

    # NOTE(review): the class is defined inside this function (rather than at module level)
    # so callers control the scope under which it is defined; serialization tests appear to
    # depend on where the class lives, so do not move or rename it.
    class SubclassedModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Single linear layer: 4 input features -> 1 output (iris regression)
            self.linear = torch.nn.Linear(4, 1)

        def forward(self, x):
            return self.linear(x)

    return SubclassedModel
 143  
 144  
 145  @pytest.fixture(scope="module")
 146  def main_scoped_subclassed_model(data):
 147      """
 148      A custom PyTorch model inheriting from ``torch.nn.Module`` whose class is defined in the
 149      "__main__" scope.
 150      """
 151      model_class = get_subclassed_model_definition()
 152      model = model_class()
 153      train_model(model=model, data=data)
 154      return model
 155  
 156  
# Subclassing the dynamically-created base at module scope gives this class an importable
# name in this module — presumably needed so saved models referencing it can be reloaded
# when this file is shipped via `code_paths` (see tests below); verify before relocating.
class ModuleScopedSubclassedModel(get_subclassed_model_definition()):
    """
    A custom PyTorch model class defined in the test module scope. This is a subclass of
    ``torch.nn.Module``.
    """
 162  
 163  
 164  @pytest.fixture(scope="module")
 165  def module_scoped_subclassed_model(data):
 166      """
 167      A custom PyTorch model inheriting from ``torch.nn.Module`` whose class is defined in the test
 168      module scope.
 169      """
 170      model = ModuleScopedSubclassedModel()
 171      train_model(model=model, data=data)
 172      return model
 173  
 174  
 175  @pytest.fixture
 176  def model_path(tmp_path):
 177      return os.path.join(tmp_path, "model")
 178  
 179  
 180  @pytest.fixture
 181  def pytorch_custom_env(tmp_path):
 182      conda_env = os.path.join(tmp_path, "conda_env.yml")
 183      _mlflow_conda_env(conda_env, additional_pip_deps=["pytorch", "torchvision", "pytest"])
 184      return conda_env
 185  
 186  
 187  def _predict(model, data):
 188      from torch.fx import GraphModule
 189  
 190      dataset = get_dataset(data)
 191      batch_size = 16
 192      num_workers = 4
 193      dataloader = DataLoader(
 194          dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False
 195      )
 196      predictions = np.zeros((len(dataloader.sampler),))
 197  
 198      if not isinstance(model, GraphModule):
 199          model.eval()
 200      with torch.no_grad():
 201          for i, batch in enumerate(dataloader):
 202              y_preds = model(batch[0]).squeeze(dim=1).numpy()
 203              predictions[i * batch_size : (i + 1) * batch_size] = y_preds
 204      return predictions
 205  
 206  
@pytest.fixture
def sequential_predicted(sequential_model, data):
    # Baseline predictions from the in-process model, compared against predictions
    # produced by saved/loaded/served copies of the same model in the tests below.
    return _predict(sequential_model, data)
 210  
 211  
 212  @pytest.mark.parametrize("scripted_model", [True, False])
 213  def test_signature_and_examples_are_saved_correctly(sequential_model, data, iris_tensor_spec):
 214      model = sequential_model
 215      example_ = data[0].head(3).values.astype(np.float32)
 216      for signature in (None, iris_tensor_spec):
 217          for example in (None, example_):
 218              with TempDir() as tmp:
 219                  path = tmp.path("model")
 220                  mlflow.pytorch.save_model(
 221                      model, path=path, signature=signature, input_example=example
 222                  )
 223                  mlflow_model = Model.load(path)
 224                  if signature is None and example is None:
 225                      assert mlflow_model.signature is None
 226                  else:
 227                      assert mlflow_model.signature == iris_tensor_spec
 228                  if example is None:
 229                      assert mlflow_model.saved_input_example_info is None
 230                  else:
 231                      np.testing.assert_allclose(_read_example(mlflow_model, path), example)
 232  
 233  
 234  @pytest.mark.parametrize("scripted_model", [True, False])
 235  def test_log_model(sequential_model, data, sequential_predicted):
 236      try:
 237          artifact_path = "pytorch"
 238          model_info = mlflow.pytorch.log_model(sequential_model, name=artifact_path)
 239  
 240          sequential_model_loaded = mlflow.pytorch.load_model(model_uri=model_info.model_uri)
 241          test_predictions = _predict(sequential_model_loaded, data)
 242          np.testing.assert_array_equal(test_predictions, sequential_predicted)
 243      finally:
 244          mlflow.end_run()
 245  
 246  
def test_log_model_calls_register_model(module_scoped_subclassed_model):
    # Passing `registered_model_name` to log_model must trigger model registration
    # with the locally downloaded model path (registry call is mocked out).
    custom_pickle_module = pickle
    artifact_path = "model"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch:
        model_info = mlflow.pytorch.log_model(
            module_scoped_subclassed_model,
            name=artifact_path,
            pickle_module=custom_pickle_module,
            registered_model_name="AdsModel1",
        )
        assert_register_model_called_with_local_model_path(
            register_model_mock=mlflow.tracking._model_registry.fluent._register_model,
            model_uri=model_info.model_uri,
            registered_model_name="AdsModel1",
        )
 263  
 264  
def test_log_model_no_registered_model_name(module_scoped_subclassed_model):
    # Without `registered_model_name`, log_model must not touch the model registry.
    custom_pickle_module = pickle
    artifact_path = "model"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch:
        mlflow.pytorch.log_model(
            module_scoped_subclassed_model,
            name=artifact_path,
            pickle_module=custom_pickle_module,
        )
        mlflow.tracking._model_registry.fluent._register_model.assert_not_called()
 276  
 277  
 278  @pytest.mark.parametrize("scripted_model", [True, False])
 279  def test_raise_exception(sequential_model):
 280      with TempDir(chdr=True, remove_on_exit=True) as tmp:
 281          path = tmp.path("model")
 282          with pytest.raises(MlflowException, match="No such artifact"):
 283              mlflow.pytorch.load_model(path)
 284  
 285          with pytest.raises(TypeError, match="Argument 'pytorch_model' should be a torch.nn.Module"):
 286              mlflow.pytorch.save_model([1, 2, 3], path)
 287  
 288          mlflow.pytorch.save_model(sequential_model, path)
 289          with pytest.raises(MlflowException, match=f"Path '{os.path.abspath(path)}' already exists"):
 290              mlflow.pytorch.save_model(sequential_model, path)
 291  
 292          import sklearn.neighbors as knn
 293  
 294          from mlflow import sklearn
 295  
 296          path = tmp.path("knn.pkl")
 297          knn = knn.KNeighborsClassifier()
 298          with open(path, "wb") as f:
 299              pickle.dump(knn, f)
 300          path = tmp.path("knn")
 301          sklearn.save_model(knn, path=path)
 302          with pytest.raises(MlflowException, match='Model does not have the "pytorch" flavor'):
 303              mlflow.pytorch.load_model(path)
 304  
 305  
 306  @pytest.mark.parametrize("scripted_model", [True, False])
 307  def test_save_and_load_model(sequential_model, model_path, data, sequential_predicted):
 308      mlflow.pytorch.save_model(sequential_model, model_path)
 309  
 310      # Loading pytorch model
 311      sequential_model_loaded = mlflow.pytorch.load_model(model_path)
 312      np.testing.assert_array_equal(_predict(sequential_model_loaded, data), sequential_predicted)
 313  
 314      # Loading pyfunc model
 315      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
 316      np.testing.assert_array_almost_equal(
 317          pyfunc_loaded.predict(data[0]).values[:, 0], sequential_predicted, decimal=4
 318      )
 319  
 320  
 321  @pytest.mark.parametrize("scripted_model", [True, False])
 322  def test_pyfunc_model_works_with_np_input_type(
 323      sequential_model, model_path, data, sequential_predicted
 324  ):
 325      mlflow.pytorch.save_model(sequential_model, model_path)
 326  
 327      # Loading pyfunc model
 328      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
 329  
 330      # predict works with dataframes
 331      df_result = pyfunc_loaded.predict(data[0])
 332      assert type(df_result) == pd.DataFrame
 333      np.testing.assert_array_almost_equal(df_result.values[:, 0], sequential_predicted, decimal=4)
 334  
 335      # predict works with numpy ndarray
 336      np_result = pyfunc_loaded.predict(data[0].values.astype(np.float32))
 337      assert type(np_result) == np.ndarray
 338      np.testing.assert_array_almost_equal(np_result[:, 0], sequential_predicted, decimal=4)
 339  
 340      # predict does not work with lists
 341      with pytest.raises(
 342          TypeError, match="The PyTorch flavor does not support List or Dict input types"
 343      ):
 344          pyfunc_loaded.predict([1, 2, 3, 4])
 345  
 346      # predict does not work with scalars
 347      with pytest.raises(TypeError, match="Input data should be pandas.DataFrame or numpy.ndarray"):
 348          pyfunc_loaded.predict(4)
 349  
 350  
 351  @pytest.mark.parametrize("scripted_model", [True, False])
 352  def test_load_model_from_remote_uri_succeeds(
 353      sequential_model, model_path, mock_s3_bucket, data, sequential_predicted
 354  ):
 355      mlflow.pytorch.save_model(sequential_model, model_path)
 356  
 357      artifact_root = f"s3://{mock_s3_bucket}"
 358      artifact_path = "model"
 359      artifact_repo = S3ArtifactRepository(artifact_root)
 360      artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)
 361  
 362      model_uri = artifact_root + "/" + artifact_path
 363      sequential_model_loaded = mlflow.pytorch.load_model(model_uri=model_uri)
 364      np.testing.assert_array_equal(_predict(sequential_model_loaded, data), sequential_predicted)
 365  
 366  
 367  @pytest.mark.parametrize("scripted_model", [True, False])
 368  def test_model_save_persists_specified_conda_env_in_mlflow_model_directory(
 369      sequential_model, model_path, pytorch_custom_env
 370  ):
 371      mlflow.pytorch.save_model(
 372          pytorch_model=sequential_model, path=model_path, conda_env=pytorch_custom_env
 373      )
 374  
 375      pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
 376      saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
 377      assert os.path.exists(saved_conda_env_path)
 378      assert saved_conda_env_path != pytorch_custom_env
 379  
 380      with open(pytorch_custom_env) as f:
 381          pytorch_custom_env_text = f.read()
 382      with open(saved_conda_env_path) as f:
 383          saved_conda_env_text = f.read()
 384      assert saved_conda_env_text == pytorch_custom_env_text
 385  
 386  
 387  @pytest.mark.parametrize("scripted_model", [True, False])
 388  def test_model_save_persists_requirements_in_mlflow_model_directory(
 389      sequential_model, model_path, pytorch_custom_env
 390  ):
 391      mlflow.pytorch.save_model(
 392          pytorch_model=sequential_model, path=model_path, conda_env=pytorch_custom_env
 393      )
 394  
 395      saved_pip_req_path = os.path.join(model_path, "requirements.txt")
 396      _compare_conda_env_requirements(pytorch_custom_env, saved_pip_req_path)
 397  
 398  
 399  @pytest.mark.parametrize("scripted_model", [False])
 400  def test_save_model_with_pip_requirements(sequential_model, tmp_path):
 401      expected_mlflow_version = _mlflow_major_version_string()
 402      # Path to a requirements file
 403      tmpdir1 = tmp_path.joinpath("1")
 404      req_file = tmp_path.joinpath("requirements.txt")
 405      req_file.write_text("a")
 406      mlflow.pytorch.save_model(sequential_model, tmpdir1, pip_requirements=str(req_file))
 407      _assert_pip_requirements(tmpdir1, [expected_mlflow_version, "a"], strict=True)
 408  
 409      # List of requirements
 410      tmpdir2 = tmp_path.joinpath("2")
 411      mlflow.pytorch.save_model(sequential_model, tmpdir2, pip_requirements=[f"-r {req_file}", "b"])
 412      _assert_pip_requirements(tmpdir2, [expected_mlflow_version, "a", "b"], strict=True)
 413  
 414      # Constraints file
 415      tmpdir3 = tmp_path.joinpath("3")
 416      mlflow.pytorch.save_model(sequential_model, tmpdir3, pip_requirements=[f"-c {req_file}", "b"])
 417      _assert_pip_requirements(
 418          tmpdir3, [expected_mlflow_version, "b", "-c constraints.txt"], ["a"], strict=True
 419      )
 420  
 421  
 422  @pytest.mark.parametrize("scripted_model", [False])
 423  def test_save_model_with_extra_pip_requirements(sequential_model, tmp_path):
 424      expected_mlflow_version = _mlflow_major_version_string()
 425      default_reqs = mlflow.pytorch.get_default_pip_requirements()
 426  
 427      # Path to a requirements file
 428      tmpdir1 = tmp_path.joinpath("1")
 429      req_file = tmp_path.joinpath("requirements.txt")
 430      req_file.write_text("a")
 431      mlflow.pytorch.save_model(sequential_model, tmpdir1, extra_pip_requirements=str(req_file))
 432      _assert_pip_requirements(tmpdir1, [expected_mlflow_version, *default_reqs, "a"])
 433  
 434      # List of requirements
 435      tmpdir2 = tmp_path.joinpath("2")
 436      mlflow.pytorch.save_model(
 437          sequential_model, tmpdir2, extra_pip_requirements=[f"-r {req_file}", "b"]
 438      )
 439      _assert_pip_requirements(tmpdir2, [expected_mlflow_version, *default_reqs, "a", "b"])
 440  
 441      # Constraints file
 442      tmpdir3 = tmp_path.joinpath("3")
 443      mlflow.pytorch.save_model(
 444          sequential_model, tmpdir3, extra_pip_requirements=[f"-c {req_file}", "b"]
 445      )
 446      _assert_pip_requirements(
 447          tmpdir3, [expected_mlflow_version, *default_reqs, "b", "-c constraints.txt"], ["a"]
 448      )
 449  
 450  
 451  @pytest.mark.parametrize("scripted_model", [True, False])
 452  def test_model_save_accepts_conda_env_as_dict(sequential_model, model_path):
 453      conda_env = dict(mlflow.pytorch.get_default_conda_env())
 454      conda_env["dependencies"].append("pytest")
 455      mlflow.pytorch.save_model(pytorch_model=sequential_model, path=model_path, conda_env=conda_env)
 456  
 457      pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
 458      saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
 459      assert os.path.exists(saved_conda_env_path)
 460  
 461      with open(saved_conda_env_path) as f:
 462          saved_conda_env_parsed = yaml.safe_load(f)
 463      assert saved_conda_env_parsed == conda_env
 464  
 465  
 466  @pytest.mark.parametrize("scripted_model", [True, False])
 467  def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
 468      sequential_model, pytorch_custom_env
 469  ):
 470      artifact_path = "model"
 471      with mlflow.start_run():
 472          model_info = mlflow.pytorch.log_model(
 473              sequential_model,
 474              name=artifact_path,
 475              conda_env=pytorch_custom_env,
 476          )
 477          model_path = _download_artifact_from_uri(model_info.model_uri)
 478  
 479      pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
 480      saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
 481      assert os.path.exists(saved_conda_env_path)
 482      assert saved_conda_env_path != pytorch_custom_env
 483  
 484      with open(pytorch_custom_env) as f:
 485          pytorch_custom_env_text = f.read()
 486      with open(saved_conda_env_path) as f:
 487          saved_conda_env_text = f.read()
 488      assert saved_conda_env_text == pytorch_custom_env_text
 489  
 490  
 491  @pytest.mark.parametrize("scripted_model", [True, False])
 492  def test_model_log_persists_requirements_in_mlflow_model_directory(
 493      sequential_model, pytorch_custom_env
 494  ):
 495      artifact_path = "model"
 496      with mlflow.start_run():
 497          model_info = mlflow.pytorch.log_model(
 498              sequential_model,
 499              name=artifact_path,
 500              conda_env=pytorch_custom_env,
 501          )
 502          model_path = _download_artifact_from_uri(model_info.model_uri)
 503  
 504      saved_pip_req_path = os.path.join(model_path, "requirements.txt")
 505      _compare_conda_env_requirements(pytorch_custom_env, saved_pip_req_path)
 506  
 507  
@pytest.mark.parametrize("scripted_model", [True, False])
def test_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    sequential_model, model_path
):
    # With no conda_env given, save_model must emit the flavor's default pip requirements.
    mlflow.pytorch.save_model(pytorch_model=sequential_model, path=model_path)
    _assert_pip_requirements(model_path, mlflow.pytorch.get_default_pip_requirements())
 514  
 515  
@pytest.mark.parametrize("scripted_model", [True, False])
def test_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    sequential_model,
):
    # With no conda_env given, log_model must emit the flavor's default pip requirements.
    with mlflow.start_run():
        model_info = mlflow.pytorch.log_model(sequential_model, name="model")

    _assert_pip_requirements(model_info.model_uri, mlflow.pytorch.get_default_pip_requirements())
 524  
 525  
@pytest.mark.parametrize("scripted_model", [True, False])
def test_load_model_with_differing_pytorch_version_logs_warning(sequential_model, model_path):
    """Loading a model under a different torch version than it was saved with warns."""
    mlflow.pytorch.save_model(pytorch_model=sequential_model, path=model_path)
    saver_pytorch_version = "1.0"
    # Rewrite the flavor config so the model appears to have been saved with torch 1.0
    model_config_path = os.path.join(model_path, "MLmodel")
    model_config = Model.load(model_config_path)
    model_config.flavors[mlflow.pytorch.FLAVOR_NAME]["pytorch_version"] = saver_pytorch_version
    model_config.save(model_config_path)

    log_messages = []

    # Capture the fully-formatted warning text emitted through the flavor's logger
    # (applies lazy %-style args the way the logging module would).
    def custom_warn(message_text, *args, **kwargs):
        log_messages.append(message_text % args % kwargs)

    loader_pytorch_version = "0.8.2"
    with (
        mock.patch("mlflow.pytorch._logger.warning") as warn_mock,
        mock.patch("torch.__version__", loader_pytorch_version),
    ):
        warn_mock.side_effect = custom_warn
        mlflow.pytorch.load_model(model_uri=model_path)

    # The warning must mention both the saving and the installed torch versions
    assert any(
        "does not match installed PyTorch version" in log_message
        and saver_pytorch_version in log_message
        and loader_pytorch_version in log_message
        for log_message in log_messages
    )
 554  
 555  
 556  def test_pyfunc_model_serving_with_module_scoped_subclassed_model_and_default_conda_env(
 557      module_scoped_subclassed_model, data
 558  ):
 559      with mlflow.start_run():
 560          model_info = mlflow.pytorch.log_model(
 561              module_scoped_subclassed_model,
 562              name="pytorch_model",
 563              code_paths=[__file__],
 564              input_example=data[0],
 565          )
 566  
 567      inference_payload = load_serving_example(model_info.model_uri)
 568      scoring_response = pyfunc_serve_and_score_model(
 569          model_uri=model_info.model_uri,
 570          data=inference_payload,
 571          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 572          extra_args=["--env-manager", "local"],
 573      )
 574      assert scoring_response.status_code == 200
 575  
 576      deployed_model_preds = pd.DataFrame(json.loads(scoring_response.content)["predictions"])
 577      np.testing.assert_array_almost_equal(
 578          deployed_model_preds.values[:, 0],
 579          _predict(model=module_scoped_subclassed_model, data=data),
 580          decimal=4,
 581      )
 582  
 583  
 584  def test_save_model_with_wrong_codepaths_fails_correctly(
 585      module_scoped_subclassed_model, model_path, data
 586  ):
 587      with pytest.raises(TypeError, match="Argument code_paths should be a list, not <class 'str'>"):
 588          mlflow.pytorch.save_model(
 589              path=model_path, pytorch_model=module_scoped_subclassed_model, code_paths="some string"
 590          )
 591  
 592  
 593  def test_pyfunc_model_serving_with_main_scoped_subclassed_model_and_custom_pickle_module(
 594      main_scoped_subclassed_model, data
 595  ):
 596      with mlflow.start_run():
 597          model_info = mlflow.pytorch.log_model(
 598              main_scoped_subclassed_model,
 599              name="pytorch_model",
 600              pickle_module=mlflow_pytorch_pickle_module,
 601              input_example=data[0],
 602          )
 603  
 604      inference_payload = load_serving_example(model_info.model_uri)
 605      scoring_response = pyfunc_serve_and_score_model(
 606          model_uri=model_info.model_uri,
 607          data=inference_payload,
 608          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 609          extra_args=["--env-manager", "local"],
 610      )
 611      assert scoring_response.status_code == 200
 612  
 613      deployed_model_preds = pd.DataFrame(json.loads(scoring_response.content)["predictions"])
 614      np.testing.assert_array_almost_equal(
 615          deployed_model_preds.values[:, 0],
 616          _predict(model=main_scoped_subclassed_model, data=data),
 617          decimal=4,
 618      )
 619  
 620  
def test_load_model_succeeds_with_dependencies_specified_via_code_paths(
    module_scoped_subclassed_model, model_path, data
):
    """End-to-end check that `code_paths` makes this module's model class importable when
    the model is loaded inside a separate local scoring process."""
    # Save a PyTorch model whose class is defined in the current test suite. Because the
    # `tests` module is not available when the model is deployed for local scoring, we include
    # the test suite file as a code dependency
    mlflow.pytorch.save_model(
        path=model_path,
        pytorch_model=module_scoped_subclassed_model,
        code_paths=[__file__],
    )

    # Define a custom pyfunc model that loads a PyTorch model artifact using
    # `mlflow.pytorch.load_model`
    class TorchValidatorModel(pyfunc.PythonModel):
        def load_context(self, context):
            self.pytorch_model = mlflow.pytorch.load_model(context.artifacts["pytorch_model"])

        def predict(self, context, model_input, params=None):
            # Inference only: no gradients; convert DataFrame -> float32 tensor -> DataFrame
            with torch.no_grad():
                input_tensor = torch.from_numpy(model_input.values.astype(np.float32))
                output_tensor = self.pytorch_model(input_tensor)
                return pd.DataFrame(output_tensor.numpy())

    pyfunc_artifact_path = "pyfunc_model"
    with mlflow.start_run():
        model_info = pyfunc.log_model(
            pyfunc_artifact_path,
            python_model=TorchValidatorModel(),
            artifacts={"pytorch_model": model_path},
            input_example=data[0],
            # save file into code_paths, otherwise after first model loading (happens when
            # validating input_example) then we can not load the model again
            code_paths=[__file__],
        )

    # Deploy the custom pyfunc model and ensure that it is able to successfully load its
    # constituent PyTorch model via `mlflow.pytorch.load_model`
    inference_payload = load_serving_example(model_info.model_uri)
    scoring_response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert scoring_response.status_code == 200

    deployed_model_preds = pd.DataFrame(json.loads(scoring_response.content)["predictions"])
    np.testing.assert_array_almost_equal(
        deployed_model_preds.values[:, 0],
        _predict(model=module_scoped_subclassed_model, data=data),
        decimal=4,
    )
 674  
 675  
def test_load_pyfunc_loads_torch_model_using_pickle_module_specified_at_save_time(
    module_scoped_subclassed_model, model_path
):
    """Loading via pyfunc must import and use the pickle module recorded at save time."""
    custom_pickle_module = pickle

    mlflow.pytorch.save_model(
        path=model_path,
        pytorch_model=module_scoped_subclassed_model,
        pickle_module=custom_pickle_module,
    )

    # Record every module imported while loading, so we can assert that the saved
    # pickle module is re-imported by name; the real import still happens.
    import_module_fn = importlib.import_module
    imported_modules = []

    def track_module_imports(module_name):
        imported_modules.append(module_name)
        return import_module_fn(module_name)

    with (
        mock.patch("importlib.import_module") as import_mock,
        mock.patch("torch.load") as torch_load_mock,
    ):
        import_mock.side_effect = track_module_imports
        pyfunc.load_model(model_path)

    # torch >= 2.6 flips the `weights_only` default, so legacy pickle-based loading
    # must pass `weights_only=False` explicitly (see ENABLE_LEGACY_DESERIALIZATION).
    expected_kwargs = {"pickle_module": custom_pickle_module}
    if ENABLE_LEGACY_DESERIALIZATION:
        expected_kwargs["weights_only"] = False

    torch_load_mock.assert_called_with(mock.ANY, **expected_kwargs)
    assert custom_pickle_module.__name__ in imported_modules
 707  
 708  
def test_load_model_loads_torch_model_using_pickle_module_specified_at_save_time(
    module_scoped_subclassed_model,
):
    """Loading a logged model must import and use the pickle module recorded at log time."""
    custom_pickle_module = pickle

    artifact_path = "pytorch_model"
    with mlflow.start_run():
        model_info = mlflow.pytorch.log_model(
            module_scoped_subclassed_model,
            name=artifact_path,
            pickle_module=custom_pickle_module,
        )
        model_uri = model_info.model_uri

    # Record every module imported while loading, so we can assert that the saved
    # pickle module is re-imported by name; the real import still happens.
    import_module_fn = importlib.import_module
    imported_modules = []

    def track_module_imports(module_name):
        imported_modules.append(module_name)
        return import_module_fn(module_name)

    with (
        mock.patch("importlib.import_module") as import_mock,
        mock.patch("torch.load") as torch_load_mock,
    ):
        import_mock.side_effect = track_module_imports
        pyfunc.load_model(model_uri=model_uri)

    # torch >= 2.6 flips the `weights_only` default, so legacy pickle-based loading
    # must pass `weights_only=False` explicitly (see ENABLE_LEGACY_DESERIALIZATION).
    expected_kwargs = {"pickle_module": custom_pickle_module}
    if ENABLE_LEGACY_DESERIALIZATION:
        expected_kwargs["weights_only"] = False

    torch_load_mock.assert_called_with(mock.ANY, **expected_kwargs)
    assert custom_pickle_module.__name__ in imported_modules
 743  
 744  
 745  def test_load_pyfunc_succeeds_when_data_is_model_file_instead_of_directory(
 746      module_scoped_subclassed_model, model_path, data
 747  ):
 748      """
 749      This test verifies that PyTorch models saved in older versions of MLflow are loaded successfully
 750      by ``mlflow.pytorch.load_model``. The ``data`` path associated with these older models is
 751      serialized PyTorch model file, as opposed to the current format: a directory containing a
 752      serialized model file and pickle module information.
 753      """
 754      mlflow.pytorch.save_model(path=model_path, pytorch_model=module_scoped_subclassed_model)
 755  
 756      model_conf_path = os.path.join(model_path, "MLmodel")
 757      model_conf = Model.load(model_conf_path)
 758      pyfunc_conf = model_conf.flavors.get(pyfunc.FLAVOR_NAME)
 759      assert pyfunc_conf is not None
 760      model_data_path = os.path.join(model_path, pyfunc_conf[pyfunc.DATA])
 761      assert os.path.exists(model_data_path)
 762      assert mlflow.pytorch._SERIALIZED_TORCH_MODEL_FILE_NAME in os.listdir(model_data_path)
 763      pyfunc_conf[pyfunc.DATA] = os.path.join(
 764          model_data_path, mlflow.pytorch._SERIALIZED_TORCH_MODEL_FILE_NAME
 765      )
 766      model_conf.save(model_conf_path)
 767  
 768      loaded_pyfunc = pyfunc.load_model(model_path)
 769  
 770      np.testing.assert_array_almost_equal(
 771          loaded_pyfunc.predict(data[0]),
 772          pd.DataFrame(_predict(model=module_scoped_subclassed_model, data=data)),
 773          decimal=4,
 774      )
 775  
 776  
 777  def test_load_model_succeeds_when_data_is_model_file_instead_of_directory(
 778      module_scoped_subclassed_model, model_path, data
 779  ):
 780      """
 781      This test verifies that PyTorch models saved in older versions of MLflow are loaded successfully
 782      by ``mlflow.pytorch.load_model``. The ``data`` path associated with these older models is
 783      serialized PyTorch model file, as opposed to the current format: a directory containing a
 784      serialized model file and pickle module information.
 785      """
 786      artifact_path = "pytorch_model"
 787      with mlflow.start_run():
 788          model_info = mlflow.pytorch.log_model(module_scoped_subclassed_model, name=artifact_path)
 789          model_path = _download_artifact_from_uri(model_info.model_uri)
 790  
 791      model_conf_path = os.path.join(model_path, "MLmodel")
 792      model_conf = Model.load(model_conf_path)
 793      pyfunc_conf = model_conf.flavors.get(pyfunc.FLAVOR_NAME)
 794      assert pyfunc_conf is not None
 795      model_data_path = os.path.join(model_path, pyfunc_conf[pyfunc.DATA])
 796      assert os.path.exists(model_data_path)
 797      assert mlflow.pytorch._SERIALIZED_TORCH_MODEL_FILE_NAME in os.listdir(model_data_path)
 798      pyfunc_conf[pyfunc.DATA] = os.path.join(
 799          model_data_path, mlflow.pytorch._SERIALIZED_TORCH_MODEL_FILE_NAME
 800      )
 801      model_conf.save(model_conf_path)
 802  
 803      loaded_pyfunc = pyfunc.load_model(model_path)
 804  
 805      np.testing.assert_array_almost_equal(
 806          loaded_pyfunc.predict(data[0]),
 807          pd.DataFrame(_predict(model=module_scoped_subclassed_model, data=data)),
 808          decimal=4,
 809      )
 810  
 811  
 812  def test_load_model_allows_user_to_override_pickle_module_via_keyword_argument(
 813      module_scoped_subclassed_model, model_path
 814  ):
 815      mlflow.pytorch.save_model(
 816          path=model_path, pytorch_model=module_scoped_subclassed_model, pickle_module=pickle
 817      )
 818  
 819      with (
 820          mock.patch("torch.load") as torch_load_mock,
 821          mock.patch("mlflow.pytorch._logger.warning") as warn_mock,
 822      ):
 823          mlflow.pytorch.load_model(model_uri=model_path, pickle_module=mlflow_pytorch_pickle_module)
 824          torch_load_mock.assert_called_with(mock.ANY, pickle_module=mlflow_pytorch_pickle_module)
 825          warn_mock.assert_any_call(mock.ANY, mlflow_pytorch_pickle_module.__name__, pickle.__name__)
 826  
 827  
 828  def test_load_model_raises_exception_when_pickle_module_cannot_be_imported(
 829      main_scoped_subclassed_model, model_path
 830  ):
 831      mlflow.pytorch.save_model(path=model_path, pytorch_model=main_scoped_subclassed_model)
 832  
 833      bad_pickle_module_name = "not.a.real.module"
 834  
 835      pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
 836      model_data_path = os.path.join(model_path, pyfunc_conf[pyfunc.DATA])
 837      assert os.path.exists(model_data_path)
 838      assert mlflow.pytorch._PICKLE_MODULE_INFO_FILE_NAME in os.listdir(model_data_path)
 839      with open(
 840          os.path.join(model_data_path, mlflow.pytorch._PICKLE_MODULE_INFO_FILE_NAME), "w"
 841      ) as f:
 842          f.write(bad_pickle_module_name)
 843  
 844      with pytest.raises(
 845          MlflowException,
 846          match=r"Failed to import the pickle module.+" + re.escape(bad_pickle_module_name),
 847      ):
 848          mlflow.pytorch.load_model(model_uri=model_path)
 849  
 850  
 851  def test_pyfunc_serve_and_score(data):
 852      model = torch.nn.Linear(4, 1)
 853      train_model(model=model, data=data)
 854  
 855      with mlflow.start_run():
 856          model_info = mlflow.pytorch.log_model(model, name="model", input_example=data[0])
 857  
 858      inference_payload = load_serving_example(model_info.model_uri)
 859      resp = pyfunc_serve_and_score_model(
 860          model_info.model_uri,
 861          inference_payload,
 862          pyfunc_scoring_server.CONTENT_TYPE_JSON,
 863          extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
 864      )
 865      from mlflow.deployments import PredictionsResponse
 866  
 867      scores = PredictionsResponse.from_json(resp.content).get_predictions()
 868      np.testing.assert_array_almost_equal(scores.values[:, 0], _predict(model=model, data=data))
 869  
 870  
@pytest.mark.skipif(not _is_importable("transformers"), reason="This test requires transformers")
def test_pyfunc_serve_and_score_transformers():
    """Log a tiny BERT-derived torch module and verify that serving it through
    the pyfunc scoring server returns the same outputs as direct invocation.
    """
    from transformers import BertConfig, BertModel

    from mlflow.deployments import PredictionsResponse

    # Subclass whose forward() returns a plain tensor (last_hidden_state)
    # instead of a ModelOutput, so the pyfunc flavor can serialize the result.
    class MyBertModel(BertModel):
        def forward(self, *args, **kwargs):
            return super().forward(*args, **kwargs).last_hidden_state

    # Deliberately minuscule config to keep the test fast.
    model = MyBertModel(
        BertConfig(
            vocab_size=16,
            hidden_size=2,
            num_hidden_layers=2,
            num_attention_heads=2,
            intermediate_size=2,
        )
    )
    model.eval()

    input_ids = model.dummy_inputs["input_ids"]

    with mlflow.start_run():
        model_info = mlflow.pytorch.log_model(
            model, name="model", input_example=np.array(input_ids.tolist())
        )

    inference_payload = load_serving_example(model_info.model_uri)
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        inference_payload,
        pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )

    scores = PredictionsResponse.from_json(resp.content.decode("utf-8")).get_predictions(
        predictions_format="ndarray"
    )
    assert_array_almost_equal(scores, model(input_ids).detach().numpy(), rtol=1e-6)
 911  
 912  
 913  @pytest.fixture
 914  def create_requirements_file(tmp_path):
 915      requirement_file_name = "requirements.txt"
 916      fp = tmp_path.joinpath(requirement_file_name)
 917      test_string = "mlflow"
 918      fp.write_text(test_string)
 919      return str(fp), test_string
 920  
 921  
 922  @pytest.fixture
 923  def create_extra_files(tmp_path):
 924      fp1 = tmp_path.joinpath("extra1.txt")
 925      fp2 = tmp_path.joinpath("extra2.txt")
 926      fp1.write_text("1")
 927      fp2.write_text("2")
 928      return [str(fp1), str(fp2)], ["1", "2"]
 929  
 930  
 931  @pytest.mark.parametrize("scripted_model", [True, False])
 932  def test_extra_files_log_model(create_extra_files, sequential_model):
 933      extra_files, contents_expected = create_extra_files
 934      with mlflow.start_run():
 935          mlflow.pytorch.log_model(sequential_model, name="models", extra_files=extra_files)
 936  
 937          model_uri = "runs:/{run_id}/{model_path}".format(
 938              run_id=mlflow.active_run().info.run_id, model_path="models"
 939          )
 940          with TempDir(remove_on_exit=True) as tmp:
 941              model_path = _download_artifact_from_uri(model_uri, tmp.path())
 942              model_config_path = os.path.join(model_path, "MLmodel")
 943              model_config = Model.load(model_config_path)
 944              flavor_config = model_config.flavors["pytorch"]
 945  
 946              assert "extra_files" in flavor_config
 947              loaded_extra_files = flavor_config["extra_files"]
 948  
 949              for loaded_extra_file, content_expected in zip(loaded_extra_files, contents_expected):
 950                  assert "path" in loaded_extra_file
 951                  extra_file_path = os.path.join(model_path, loaded_extra_file["path"])
 952                  with open(extra_file_path) as fp:
 953                      assert fp.read() == content_expected
 954  
 955  
 956  @pytest.mark.parametrize("scripted_model", [True, False])
 957  def test_extra_files_save_model(create_extra_files, sequential_model):
 958      extra_files, contents_expected = create_extra_files
 959      with TempDir(remove_on_exit=True) as tmp:
 960          model_path = os.path.join(tmp.path(), "models")
 961          mlflow.pytorch.save_model(
 962              pytorch_model=sequential_model, path=model_path, extra_files=extra_files
 963          )
 964          model_config_path = os.path.join(model_path, "MLmodel")
 965          model_config = Model.load(model_config_path)
 966          flavor_config = model_config.flavors["pytorch"]
 967  
 968          assert "extra_files" in flavor_config
 969          loaded_extra_files = flavor_config["extra_files"]
 970  
 971          for loaded_extra_file, content_expected in zip(loaded_extra_files, contents_expected):
 972              assert "path" in loaded_extra_file
 973              extra_file_path = os.path.join(model_path, loaded_extra_file["path"])
 974              with open(extra_file_path) as fp:
 975                  assert fp.read() == content_expected
 976  
 977  
 978  @pytest.mark.parametrize("scripted_model", [True, False])
 979  def test_log_model_invalid_extra_file_path(sequential_model):
 980      with (
 981          mlflow.start_run(),
 982          pytest.raises(MlflowException, match="No such artifact: 'non_existing_file.txt'"),
 983      ):
 984          mlflow.pytorch.log_model(
 985              sequential_model,
 986              name="models",
 987              extra_files=["non_existing_file.txt"],
 988          )
 989  
 990  
 991  @pytest.mark.parametrize("scripted_model", [True, False])
 992  def test_log_model_invalid_extra_file_type(sequential_model):
 993      with (
 994          mlflow.start_run(),
 995          pytest.raises(TypeError, match="Extra files argument should be a list"),
 996      ):
 997          mlflow.pytorch.log_model(
 998              sequential_model,
 999              name="models",
1000              extra_files="non_existing_file.txt",
1001          )
1002  
1003  
def state_dict_equal(state_dict1, state_dict2):
    """Recursively compare two (possibly nested) state dicts for equality.

    Values must match in exact type. ``torch.Tensor`` values are compared with
    ``torch.equal``, nested dicts are compared recursively, and everything else
    is compared with ``==``.

    Returns:
        True only if both dicts contain exactly the same keys with equal values.
    """
    # Symmetric key check. The previous implementation only verified that
    # state_dict1's keys were a subset of state_dict2's, so extra keys in
    # state_dict2 went undetected.
    if state_dict1.keys() != state_dict2.keys():
        return False

    for key, value1 in state_dict1.items():
        value2 = state_dict2[key]

        # Exact type match: identity comparison is the idiomatic type check.
        if type(value1) is not type(value2):
            return False
        if isinstance(value1, dict):
            if not state_dict_equal(value1, value2):
                return False
        elif isinstance(value1, torch.Tensor):
            # `==` on tensors is elementwise; torch.equal gives a single bool.
            if not torch.equal(value1, value2):
                return False
        elif value1 != value2:
            return False

    return True
1026  
1027  
@pytest.mark.parametrize("scripted_model", [True, False])
def test_save_state_dict(sequential_model, model_path, data):
    """A saved state_dict round-trips and reproduces the source model's predictions."""
    original_state_dict = sequential_model.state_dict()
    mlflow.pytorch.save_state_dict(original_state_dict, model_path)

    reloaded = mlflow.pytorch.load_state_dict(model_path)
    assert state_dict_equal(reloaded, original_state_dict)

    rehydrated_model = get_sequential_model()
    rehydrated_model.load_state_dict(reloaded)
    np.testing.assert_array_almost_equal(
        _predict(rehydrated_model, data),
        _predict(sequential_model, data),
        decimal=4,
    )
1042  
1043  
def test_save_state_dict_can_save_nested_state_dict(model_path):
    """
    Ensure `save_state_dict` supports bundling multiple objects (e.g., model and
    optimizer) into a single nested state_dict, per the PyTorch checkpoint recipe:
    https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
    """
    model = get_sequential_model()
    optimizer = torch.optim.Adam(model.parameters())
    checkpoint = {"model": model.state_dict(), "optim": optimizer.state_dict()}
    mlflow.pytorch.save_state_dict(checkpoint, model_path)

    reloaded = mlflow.pytorch.load_state_dict(model_path)
    assert state_dict_equal(reloaded, checkpoint)
    model.load_state_dict(reloaded["model"])
    optimizer.load_state_dict(reloaded["optim"])
1060  
1061  
def test_load_state_dict_disallows_pickle_deserialization(model_path, monkeypatch):
    """``load_state_dict`` fails when pickle deserialization is disabled via env var."""
    mlflow.pytorch.save_state_dict(get_sequential_model().state_dict(), model_path)

    monkeypatch.setenv("MLFLOW_ALLOW_PICKLE_DESERIALIZATION", "false")
    with pytest.raises(MlflowException, match="MLFLOW_ALLOW_PICKLE_DESERIALIZATION"):
        mlflow.pytorch.load_state_dict(model_path)
1069  
1070  
@pytest.mark.parametrize("not_state_dict", [0, "", get_sequential_model()])
def test_save_state_dict_throws_for_invalid_object_type(not_state_dict, model_path):
    """``save_state_dict`` raises TypeError for objects that are not state dicts."""
    expected_message = "Invalid object type for `state_dict`"
    with pytest.raises(TypeError, match=expected_message):
        mlflow.pytorch.save_state_dict(not_state_dict, model_path)
1075  
1076  
@pytest.mark.parametrize("scripted_model", [True, False])
def test_log_state_dict(sequential_model, data):
    """A logged state_dict loads from its artifact URI and reproduces predictions."""
    state_dict = sequential_model.state_dict()
    with mlflow.start_run():
        mlflow.pytorch.log_state_dict(state_dict, "model")
        state_dict_uri = mlflow.get_artifact_uri("model")

    reloaded = mlflow.pytorch.load_state_dict(state_dict_uri)
    assert state_dict_equal(reloaded, state_dict)

    rehydrated_model = get_sequential_model()
    rehydrated_model.load_state_dict(reloaded)
    np.testing.assert_array_almost_equal(
        _predict(rehydrated_model, data),
        _predict(sequential_model, data),
        decimal=4,
    )
1094  
1095  
@pytest.mark.parametrize("scripted_model", [True, False])
def test_log_model_with_code_paths(sequential_model):
    """``code_paths`` are logged with the model and added to sys.path on load."""
    with mlflow.start_run():
        with mock.patch("mlflow.pytorch._add_code_from_conf_to_system_path") as add_mock:
            model_info = mlflow.pytorch.log_model(
                sequential_model, name="model", code_paths=[__file__]
            )
            _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.pytorch.FLAVOR_NAME)
            mlflow.pytorch.load_model(model_info.model_uri)
            add_mock.assert_called()
1109  
1110  
def test_virtualenv_subfield_points_to_correct_path(model_path):
    """The pyfunc env config's 'virtualenv' entry points at an existing file."""
    mlflow.pytorch.save_model(get_sequential_model(), path=model_path)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    python_env_path = Path(model_path, pyfunc_conf[pyfunc.ENV]["virtualenv"])
    assert python_env_path.is_file()
    assert python_env_path.exists()
1118  
1119  
@pytest.mark.parametrize("scripted_model", [True, False])
def test_model_save_load_with_metadata(sequential_model, model_path):
    """Metadata passed to ``save_model`` survives a pyfunc reload."""
    mlflow.pytorch.save_model(
        sequential_model, path=model_path, metadata={"metadata_key": "metadata_value"}
    )

    loaded = mlflow.pyfunc.load_model(model_uri=model_path)
    assert loaded.metadata.metadata["metadata_key"] == "metadata_value"
1128  
1129  
@pytest.mark.parametrize("scripted_model", [True, False])
def test_model_log_with_metadata(sequential_model):
    """Metadata passed to ``log_model`` survives a pyfunc reload."""
    with mlflow.start_run():
        model_info = mlflow.pytorch.log_model(
            sequential_model,
            name="model",
            metadata={"metadata_key": "metadata_value"},
        )

    loaded = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
    assert loaded.metadata.metadata["metadata_key"] == "metadata_value"
1143  
1144  
@pytest.mark.parametrize("scripted_model", [True, False])
def test_model_log_with_signature_inference(sequential_model, data):
    """Logging with an input example infers the tensor signature, and the served
    model's predictions match local inference."""
    input_example = data[0].head(3).values.astype(np.float32)

    with mlflow.start_run():
        model_info = mlflow.pytorch.log_model(
            sequential_model, name="model", input_example=input_example
        )

    expected_signature = ModelSignature(
        inputs=Schema([TensorSpec(np.dtype("float32"), (-1, 4))]),
        outputs=Schema([TensorSpec(np.dtype("float32"), (-1, 1))]),
    )
    assert model_info.signature == expected_signature

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        inference_payload,
        pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert response.status_code == 200
    served_preds = pd.DataFrame(json.loads(response.content)["predictions"])
    np.testing.assert_array_almost_equal(
        served_preds.values[:, 0],
        _predict(model=sequential_model, data=(data[0].head(3), data[1].head(3))),
        decimal=4,
    )
1173  
1174  
@pytest.mark.parametrize("scripted_model", [False])
def test_load_model_to_device(sequential_model):
    """Verify that model_config entries (notably ``device``) are forwarded to the
    internal ``mlflow.pytorch._load_model`` helper by both the pyfunc and the
    native pytorch loading paths.
    """
    with mock.patch("mlflow.pytorch._load_model") as load_model_mock:
        with mlflow.start_run():
            model_info = mlflow.pytorch.log_model(sequential_model, name="pytorch")
            model_config = {"device": "cuda"}
            # ENABLE_LEGACY_DESERIALIZATION toggles whether the loader is
            # expected to pass weights_only=False through as well.
            if ENABLE_LEGACY_DESERIALIZATION:
                model_config["weights_only"] = False

            mlflow.pyfunc.load_model(model_uri=model_info.model_uri, model_config=model_config)

            # Pyfunc path forwards the config as keyword args (path positional).
            load_model_mock.assert_called_with(mock.ANY, **model_config)
            mlflow.pytorch.load_model(model_uri=model_info.model_uri, **model_config)
            # Native path passes the path as a keyword argument instead.
            load_model_mock.assert_called_with(path=mock.ANY, **model_config)
1189  
1190  
def test_passing_params_to_model(data):
    """Inference params declared in the signature are forwarded to forward()."""

    class CustomModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 1)

        def forward(self, x, y):
            if not torch.is_tensor(x):
                x = torch.from_numpy(x)
            y = torch.tensor(y)
            return self.linear(x * y)

    model = CustomModel()
    inputs = np.random.randn(8, 4).astype(np.float32)

    signature = mlflow.models.infer_signature(inputs, None, {"y": 1})
    with mlflow.start_run():
        model_info = mlflow.pytorch.log_model(model, name="model", signature=signature)

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    with torch.no_grad():
        # Default param value (y=1) and an explicit override (y=2).
        np.testing.assert_array_almost_equal(
            pyfunc_model.predict(inputs), model(inputs, 1), decimal=4
        )
        np.testing.assert_array_almost_equal(
            pyfunc_model.predict(inputs, {"y": 2}), model(inputs, 2), decimal=4
        )
1217  
1218  
def test_log_model_with_datetime_input():
    """A datetime column in the input example is captured in the signature, and
    pyfunc predictions still match direct model invocation."""
    df = pd.DataFrame({
        "datetime": pd.date_range("2022-01-01", periods=5, freq="D"),
        "x": np.random.uniform(20, 30, 5),
        "y": np.random.uniform(2, 4, 5),
        "z": np.random.uniform(0, 10, 5),
    })
    model = get_sequential_model()
    model_info = mlflow.pytorch.log_model(model, name="pytorch", input_example=df)
    assert model_info.signature.inputs.inputs[0].type == DataType.datetime

    with torch.no_grad():
        expected_result = model(torch.from_numpy(df.to_numpy(dtype=np.float32)))

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    with torch.no_grad():
        np.testing.assert_array_almost_equal(
            pyfunc_model.predict(df), expected_result, decimal=4
        )
1235  
1236  
@pytest.mark.skipif(
    Version(torch.__version__) < Version("2.4"), reason="This test requires torch>=2.4"
)
@pytest.mark.parametrize("scripted_model", [False])
def test_save_and_load_exported_model(sequential_model, model_path, data, sequential_predicted):
    """Models saved with the 'pt2' (torch.export) format load correctly through
    both the native pytorch flavor and the pyfunc flavor."""
    input_example = data[0].to_numpy(dtype=np.float32)

    mlflow.pytorch.save_model(
        sequential_model,
        model_path,
        serialization_format="pt2",
        input_example=input_example,
    )

    # Native pytorch flavor.
    reloaded = mlflow.pytorch.load_model(model_path)
    np.testing.assert_array_equal(_predict(reloaded, data), sequential_predicted)

    # Pyfunc flavor.
    pyfunc_model = mlflow.pyfunc.load_model(model_path)
    np.testing.assert_array_almost_equal(
        pyfunc_model.predict(input_example)[:, 0], sequential_predicted, decimal=4
    )
1260  
1261  
@pytest.mark.skipif(
    Version(torch.__version__) < Version("2.4"), reason="This test requires torch>=2.4"
)
def test_exported_model_infer_dynamic_dim(tmp_path):
    """Exported (pt2) models honor dynamic dims: the first (batch) dim when the
    signature is auto-inferred, or whichever dim an explicit signature marks -1."""

    class MyModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.sin(x)

    origin_model = MyModule()
    input_example = torch.randn(3, 4, 5).numpy()

    # Case 1: auto-inferred signature makes the first (batch) dim dynamic,
    # so inputs with a different batch size must still work.
    path_auto = tmp_path / "model1"
    mlflow.pytorch.save_model(
        origin_model,
        path_auto,
        serialization_format="pt2",
        input_example=input_example,
    )
    loaded_auto = mlflow.pytorch.load_model(path_auto)
    batch_resized = torch.randn(6, 4, 5)
    np.testing.assert_array_almost_equal(
        loaded_auto(batch_resized),
        origin_model(batch_resized),
        decimal=4,
    )

    # Case 2: an explicit signature marks the second dim as dynamic instead,
    # so inputs resized along that dim must still work.
    path_explicit = tmp_path / "model2"
    mlflow.pytorch.save_model(
        origin_model,
        path_explicit,
        serialization_format="pt2",
        input_example=input_example,
        signature=ModelSignature(
            inputs=Schema([TensorSpec(np.dtype("float32"), (3, -1, 5))]),
        ),
    )
    loaded_explicit = mlflow.pytorch.load_model(path_explicit)
    dim2_resized = torch.randn(3, 2, 5)
    np.testing.assert_array_almost_equal(
        loaded_explicit(dim2_resized),
        origin_model(dim2_resized),
        decimal=4,
    )
1317  
1318  
@pytest.mark.skipif(
    Version(torch.__version__) < Version("2.4"), reason="This test requires torch>=2.4"
)
@pytest.mark.parametrize("scripted_model", [False])
def test_load_exported_model_check_device_mismatch(sequential_model, model_path):
    """A pt2 model exported on CPU loads on CPU but refuses to load on CUDA."""
    mlflow.pytorch.save_model(
        sequential_model,
        model_path,
        serialization_format="pt2",
        input_example=torch.randn(3, 4).numpy(),
    )

    # Loading on the device the model was exported for succeeds.
    mlflow.pytorch.load_model(model_path, device="cpu")

    # Loading on a different device is rejected.
    with pytest.raises(MlflowException, match="it can't be loaded on 'cuda' device."):
        mlflow.pytorch.load_model(model_path, device="cuda")
1339  
1340  
@pytest.mark.skipif(
    Version(torch.__version__) < Version("2.4"), reason="This test requires torch>=2.4"
)
def test_save_and_load_exported_model_with_multi_inputs(model_path):
    """A pt2-exported model taking multiple tensor inputs round-trips correctly."""

    class CustomModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 1)

        def forward(self, x, y):
            with torch.no_grad():
                return self.linear(x + y)

    model = CustomModel()
    input_example = (torch.randn(10, 4), torch.randn(10, 4))

    multi_input_signature = ModelSignature(
        inputs=Schema([
            TensorSpec(np.dtype("float32"), (-1, 4), "v1"),
            TensorSpec(np.dtype("float32"), (-1, 4), "v2"),
        ]),
    )
    mlflow.pytorch.save_model(
        model,
        model_path,
        serialization_format="pt2",
        input_example=input_example,
        signature=multi_input_signature,
    )

    reloaded = mlflow.pytorch.load_model(model_path)
    np.testing.assert_array_almost_equal(
        model(*input_example),
        reloaded(*input_example),
        decimal=4,
    )