author    greaber <grant.reaber@gmail.com>    2017-10-14 19:54:53 +0300
committer Adam Paszke <adam.paszke@gmail.com> 2017-10-14 18:54:53 +0200
commit    490d5c2f13002e39670e657129b040ca3166f025 (patch)
tree      c4f96e34265abe4c085eb8e0fd05546bc959d28f /torch/serialization.py
parent    75665ca6db84ee5302c807de4c86471983f66a91 (diff)
improve torch.load documentation (#3118)
Diffstat (limited to 'torch/serialization.py')
-rw-r--r--  torch/serialization.py | 39
1 file changed, 28 insertions(+), 11 deletions(-)
diff --git a/torch/serialization.py b/torch/serialization.py
index b08fc05638..2e20232c98 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -203,17 +203,31 @@ def _save(obj, f, pickle_module, pickle_protocol):
def load(f, map_location=None, pickle_module=pickle):
"""Loads an object saved with :func:`torch.save` from a file.
- torch.load can dynamically remap storages to be loaded on a different device
- using the map_location argument. If it's a callable, it will be called with
- two arguments: storage and location tag. It's expected to either return a
- storage that's been moved to a different location, or None (and the location
- will be resolved using the default method). If this argument is a dict it's
- expected to be a mapping from location tags used in a file, to location
- tags of the current system.
-
- By default the location tags are 'cpu' for host tensors and 'cuda:device_id'
- (e.g. 'cuda:2') for cuda tensors. User extensions can register their own
- tagging and deserialization methods using register_package.
+ torch.load uses Python's unpickling facilities but treats storages,
+ which underlie tensors, specially. They are first deserialized on the
+ CPU and are then moved to the device they were saved from. If this fails
+ (e.g. because the runtime system doesn't have certain devices), an exception
+ is raised. However, storages can be dynamically remapped to an alternative
+ set of devices using the map_location argument.
+
+ If map_location is a callable, it will be called once for each serialized
+ storage with two arguments: storage and location. The storage argument
+ will be the initial deserialization of the storage, residing on the CPU.
+ Each serialized storage has a location tag associated with it which
+ identifies the device it was saved from, and this tag is the second
+ argument passed to map_location. The builtin location tags are 'cpu' for
+ CPU tensors and 'cuda:device_id' (e.g. 'cuda:2') for CUDA tensors.
+ map_location should return either None or a storage. If map_location returns
+ a storage, it will be used as the final deserialized object, already moved to
+ the right device. Otherwise, torch.load will fall back to the default behavior,
+ as if map_location wasn't specified.
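+
+ For instance, a minimal sketch that moves storages saved from GPU 2 onto
+ GPU 0 and falls back to the default behavior for any other location
+ (the device indices here are arbitrary):
+
+ >>> torch.load('tensors.pt',
+ ...            map_location=lambda storage, loc: storage.cuda(0) if loc == 'cuda:2' else None)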
+
+ If map_location is a dict, it will be used to remap location tags
+ appearing in the file (keys) to ones that specify where to put the
+ storages (values).
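+
+ For example, a mapping that loads storages saved from either of two GPUs
+ onto the CPU ('cpu', 'cuda:0' and 'cuda:1' are builtin location tags):
+
+ >>> torch.load('tensors.pt', map_location={'cuda:0': 'cpu', 'cuda:1': 'cpu'})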
+
+ User extensions can register their own location tags, along with tagging
+ and deserialization methods for them, using register_package.
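+
+ A hypothetical sketch of such an extension (my_tag and my_restore are
+ placeholder callables, not part of torch, and the priority value merely
+ orders them after the builtin handlers):
+
+ >>> # my_tag(storage) should return a location tag, or None to pass
+ >>> # my_restore(storage, location) should return a remapped storage, or None
+ >>> torch.serialization.register_package(30, my_tag, my_restore)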
Args:
f: a file-like object (has to implement fileno that returns a file
@@ -228,8 +242,11 @@ def load(f, map_location=None, pickle_module=pickle):
>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
+ # Load all tensors onto GPU 1
+ >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
+
"""
new_fd = False
if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)):