/ mlflow / pytorch / pickle_module.py
pickle_module.py
 1  """
 2  This module imports contents from CloudPickle in a way that is compatible with the
 3  ``pickle_module`` parameter of PyTorch's model persistence function: ``torch.save``
 4  (see https://github.com/pytorch/pytorch/blob/692898fe379c9092f5e380797c32305145cd06e1/torch/
 5  serialization.py#L192). It is included as a distinct module from :mod:`mlflow.pytorch` to avoid
 6  polluting the namespace with wildcard imports.
 7  
 8  Calling ``torch.save(..., pickle_module=mlflow.pytorch.pickle_module)`` will persist PyTorch model
 9  definitions using CloudPickle, leveraging improved pickling functionality such as the ability
10  to capture class definitions in the ``__main__`` scope.
11  
12  TODO: Remove this module or make it an alias of CloudPickle when CloudPickle and PyTorch have
13  compatible pickling APIs.
14  """
15  
16  # The wildcard import of CloudPickle below brings in all of its contents in an attempt to
17  # include every function required by ``torch.save``.
18  
19  # PyTorch requires an ``Unpickler`` class for deserialization, but CloudPickle does not expose
20  # one in its namespace. Because CloudPickle's ``load()`` and ``loads()`` routines are aliases
21  # for ``pickle.load()`` and ``pickle.loads()``, we import ``Unpickler`` from Python's native
22  # ``pickle`` library instead.
23  from pickle import Unpickler  # noqa: F401
24  
25  from cloudpickle import *  # noqa: F403
26  
27  # PyTorch uses the ``Pickler`` class of the specified ``pickle_module``
28  # (https://github.com/pytorch/pytorch/blob/692898fe379c9092f5e380797c32305145cd06e1/torch/
29  # serialization.py#L290). Unfortunately, ``cloudpickle.Pickler`` is an alias for Python's native
30  # pickling class, ``pickle.Pickler``, rather than for ``cloudpickle.CloudPickler``.
31  # Pull request https://github.com/cloudpipe/cloudpickle/pull/235 has been filed to fix this, but
32  # the renamed import below is necessary until either the requested change is incorporated into
33  # a CloudPickle release or the ``torch.save`` API is updated to be compatible with the existing
34  # CloudPickle API.
35  from cloudpickle import CloudPickler as Pickler  # noqa: F401