summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorEdward Z. Yang <ezyang@fb.com>2017-08-31 08:46:30 -0700
committerSoumith Chintala <soumith@gmail.com>2017-09-05 17:48:55 -0400
commit57eb8bd28864713e63f82172329aa891c719a6c6 (patch)
tree268e378bcfd3443afebabf70188ae4bb1b5ed51a /torch
parent6ae77b32b993b5407046b44b9ebf0748f98690d3 (diff)
downloadpytorch-57eb8bd28864713e63f82172329aa891c719a6c6.tar.gz
pytorch-57eb8bd28864713e63f82172329aa891c719a6c6.tar.bz2
pytorch-57eb8bd28864713e63f82172329aa891c719a6c6.zip
Frontend refactor, and some documentation.
- BC BREAKING: export now also takes a mandatory file-ish argument, specifying the file to export the protobuf to. I rewrote the tests to use BytesIO to get out the string so they could parse it again. - BC BREAKING: export no longer returns the tensors that were computed. To get these, use the internal _export function. - Multiple inputs to models are now supported by passing a tuple to input. (Old API of a single Variable still works.) - Keyword arguments to models are now supported via kwargs keyword arg. - Renamed embed_params to export_params, and it now defaults to True. - Toffee tests now live in their own test_toffee.py file. I had to rename a pile of expect files for this. - Removed defunct torch.toffee imports from autograd to solve module import cycle. - Helper function _with_file_like to abstract over opening file-ish arguments, taken from torch.save() Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Diffstat (limited to 'torch')
-rw-r--r--torch/autograd/_functions/basic_ops.py1
-rw-r--r--torch/autograd/_functions/blas.py1
-rw-r--r--torch/autograd/_functions/tensor.py1
-rw-r--r--torch/nn/_functions/thnn/auto_primspec.py3
-rw-r--r--torch/nn/_functions/thnn/pooling.py1
-rw-r--r--torch/serialization.py26
-rw-r--r--torch/toffee.py58
7 files changed, 67 insertions, 24 deletions
diff --git a/torch/autograd/_functions/basic_ops.py b/torch/autograd/_functions/basic_ops.py
index ca1d0b520b..b5ffa3abdd 100644
--- a/torch/autograd/_functions/basic_ops.py
+++ b/torch/autograd/_functions/basic_ops.py
@@ -1,5 +1,4 @@
import torch
-import torch.toffee
from ..function import Function, InplaceFunction
from .utils import maybe_unexpand, maybe_unexpand_or_view
import math
diff --git a/torch/autograd/_functions/blas.py b/torch/autograd/_functions/blas.py
index ac769b9b26..4903b13b1e 100644
--- a/torch/autograd/_functions/blas.py
+++ b/torch/autograd/_functions/blas.py
@@ -1,5 +1,4 @@
import torch
-import torch.toffee
from ..function import Function, InplaceFunction
from .utils import maybe_unexpand
diff --git a/torch/autograd/_functions/tensor.py b/torch/autograd/_functions/tensor.py
index e0f871d4af..caaee08161 100644
--- a/torch/autograd/_functions/tensor.py
+++ b/torch/autograd/_functions/tensor.py
@@ -1,6 +1,5 @@
from functools import reduce
import torch
-import torch.toffee
from torch._utils import _accumulate
from ..function import Function, InplaceFunction, once_differentiable
diff --git a/torch/nn/_functions/thnn/auto_primspec.py b/torch/nn/_functions/thnn/auto_primspec.py
index 5ed0090d54..64e17611a9 100644
--- a/torch/nn/_functions/thnn/auto_primspec.py
+++ b/torch/nn/_functions/thnn/auto_primspec.py
@@ -1,6 +1,3 @@
-import torch.toffee
-
-
def threshold_primspec(g, input, threshold=0, value=0, inplace=False):
if inplace or threshold != 0 or value != 0:
return None
diff --git a/torch/nn/_functions/thnn/pooling.py b/torch/nn/_functions/thnn/pooling.py
index a12a60f4eb..2ce62f2779 100644
--- a/torch/nn/_functions/thnn/pooling.py
+++ b/torch/nn/_functions/thnn/pooling.py
@@ -1,7 +1,6 @@
from torch.autograd import Variable
from torch.autograd.function import Function, once_differentiable
from torch._thnn import type2backend
-import torch.toffee
from . import _all_functions
from torch.nn.modules.utils import _single, _pair, _triple
diff --git a/torch/serialization.py b/torch/serialization.py
index da496894f6..b08fc05638 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -100,6 +100,22 @@ def storage_to_tensor_type(storage):
return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
+def _with_file_like(f, mode, body):
+ """
+ Executes a body function with a file object for f, opening
+ it in 'mode' if it is a string filename.
+ """
+ new_fd = False
+ if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)):
+ new_fd = True
+ f = open(f, mode)
+ try:
+ return body(f)
+ finally:
+ if new_fd:
+ f.close()
+
+
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
"""Saves an object to a disk file.
@@ -112,15 +128,7 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
pickle_module: module used for pickling metadata and objects
pickle_protocol: can be specified to override the default protocol
"""
- new_fd = False
- if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)):
- new_fd = True
- f = open(f, "wb")
- try:
- return _save(obj, f, pickle_module, pickle_protocol)
- finally:
- if new_fd:
- f.close()
+ return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
def _save(obj, f, pickle_module, pickle_protocol):
diff --git a/torch/toffee.py b/torch/toffee.py
index 6f5f82186b..67e566e4c7 100644
--- a/torch/toffee.py
+++ b/torch/toffee.py
@@ -1,13 +1,55 @@
+"""
+The torch.toffee module contains functions to export models into the Toffee
+IR format. These models can be loaded with the ToffeeIR library and then
+converted to models which run on other deep learning frameworks.
+"""
+
import torch
+import torch.jit
+import torch.autograd
+import torch.serialization
-def export(model, input, embed_params):
+def export(model, args, f, export_params=True, kwargs=None):
+ """
+ Export a model into Toffee format. This exporter runs your model
+ once in order to get a trace of its execution to be exported; at the
+ moment, it does not support dynamic models (e.g., RNNs.)
- # Enable tracing on the model
- trace, torch_out = torch.jit.record_trace(model, input)
- if embed_params is False:
- proto = trace.export()
- else:
+ See also: :ref:`toffee-export`
+
+ Arguments:
+ model (torch.nn.Module): the model to be exported.
+ args (torch.autograd.Variable or tuple of variables): the inputs to
+ the model, e.g., such that ``model(*args)`` is a valid invocation
+ of the model.
+    f: a file-like object (has to implement a ``write`` method accepting bytes)
+        or a string containing a file name. A binary Protobuf will be written
+        to this file.
+ export_params (bool, default True): if specified, all parameters will
+ be exported. Set this to False if you are exporting an
+ untrained model.
+ kwargs (dict, optional): keyword inputs to the model.
+ """
+    _export(model, args, f, export_params=export_params, kwargs=kwargs)
+
+
+# Internal helper function which also returns the computed tensors, which
+# can be useful for comparing PyTorch's execution of the model with the
+# eventual runner.
+def _export(model, args, f, export_params=True, kwargs=None):
+ # Special case for common case of passing a single Variable
+ if isinstance(args, torch.autograd.Variable):
+ args = (args, )
+ if not kwargs:
+ kwargs = {}
+ trace, torch_out = torch.jit.record_trace(model, *args, **kwargs)
+    # TODO: Don't allocate an in-memory string for the protobuf
+ if export_params:
proto = trace.export(model.state_dict().values())
- # TODO: a way to print the proto
- return proto, torch_out
+ else:
+ proto = trace.export()
+ torch.serialization._with_file_like(f, "wb", lambda f: f.write(proto))
+ return torch_out
+ # NB: It's very important that trace dies at the end of this function;
+ # otherwise you can't retrace the model.