summaryrefslogtreecommitdiff
path: root/torch/serialization.py
diff options
context:
space:
mode:
authorSam Gross <colesbury@gmail.com>2016-10-31 12:12:22 -0400
committerGitHub <noreply@github.com>2016-10-31 12:12:22 -0400
commitad5fdef6acb11862b74bb711734d00f46035896f (patch)
treefdeca371242c63c99bb09d4cef3bf1a9e1bba06d /torch/serialization.py
parent0cb5943be8e1581f0ee2b2a76d0b8eec81757654 (diff)
downloadpytorch-ad5fdef6acb11862b74bb711734d00f46035896f.tar.gz
pytorch-ad5fdef6acb11862b74bb711734d00f46035896f.tar.bz2
pytorch-ad5fdef6acb11862b74bb711734d00f46035896f.zip
Make every user-visible Tensor have a Storage (#179)
Diffstat (limited to 'torch/serialization.py')
-rw-r--r--torch/serialization.py10
1 files changed, 3 insertions, 7 deletions
diff --git a/torch/serialization.py b/torch/serialization.py
index 6358ab7e46..9fe609013c 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -248,12 +248,9 @@ def load(f, map_location=None, pickle_module=pickle):
for i in range(num_tensors):
args = pickle_module.load(f)
key, storage_id, original_tensor_type = args
- storage = deserialized_objects.get(storage_id, None)
- if storage:
- tensor_type = storage_to_tensor_type(storage)
- tensor = tensor_type._new_with_metadata_file(f, storage)
- else:
- tensor = original_tensor_type._new_with_metadata_file(f, storage)
+ storage = deserialized_objects[storage_id]
+ tensor_type = storage_to_tensor_type(storage)
+ tensor = tensor_type._new_with_metadata_file(f, storage)
deserialized_objects[key] = tensor
pickle_file = tar.extractfile('pickle')
@@ -261,4 +258,3 @@ def load(f, map_location=None, pickle_module=pickle):
unpickler.persistent_load = persistent_load
result = unpickler.load()
return result
-