-rw-r--r--  tools/autograd/gen_variable_type.py            2
-rw-r--r--  tools/autograd/templates/VariableType.cpp      8
-rw-r--r--  torch/csrc/autograd/function.cpp               6
-rw-r--r--  torch/csrc/autograd/function.h               148
-rw-r--r--  torch/csrc/autograd/functions/basic_ops.cpp    8
-rw-r--r--  torch/csrc/autograd/functions/basic_ops.h      9
-rw-r--r--  torch/csrc/autograd/functions/utils.cpp       11
-rw-r--r--  torch/csrc/autograd/functions/utils.h          4
-rw-r--r--  torch/csrc/autograd/python_function.cpp       24
-rw-r--r--  torch/csrc/autograd/variable.cpp               2
10 files changed, 89 insertions, 133 deletions
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 8f86898283..eddde451c2 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -93,7 +93,7 @@ if (compute_requires_grad( ${args_with_derivatives} )) {
ASSIGN_GRAD_FN = CodeTemplate("""\
grad_fn = std::make_shared<${op}>(${op_ctor});
-grad_fn->next_functions = compute_next_functions( ${args_with_derivatives} );
+grad_fn->next_functions = get_next_functions( ${args_with_derivatives} );
""")
CALL_VIA_TYPE = CodeTemplate("""\
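
For orientation, the two templates above (the compute_requires_grad guard at the top of the hunk and ASSIGN_GRAD_FN) now expand to generated VariableType code of roughly the following shape. This is a sketch only: the op name MulBackward0 and the argument names self/other are hypothetical placeholders, not lines taken from the real generated file.

// Hypothetical expansion for a binary op; MulBackward0, self, and other are placeholders.
std::shared_ptr<MulBackward0> grad_fn;
if (compute_requires_grad( self, other )) {
  grad_fn = std::make_shared<MulBackward0>();
  grad_fn->next_functions = get_next_functions( self, other );
}
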
diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp
index 69cbf26823..68fda72e53 100644
--- a/tools/autograd/templates/VariableType.cpp
+++ b/tools/autograd/templates/VariableType.cpp
@@ -283,12 +283,6 @@ static void check_no_requires_grad(const Tensor& tensor, const char* name) {
}
}
-// NB: This should be called with Tensor/TensorList arguments (not Variables)
-template <typename... Args>
-static function_list compute_next_functions(Args&&... args) {
- return Function::tensor_flags(std::forward<Args>(args)...).next_functions;
-}
-
static void check_inplace(const Tensor& tensor) {
auto& var = static_cast<const Variable&>(tensor);
if (var.requires_grad() && var.is_leaf() && GradMode::is_enabled()) {
@@ -387,7 +381,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block
requires_grad &= isFloatingPoint(self.type().scalarType());
if (requires_grad) {
grad_fn = std::make_shared<CopyBackwards>();
- grad_fn->next_functions = compute_next_functions( self, src );
+ grad_fn->next_functions = get_next_functions(self, src);
grad_fn->num_inputs = 1;
grad_fn->src_type = &src.type();
grad_fn->src_device = src.is_cuda() ? src.get_device() : -1;
diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp
index 7a2b548df1..ef63478692 100644
--- a/torch/csrc/autograd/function.cpp
+++ b/torch/csrc/autograd/function.cpp
@@ -1,13 +1,15 @@
#include "Python.h"
#include "function.h"
-#include <string>
-
#include "variable.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/autograd/grad_mode.h"
#include "torch/csrc/autograd/functions/special.h"
+#include <string>
+#include <cstdint>
+#include <vector>
+
namespace torch { namespace autograd {
thread_local uint64_t Function::function_counter = 0;
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 11e6b32ade..b54216d3e2 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -6,18 +6,22 @@
// Subclasses may represent "forward" or "backward" operations (i.e. functions
// and their derivatives). Some functions may be used as both.
+#include "torch/csrc/assertions.h"
+#include "torch/csrc/autograd/function_hook.h"
+#include "torch/csrc/autograd/grad_mode.h"
+#include "torch/csrc/autograd/profiler.h"
#include "torch/csrc/autograd/saved_variable.h"
+#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/utils/auto_unique_ptr.h"
#include "torch/csrc/utils/python_stub.h"
#include "torch/csrc/utils/variadic.h"
-#include "torch/csrc/autograd/function_hook.h"
-#include "torch/csrc/autograd/profiler.h"
-#include "torch/csrc/jit/tracer.h"
-#include "torch/csrc/autograd/grad_mode.h"
#include <ATen/ATen.h>
+#include <algorithm>
+#include <cstdint>
#include <memory>
+#include <utility>
#include <vector>
namespace torch { namespace autograd {
@@ -39,79 +43,50 @@ struct edge_hasher {
}
};
-// TODO: separate is_executable and next_functions
-// State used to create "backward" functions
-struct FunctionFlags {
- // Roughly speaking, is_executable corresponds to requires_grad.
- // It's true if any input requires grad and gradient calculation is enabled.
- // See http://pytorch.org/docs/notes/autograd.html for more details.
- bool is_executable = false;
- // What functions take the output of this function as input.
- // There is one function per output of this function.
- function_list next_functions;
-};
-
namespace detail {
-
-// Why can't we just combine the set_variable and set_tensor variants
-// into one set of overloads? The problem is Variable is convertible
-// to both Tensor and ArrayRef<Variable>, making the overload ambiguous.
-
-// Invariant: this function unconditionally calls f.next_functions.emplace_back
-inline void set_function_flags(FunctionFlags& f, const Variable& var) {
- if (!var.defined()) {
- f.next_functions.emplace_back();
- return;
- }
- f.is_executable |= var.requires_grad();
- if (var.grad_fn()) {
- f.next_functions.emplace_back(var.grad_fn(), var.output_nr());
- } else if (var.requires_grad()) {
- f.next_functions.emplace_back(var.grad_accumulator(), 0);
- } else {
- f.next_functions.emplace_back();
+inline edge_type make_edge(const Variable &variable) {
+ if (variable.defined()) {
+ if (variable.grad_fn() != nullptr) {
+ return {variable.grad_fn(), variable.output_nr()};
+ } else if (variable.requires_grad()) {
+ return {variable.grad_accumulator(), 0};
+ }
}
+ return {nullptr, 0};
}
-struct SetFunctionFlags : IterArgs<SetFunctionFlags> {
- FunctionFlags& out;
- SetFunctionFlags(FunctionFlags& out) : out(out) {}
- using IterArgs<SetFunctionFlags>::operator();
- void operator()(const Variable& v) { set_function_flags(out, v); }
-};
-
-struct SetTensorFunctionFlags : IterArgs<SetTensorFunctionFlags> {
- FunctionFlags& out;
- SetTensorFunctionFlags(FunctionFlags& out) : out(out) {}
- using IterArgs<SetTensorFunctionFlags>::operator();
- void operator()(const Tensor& t) {
- set_function_flags(out, static_cast<const Variable&>(t));
+struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
+ function_list next_functions;
+ using IterArgs<MakeNextFunctionList>::operator();
+ void operator()(const Variable& variable) {
+ next_functions.push_back(make_edge(variable));
}
};
+} // namespace detail
+// Returns true if any of the variables in the list require a gradient.
+inline bool any_variable_requires_grad(const variable_list& variables) {
+ return std::any_of(
+ variables.begin(), variables.end(), [](const Variable& variable) {
+ return variable.requires_grad();
+ });
+}
-} // namespace detail
+template <typename... Variables>
+function_list get_next_functions(Variables&&... variables) {
+ if (!GradMode::is_enabled()) return {};
+ detail::MakeNextFunctionList make;
+ make.apply(std::forward<Variables>(variables)...);
+ return std::move(make.next_functions);
+}
struct Function : std::enable_shared_from_this<Function> {
static thread_local uint64_t function_counter;
- Function()
- : num_inputs(0)
- , time(function_counter++)
- , next_functions()
- , pre_hooks()
- , post_hooks()
- , pyobj(nullptr)
- {}
-
- Function(FunctionFlags&& flags)
- : num_inputs(0)
- , time(function_counter++)
- , next_functions(std::move(flags.next_functions))
- , pre_hooks()
- , post_hooks()
- , pyobj(nullptr)
- {}
+ Function() : time(function_counter++) {}
+ Function(function_list&& next_functions_) : Function() {
+ next_functions = std::move(next_functions_);
+ }
Function(const Function& other) = delete;
Function(Function&& other) = delete;
@@ -136,26 +111,6 @@ struct Function : std::enable_shared_from_this<Function> {
return shared_from_this();
};
- // Computes is_executable and next_functions from an arbitrary argument list
- // of variables and lists of variables (but whose static type is Tensor)
- template<typename... Args> inline static FunctionFlags tensor_flags(Args&&... args) {
- FunctionFlags f;
- if (!GradMode::is_enabled()) return f;
- f.next_functions.reserve(count_tensors(std::forward<Args>(args)...));
- detail::SetTensorFunctionFlags(f).apply(std::forward<Args>(args)...);
- return f; // RVO
- }
-
- // Computes is_executable and next_functions from an arbitrary argument list
- // of variables and lists of variables
- template<typename... Args> inline static FunctionFlags flags(Args&&... args) {
- FunctionFlags f;
- if (!GradMode::is_enabled()) return f;
- f.next_functions.reserve(count_variables(std::forward<Args>(args)...));
- detail::SetFunctionFlags(f).apply(std::forward<Args>(args)...);
- return f; // RVO
- }
-
// Releases saved variables if the operation won't be reused
virtual inline void releaseVariables() {}
// Called before an apply() if releaseVariables() is going to be called.
@@ -165,26 +120,27 @@ struct Function : std::enable_shared_from_this<Function> {
// Function name for debugging
virtual std::string name();
- inline bool should_compute_output(int i) const {
- return bool(next_functions[i].first);
+ bool should_compute_output(size_t index) const {
+ TORCH_ASSERTM(index < next_functions.size(), "Index out of range");
+ return next_functions[index].first != nullptr;
}
- inline bool should_compute_any_outputs() const {
+ bool should_compute_any_outputs() const {
for (size_t i = 0; i < next_functions.size(); ++i) {
- if (should_compute_output((int)i)) {
+ if (should_compute_output(i)) {
return true;
}
}
return false;
}
- inline bool should_compute_output(std::initializer_list<int> idxs) const {
- return std::any_of(idxs.begin(), idxs.end(), [this](int i) {
+ bool should_compute_output(std::initializer_list<size_t> idxs) const {
+ return std::any_of(idxs.begin(), idxs.end(), [this](size_t i) {
return should_compute_output(i);
});
}
- inline bool should_compute_output(std::initializer_list<std::pair<size_t, size_t>> idxs) const {
+ bool should_compute_output(std::initializer_list<std::pair<size_t, size_t>> idxs) const {
return std::any_of(idxs.begin(), idxs.end(), [this](std::pair<size_t, size_t> range) {
for (size_t i = range.first; i < range.second; i++) {
if (should_compute_output(i)) return true;
@@ -193,8 +149,8 @@ struct Function : std::enable_shared_from_this<Function> {
});
}
- inline void set_flags(FunctionFlags&& flags) {
- next_functions = std::move(flags.next_functions);
+ void set_next_functions(function_list&& next_functions) {
+ this->next_functions = std::move(next_functions);
}
// An op is traceable if all operations happening within apply() are performed
@@ -222,13 +178,13 @@ struct Function : std::enable_shared_from_this<Function> {
static void setUpContextEdge(jit::Node* this_node,
const variable_list& inputs, const variable_list& outputs);
- int num_inputs;
+ int num_inputs = 0;
uint64_t time;
function_list next_functions;
std::vector<std::shared_ptr<FunctionPreHook>> pre_hooks;
std::vector<std::shared_ptr<FunctionPostHook>> post_hooks;
- PyObject *pyobj; // weak reference
+ PyObject* pyobj = nullptr; // weak reference
auto_unique_ptr<jit::tracer::FunctionTracingState> tracing_state;
};
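
The function.h changes above replace the FunctionFlags/is_executable pair with plain edge lists. Below is a minimal sketch of how a backward node can be wired up using only the API visible in this diff (get_next_functions, set_next_functions, num_inputs); the MyBackward type and the make_my_backward helper are hypothetical and not part of the patch.

// Minimal sketch, not part of this patch; MyBackward and make_my_backward are hypothetical.
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include <memory>

namespace torch { namespace autograd {

struct MyBackward : public Function {
  virtual variable_list apply(const variable_list& grad_outputs) {
    return grad_outputs;  // identity backward, purely for illustration
  }
};

std::shared_ptr<Function> make_my_backward(const Variable& a, const Variable& b) {
  auto fn = std::make_shared<MyBackward>();
  // get_next_functions() returns an empty list when grad mode is off; otherwise each
  // edge is (grad_fn or grad_accumulator, output_nr), as built by make_edge() above.
  fn->set_next_functions(get_next_functions(a, b));
  fn->num_inputs = 1;  // the node expects a single gradient input
  return fn;
}

}} // namespace torch::autograd
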
diff --git a/torch/csrc/autograd/functions/basic_ops.cpp b/torch/csrc/autograd/functions/basic_ops.cpp
index 34dcf21bd0..9fc2207dc5 100644
--- a/torch/csrc/autograd/functions/basic_ops.cpp
+++ b/torch/csrc/autograd/functions/basic_ops.cpp
@@ -1,9 +1,13 @@
#include "basic_ops.h"
+#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/functions/utils.h"
#include "torch/csrc/utils/auto_gpu.h"
+#include <memory>
+#include <utility>
+
namespace torch { namespace autograd {
auto Error::apply(const variable_list& grad_outputs) -> variable_list {
@@ -17,8 +21,8 @@ auto DelayedError::apply(const variable_list& inputs) -> variable_list {
// FIXME: share version counters
outputs.emplace_back(var.defined() ? var.data() : Tensor());
}
- return wrap_outputs(inputs, std::move(outputs), [&](FunctionFlags f) {
- return std::make_shared<Error>(msg, std::move(f));
+ return wrap_outputs(inputs, std::move(outputs), [&](function_list&& next_functions) {
+ return std::make_shared<Error>(msg, std::move(next_functions));
});
};
diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h
index 4af808447b..2e494dbecc 100644
--- a/torch/csrc/autograd/functions/basic_ops.h
+++ b/torch/csrc/autograd/functions/basic_ops.h
@@ -11,8 +11,8 @@
namespace torch { namespace autograd {
struct Error : public Function {
- Error(std::string msg, FunctionFlags&& flags)
- : Function(std::move(flags))
+ Error(std::string msg, function_list&& next_functions)
+ : Function(std::move(next_functions))
, msg(std::move(msg)) {}
Error(std::string msg)
@@ -35,9 +35,8 @@ struct DelayedError : public Function {
struct GraphRoot : public Function {
GraphRoot(function_list functions, variable_list inputs)
- : outputs(std::move(inputs)) {
- next_functions = std::move(functions);
- };
+ : Function(std::move(functions)), outputs(std::move(inputs)) {
+ }
virtual variable_list apply(const variable_list& inputs) {
return outputs;
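
Because GraphRoot now hands its edge list straight to the Function base constructor, a caller can build it in one expression. A sketch under the assumption that a Variable `loss` and a gradient list `grad_outputs` already exist; both names are hypothetical.

// Sketch only; `loss`, `grad_outputs`, and start_backward are hypothetical.
#include "torch/csrc/autograd/functions/basic_ops.h"
#include "torch/csrc/autograd/variable.h"
#include <memory>

void start_backward(const torch::autograd::Variable& loss,
                    const torch::autograd::variable_list& grad_outputs) {
  using namespace torch::autograd;
  function_list roots = get_next_functions(loss);  // edges into loss's grad_fn
  auto graph_root = std::make_shared<GraphRoot>(std::move(roots), grad_outputs);
  // GraphRoot::apply ignores its inputs and simply returns the seed gradients.
  variable_list seeds = graph_root->apply({});
}
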
diff --git a/torch/csrc/autograd/functions/utils.cpp b/torch/csrc/autograd/functions/utils.cpp
index ecec973f8a..b225d48455 100644
--- a/torch/csrc/autograd/functions/utils.cpp
+++ b/torch/csrc/autograd/functions/utils.cpp
@@ -1,19 +1,17 @@
#include "torch/csrc/autograd/functions/utils.h"
-#include "torch/csrc/utils/functional.h"
-#include "torch/csrc/jit/tracer.h"
-
+#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include <sstream>
+#include <vector>
namespace torch { namespace autograd {
variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
function_constructor ctr) {
- auto flags = Function::flags(inputs);
variable_list result;
result.reserve(outputs.size());
- if (!flags.is_executable) {
+ if (!any_variable_requires_grad(inputs)) {
for (auto& output : outputs) {
if (output.defined()) {
result.emplace_back(make_variable(output, false));
@@ -22,7 +20,7 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
}
}
} else {
- auto grad_fn = ctr(std::move(flags));
+ auto grad_fn = ctr(get_next_functions(inputs));
for (auto& output : outputs) {
if (output.defined()) {
result.emplace_back(make_variable(output, grad_fn));
@@ -53,5 +51,4 @@ void check_input_variables(const char* name, const variable_list& inputs, int ar
}
}
}
-
}}
diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h
index cd9067793a..c4db7aa5e6 100644
--- a/torch/csrc/autograd/functions/utils.h
+++ b/torch/csrc/autograd/functions/utils.h
@@ -3,14 +3,13 @@
#include <Python.h>
#include <functional>
#include <memory>
-#include <array>
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
namespace torch { namespace autograd {
-using function_constructor = std::function<std::shared_ptr<Function>(FunctionFlags)>;
+using function_constructor = std::function<std::shared_ptr<Function>(function_list&&)>;
/**
* Wraps the tensor outputs in variables and creates the grad_fn and sets the
@@ -24,5 +23,4 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
* items are not NULL. If not specified, `required_args` defaults to `args`.
*/
void check_input_variables(const char* name, const variable_list& inputs, int args, int required_args=-1);
-
}}
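
Under the new function_constructor signature, a wrap_outputs call site passes a lambda that receives the precomputed edge list instead of a FunctionFlags bundle. A minimal sketch combining it with check_input_variables; the op name, its argument count, and the error message are hypothetical.

// Sketch only; MyOp, its arity, and the message are hypothetical.
#include "torch/csrc/autograd/functions/basic_ops.h"
#include "torch/csrc/autograd/functions/utils.h"
#include <memory>
#include <utility>

namespace torch { namespace autograd {

variable_list my_op_wrapper(const variable_list& inputs, tensor_list&& raw_outputs) {
  check_input_variables("MyOp", inputs, /*args=*/2);  // required_args defaults to args
  return wrap_outputs(inputs, std::move(raw_outputs),
      [](function_list&& next_functions) {
        return std::make_shared<Error>("MyOp is not differentiable twice",
                                       std::move(next_functions));
      });
}

}} // namespace torch::autograd
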
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 65b792bcc0..fe19d6d6a8 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -94,9 +94,13 @@ auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {
// leads to unexpected error messages ("no nodes require computing gradients"),
// but I don't have a better idea. These functions would raise an error
// in backward anyway.
- return wrap_outputs(inputs, std::move(tensor_results), [this](FunctionFlags &&f) {
- return std::make_shared<Error>(name() + " is not differentiable twice", std::move(f));
- });
+ return wrap_outputs(
+ inputs,
+ std::move(tensor_results),
+ [this](function_list&& next_functions) {
+ return std::make_shared<Error>(
+ name() + " is not differentiable twice", std::move(next_functions));
+ });
}
// NOTE: this function is written in a way that assumes it's only called for backward;
@@ -566,7 +570,8 @@ struct UnpackedInput {
};
struct InputFlags {
- FunctionFlags flags;
+ bool is_executable = false;
+ function_list next_functions;
THPObjectPtr needs_input_grad;
std::vector<bool> is_variable_input;
};
@@ -606,7 +611,8 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) {
PyTuple_SET_ITEM(unpacked.tensor_input.get(), i, new_arg);
}
- flags.flags = Function::flags(unpacked.input_vars);
+ flags.is_executable = any_variable_requires_grad(unpacked.input_vars);
+ flags.next_functions = get_next_functions(unpacked.input_vars);
return std::make_pair(std::move(unpacked), std::move(flags));
}
@@ -779,8 +785,8 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs)
auto info_pair = unpack_input<true>(_inputs);
auto& unpacked_input = info_pair.first;
auto& input_info = info_pair.second;
- bool is_executable = input_info.flags.is_executable;
- self->cdata.set_flags(std::move(input_info.flags));
+ bool is_executable = input_info.is_executable;
+ self->cdata.set_next_functions(std::move(input_info.next_functions));
self->needs_input_grad = input_info.needs_input_grad.release();
// Now we're ready to call a forward (implemented in Python)
@@ -811,8 +817,8 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
InputFlags& input_info = info_pair.second;
// Initialize backward function (and ctx)
- bool is_executable = input_info.flags.is_executable;
- ctx->cdata.set_flags(std::move(input_info.flags));
+ bool is_executable = input_info.is_executable;
+ ctx->cdata.set_next_functions(std::move(input_info.next_functions));
ctx->needs_input_grad = input_info.needs_input_grad.release();
ctx->is_variable_input = std::move(input_info.is_variable_input);
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index 657fbcfa66..605c1d567b 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -117,7 +117,7 @@ std::shared_ptr<Function>& VariableViewImpl::get_grad_fn() {
fn->size = sizes();
fn->stride = strides();
fn->storage_offset = data.storage_offset();
- fn->set_flags(Function::flags(base));
+ fn->set_next_functions(get_next_functions(base));
fn->num_inputs = 1;
_grad_fn = std::move(fn);
attr_version = current_version;