diff options
Diffstat (limited to 'torch/serialization.py')
-rw-r--r-- | torch/serialization.py | 62 |
1 files changed, 50 insertions, 12 deletions
diff --git a/torch/serialization.py b/torch/serialization.py index 1473202a59..934e4d35f7 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1,6 +1,7 @@ import difflib import inspect import os +import io import shutil import struct import sys @@ -120,6 +121,16 @@ def _with_file_like(f, mode, body): f.close() +def _is_real_file(f): + """Checks if f is backed by a real file (has a fileno)""" + try: + return f.fileno() >= 0 + except io.UnsupportedOperation: + return False + except AttributeError: + return False + + def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL): """Saves an object to a disk file. @@ -127,15 +138,39 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL): Args: obj: saved object - f: a file-like object (has to implement fileno that returns a file descriptor) - or a string containing a file name + f: a file-like object (has to implement write and flush) or a string + containing a file name pickle_module: module used for pickling metadata and objects pickle_protocol: can be specified to override the default protocol + + .. warning:: + If you are using Python 2, torch.save does NOT support StringIO.StringIO + as a valid file-like object. This is because the write method should return + the number of bytes written; StringIO.write() does not do this. + + Please use something like io.BytesIO instead. + + Example: + # Save to file + >>> x = torch.Tensor([0, 1, 2, 3, 4]) + >>> torch.save(x, 'tensor.pt') + + # Save to io.BytesIO buffer + >>> buffer = io.BytesIO() + >>> torch.save(x, buffer) """ return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol)) def _save(obj, f, pickle_module, pickle_protocol): + if sys.version_info[0] == 2: + import StringIO + if isinstance(f, StringIO.StringIO): + msg = ('torch.save received unsupported StringIO.StringIO file object, whose ' + 'write method does not return the number of bytes written. ' + 'Please use something like io.BytesIO for torch.save instead.') + raise RuntimeError(msg) + import torch.nn as nn serialized_container_types = {} serialized_storages = {} @@ -201,7 +236,7 @@ def _save(obj, f, pickle_module, pickle_protocol): pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol) f.flush() for key in serialized_storage_keys: - serialized_storages[key]._write_file(f) + serialized_storages[key]._write_file(f, _is_real_file(f)) def load(f, map_location=None, pickle_module=pickle): @@ -237,9 +272,8 @@ def load(f, map_location=None, pickle_module=pickle): deserialization methods using `register_package`. Args: - f: a file-like object (has to implement fileno that returns a file - descriptor, and must implement seek), or a string containing a file - name + f: a file-like object (has to implement read, readline, tell, and seek), + or a string containing a file name map_location: a function, string or a dict specifying how to remap storage locations pickle_module: module used for unpickling metadata and objects (has to @@ -255,7 +289,10 @@ def load(f, map_location=None, pickle_module=pickle): >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) # Map tensors from GPU 1 to GPU 0 >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'}) - + # Load tensor from io.BytesIO object + >>> with open('tensor.pt') as f: + buffer = io.BytesIO(f.read()) + >>> torch.load(buffer) """ new_fd = False if isinstance(f, str) or \ @@ -410,14 +447,15 @@ def _load(f, map_location, pickle_module): else: raise RuntimeError("Unknown saved id type: %s" % saved_id[0]) - foffset = f.tell() - if foffset == 0: + f_is_real_file = _is_real_file(f) + if f_is_real_file and f.tell() == 0: + # legacy_load requires that f has fileno() # only if offset is zero we can attempt the legacy tar file loader try: return legacy_load(f) except tarfile.TarError: # if not a tarfile, reset file offset and proceed - f.seek(foffset) + f.seek(0) magic_number = pickle_module.load(f) if magic_number != MAGIC_NUMBER: @@ -433,10 +471,10 @@ def _load(f, map_location, pickle_module): deserialized_storage_keys = pickle_module.load(f) - offset = f.tell() + offset = f.tell() if f_is_real_file else None for key in deserialized_storage_keys: assert key in deserialized_objects - deserialized_objects[key]._set_from_file(f, offset) + deserialized_objects[key]._set_from_file(f, offset, f_is_real_file) offset = None return result |