author     Priya Goyal <priy2201@gmail.com>   2018-04-10 15:26:24 -0400
committer  GitHub <noreply@github.com>        2018-04-10 15:26:24 -0400
commit     e3196e0ea8e3d0a9830bbd2a87beadd12a8c450c (patch)
tree       215d14289a9da6b8409c4e5dee1de81f838030b2 /torch/autograd
parent     04c215b4454aa5816a15f60567800d65a8351d33 (diff)
[Re-checkpointing] Autograd container for trading compute for memory (#6467)
* Autograd container for trading compute for memory
* add a unit test for checkpoint
* address comments
* address review comments
* adding some docs for the checkpoint api
* more comments
* more comments
* repro bug
* Fix a subtle bug/apply some review comments
* Update checkpoint.py
* Run everything in grad mode
* fix flake and chunk=1
* use imperative backward as per discussion
* remove Variable and also add models and test for models
* Add a simple thread local variable to check for autograd grad mode
* remove models and models test after debugging
* address review comments
* address more comments
* address more comments
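For orientation, here is a minimal usage sketch of the checkpointing pattern this change supports. It assumes the torch.utils.checkpoint.checkpoint API introduced alongside this container; the model, tensor sizes, and names are illustrative only and may differ between PyTorch versions.

# Minimal sketch (assumptions: torch.utils.checkpoint.checkpoint as added in
# this PR series; the model and shapes below are illustrative).
import torch
from torch.utils.checkpoint import checkpoint

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 128),
)

# The input must require grad so the checkpointed segment participates in backward.
x = torch.randn(4, 128, requires_grad=True)

# Intermediate activations inside `model` are not stored during the forward
# pass; they are recomputed when backward reaches the checkpointed segment,
# trading extra compute for lower memory use.
out = checkpoint(model, x)

# Only the imperative backward() call is supported for checkpointed graphs
# (see the _is_checkpoint_valid() helper added in the diff below).
out.sum().backward()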
Diffstat (limited to 'torch/autograd')
-rw-r--r--  torch/autograd/__init__.py  18
1 file changed, 18 insertions, 0 deletions
diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py
index 67b782fcde..6d6f2f89e8 100644
--- a/torch/autograd/__init__.py
+++ b/torch/autograd/__init__.py
@@ -140,6 +140,24 @@ def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=Fal
inputs)
+# This function is used for gradient checkpointing (trading compute for
+# memory). Checkpointing currently supports only the imperative backward
+# call, i.e. torch.autograd.backward(); torch.autograd.grad() will not work.
+# The reason is that torch.autograd.grad() only computes gradients for the
+# inputs explicitly passed by the user and for nothing else, e.g. model
+# parameters such as weights and biases, whereas torch.autograd.backward()
+# also computes the gradients for those parameters.
+#
+# This function returns whether checkpointing is valid, i.e. whether the
+# reentrant backward was started via torch.autograd.backward rather than
+# torch.autograd.grad. The implementation maintains a thread-local variable
+# in torch/csrc/autograd/engine.cpp which looks at the FunctionTask on the
+# stack: before a FunctionTask is executed in evaluate_function, it checks
+# whether the reentrant backward is imperative or not.
+# See https://github.com/pytorch/pytorch/pull/4594 for more discussion/context
+def _is_checkpoint_valid():
+    return Variable._execution_engine.is_checkpoint_valid()
+
+
def variable(*args, **kwargs):
warnings.warn("torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead")
return torch.tensor(*args, **kwargs)
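As a hedged illustration of the constraint described in the new comment: checkpointed graphs are expected to backpropagate through the imperative backward() call but not through torch.autograd.grad(). The exact error type and message come from the checkpoint implementation in torch/utils/checkpoint.py and may vary by version; the layer, sizes, and names below are illustrative.

# Sketch of the backward-only constraint guarded by _is_checkpoint_valid()
# (assumes the torch.utils.checkpoint.checkpoint API from this PR series;
# the exact failure mode for .grad() may differ between versions).
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)

out = checkpoint(layer, x).sum()

# Imperative backward: parameter gradients (weight, bias) are populated.
out.backward()
print(layer.weight.grad is not None)  # True

# torch.autograd.grad() is not supported for checkpointed graphs; the
# recomputation step checks _is_checkpoint_valid() and is expected to raise.
out2 = checkpoint(layer, x).sum()
try:
    torch.autograd.grad(out2, x)
except RuntimeError as err:
    print("checkpoint + autograd.grad failed as expected:", err)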