summaryrefslogtreecommitdiff
path: root/torch/csrc/jit/autodiff.h
blob: ea2b7a1170efebfe1d9d791f414c16cae279a3e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#pragma once

#include "torch/csrc/WindowsTorchApiMacro.h"
#include "torch/csrc/jit/ir.h"

#include <ATen/ATen.h>

#include <vector>
#include <memory>

namespace torch { namespace jit {

using value_list = std::vector<Value*>;
// Example showcasing how Gradient is constructed:
//
// Let's assume we have a function f, where `m` and `n` do not require grad
// (`n` can depend only on `m`):
//   y, n = f(x, m)
//
// Now, let's assume that the reverse of f (called f') needs to use values of `x`, `t` and `y`.
// `t` is an intermediate value produced in the body of f, and let's assume that it requires
// grad too.
//
// In this case differentiate(f) will return this:
//   y, n, t = f(x, m)        // `t` is appended to the output list
//   dx = f'(dy, dt, x, t, y) // No `dm` or `dn` because they do not require gradient.
//                            // All values needed from f are appended to the input list,
//                            // after the vjps (dy, dt).
//
//   f_real_outputs = 2       // Only first two outputs were present in f originally
//   df_input_vjps = {0, 2}   // i.e. connect grad_fn of y and t variables produced by f,
//                    y  t    // with y's output_nr = 0 and t's output_nr = 1
//   df_input_captured_inputs = {0}     // offsets into f's inputs:  x
//   df_input_captured_outputs = {2, 0} // offsets into f's outputs: t, y
//                            // Captured inputs followed by captured outputs (x, t, y)
//                            // match the inputs of df that come after the vjps.
//   df_output_vjps = {0}     // i.e. connect next_edge[0] of grad_fn to x's (grad_fn, output_nr).
//
// Terminology: vjp = vector-jacobian product

struct Gradient {
  // Converts to true iff df has been set (is non-null).
  explicit operator bool() const {
    return df != nullptr;
  }
  std::shared_ptr<Graph> f;
  std::shared_ptr<Graph> df;

  // Describes how to construct outputs of f from what its graph will return.
  // This is necessary because some trailing outputs are intermediates produced
  // only to be saved for df (and should be ignored).
  // Initialized so that a default-constructed Gradient never holds an
  // indeterminate value here.
  size_t f_real_outputs = 0;

  // df inputs are split into two sections: vjps (aka grad_outputs) followed by
  // captures.
  // VJPs are the "seeds" of the gradient computation; df_input_vjps[i] is the
  // offset of the output of f whose vjp is the i-th input of df.
  // Captures are values that need to be saved when f is run. We handle inputs
  // specially, because this allows us to avoid adding extra vjps as df inputs.

  std::vector<size_t> df_input_vjps; // Offsets into f's outputs.
  // Captures can come from f's inputs or from f's outputs.
  std::vector<size_t> df_input_captured_inputs; // Offsets into f's inputs
  std::vector<size_t> df_input_captured_outputs; // Offsets into f's outputs


  // df will produce vjps for a subset of inputs of f that required grad.
  // df_output_vjps[idx] == inp_idx means that idx-th output of df produces a vjp
  // for inp_idx-th input of f.
  std::vector<size_t> df_output_vjps; // Offsets into f's inputs.

  // How to use gradient to implement a differentiable autograd function:
  // When running f:
  //   - Unwrap input Variables
  //   - Run f's graph
  //   - Create grad_fn
  //   - Wrap outputs in Variables (assume we have a tensor_outputs array):
  //       outputs = map(Variable, tensor_outputs)
  //       for i, offset in enumerate(df_input_vjps):
  //         outputs[offset].set_grad_fn(grad_fn, output_nr=i)
  //   - Use df_output_vjps to connect next_edges of grad_fn:
  //       for idx in df_output_vjps:
  //         grad_fn.add_next_edge(inputs[idx].gradient_edge())
  //   - Save captures for df (care needs to be taken to use SavedVariables for inputs and
  //                           outputs that we will actually return)
  //   - Return outputs[:f_real_outputs]
  //
  // When running df:
  //   - Concatenate received vjps and captured Variables (vjps first)
  //   - Interpret df
  //   - Wrap outputs of df into Variables (that don't require grad)
};
// Builds the forward/backward pair described by Gradient for `graph`.
// `requires_grad` presumably holds one flag per input of `graph` — NOTE(review): confirm.
// `graph` is taken by non-const reference — NOTE(review): confirm whether callers
// should expect it to be modified.
// XXX: When calling this function, graph should have complete type information.
// Use the shape analysis pass to fill in the gaps if it doesn't.
TORCH_API Gradient differentiate(std::shared_ptr<Graph>& graph, const std::vector<bool>& requires_grad);

// can we take a derivative of this node symbolically?
TORCH_API bool isDifferentiable(Node * n);
// Graph overload; presumably true when every node of `g` is differentiable — NOTE(review): confirm.
TORCH_API bool isDifferentiable(Graph & g);
// NOTE(review): presumably reports whether `v` is known to be a zero value — confirm
// against the implementation.
TORCH_API bool isZero(Value * v);

}}