author    | Priya Goyal <priy2201@gmail.com> | 2018-04-10 15:26:24 -0400
committer | GitHub <noreply@github.com> | 2018-04-10 15:26:24 -0400
commit    | e3196e0ea8e3d0a9830bbd2a87beadd12a8c450c (patch)
tree      | 215d14289a9da6b8409c4e5dee1de81f838030b2 /torch/autograd
parent    | 04c215b4454aa5816a15f60567800d65a8351d33 (diff)
download  | pytorch-e3196e0ea8e3d0a9830bbd2a87beadd12a8c450c.tar.gz pytorch-e3196e0ea8e3d0a9830bbd2a87beadd12a8c450c.tar.bz2 pytorch-e3196e0ea8e3d0a9830bbd2a87beadd12a8c450c.zip
[Re-checkpointing] Autograd container for trading compute for memory (#6467)
* Autograd container for trading compute for memory
* add a unit test for checkpoint
* address comments
* address review comments
* adding some docs for the checkpoint api
* more comments
* more comments
* repro bug
* Fix a subtle bug/apply some review comments
* Update checkpoint.py
* Run everything in grad mode
* fix flake and chunk=1
* use imperative backward as per discussion
* remove Variable and also add models and test for models
* Add a simple thread local variable to check for autograd grad mode
* remove models and models test after debugging
* address review comments
* address more comments
* address more comments
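
The commit message above describes the new checkpointing container only at a high level. As a quick illustration, here is a minimal usage sketch of gradient checkpointing, assuming the `torch.utils.checkpoint.checkpoint(function, *args)` interface as it exists in released PyTorch; the toy module, layer sizes, and names are illustrative, not taken from this commit:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint  # assumed API, as in released PyTorch

class Net(nn.Module):
    """Toy model whose middle segment trades compute for memory."""
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(32, 64)
        # Activations inside this segment are NOT stored during forward;
        # they are recomputed during the backward pass instead.
        self.segment = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
        self.tail = nn.Linear(64, 10)

    def forward(self, x):
        x = self.head(x)
        x = checkpoint(self.segment, x)  # recompute-on-backward
        return self.tail(x)

model = Net()
out = model(torch.randn(8, 32, requires_grad=True))
# Only the imperative backward() path is supported for checkpointed graphs
# (see the engine comment in the diff below); torch.autograd.grad() is not.
out.sum().backward()
```

Calling backward() here re-runs the forward pass of `segment` under grad mode to rebuild the discarded activations, which is the "trading compute for memory" behaviour the commit title refers to.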
Diffstat (limited to 'torch/autograd')
-rw-r--r-- | torch/autograd/__init__.py | 18
1 file changed, 18 insertions, 0 deletions
diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py
index 67b782fcde..6d6f2f89e8 100644
--- a/torch/autograd/__init__.py
+++ b/torch/autograd/__init__.py
@@ -140,6 +140,24 @@ def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=Fal
         inputs)
 
 
+# This function applies in case of gradient checkpointing for memory
+# optimization. Currently, for gradient checkpointing, we only support imperative
+# backwards call i.e. torch.autograd.backward() and the torch.autograd.grad() won't
+# work. The reason being that: torch.autograd.grad() only calculates the grads
+# for the inputs that are passed by user but it doesn't calculate grad for
+# anything else e.g. model parameters like weights, bias etc. However, for
+# torch.autograd.backward(), we would actually compute the grad for the weights as well.
+#
+# This function returns whether the checkpointing is valid i.e. torch.autograd.backward
+# or not i.e. torch.autograd.grad. The implementation works by maintaining a thread
+# local variable in torch/csrc/autograd/engine.cpp which looks at the FunctionTask
+# in the stack and before a FunctionTask is executed in evaluate_function, it
+# checks for whether reentrant backwards is imperative or not.
+# See https://github.com/pytorch/pytorch/pull/4594 for more discussion/context
+def _is_checkpoint_valid():
+    return Variable._execution_engine.is_checkpoint_valid()
+
+
 def variable(*args, **kwargs):
     warnings.warn("torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead")
     return torch.tensor(*args, **kwargs)
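
The long comment in the hunk above hinges on the difference between torch.autograd.grad() and torch.autograd.backward(). A small self-contained sketch of that difference, using standard PyTorch calls (the module and variable names are illustrative only):

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 2)
x = torch.randn(3, 4, requires_grad=True)
loss = lin(x).sum()

# torch.autograd.grad() returns gradients only for the tensors passed as
# `inputs`; the module parameters are left untouched (their .grad stays None).
(gx,) = torch.autograd.grad(loss, (x,), retain_graph=True)
assert lin.weight.grad is None and lin.bias.grad is None

# torch.autograd.backward() (equivalently loss.backward()) accumulates .grad
# on every leaf tensor that requires grad, including the parameters -- this is
# the "imperative" path that checkpointing supports.
loss.backward()
assert lin.weight.grad is not None and lin.bias.grad is not None
```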