summaryrefslogtreecommitdiff
path: root/torch/serialization.py
diff options
context:
space:
mode:
authorEdward Z. Yang <ezyang@fb.com>2017-08-31 08:46:30 -0700
committerSoumith Chintala <soumith@gmail.com>2017-09-05 17:48:55 -0400
commit57eb8bd28864713e63f82172329aa891c719a6c6 (patch)
tree268e378bcfd3443afebabf70188ae4bb1b5ed51a /torch/serialization.py
parent6ae77b32b993b5407046b44b9ebf0748f98690d3 (diff)
downloadpytorch-57eb8bd28864713e63f82172329aa891c719a6c6.tar.gz
pytorch-57eb8bd28864713e63f82172329aa891c719a6c6.tar.bz2
pytorch-57eb8bd28864713e63f82172329aa891c719a6c6.zip
Frontend refactor, and some documentation.
- BC BREAKING: export now also takes a mandatory file-ish argument, specifying the file to export the protobuf to. I rewrote the tests to use BytesIO to get out the string so they could parse it again. - BC BREAKING: export no longer returns the tensors that were computed. To get these, use the internal _export function. - Multiple inputs to models are now supported by passing a tuple to input. (Old API of a single Variable still works.) - Keyword arguments to models are now supported via kwargs keyword arg. - Renamed embed_params to export_params, and it now defaults to True. - Toffee tests now live in their own test_toffee.py file. I had to rename a pile of expect files for this. - Removed defunct torch.toffee imports from autograd to solve module import cycle. - Helper function _with_file_like to abstract over opening file-ish arguments, taken from torch.save() Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Diffstat (limited to 'torch/serialization.py')
-rw-r--r--torch/serialization.py26
1 files changed, 17 insertions, 9 deletions
diff --git a/torch/serialization.py b/torch/serialization.py
index da496894f6..b08fc05638 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -100,6 +100,22 @@ def storage_to_tensor_type(storage):
return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
+def _with_file_like(f, mode, body):
+ """
+ Executes a body function with a file object for f, opening
+ it in 'mode' if it is a string filename.
+ """
+ new_fd = False
+ if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)):
+ new_fd = True
+ f = open(f, mode)
+ try:
+ return body(f)
+ finally:
+ if new_fd:
+ f.close()
+
+
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
"""Saves an object to a disk file.
@@ -112,15 +128,7 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
pickle_module: module used for pickling metadata and objects
pickle_protocol: can be specified to override the default protocol
"""
- new_fd = False
- if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)):
- new_fd = True
- f = open(f, "wb")
- try:
- return _save(obj, f, pickle_module, pickle_protocol)
- finally:
- if new_fd:
- f.close()
+ return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
def _save(obj, f, pickle_module, pickle_protocol):