-rw-r--r--  tools/autograd/gen_variable_type.py          |   2
-rw-r--r--  tools/autograd/templates/VariableType.cpp    |   8
-rw-r--r--  torch/csrc/autograd/function.cpp             |   6
-rw-r--r--  torch/csrc/autograd/function.h               | 148
-rw-r--r--  torch/csrc/autograd/functions/basic_ops.cpp  |   8
-rw-r--r--  torch/csrc/autograd/functions/basic_ops.h    |   9
-rw-r--r--  torch/csrc/autograd/functions/utils.cpp      |  11
-rw-r--r--  torch/csrc/autograd/functions/utils.h        |   4
-rw-r--r--  torch/csrc/autograd/python_function.cpp      |  24
-rw-r--r--  torch/csrc/autograd/variable.cpp             |   2

10 files changed, 89 insertions(+), 133 deletions(-)
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 8f86898283..eddde451c2 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -93,7 +93,7 @@ if (compute_requires_grad( ${args_with_derivatives} )) {
 ASSIGN_GRAD_FN = CodeTemplate("""\
 grad_fn = std::make_shared<${op}>(${op_ctor});
-grad_fn->next_functions = compute_next_functions( ${args_with_derivatives} );
+grad_fn->next_functions = get_next_functions( ${args_with_derivatives} );
 """)
 
 CALL_VIA_TYPE = CodeTemplate("""\
diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp
index 69cbf26823..68fda72e53 100644
--- a/tools/autograd/templates/VariableType.cpp
+++ b/tools/autograd/templates/VariableType.cpp
@@ -283,12 +283,6 @@ static void check_no_requires_grad(const Tensor& tensor, const char* name) {
   }
 }
 
-// NB: This should be called with Tensor/TensorList arguments (not Variables)
-template <typename... Args>
-static function_list compute_next_functions(Args&&... args) {
-  return Function::tensor_flags(std::forward<Args>(args)...).next_functions;
-}
-
 static void check_inplace(const Tensor& tensor) {
   auto& var = static_cast<const Variable&>(tensor);
   if (var.requires_grad() && var.is_leaf() && GradMode::is_enabled()) {
@@ -387,7 +381,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block
   requires_grad &= isFloatingPoint(self.type().scalarType());
   if (requires_grad) {
     grad_fn = std::make_shared<CopyBackwards>();
-    grad_fn->next_functions = compute_next_functions( self, src );
+    grad_fn->next_functions = get_next_functions(self, src);
     grad_fn->num_inputs = 1;
     grad_fn->src_type = &src.type();
     grad_fn->src_device = src.is_cuda() ? src.get_device() : -1;
diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp
index 7a2b548df1..ef63478692 100644
--- a/torch/csrc/autograd/function.cpp
+++ b/torch/csrc/autograd/function.cpp
@@ -1,13 +1,15 @@
 #include "Python.h"
 #include "function.h"
 
-#include <string>
-
 #include "variable.h"
 #include "torch/csrc/jit/ir.h"
 #include "torch/csrc/autograd/grad_mode.h"
 #include "torch/csrc/autograd/functions/special.h"
 
+#include <string>
+#include <cstdint>
+#include <vector>
+
 namespace torch { namespace autograd {
 
 thread_local uint64_t Function::function_counter = 0;
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 11e6b32ade..b54216d3e2 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -6,18 +6,22 @@
 // Subclasses may represent "forward" or "backward" operations (i.e functions
 // and their derivatives). Some functions may be used as both.
 
+#include "torch/csrc/assertions.h"
+#include "torch/csrc/autograd/function_hook.h"
+#include "torch/csrc/autograd/grad_mode.h"
+#include "torch/csrc/autograd/profiler.h"
 #include "torch/csrc/autograd/saved_variable.h"
+#include "torch/csrc/jit/tracer.h"
 #include "torch/csrc/utils/auto_unique_ptr.h"
 #include "torch/csrc/utils/python_stub.h"
 #include "torch/csrc/utils/variadic.h"
-#include "torch/csrc/autograd/function_hook.h"
-#include "torch/csrc/autograd/profiler.h"
-#include "torch/csrc/jit/tracer.h"
-#include "torch/csrc/autograd/grad_mode.h"
 
 #include <ATen/ATen.h>
 
+#include <algorithm>
+#include <cstdint>
 #include <memory>
+#include <utility>
 #include <vector>
 
 namespace torch { namespace autograd {
@@ -39,79 +43,50 @@ struct edge_hasher {
   }
 };
 
-// TODO: separate is_executable and next_functions
-// State used to create "backward" functions
-struct FunctionFlags {
-  // Roughly speaking, is_executable corresponds to requires_grad.
-  // It's true if any input requires grad and gradient calculation is enabled.
-  // See http://pytorch.org/docs/notes/autograd.html for more details.
-  bool is_executable = false;
-  // What functions take the output of this function as input.
-  // There is one function per output of this function.
-  function_list next_functions;
-};
-
 namespace detail {
-
-// Why can't we just combine the set_variable and set_tensor variants
-// into one set of overloads? The problem is Variable is convertible
-// to both Tensor and ArrayRef<Variable>, making the overload ambiguous.
-
-// Invariant: this function unconditionally calls f.next_functions.emplace_back
-inline void set_function_flags(FunctionFlags& f, const Variable& var) {
-  if (!var.defined()) {
-    f.next_functions.emplace_back();
-    return;
-  }
-  f.is_executable |= var.requires_grad();
-  if (var.grad_fn()) {
-    f.next_functions.emplace_back(var.grad_fn(), var.output_nr());
-  } else if (var.requires_grad()) {
-    f.next_functions.emplace_back(var.grad_accumulator(), 0);
-  } else {
-    f.next_functions.emplace_back();
+inline edge_type make_edge(const Variable &variable) {
+  if (variable.defined()) {
+    if (variable.grad_fn() != nullptr) {
+      return {variable.grad_fn(), variable.output_nr()};
+    } else if (variable.requires_grad()) {
+      return {variable.grad_accumulator(), 0};
+    }
   }
+  return {nullptr, 0};
 }
 
-struct SetFunctionFlags : IterArgs<SetFunctionFlags> {
-  FunctionFlags& out;
-  SetFunctionFlags(FunctionFlags& out) : out(out) {}
-  using IterArgs<SetFunctionFlags>::operator();
-  void operator()(const Variable& v) { set_function_flags(out, v); }
-};
-
-struct SetTensorFunctionFlags : IterArgs<SetTensorFunctionFlags> {
-  FunctionFlags& out;
-  SetTensorFunctionFlags(FunctionFlags& out) : out(out) {}
-  using IterArgs<SetTensorFunctionFlags>::operator();
-  void operator()(const Tensor& t) {
-    set_function_flags(out, static_cast<const Variable&>(t));
+struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
+  function_list next_functions;
+  using IterArgs<MakeNextFunctionList>::operator();
+  void operator()(const Variable& variable) {
+    next_functions.push_back(make_edge(variable));
   }
 };
+} // namespace detail
 
+// Returns true if any of the variables in the list require a gradient.
+inline bool any_variable_requires_grad(const variable_list& variables) {
+  return std::any_of(
+      variables.begin(), variables.end(), [](const Variable& variable) {
+        return variable.requires_grad();
+      });
+}
 
-} // namespace detail
+template <typename... Variables>
+function_list get_next_functions(Variables&&... variables) {
+  if (!GradMode::is_enabled()) return {};
+  detail::MakeNextFunctionList make;
+  make.apply(std::forward<Variables>(variables)...);
+  return std::move(make.next_functions);
+}
 
 struct Function : std::enable_shared_from_this<Function> {
   static thread_local uint64_t function_counter;
 
-  Function()
-    : num_inputs(0)
-    , time(function_counter++)
-    , next_functions()
-    , pre_hooks()
-    , post_hooks()
-    , pyobj(nullptr)
-    {}
-
-  Function(FunctionFlags&& flags)
-    : num_inputs(0)
-    , time(function_counter++)
-    , next_functions(std::move(flags.next_functions))
-    , pre_hooks()
-    , post_hooks()
-    , pyobj(nullptr)
-    {}
+  Function() : time(function_counter++) {}
+  Function(function_list&& next_functions_) : Function() {
+    next_functions = std::move(next_functions_);
+  }
 
   Function(const Function& other) = delete;
   Function(Function&& other) = delete;
@@ -136,26 +111,6 @@ struct Function : std::enable_shared_from_this<Function> {
     return shared_from_this();
   };
 
-  // Computes is_executable and next_functions from an arbitrary argument list
-  // of variables and lists of variables (but whose static type is Tensor)
-  template<typename... Args> inline static FunctionFlags tensor_flags(Args&&... args) {
-    FunctionFlags f;
-    if (!GradMode::is_enabled()) return f;
-    f.next_functions.reserve(count_tensors(std::forward<Args>(args)...));
-    detail::SetTensorFunctionFlags(f).apply(std::forward<Args>(args)...);
-    return f; // RVO
-  }
-
-  // Computes is_executable and next_functions from an arbitrary argument list
-  // of variables and lists of variables
-  template<typename... Args> inline static FunctionFlags flags(Args&&... args) {
-    FunctionFlags f;
-    if (!GradMode::is_enabled()) return f;
-    f.next_functions.reserve(count_variables(std::forward<Args>(args)...));
-    detail::SetFunctionFlags(f).apply(std::forward<Args>(args)...);
-    return f; // RVO
-  }
-
   // Releases saved variables if the operation won't be reused
   virtual inline void releaseVariables() {}
   // called before a an apply if will release variables is going to be called
@@ -165,26 +120,27 @@ struct Function : std::enable_shared_from_this<Function> {
   // Function name for debugging
   virtual std::string name();
 
-  inline bool should_compute_output(int i) const {
-    return bool(next_functions[i].first);
+  bool should_compute_output(size_t index) const {
+    TORCH_ASSERTM(index < next_functions.size(), "Index out of range");
+    return next_functions[index].first != nullptr;
   }
 
-  inline bool should_compute_any_outputs() const {
+  bool should_compute_any_outputs() const {
     for (size_t i = 0; i < next_functions.size(); ++i) {
-      if (should_compute_output((int)i)) {
+      if (should_compute_output(i)) {
         return true;
       }
     }
     return false;
   }
 
-  inline bool should_compute_output(std::initializer_list<int> idxs) const {
-    return std::any_of(idxs.begin(), idxs.end(), [this](int i) {
+  bool should_compute_output(std::initializer_list<size_t> idxs) const {
+    return std::any_of(idxs.begin(), idxs.end(), [this](size_t i) {
       return should_compute_output(i);
     });
   }
 
-  inline bool should_compute_output(std::initializer_list<std::pair<size_t, size_t>> idxs) const {
+  bool should_compute_output(std::initializer_list<std::pair<size_t, size_t>> idxs) const {
     return std::any_of(idxs.begin(), idxs.end(), [this](std::pair<size_t, size_t> range) {
       for (size_t i = range.first; i < range.second; i++) {
         if (should_compute_output(i)) return true;
@@ -193,8 +149,8 @@ struct Function : std::enable_shared_from_this<Function> {
     });
   }
 
-  inline void set_flags(FunctionFlags&& flags) {
-    next_functions = std::move(flags.next_functions);
+  void set_next_functions(function_list&& next_functions) {
+    this->next_functions = std::move(next_functions);
   }
 
   // An op is traceable if all operations happening within apply() are performed
@@ -222,13 +178,13 @@ struct Function : std::enable_shared_from_this<Function> {
   static void setUpContextEdge(jit::Node* this_node,
                                const variable_list& inputs, const variable_list& outputs);
 
-  int num_inputs;
+  int num_inputs = 0;
   uint64_t time;
   function_list next_functions;
   std::vector<std::shared_ptr<FunctionPreHook>> pre_hooks;
   std::vector<std::shared_ptr<FunctionPostHook>> post_hooks;
-  PyObject *pyobj; // weak reference
+  PyObject* pyobj = nullptr; // weak reference
 
   auto_unique_ptr<jit::tracer::FunctionTracingState> tracing_state;
 };
diff --git a/torch/csrc/autograd/functions/basic_ops.cpp b/torch/csrc/autograd/functions/basic_ops.cpp
index 34dcf21bd0..9fc2207dc5 100644
--- a/torch/csrc/autograd/functions/basic_ops.cpp
+++ b/torch/csrc/autograd/functions/basic_ops.cpp
@@ -1,9 +1,13 @@
 #include "basic_ops.h"
 
+#include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/variable.h"
 #include "torch/csrc/autograd/functions/utils.h"
 #include "torch/csrc/utils/auto_gpu.h"
 
+#include <memory>
+#include <utility>
+
 namespace torch { namespace autograd {
 
 auto Error::apply(const variable_list& grad_outputs) -> variable_list {
@@ -17,8 +21,8 @@ auto DelayedError::apply(const variable_list& inputs) -> variable_list {
     // FIXME: share version counters
     outputs.emplace_back(var.defined() ? var.data() : Tensor());
   }
-  return wrap_outputs(inputs, std::move(outputs), [&](FunctionFlags f) {
-    return std::make_shared<Error>(msg, std::move(f));
+  return wrap_outputs(inputs, std::move(outputs), [&](function_list&& next_functions) {
+    return std::make_shared<Error>(msg, std::move(next_functions));
   });
 };
diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h
index 4af808447b..2e494dbecc 100644
--- a/torch/csrc/autograd/functions/basic_ops.h
+++ b/torch/csrc/autograd/functions/basic_ops.h
@@ -11,8 +11,8 @@ namespace torch { namespace autograd {
 
 struct Error : public Function {
-  Error(std::string msg, FunctionFlags&& flags)
-    : Function(std::move(flags))
+  Error(std::string msg, function_list&& next_functions)
+    : Function(std::move(next_functions))
     , msg(std::move(msg)) {}
 
   Error(std::string msg)
@@ -35,9 +35,8 @@ struct DelayedError : public Function {
 
 struct GraphRoot : public Function {
   GraphRoot(function_list functions, variable_list inputs)
-    : outputs(std::move(inputs)) {
-      next_functions = std::move(functions);
-    };
+    : Function(std::move(functions)), outputs(std::move(inputs)) {
+  }
 
   virtual variable_list apply(const variable_list& inputs) {
     return outputs;
diff --git a/torch/csrc/autograd/functions/utils.cpp b/torch/csrc/autograd/functions/utils.cpp
index ecec973f8a..b225d48455 100644
--- a/torch/csrc/autograd/functions/utils.cpp
+++ b/torch/csrc/autograd/functions/utils.cpp
@@ -1,19 +1,17 @@
 #include "torch/csrc/autograd/functions/utils.h"
 
-#include "torch/csrc/utils/functional.h"
-#include "torch/csrc/jit/tracer.h"
-
+#include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/variable.h"
 
 #include <sstream>
+#include <vector>
 
 namespace torch { namespace autograd {
 
 variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
                            function_constructor ctr) {
-  auto flags = Function::flags(inputs);
   variable_list result;
   result.reserve(outputs.size());
-  if (!flags.is_executable) {
+  if (!any_variable_requires_grad(inputs)) {
     for (auto& output : outputs) {
       if (output.defined()) {
         result.emplace_back(make_variable(output, false));
@@ -22,7 +20,7 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
       }
     }
   } else {
-    auto grad_fn = ctr(std::move(flags));
+    auto grad_fn = ctr(get_next_functions(inputs));
     for (auto& output : outputs) {
       if (output.defined()) {
         result.emplace_back(make_variable(output, grad_fn));
@@ -53,5 +51,4 @@ void check_input_variables(const char* name, const variable_list& inputs, int ar
       }
     }
   }
-
 }}
diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h
index cd9067793a..c4db7aa5e6 100644
--- a/torch/csrc/autograd/functions/utils.h
+++ b/torch/csrc/autograd/functions/utils.h
@@ -3,14 +3,13 @@
 #include <Python.h>
 #include <functional>
 #include <memory>
-#include <array>
 
 #include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/variable.h"
 
 namespace torch { namespace autograd {
 
-using function_constructor = std::function<std::shared_ptr<Function>(FunctionFlags)>;
+using function_constructor = std::function<std::shared_ptr<Function>(function_list&&)>;
 
 /**
  * Wraps the tensor outputs in variables and creates the grad_fn and sets the
@@ -24,5 +23,4 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
 * items are not NULL. If not specified, `required_args` defaults to `args`.
 */
 void check_input_variables(const char* name, const variable_list& inputs, int args, int required_args=-1);
-
 }}
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 65b792bcc0..fe19d6d6a8 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -94,9 +94,13 @@ auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {
   // leads to unexpected error messages ("no nodes require computing gradients"),
   // but I don't have a better idea. These functions would raise an error
   // in backward anyway.
-  return wrap_outputs(inputs, std::move(tensor_results), [this](FunctionFlags &&f) {
-    return std::make_shared<Error>(name() + " is not differentiable twice", std::move(f));
-  });
+  return wrap_outputs(
+      inputs,
+      std::move(tensor_results),
+      [this](function_list&& next_functions) {
+        return std::make_shared<Error>(
+            name() + " is not differentiable twice", std::move(next_functions));
+      });
 }
 
 // NOTE: this function is written in a way that assumes it's only called for backward;
@@ -566,7 +570,8 @@ struct UnpackedInput {
 };
 
 struct InputFlags {
-  FunctionFlags flags;
+  bool is_executable = false;
+  function_list next_functions;
   THPObjectPtr needs_input_grad;
   std::vector<bool> is_variable_input;
 };
@@ -606,7 +611,8 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) {
     PyTuple_SET_ITEM(unpacked.tensor_input.get(), i, new_arg);
   }
 
-  flags.flags = Function::flags(unpacked.input_vars);
+  flags.is_executable = any_variable_requires_grad(unpacked.input_vars);
+  flags.next_functions = get_next_functions(unpacked.input_vars);
   return std::make_pair(std::move(unpacked), std::move(flags));
 }
 
@@ -779,8 +785,8 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs)
   auto info_pair = unpack_input<true>(_inputs);
   auto& unpacked_input = info_pair.first;
   auto& input_info = info_pair.second;
-  bool is_executable = input_info.flags.is_executable;
-  self->cdata.set_flags(std::move(input_info.flags));
+  bool is_executable = input_info.is_executable;
+  self->cdata.set_next_functions(std::move(input_info.next_functions));
   self->needs_input_grad = input_info.needs_input_grad.release();
 
   // Now we're ready to call a forward (implemented in Python)
@@ -811,8 +817,8 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
   InputFlags& input_info = info_pair.second;
 
   // Initialize backward function (and ctx)
-  bool is_executable = input_info.flags.is_executable;
-  ctx->cdata.set_flags(std::move(input_info.flags));
+  bool is_executable = input_info.is_executable;
+  ctx->cdata.set_next_functions(std::move(input_info.next_functions));
   ctx->needs_input_grad = input_info.needs_input_grad.release();
   ctx->is_variable_input = std::move(input_info.is_variable_input);
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index 657fbcfa66..605c1d567b 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -117,7 +117,7 @@ std::shared_ptr<Function>& VariableViewImpl::get_grad_fn() {
     fn->size = sizes();
     fn->stride = strides();
     fn->storage_offset = data.storage_offset();
-    fn->set_flags(Function::flags(base));
+    fn->set_next_functions(get_next_functions(base));
     fn->num_inputs = 1;
    _grad_fn = std::move(fn);
     attr_version = current_version;
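The sketch below is not part of the commit; it is a minimal illustration of how the helpers added to torch/csrc/autograd/function.h above (any_variable_requires_grad() and get_next_functions()) take over the two roles of the removed FunctionFlags struct when a backward node is constructed. The names ExampleBackward and make_backward are hypothetical and exist only for this example; the real call sites are the ones changed in functions/utils.cpp and python_function.cpp.

// Illustrative sketch only (not from the commit). ExampleBackward and
// make_backward are hypothetical names used purely for demonstration.
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"

#include <memory>

namespace torch { namespace autograd {

// A minimal backward node; a real subclass would compute actual gradients
// in apply() instead of passing them through.
struct ExampleBackward : public Function {
  using Function::Function;  // inherits Function(function_list&&)
  virtual variable_list apply(const variable_list& grads) override {
    return grads;
  }
};

std::shared_ptr<Function> make_backward(const Variable& a, const Variable& b) {
  // The two pieces of state FunctionFlags used to bundle are now obtained
  // separately: executability via any_variable_requires_grad(), and the
  // edges to the next functions in the graph via get_next_functions().
  if (!any_variable_requires_grad({a, b})) {
    return nullptr;  // no input requires grad, so no backward node is needed
  }
  return std::make_shared<ExampleBackward>(get_next_functions(a, b));
}

}} // namespace torch::autograd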