author | Edward Z. Yang <ezyang@mit.edu> | 2018-01-26 15:56:39 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-26 15:56:39 -0500 |
commit | b8ab7bee26c41ceb90a1f944dd27c135699adaf2 (patch) | |
tree | b721a9a06e064ee8195e7da188685c4338f1bd42 /setup.py | |
parent | 24177adc128fc76981a64c862872acc3da99d4bb (diff) | |
Use variadic templates instead of initializer lists and overloads. (#4772)
Suppose you are given a list of arguments, each of which may be Tensor or
TensorList. How can you write a function that can treat these arguments
uniformly as a list of tensors? This patch solves the problem using
variadic templates.
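Concretely, the goal is an interface roughly like the following sketch (stand-in Tensor/TensorList types only, not the actual ATen/autograd declarations): a variadic flatten() that accepts any mix of tensors and tensor lists and treats them uniformly as one flat list.

```cpp
#include <utility>
#include <vector>

// Stand-in types for illustration only.
struct Tensor {};
using TensorList = std::vector<Tensor>;

inline void append(std::vector<Tensor>& out, const Tensor& t) { out.push_back(t); }
inline void append(std::vector<Tensor>& out, const TensorList& ts) {
  out.insert(out.end(), ts.begin(), ts.end());
}

// Base case: nothing left to flatten.
inline void flatten_into(std::vector<Tensor>&) {}

// Recursive case: handle the first argument, then recurse on the rest.
template <typename T, typename... Rest>
void flatten_into(std::vector<Tensor>& out, T&& first, Rest&&... rest) {
  append(out, std::forward<T>(first));
  flatten_into(out, std::forward<Rest>(rest)...);
}

template <typename... Args>
std::vector<Tensor> flatten(Args&&... args) {
  std::vector<Tensor> out;
  flatten_into(out, std::forward<Args>(args)...);
  return out;
}
```

A call like flatten(t, ts, u) then handles single tensors and tensor lists uniformly.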
Why variadic templates? Use of variadic templates means anyone working
with this code has to understand universal references, perfect
forwarding, parameter packs and some idioms of C++ template design.
However, I argue that variadic templates are the *right* tool for
supporting the implementation of functions which must take an
arbitrarily heterogeneous set of inputs. We were able to limp by
in old code because, for the most part, tensor inputs were homogeneous,
but this is no longer the case for some non-primitively differentiable
functions; and with the upcoming cuDNN RNN in ATen PR, it will no longer
be the case for primitively differentiable functions either.
There are three parts to the PR.
First, we add torch/csrc/utils/variadic.h, which defines a mix-in
IterArgs that takes any class which supports operator() and augments it
with a new variadic function apply() which calls operator() on each
argument passed to it. In an original draft of the patch, I wrote the
recursion for each parameter pack from scratch for each function;
however, it turns out there are no fewer than seven instances where we
need this idiom, and the mix-in reduces the lines of code, and also
helps centralize the most important (and easy to forget) boilerplate
for perfect forwarding.
To verify that IterArgs is compiled away into an unrolled form per
call site, I inspected the assembly on some synthetic examples.
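For intuition, a simplified IterArgs-style mix-in (illustrative only; the real definition lives in torch/csrc/utils/variadic.h) looks roughly like this:

```cpp
#include <utility>

// Simplified CRTP mix-in: apply() perfect-forwards each argument in the
// pack to the derived class's operator().  Not the actual header.
template <typename F>
struct IterArgs {
  // Base case: no arguments remain; return the derived visitor so callers
  // can read any state it accumulated.
  F& apply() { return self(); }

  // Recursive case: visit the first argument, then recurse on the rest,
  // perfect-forwarding everything along the way.
  template <typename T, typename... Args>
  F& apply(T&& arg, Args&&... args) {
    self()(std::forward<T>(arg));
    return apply(std::forward<Args>(args)...);
  }

 private:
  F& self() { return *static_cast<F*>(this); }
};
```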
Next, we modify the following functions to make use of IterArgs (a sketch of one such use follows the list):
- compute_requires_grad
- Function::flags (Variable and Tensor variants)
- flatten
- isTracing
- count_tensors / count_variables
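As a rough illustration (again with stand-in types and the simplified IterArgs sketch above, not the real signatures), a count_tensors-style helper reduces to a small visitor:

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Stand-in types as before; reuses the simplified IterArgs mix-in
// sketched earlier.
struct Tensor {};
using TensorList = std::vector<Tensor>;

// Visitor that accumulates how many tensors were passed, counting a
// TensorList as all of its elements.
struct CountTensors : IterArgs<CountTensors> {
  std::size_t count = 0;
  void operator()(const Tensor&) { count += 1; }
  void operator()(const TensorList& ts) { count += ts.size(); }
};

template <typename... Args>
std::size_t count_tensors(Args&&... args) {
  return CountTensors().apply(std::forward<Args>(args)...).count;
}

// e.g. count_tensors(t, ts, u) == 2 + ts.size()
```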
Finally, the tuple packer is rewritten to be variadic, although we
cannot make use of IterArgs (since we are given a tuple). It might
make sense to refactor the code into a generic piece which invokes
a function with the arguments specified by a tuple, and then an
appropriate IterArgs, but we leave this for future work.
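For the record, that hypothetical generic piece could look roughly like the following C++14-style sketch (apply_tuple is a made-up name, not code in this patch): expand the tuple's indices and hand its elements to a callable, which could then be an IterArgs-style visitor.

```cpp
#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>

// Hypothetical helper: invoke fn with the elements of tup as its arguments.
template <typename Fn, typename Tuple, std::size_t... Is>
decltype(auto) apply_tuple_impl(Fn&& fn, Tuple&& tup, std::index_sequence<Is...>) {
  return std::forward<Fn>(fn)(std::get<Is>(std::forward<Tuple>(tup))...);
}

template <typename Fn, typename Tuple>
decltype(auto) apply_tuple(Fn&& fn, Tuple&& tup) {
  using T = typename std::decay<Tuple>::type;
  return apply_tuple_impl(std::forward<Fn>(fn), std::forward<Tuple>(tup),
                          std::make_index_sequence<std::tuple_size<T>::value>{});
}

// e.g. apply_tuple([](int a, double b) { return a + b; }, std::make_tuple(1, 2.5));
```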
One thing to note: we cannot write a function with overloads for both
Tensor and ArrayRef<Variable>, because both types have implicit
conversions from Variable, making a call with a Variable argument
ambiguous.
It may be interesting to remove the implicit conversion from ArrayRef.
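To make the ambiguity concrete, simplified stand-ins (not the actual ATen declarations) reproduce it:

```cpp
#include <vector>

// Simplified stand-ins: Tensor is implicitly constructible from Variable,
// and ArrayRef has an implicit one-element converting constructor.
struct Variable {};

struct Tensor {
  Tensor() = default;
  /*implicit*/ Tensor(const Variable&) {}
};

template <typename T>
struct ArrayRef {
  /*implicit*/ ArrayRef(const T&) {}
  /*implicit*/ ArrayRef(const std::vector<T>&) {}
};

void f(const Tensor&) {}        // "single tensor" overload
void f(ArrayRef<Variable>) {}   // "list of variables" overload

int main() {
  Variable v;
  // f(v);  // error: ambiguous -- both overloads need exactly one
  //        // user-defined conversion from Variable.
  (void)v;
  return 0;
}
```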
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Diffstat (limited to 'setup.py')
-rw-r--r-- | setup.py | 1 |
1 file changed, 1 insertion, 0 deletions
@@ -455,6 +455,7 @@ main_sources = [
     "torch/csrc/utils/tuple_parser.cpp",
     "torch/csrc/utils/tensor_apply.cpp",
     "torch/csrc/utils/tensor_flatten.cpp",
+    "torch/csrc/utils/variadic.cpp",
     "torch/csrc/allocators.cpp",
     "torch/csrc/serialization.cpp",
     "torch/csrc/jit/init.cpp",