diff options
author | greaber <grant.reaber@gmail.com> | 2017-10-14 19:54:53 +0300 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2017-10-14 18:54:53 +0200 |
commit | 490d5c2f13002e39670e657129b040ca3166f025 (patch) | |
tree | c4f96e34265abe4c085eb8e0fd05546bc959d28f /torch | |
parent | 75665ca6db84ee5302c807de4c86471983f66a91 (diff) | |
download | pytorch-490d5c2f13002e39670e657129b040ca3166f025.tar.gz pytorch-490d5c2f13002e39670e657129b040ca3166f025.tar.bz2 pytorch-490d5c2f13002e39670e657129b040ca3166f025.zip |
improve torch.load documentation (#3118)
Diffstat (limited to 'torch')
-rw-r--r-- | torch/serialization.py | 39 |
1 files changed, 28 insertions, 11 deletions
diff --git a/torch/serialization.py b/torch/serialization.py index b08fc05638..2e20232c98 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -203,17 +203,31 @@ def _save(obj, f, pickle_module, pickle_protocol): def load(f, map_location=None, pickle_module=pickle): """Loads an object saved with :func:`torch.save` from a file. - torch.load can dynamically remap storages to be loaded on a different device - using the map_location argument. If it's a callable, it will be called with - two arguments: storage and location tag. It's expected to either return a - storage that's been moved to a different location, or None (and the location - will be resolved using the default method). If this argument is a dict it's - expected to be a mapping from location tags used in a file, to location - tags of the current system. - - By default the location tags are 'cpu' for host tensors and 'cuda:device_id' - (e.g. 'cuda:2') for cuda tensors. User extensions can register their own - tagging and deserialization methods using register_package. + torch.load uses Python's unpickling facilities but treats storages, + which underlie tensors, specially. They are first deserialized on the + CPU and are then moved to the device they were saved from. If this fails + (e.g. because the run time system doesn't have certain devices), an exception + is raised. However, storages can be dynamically remapped to an alternative + set of devices using the map_location argument. + + If map_location is a callable, it will be called once for each serialized + storage with two arguments: storage and location. The storage argument + will be the initial deserialization of the storage, residing on the CPU. + Each serialized storage has a location tag associated with it which + identifies the device it was saved from, and this tag is the second + argument passed to map_location. The builtin location tags are 'cpu' for + CPU tensors and 'cuda:device_id' (e.g. 'cuda:2') for CUDA tensors. + map_location should return either None or a storage. If map_location returns + a storage, it will be used as the final deserialized object, already moved to + the right device. Otherwise, torch.load will fall back to the default behavior, + as if map_location wasn't specified. + + If map_location is a dict, it will be used to remap location tags + appearing in the file (keys), to ones that specify where to put the + storages (values). + + User extensions can register their own location tags and tagging and + deserialization methods using register_package. Args: f: a file-like object (has to implement fileno that returns a file @@ -228,8 +242,11 @@ def load(f, map_location=None, pickle_module=pickle): >>> torch.load('tensors.pt') # Load all tensors onto the CPU >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage) + # Load all tensors onto GPU 1 + >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) # Map tensors from GPU 1 to GPU 0 >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'}) + """ new_fd = False if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)): |