summaryrefslogtreecommitdiff
path: root/docs
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2017-08-04 19:48:05 -0400
committerSoumith Chintala <soumith@gmail.com>2017-08-05 05:18:05 +0530
commit4599c0c7df92410e0a062e264330d62abd1a3e79 (patch)
tree85a5e11e3685dbe749b6e390d44265fa44905bdd /docs
parent8ce4401f09a516a006d9d4aebe685ab45a3d54d7 (diff)
downloadpytorch-4599c0c7df92410e0a062e264330d62abd1a3e79.tar.gz
pytorch-4599c0c7df92410e0a062e264330d62abd1a3e79.tar.bz2
pytorch-4599c0c7df92410e0a062e264330d62abd1a3e79.zip
Update autograd notes (#2295)
Diffstat (limited to 'docs')
-rw-r--r--docs/source/notes/extending.rst82
1 files changed, 47 insertions, 35 deletions
diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst
index ad0ba402c2..e78794c187 100644
--- a/docs/source/notes/extending.rst
+++ b/docs/source/notes/extending.rst
@@ -13,31 +13,28 @@ Extending :mod:`torch.autograd`
Adding operations to :mod:`~torch.autograd` requires implementing a new
:class:`Function` subclass for each operation. Recall that :class:`Function` s
are what :mod:`~torch.autograd` uses to compute the results and gradients, and
-encode the operation history. Every new function requires you to implement 3
+encode the operation history. Every new function requires you to implement 2
methods:
-- ``__init__`` (*optional*) - if your operation is parametrized by/uses
- objects different than :class:`Variable` s, you should pass them as arguments
- to ``__init__``. For example, ``AddConstant`` function takes a scalar to add,
- while ``Transpose`` requires specifying which two dimensions to swap. If your
- function doesn't require any additional parameters, you can skip it.
- :meth:`~Function.forward` - the code that performs the operation. It can take
- as many arguments as you want, with some of them being
- optional, if you specify the default values. Keep in mind that only
- :class:`Variable` s will be passed in here. You can return either a single
- :class:`Variable` output, or a :class:`tuple` of :class:`Variable` s if there
- are multiple. Also, please refer to the docs of :class:`Function` to find
- descriptions of useful methods that can be called only from
- :meth:`~Function.forward`.
+ as many arguments as you want, with some of them being optional, if you
+ specify the default values. All kinds of Python objects are accepted here.
+ :class:`Variable` arguments will be converted to :class:`Tensor` s before the
+ call, and their use will be registered in the graph. Note that this logic won't
+ traverse lists/dicts/any other data structures and will only consider Variables
+ that are direct arguments to the call. You can return either a single
+ :class:`Tensor` output, or a :class:`tuple` of :class:`Tensor` s if there are
+ multiple outputs. Also, please refer to the docs of :class:`Function` to find
+ descriptions of useful methods that can be called only from :meth:`~Function.forward`.
- :meth:`~Function.backward` - gradient formula. It will be given
- as many arguments as there were outputs, with each of them representing
- gradient w.r.t. that output. It should return as many :class:`Tensor` s as
- there were inputs, with each of them containing the gradient w.r.t.
- corresponding input. If your inputs didn't require gradient (see
- :attr:`~Variable.needs_input_grad`), or it was non-differentiable, you
- can return :class:`None`. Also, if you have optional arguments to
- :meth:`~Variable.forward` you can return more gradients than there were
- inputs, as long as they're all :any:`python:None`.
+ as many :class:`Variable` arguments as there were outputs, with each of them
+ representing gradient w.r.t. that output. It should return as many
+ :class:`Variable` s as there were inputs, with each of them containing the
+ gradient w.r.t. its corresponding input. If your inputs didn't require
+ gradient (see :attr:`~Variable.needs_input_grad`), or were non-:class:`Variable`
+ objects, you can return :class:`python:None`. Also, if you have optional
+ arguments to :meth:`~Variable.forward` you can return more gradients than there
+ were inputs, as long as they're all :any:`python:None`.
Below you can find code for a ``Linear`` function from :mod:`torch.nn`, with
additional comments::
@@ -45,22 +42,25 @@ additional comments::
# Inherit from Function
class Linear(Function):
+ # Note that both forward and backward are @staticmethods
+ @staticmethod
# bias is an optional argument
- def forward(self, input, weight, bias=None):
- self.save_for_backward(input, weight, bias)
+ def forward(ctx, input, weight, bias=None):
+ ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
# This function has only a single output, so it gets only one gradient
- def backward(self, grad_output):
+ @staticmethod
+ def backward(ctx, grad_output):
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
- input, weight, bias = self.saved_tensors
+ input, weight, bias = ctx.saved_variables
grad_input = grad_weight = grad_bias = None
# These needs_input_grad checks are optional and there only to
@@ -76,27 +76,39 @@ additional comments::
return grad_input, grad_weight, grad_bias
-Now, to make it easier to use these custom ops, we recommend wrapping them in
-small helper functions::
+Now, to make it easier to use these custom ops, we recommend aliasing their
+``apply`` method::
- def linear(input, weight, bias=None):
- # First braces create a Function object. Any arguments given here
- # will be passed to __init__. Second braces will invoke the __call__
- # operator, that will then use forward() to compute the result and
- # return it.
- return Linear()(input, weight, bias)
+ linear = Linear.aply
+
+Here, we give an additional example of a function that is parametrized by
+non-Variable arguments::
+
+ class MulConstant(Function):
+ @staticmethod
+ def forward(ctx, tensor, constant):
+ # ctx is a context object that can be used to stash information
+ for backward computation
+ ctx.constant = constant
+ return tensor * constant
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ # We return as many input gradients as there were arguments.
+ # Gradients of non-Tensor arguments to forward must be None.
+ return grad_output * ctx.constant, None
You probably want to check if the backward method you implemented actually
computes the derivatives of your function. It is possible by comparing with
numerical approximations using small finite differences::
from torch.autograd import gradcheck
-
+
# gradchek takes a tuple of tensor as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (Variable(torch.randn(20,20).double(), requires_grad=True), Variable(torch.randn(30,20).double(), requires_grad=True),)
- test = gradcheck(Linear(), input, eps=1e-6, atol=1e-4)
+ test = gradcheck(Linear.apply, input, eps=1e-6, atol=1e-4)
print(test)
Extending :mod:`torch.nn`