diff options
Diffstat (limited to 'torch')
-rw-r--r-- | torch/serialization.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/torch/serialization.py b/torch/serialization.py index ffae1be547..1bf7d3219b 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -161,6 +161,24 @@ def _should_read_directly(f): return False +def _check_seekable(f): + + def raise_err_msg(patterns, e): + for p in patterns: + if p in str(e): + msg = (str(e) + ". You can only torch.load from a file that is seekable." + + " Please pre-load the data into a buffer like io.BytesIO and" + + " try to load from it instead.") + raise type(e)(msg) + raise e + + try: + f.seek(f.tell()) + return True + except (io.UnsupportedOperation, AttributeError) as e: + raise_err_msg(["seek", "tell"], e) + + def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL): """Saves an object to a disk file. @@ -479,7 +497,9 @@ def _load(f, map_location, pickle_module): else: raise RuntimeError("Unknown saved id type: %s" % saved_id[0]) + _check_seekable(f) f_should_read_directly = _should_read_directly(f) + if f_should_read_directly and f.tell() == 0: # legacy_load requires that f has fileno() # only if offset is zero we can attempt the legacy tar file loader |