summaryrefslogtreecommitdiff
path: root/torch/serialization.py
diff options
context:
space:
mode:
Diffstat (limited to 'torch/serialization.py')
-rw-r--r--torch/serialization.py62
1 files changed, 50 insertions, 12 deletions
diff --git a/torch/serialization.py b/torch/serialization.py
index 1473202a59..934e4d35f7 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -1,6 +1,7 @@
import difflib
import inspect
import os
+import io
import shutil
import struct
import sys
@@ -120,6 +121,16 @@ def _with_file_like(f, mode, body):
f.close()
+def _is_real_file(f):
+ """Checks if f is backed by a real file (has a fileno)"""
+ try:
+ return f.fileno() >= 0
+ except io.UnsupportedOperation:
+ return False
+ except AttributeError:
+ return False
+
+
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
"""Saves an object to a disk file.
@@ -127,15 +138,39 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
Args:
obj: saved object
- f: a file-like object (has to implement fileno that returns a file descriptor)
- or a string containing a file name
+ f: a file-like object (has to implement write and flush) or a string
+ containing a file name
pickle_module: module used for pickling metadata and objects
pickle_protocol: can be specified to override the default protocol
+
+ .. warning::
+ If you are using Python 2, torch.save does NOT support StringIO.StringIO
+ as a valid file-like object. This is because the write method should return
+ the number of bytes written; StringIO.write() does not do this.
+
+ Please use something like io.BytesIO instead.
+
+ Example:
+ # Save to file
+ >>> x = torch.Tensor([0, 1, 2, 3, 4])
+ >>> torch.save(x, 'tensor.pt')
+
+ # Save to io.BytesIO buffer
+ >>> buffer = io.BytesIO()
+ >>> torch.save(x, buffer)
"""
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
def _save(obj, f, pickle_module, pickle_protocol):
+ if sys.version_info[0] == 2:
+ import StringIO
+ if isinstance(f, StringIO.StringIO):
+ msg = ('torch.save received unsupported StringIO.StringIO file object, whose '
+ 'write method does not return the number of bytes written. '
+ 'Please use something like io.BytesIO for torch.save instead.')
+ raise RuntimeError(msg)
+
import torch.nn as nn
serialized_container_types = {}
serialized_storages = {}
@@ -201,7 +236,7 @@ def _save(obj, f, pickle_module, pickle_protocol):
pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
f.flush()
for key in serialized_storage_keys:
- serialized_storages[key]._write_file(f)
+ serialized_storages[key]._write_file(f, _is_real_file(f))
def load(f, map_location=None, pickle_module=pickle):
@@ -237,9 +272,8 @@ def load(f, map_location=None, pickle_module=pickle):
deserialization methods using `register_package`.
Args:
- f: a file-like object (has to implement fileno that returns a file
- descriptor, and must implement seek), or a string containing a file
- name
+ f: a file-like object (has to implement read, readline, tell, and seek),
+ or a string containing a file name
map_location: a function, string or a dict specifying how to remap storage
locations
pickle_module: module used for unpickling metadata and objects (has to
@@ -255,7 +289,10 @@ def load(f, map_location=None, pickle_module=pickle):
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
-
+ # Load tensor from io.BytesIO object
+ >>> with open('tensor.pt') as f:
+ buffer = io.BytesIO(f.read())
+ >>> torch.load(buffer)
"""
new_fd = False
if isinstance(f, str) or \
@@ -410,14 +447,15 @@ def _load(f, map_location, pickle_module):
else:
raise RuntimeError("Unknown saved id type: %s" % saved_id[0])
- foffset = f.tell()
- if foffset == 0:
+ f_is_real_file = _is_real_file(f)
+ if f_is_real_file and f.tell() == 0:
+ # legacy_load requires that f has fileno()
# only if offset is zero we can attempt the legacy tar file loader
try:
return legacy_load(f)
except tarfile.TarError:
# if not a tarfile, reset file offset and proceed
- f.seek(foffset)
+ f.seek(0)
magic_number = pickle_module.load(f)
if magic_number != MAGIC_NUMBER:
@@ -433,10 +471,10 @@ def _load(f, map_location, pickle_module):
deserialized_storage_keys = pickle_module.load(f)
- offset = f.tell()
+ offset = f.tell() if f_is_real_file else None
for key in deserialized_storage_keys:
assert key in deserialized_objects
- deserialized_objects[key]._set_from_file(f, offset)
+ deserialized_objects[key]._set_from_file(f, offset, f_is_real_file)
offset = None
return result