| author | James Reed <jamesreed@fb.com> | 2019-04-23 15:24:41 -0700 |
|---|---|---|
| committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-23 15:40:11 -0700 |
| commit | 80020b3d2d310dd00835756b231371da64026acc | |
| tree | 8f2ba303fe106637cb983e9a167e048cd1e03349 | |
| parent | fb9fc42a0c00a04aaf5574ae4b932dc90221d147 | |
Guard {set,rebase}_history on grad_fn check (#19623)
Summary:
We would previously have statements like
```
set_history(flatten_tensor_args( result ), grad_fn);
```
Internally, {set,rebase}_history would check grad_fn and short-circuit if it is nullptr. However, this meant the expression `flatten_tensor_args( result )` was still evaluated and its result immediately thrown away, causing unnecessary allocations and overhead.
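To illustrate why the call-site guard matters, here is a minimal, self-contained C++ sketch (not the actual PyTorch code; `Node`, `flatten_args`, and the call counter are invented for this example): when the callee does the nullptr check, the caller has already paid for building the argument, whereas hoisting the check to the call site skips that work entirely.
```
#include <cassert>
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-ins for this example only: Node plays the role of
// grad_fn's type, flatten_args() the role of flatten_tensor_args().
struct Node {};

static int flatten_calls = 0;

std::vector<int> flatten_args() {
  ++flatten_calls;                 // count how often the argument is built
  return std::vector<int>(16, 0);  // stand-in for the flattened tensor list
}

// Old pattern: the callee checks grad_fn, but the argument expression has
// already been evaluated (and its allocation wasted) by the caller.
void set_history_old(std::vector<int> /*flat*/,
                     const std::shared_ptr<Node>& grad_fn) {
  if (!grad_fn) return;  // short-circuit: the vector was built for nothing
  // ... record history ...
}

// New pattern: callers guard on grad_fn, so the callee can assume it is set.
void set_history_new(std::vector<int> /*flat*/,
                     const std::shared_ptr<Node>& grad_fn) {
  assert(grad_fn != nullptr);  // analogous to AT_ASSERT(grad_fn)
  // ... record history ...
}

int main() {
  std::shared_ptr<Node> grad_fn = nullptr;  // e.g. running under no_grad

  set_history_old(flatten_args(), grad_fn);    // flatten_args() runs, result discarded
  if (grad_fn) {
    set_history_new(flatten_args(), grad_fn);  // flatten_args() never runs here
  }

  std::cout << "flatten_args() calls: " << flatten_calls << "\n";  // prints 1
}
```
This mirrors the change in the patch below: the generated code only evaluates the flattened outputs when grad_fn is set, and set_history asserts that precondition instead of silently returning.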
My JIT overhead benchmark script (using a custom `benchmark` method):
```
import torch, time

@torch.jit.script
def add(x, y):
    return x + y

a = torch.rand([])
b = torch.rand([])

niter = 1000000
with torch.no_grad():
    s = time.time()
    # 'benchmark' is the author's custom method (see note above), not stock PyTorch.
    add.__getattr__('forward').benchmark(niter, a, b)
    e = time.time() - s
    print('overhead per call (us)', e / niter * 1e6)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19623
Differential Revision: D15053399
Pulled By: jamesr66a
fbshipit-source-id: 8777e1a2b5c5a5bbd3a035b7247c8154c5fc4aa6
tools/autograd/gen_variable_type.py   |  4
torch/csrc/autograd/functions/utils.h | 15
2 files changed, 10 insertions(+), 9 deletions(-)
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 87f1318759..f7a68d8bcc 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -196,7 +196,9 @@ DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate("""\
 """)
 
 SET_HISTORY = CodeTemplate("""\
-${fn}_history(${differentiable_outputs}, grad_fn);
+if (grad_fn) {
+  ${fn}_history(${differentiable_outputs}, grad_fn);
+}
 """)
 
 CONDITIONAL = CodeTemplate("""\
diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h
index c63252605e..b4ee46dcec 100644
--- a/torch/csrc/autograd/functions/utils.h
+++ b/torch/csrc/autograd/functions/utils.h
@@ -51,14 +51,13 @@ inline bool compute_requires_grad(Args&&... args) {
 inline void set_history(
     at::Tensor& variable,
     const std::shared_ptr<Function>& grad_fn) {
-  if (grad_fn) {
-    if (variable.defined()) {
-      auto output_nr =
-          grad_fn->add_input_metadata(variable);
-      as_variable_ref(variable).set_gradient_edge({grad_fn, output_nr});
-    } else {
-      grad_fn->add_input_metadata(Function::undefined_input());
-    }
+  AT_ASSERT(grad_fn);
+  if (variable.defined()) {
+    auto output_nr =
+        grad_fn->add_input_metadata(variable);
+    as_variable_ref(variable).set_gradient_edge({grad_fn, output_nr});
+  } else {
+    grad_fn->add_input_metadata(Function::undefined_input());
   }
 }
 