summaryrefslogtreecommitdiff
path: root/setup.py
diff options
context:
space:
mode:
authorEdward Z. Yang <ezyang@mit.edu>2018-01-26 15:56:39 -0500
committerGitHub <noreply@github.com>2018-01-26 15:56:39 -0500
commitb8ab7bee26c41ceb90a1f944dd27c135699adaf2 (patch)
treeb721a9a06e064ee8195e7da188685c4338f1bd42 /setup.py
parent24177adc128fc76981a64c862872acc3da99d4bb (diff)
downloadpytorch-b8ab7bee26c41ceb90a1f944dd27c135699adaf2.tar.gz
pytorch-b8ab7bee26c41ceb90a1f944dd27c135699adaf2.tar.bz2
pytorch-b8ab7bee26c41ceb90a1f944dd27c135699adaf2.zip
Use variadic templates instead of initializer lists and overloads. (#4772)
Suppose you are given a list of arguments, each of which may be Tensor or TensorList. How can you write a function that can treat these arguments uniformly as a list of tensors? This patch solves the problem using variadic templates. Why variadic templates? Use of variadic templates means anyone working with this code has to understand universal references, perfect forwarding, parameter packs and some idioms of C++ template design. However, I argue that variadic templates are the *right* tool for supporting the implementation of functions which must take an arbitrarily heterogenous set of inputs. We were able to limp by in old code because, for the most part, tensor inputs were homogenous, but this is no longer the case for some non-primitively differentiable functions; and with the upcoming cuDNN RNN in ATen PR, will no longer be the case for primitively differentiable functions too. There are two parts to the PR. First, we add torch/csrc/utils/variadic.h, which defines a mix-in IterArgs that takes any class which supports operator(), and augments with a new variadic function apply() which calls operator() on each argument passed to it. In an original draft of the patch, I wrote the recursion for each parameter pack from scratch for each function; however, it turns out there are no fewer than seven instances where we need this idiom, and the mix-in reduces the lines of code, and also helps centralize the most important (and easy to forget) boilerplate for perfect forwarding. To verify that IterArgs is compiled away into an unrolled form per call site, I inspected the assembly on some synthetic examples. Next, we modify the following functions to make use of IterArgs: - compute_requires_grad - Function::flags (Variable and Tensor variants) - flatten - isTracing - count_tensors / count_variables Finally, the tuple packer is rewritten to be variadic, although we cannot make use of IterArgs (since we are given a tuple). It might make sense to refactor the code into a generic piece which invokes a function with the arguments specified by a tuple, and then an appropriate IterArgs, but we leave this for future work. One thing to note: we cannot write a function with overloads for both Tensor and Variable, because both ArrayRef<Variable> and Tensor have implicit conversions from Variable, making such an overload ambiguous. It may be interesting to remove the implicit conversion from ArrayRef. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Diffstat (limited to 'setup.py')
-rw-r--r--setup.py1
1 files changed, 1 insertions, 0 deletions
diff --git a/setup.py b/setup.py
index ca171db0dc..b15801966c 100644
--- a/setup.py
+++ b/setup.py
@@ -455,6 +455,7 @@ main_sources = [
"torch/csrc/utils/tuple_parser.cpp",
"torch/csrc/utils/tensor_apply.cpp",
"torch/csrc/utils/tensor_flatten.cpp",
+ "torch/csrc/utils/variadic.cpp",
"torch/csrc/allocators.cpp",
"torch/csrc/serialization.cpp",
"torch/csrc/jit/init.cpp",