summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
Diffstat (limited to 'torch')
-rw-r--r--torch/serialization.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/torch/serialization.py b/torch/serialization.py
index ffae1be547..1bf7d3219b 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -161,6 +161,24 @@ def _should_read_directly(f):
return False
+def _check_seekable(f):
+
+ def raise_err_msg(patterns, e):
+ for p in patterns:
+ if p in str(e):
+ msg = (str(e) + ". You can only torch.load from a file that is seekable." +
+ " Please pre-load the data into a buffer like io.BytesIO and" +
+ " try to load from it instead.")
+ raise type(e)(msg)
+ raise e
+
+ try:
+ f.seek(f.tell())
+ return True
+ except (io.UnsupportedOperation, AttributeError) as e:
+ raise_err_msg(["seek", "tell"], e)
+
+
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
"""Saves an object to a disk file.
@@ -479,7 +497,9 @@ def _load(f, map_location, pickle_module):
else:
raise RuntimeError("Unknown saved id type: %s" % saved_id[0])
+ _check_seekable(f)
f_should_read_directly = _should_read_directly(f)
+
if f_should_read_directly and f.tell() == 0:
# legacy_load requires that f has fileno()
# only if offset is zero we can attempt the legacy tar file loader