pickle_module.py
"""
This module imports contents from CloudPickle in a way that is compatible with the
``pickle_module`` parameter of PyTorch's model persistence function, ``torch.save``
(see https://github.com/pytorch/pytorch/blob/692898fe379c9092f5e380797c32305145cd06e1/torch/
serialization.py#L192). It is a separate module from :mod:`mlflow.pytorch` to avoid
polluting that namespace with wildcard imports.

Calling ``torch.save(..., pickle_module=mlflow.pytorch.pickle_module)`` persists PyTorch model
definitions using CloudPickle, which provides improved pickling functionality such as the
ability to capture classes defined in the ``__main__`` scope.

TODO: Remove this module or make it an alias of CloudPickle once CloudPickle and PyTorch have
compatible pickling APIs.
"""

# CloudPickle does not include ``Unpickler`` in its namespace, but PyTorch requires it for
# deserialization. Because CloudPickle's ``load()`` and ``loads()`` routines are aliases for
# ``pickle.load()`` and ``pickle.loads()``, we import ``Unpickler`` from the native Python
# pickle library instead.
from pickle import Unpickler  # noqa: F401

# Import all contents of the CloudPickle module in an attempt to include every function
# required by ``torch.save``.
from cloudpickle import *  # noqa: F403

# PyTorch uses the ``Pickler`` class of the specified ``pickle_module``
# (https://github.com/pytorch/pytorch/blob/692898fe379c9092f5e380797c32305145cd06e1/torch/
# serialization.py#L290). Unfortunately, ``cloudpickle.Pickler`` is an alias for Python's native
# pickling class, ``pickle.Pickler``, rather than for ``cloudpickle.CloudPickler``.
# https://github.com/cloudpipe/cloudpickle/pull/235 has been filed to correct the issue, but
# this import renaming is necessary until either that change is included in a CloudPickle
# release or the ``torch.save`` API is updated to be compatible with the existing CloudPickle API.
from cloudpickle import CloudPickler as Pickler  # noqa: F401
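The comments above hinge on one detail: a caller such as ``torch.save`` looks up ``Pickler`` (and, when loading, ``Unpickler``) as attributes of whatever object it receives as ``pickle_module``, so rebinding those names is what changes its behavior. The stdlib-only sketch below simulates that lookup. ``RecordingPickler``, ``make_pickle_module``, ``simplified_save``, and ``simplified_load`` are hypothetical names; the two ``simplified_*`` functions are an assumed, heavily reduced model of how PyTorch consumes ``pickle_module`` (the default ``protocol=2`` is also an assumption), not PyTorch's actual code. ``RecordingPickler`` stands in for ``cloudpickle.CloudPickler`` only in the sense that it customizes pickling, via ``reducer_override`` (Python 3.8+).

```python
import io
import pickle
import types
from collections import OrderedDict


class RecordingPickler(pickle.Pickler):
    """Hypothetical stand-in for a customized Pickler such as CloudPickler.

    It records the type name of every object passed to ``reducer_override``
    (Python 3.8+) and then defers to the default reduction, so the pickled
    bytes are the same as with ``pickle.Pickler``.
    """

    def __init__(self, file, protocol=None):
        super().__init__(file, protocol=protocol)
        self.seen = []

    def reducer_override(self, obj):
        self.seen.append(type(obj).__name__)
        return NotImplemented  # fall back to the standard reduction


def make_pickle_module(name, pickler_cls, include_unpickler=True):
    """Build a module object exposing the attributes a caller may look up."""
    mod = types.ModuleType(name)
    mod.Pickler = pickler_cls
    if include_unpickler:
        mod.Unpickler = pickle.Unpickler
    return mod


def simplified_save(obj, f, pickle_module, protocol=2):
    """Assumed, simplified model of how torch.save consumes ``pickle_module``."""
    pickler = pickle_module.Pickler(f, protocol=protocol)
    pickler.dump(obj)
    return pickler


def simplified_load(f, pickle_module):
    """Assumed, simplified model of how torch.load consumes ``pickle_module``."""
    return pickle_module.Unpickler(f).load()


state = OrderedDict([("weights", [0.5, -1.25]), ("name", "toy-model")])

# Pickler bound to the plain stdlib class: the custom hook is never consulted.
plain = make_pickle_module("plain", pickle.Pickler)
pickler = simplified_save(state, io.BytesIO(), plain)
print(hasattr(pickler, "seen"))  # False

# Pickler rebound to the custom class, analogous to ``CloudPickler as Pickler``.
patched = make_pickle_module("patched", RecordingPickler)
buf = io.BytesIO()
pickler = simplified_save(state, buf, patched)
print("OrderedDict" in pickler.seen)  # True
buf.seek(0)
print(simplified_load(buf, patched) == state)  # True

# Without an ``Unpickler`` attribute, loading fails before any bytes are read.
no_unpickler = make_pickle_module("no_unpickler", RecordingPickler, include_unpickler=False)
buf.seek(0)
try:
    simplified_load(buf, no_unpickler)
except AttributeError as exc:
    print("AttributeError:", exc)
```

The sketch pickles an ``OrderedDict`` rather than a plain ``dict`` because, for performance reasons, the C implementation of ``pickle.Pickler`` may skip ``reducer_override`` for exact instances of builtin types such as ``dict``, ``list``, ``str``, and ``int``.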