diff options
author | Edward Z. Yang <ezyang@fb.com> | 2017-08-31 08:46:30 -0700 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-09-05 17:48:55 -0400 |
commit | 57eb8bd28864713e63f82172329aa891c719a6c6 (patch) | |
tree | 268e378bcfd3443afebabf70188ae4bb1b5ed51a /torch | |
parent | 6ae77b32b993b5407046b44b9ebf0748f98690d3 (diff) | |
download | pytorch-57eb8bd28864713e63f82172329aa891c719a6c6.tar.gz pytorch-57eb8bd28864713e63f82172329aa891c719a6c6.tar.bz2 pytorch-57eb8bd28864713e63f82172329aa891c719a6c6.zip |
Frontend refactor, and some documentation.
- BC BREAKING: export now also takes a mandatory file-ish argument, specifying
the file to export the protobuf to. I rewrote the tests to use BytesIO to
get out the string so they could parse it again.
- BC BREAKING: export no longer returns the tensors that were computed. To
get these, use the internal _export function.
- Multiple inputs to models are now supported by passing a tuple to input.
(Old API of a single Variable still works.)
- Keyword arguments to models are now supported via kwargs keyword arg.
- Renamed embed_params to export_params, and it now defaults to True.
- Toffee tests now live in their own test_toffee.py file. I had to
rename a pile of expect files for this.
- Removed defunct torch.toffee imports from autograd to solve module import
cycle.
- Helper function _with_file_like to abstract over opening file-ish arguments,
taken from torch.save()
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Diffstat (limited to 'torch')
-rw-r--r-- | torch/autograd/_functions/basic_ops.py | 1 | ||||
-rw-r--r-- | torch/autograd/_functions/blas.py | 1 | ||||
-rw-r--r-- | torch/autograd/_functions/tensor.py | 1 | ||||
-rw-r--r-- | torch/nn/_functions/thnn/auto_primspec.py | 3 | ||||
-rw-r--r-- | torch/nn/_functions/thnn/pooling.py | 1 | ||||
-rw-r--r-- | torch/serialization.py | 26 | ||||
-rw-r--r-- | torch/toffee.py | 58 |
7 files changed, 67 insertions, 24 deletions
diff --git a/torch/autograd/_functions/basic_ops.py b/torch/autograd/_functions/basic_ops.py index ca1d0b520b..b5ffa3abdd 100644 --- a/torch/autograd/_functions/basic_ops.py +++ b/torch/autograd/_functions/basic_ops.py @@ -1,5 +1,4 @@ import torch -import torch.toffee from ..function import Function, InplaceFunction from .utils import maybe_unexpand, maybe_unexpand_or_view import math diff --git a/torch/autograd/_functions/blas.py b/torch/autograd/_functions/blas.py index ac769b9b26..4903b13b1e 100644 --- a/torch/autograd/_functions/blas.py +++ b/torch/autograd/_functions/blas.py @@ -1,5 +1,4 @@ import torch -import torch.toffee from ..function import Function, InplaceFunction from .utils import maybe_unexpand diff --git a/torch/autograd/_functions/tensor.py b/torch/autograd/_functions/tensor.py index e0f871d4af..caaee08161 100644 --- a/torch/autograd/_functions/tensor.py +++ b/torch/autograd/_functions/tensor.py @@ -1,6 +1,5 @@ from functools import reduce import torch -import torch.toffee from torch._utils import _accumulate from ..function import Function, InplaceFunction, once_differentiable diff --git a/torch/nn/_functions/thnn/auto_primspec.py b/torch/nn/_functions/thnn/auto_primspec.py index 5ed0090d54..64e17611a9 100644 --- a/torch/nn/_functions/thnn/auto_primspec.py +++ b/torch/nn/_functions/thnn/auto_primspec.py @@ -1,6 +1,3 @@ -import torch.toffee - - def threshold_primspec(g, input, threshold=0, value=0, inplace=False): if inplace or threshold != 0 or value != 0: return None diff --git a/torch/nn/_functions/thnn/pooling.py b/torch/nn/_functions/thnn/pooling.py index a12a60f4eb..2ce62f2779 100644 --- a/torch/nn/_functions/thnn/pooling.py +++ b/torch/nn/_functions/thnn/pooling.py @@ -1,7 +1,6 @@ from torch.autograd import Variable from torch.autograd.function import Function, once_differentiable from torch._thnn import type2backend -import torch.toffee from . import _all_functions from torch.nn.modules.utils import _single, _pair, _triple diff --git a/torch/serialization.py b/torch/serialization.py index da496894f6..b08fc05638 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -100,6 +100,22 @@ def storage_to_tensor_type(storage): return getattr(module, storage_type.__name__.replace('Storage', 'Tensor')) +def _with_file_like(f, mode, body): + """ + Executes a body function with a file object for f, opening + it in 'mode' if it is a string filename. + """ + new_fd = False + if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)): + new_fd = True + f = open(f, mode) + try: + return body(f) + finally: + if new_fd: + f.close() + + def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL): """Saves an object to a disk file. @@ -112,15 +128,7 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL): pickle_module: module used for pickling metadata and objects pickle_protocol: can be specified to override the default protocol """ - new_fd = False - if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)): - new_fd = True - f = open(f, "wb") - try: - return _save(obj, f, pickle_module, pickle_protocol) - finally: - if new_fd: - f.close() + return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol)) def _save(obj, f, pickle_module, pickle_protocol): diff --git a/torch/toffee.py b/torch/toffee.py index 6f5f82186b..67e566e4c7 100644 --- a/torch/toffee.py +++ b/torch/toffee.py @@ -1,13 +1,55 @@ +""" +The torch.toffee module contains functions to export models into the Toffee +IR format. These models can be loaded with the ToffeeIR library and then +converted to models which run on other deep learning frameworks. +""" + import torch +import torch.jit +import torch.autograd +import torch.serialization -def export(model, input, embed_params): +def export(model, args, f, export_params=True, kwargs=None): + """ + Export a model into Toffee format. This exporter runs your model + once in order to get a trace of its execution to be exported; at the + moment, it does not support dynamic models (e.g., RNNs.) - # Enable tracing on the model - trace, torch_out = torch.jit.record_trace(model, input) - if embed_params is False: - proto = trace.export() - else: + See also: :ref:`toffee-export` + + Arguments: + model (torch.nn.Module): the model to be exported. + args (torch.autograd.Variable or tuple of variables): the inputs to + the model, e.g., such that ``model(*args)`` is a valid invocation + of the model. + f: a file-like object (has to implement fileno that returns a file descriptor) + or a string containing a file name. A binary Protobuf will be written + to this file. + export_params (bool, default True): if specified, all parameters will + be exported. Set this to False if you are exporting an + untrained model. + kwargs (dict, optional): keyword inputs to the model. + """ + _export(model, args, f, export_params=export_params, kwargs=None) + + +# Internal helper function which also returns the computed tensors, which +# can be useful for comparing PyTorch's execution of the model with the +# eventual runner. +def _export(model, args, f, export_params=True, kwargs=None): + # Special case for common case of passing a single Variable + if isinstance(args, torch.autograd.Variable): + args = (args, ) + if not kwargs: + kwargs = {} + trace, torch_out = torch.jit.record_trace(model, *args, **kwargs) + # TODO: Don't allocate a in-memory string for the protobuf + if export_params: proto = trace.export(model.state_dict().values()) - # TODO: a way to print the proto - return proto, torch_out + else: + proto = trace.export() + torch.serialization._with_file_like(f, "wb", lambda f: f.write(proto)) + return torch_out + # NB: It's very important that trace dies at the end of this function; + # otherwise you can't retrace the model. |