diff options
author | Wei Yang <38509346+weiyangfb@users.noreply.github.com> | 2018-06-12 12:57:28 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-12 12:57:28 -0700 |
commit | c3e4b3c88be23698f833e539606bc5567b7a1162 (patch) | |
tree | 792b96be5b63b5b9c4fcea6f3df360138e49c8c7 /torch | |
parent | c6db1bc952f565752380cf550cb3ee8d4ec974df (diff) | |
download | pytorch-c3e4b3c88be23698f833e539606bc5567b7a1162.tar.gz pytorch-c3e4b3c88be23698f833e539606bc5567b7a1162.tar.bz2 pytorch-c3e4b3c88be23698f833e539606bc5567b7a1162.zip |
raise more informative error msg for torch.load not support seek (#7754)
Raising more informative error msg for torch.load() when input file does not support seek() or tell()
Diffstat (limited to 'torch')
-rw-r--r-- | torch/serialization.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/torch/serialization.py b/torch/serialization.py index ffae1be547..1bf7d3219b 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -161,6 +161,24 @@ def _should_read_directly(f): return False +def _check_seekable(f): + + def raise_err_msg(patterns, e): + for p in patterns: + if p in str(e): + msg = (str(e) + ". You can only torch.load from a file that is seekable." + + " Please pre-load the data into a buffer like io.BytesIO and" + + " try to load from it instead.") + raise type(e)(msg) + raise e + + try: + f.seek(f.tell()) + return True + except (io.UnsupportedOperation, AttributeError) as e: + raise_err_msg(["seek", "tell"], e) + + def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL): """Saves an object to a disk file. @@ -479,7 +497,9 @@ def _load(f, map_location, pickle_module): else: raise RuntimeError("Unknown saved id type: %s" % saved_id[0]) + _check_seekable(f) f_should_read_directly = _should_read_directly(f) + if f_should_read_directly and f.tell() == 0: # legacy_load requires that f has fileno() # only if offset is zero we can attempt the legacy tar file loader |