diff options
author | Edward Z. Yang <ezyang@fb.com> | 2017-08-31 08:46:30 -0700 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-09-05 17:48:55 -0400 |
commit | 57eb8bd28864713e63f82172329aa891c719a6c6 (patch) | |
tree | 268e378bcfd3443afebabf70188ae4bb1b5ed51a /torch/serialization.py | |
parent | 6ae77b32b993b5407046b44b9ebf0748f98690d3 (diff) | |
download | pytorch-57eb8bd28864713e63f82172329aa891c719a6c6.tar.gz pytorch-57eb8bd28864713e63f82172329aa891c719a6c6.tar.bz2 pytorch-57eb8bd28864713e63f82172329aa891c719a6c6.zip |
Frontend refactor, and some documentation.
- BC BREAKING: export now also takes a mandatory file-ish argument, specifying
the file to export the protobuf to. I rewrote the tests to use BytesIO to
get out the string so they could parse it again.
- BC BREAKING: export no longer returns the tensors that were computed. To
get these, use the internal _export function.
- Multiple inputs to models are now supported by passing a tuple to input.
(Old API of a single Variable still works.)
- Keyword arguments to models are now supported via kwargs keyword arg.
- Renamed embed_params to export_params, and it now defaults to True.
- Toffee tests now live in their own test_toffee.py file. I had to
rename a pile of expect files for this.
- Removed defunct torch.toffee imports from autograd to solve module import
cycle.
- Helper function _with_file_like to abstract over opening file-ish arguments,
taken from torch.save()
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Diffstat (limited to 'torch/serialization.py')
-rw-r--r-- | torch/serialization.py | 26 |
1 files changed, 17 insertions, 9 deletions
diff --git a/torch/serialization.py b/torch/serialization.py index da496894f6..b08fc05638 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -100,6 +100,22 @@ def storage_to_tensor_type(storage): return getattr(module, storage_type.__name__.replace('Storage', 'Tensor')) +def _with_file_like(f, mode, body): + """ + Executes a body function with a file object for f, opening + it in 'mode' if it is a string filename. + """ + new_fd = False + if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)): + new_fd = True + f = open(f, mode) + try: + return body(f) + finally: + if new_fd: + f.close() + + def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL): """Saves an object to a disk file. @@ -112,15 +128,7 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL): pickle_module: module used for pickling metadata and objects pickle_protocol: can be specified to override the default protocol """ - new_fd = False - if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)): - new_fd = True - f = open(f, "wb") - try: - return _save(obj, f, pickle_module, pickle_protocol) - finally: - if new_fd: - f.close() + return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol)) def _save(obj, f, pickle_module, pickle_protocol): |