summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Reed <jamesreed@fb.com>2019-04-23 15:24:41 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-23 15:40:11 -0700
commit80020b3d2d310dd00835756b231371da64026acc (patch)
tree8f2ba303fe106637cb983e9a167e048cd1e03349
parentfb9fc42a0c00a04aaf5574ae4b932dc90221d147 (diff)
downloadpytorch-80020b3d2d310dd00835756b231371da64026acc.tar.gz
pytorch-80020b3d2d310dd00835756b231371da64026acc.tar.bz2
pytorch-80020b3d2d310dd00835756b231371da64026acc.zip
Guard {set,rebase}_history on grad_fn check (#19623)
Summary: We would previously have statements like ``` set_history(flatten_tensor_args( result ), grad_fn); ``` Internally, {set,rebase}_history would check grad_fn and short circuit if it is nullptr. However, this means that we are executing the expression `flatten_tensor_args( result )` and immediately throwing away the results. This was causing unnecessary allocations + overhead. My JIT overhead benchmark script (with custom benchmark method): ``` import torch, time torch.jit.script def add(x, y): return x + y a = torch.rand([]) b = torch.rand([]) niter = 1000000 with torch.no_grad(): s = time.time() add.__getattr__('forward').benchmark(niter, a, b) e = time.time() - s print('overhead per call (us)', e / niter * 1e6) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/19623 Differential Revision: D15053399 Pulled By: jamesr66a fbshipit-source-id: 8777e1a2b5c5a5bbd3a035b7247c8154c5fc4aa6
-rw-r--r--tools/autograd/gen_variable_type.py4
-rw-r--r--torch/csrc/autograd/functions/utils.h15
2 files changed, 10 insertions, 9 deletions
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 87f1318759..f7a68d8bcc 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -196,7 +196,9 @@ DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate("""\
""")
SET_HISTORY = CodeTemplate("""\
-${fn}_history(${differentiable_outputs}, grad_fn);
+if (grad_fn) {
+ ${fn}_history(${differentiable_outputs}, grad_fn);
+}
""")
CONDITIONAL = CodeTemplate("""\
diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h
index c63252605e..b4ee46dcec 100644
--- a/torch/csrc/autograd/functions/utils.h
+++ b/torch/csrc/autograd/functions/utils.h
@@ -51,14 +51,13 @@ inline bool compute_requires_grad(Args&&... args) {
inline void set_history(
at::Tensor& variable,
const std::shared_ptr<Function>& grad_fn) {
- if (grad_fn) {
- if (variable.defined()) {
- auto output_nr =
- grad_fn->add_input_metadata(variable);
- as_variable_ref(variable).set_gradient_edge({grad_fn, output_nr});
- } else {
- grad_fn->add_input_metadata(Function::undefined_input());
- }
+ AT_ASSERT(grad_fn);
+ if (variable.defined()) {
+ auto output_nr =
+ grad_fn->add_input_metadata(variable);
+ as_variable_ref(variable).set_gradient_edge({grad_fn, output_nr});
+ } else {
+ grad_fn->add_input_metadata(Function::undefined_input());
}
}