23 files changed, 208 insertions, 66 deletions
diff --git a/aten/src/ATen/ATen.h b/aten/src/ATen/ATen.h
index e41c2d9565..84568880fc 100644
--- a/aten/src/ATen/ATen.h
+++ b/aten/src/ATen/ATen.h
@@ -15,3 +15,4 @@
 #include "ATen/TensorOperators.h"
 #include "ATen/TensorMethods.h"
 #include "ATen/Dispatch.h"
+#include "ATen/DimVector.h"
diff --git a/aten/src/ATen/DimVector.h b/aten/src/ATen/DimVector.h
new file mode 100644
index 0000000000..aaa4dc9c07
--- /dev/null
+++ b/aten/src/ATen/DimVector.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "SmallVector.h"
+#include <stdint.h>
+
+namespace at {
+
+/// A container for sizes or strides
+using DimVector = SmallVector<int64_t, 5>;
+
+}
diff --git a/aten/src/ATen/SmallVector.h b/aten/src/ATen/SmallVector.h
index 521e46082b..3a5926a06d 100644
--- a/aten/src/ATen/SmallVector.h
+++ b/aten/src/ATen/SmallVector.h
@@ -921,6 +921,12 @@ public:
     SmallVectorImpl<T>::operator=(::std::move(RHS));
   }
 
+  template<typename Container>
+  const SmallVector &operator=(const Container &RHS) {
+    this->assign(RHS.begin(), RHS.end());
+    return *this;
+  }
+
   SmallVector(SmallVectorImpl<T> &&RHS) : SmallVectorImpl<T>(N) {
     if (!RHS.empty())
       SmallVectorImpl<T>::operator=(::std::move(RHS));
diff --git a/test/cpp/api/misc.cpp b/test/cpp/api/misc.cpp
index a494e3389e..0f4fa33ddf 100644
--- a/test/cpp/api/misc.cpp
+++ b/test/cpp/api/misc.cpp
@@ -71,7 +71,7 @@ TEST_CASE("autograd") {
   }
   SECTION("custom gradient inputs") {
     z.sum().backward(
-        autograd::make_variable(at::ones(at::CPU(at::kFloat), {1}) * 2));
+        autograd::make_variable(at::ones(at::CPU(at::kFloat), {}) * 2));
     REQUIRE(x.grad().allclose(y * 2));
   }
   // Assume everything else is safe from PyTorch tests.
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 0c19e24208..4c26e58aa6 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -75,7 +75,7 @@ class TestAutograd(TestCase):
         x = torch.randn(5, 5, requires_grad=True)
         y = torch.randn(5, 5, requires_grad=True)
         result = cls.apply(x, 2, y)
-        go = torch.ones(1, requires_grad=True)
+        go = torch.ones((), requires_grad=True)
         result.sum().backward(go, create_graph=True)
         self.assertEqual(x.grad.data, y.data + torch.ones(5, 5))
 
@@ -173,6 +173,23 @@ class TestAutograd(TestCase):
             MyFunction()(y).sum().backward()
             self.assertEqual(v.grad.data, torch.zeros(shape))
 
+    def test_invalid_gradients(self):
+        class MyFunction(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x * 2
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                return torch.randn(10, dtype=torch.float)
+
+        with self.assertRaisesRegex(RuntimeError, 'expected shape'):
+            input = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
+            MyFunction.apply(input).sum().backward()
+        with self.assertRaisesRegex(RuntimeError, 'expected type'):
+            input = torch.randn(10, dtype=torch.double, requires_grad=True)
+            MyFunction.apply(input).sum().backward()
+
     def test_accumulate_grad(self):
         grad_output = torch.ones(5, 5)
 
@@ -495,7 +512,6 @@ class TestAutograd(TestCase):
 
     def test_sparse_backward(self):
         class FixedGradientFunction(Function):
-
             def __init__(self, grad):
                 self.grad = grad
 
@@ -524,15 +540,15 @@ class TestAutograd(TestCase):
         dense_fn = FixedGradientFunction(dense_grad)
 
         # sparse first
-        x = torch.randn(5, 5, requires_grad=True)
+        x = torch.randn(size, requires_grad=True)
         (sparse_fn1(x) + dense_fn(x) + sparse_fn2(x)).sum().backward()
         self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
         # dense first
-        x = torch.randn(5, 5, requires_grad=True)
+        x = torch.randn(size, requires_grad=True)
         (dense_fn(x) + sparse_fn1(x) + sparse_fn2(x)).sum().backward()
         self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
         # sparse only
-        x = torch.randn(5, 5, requires_grad=True)
+        x = torch.randn(size, requires_grad=True)
         (sparse_fn1(x) + sparse_fn2(x)).sum().backward()
         self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)
diff --git a/test/test_jit.py b/test/test_jit.py
index d56f5ad5f8..4b6ee940b5 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1152,7 +1152,7 @@ class TestScript(JitTestCase):
         out = func(x, y)
         self.assertEqual(func(x, y), x + y)
 
-        grad = torch.randn(2, 3)
+        grad = torch.randn(2, 3, dtype=torch.float)
         out.backward(grad)
         self.assertEqual(x.grad, grad)
         self.assertEqual(y.grad, grad.sum(dim=0))
diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py
index a343050eb3..94ed1f8578 100644
--- a/tools/autograd/gen_autograd_functions.py
+++ b/tools/autograd/gen_autograd_functions.py
@@ -18,7 +18,7 @@ FUNCTION_DECLARATION = CodeTemplate("""\
 struct ${op} : public ${superclass} {
   using ${superclass}::${superclass};
   variable_list apply(const variable_list& grads) override;
-  std::string name() override { return "${op}"; }
+  std::string name() const override { return "${op}"; }
   void release_variables() override {
     ${release_variables}
   }
diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp
index 38fa9c2ed6..3dfb63313e 100644
--- a/tools/autograd/templates/VariableType.cpp
+++ b/tools/autograd/templates/VariableType.cpp
@@ -359,35 +359,35 @@ static void throw_error_out_requires_grad(const char* name) {
 
 static void rebase_history(Variable& var, std::shared_ptr<Function> grad_fn) {
   if (grad_fn && var.defined()) {
-    grad_fn->set_num_inputs(1);
+    grad_fn->add_input_metadata(var.type(), var.sizes());
     var.rebase_history({std::move(grad_fn), 0});
   }
 }
 
 static void rebase_history(ArrayRef<Variable> vars, std::shared_ptr<Function> grad_fn) {
   if (grad_fn) {
-    grad_fn->set_num_inputs(vars.size());
-    uint32_t output_nr = 0;
     for (auto& var : vars) {
       if (var.defined()) {
         // TODO: eliminate const_cast
+        auto output_nr = grad_fn->add_input_metadata(var.type(), var.sizes());
         const_cast<Variable&>(var).rebase_history({grad_fn, output_nr});
+      } else {
+        grad_fn->add_input_metadata(Function::undefined_input());
       }
-      output_nr++;
     }
   }
 }
 
 static void set_history(ArrayRef<Variable> vars, std::shared_ptr<Function> grad_fn) {
   if (grad_fn) {
-    grad_fn->set_num_inputs(vars.size());
-    uint32_t output_nr = 0;
     for (auto& var : vars) {
       if (var.defined()) {
         // TODO: eliminate const_cast
+        auto output_nr = grad_fn->add_input_metadata(var.type(), var.sizes());
        const_cast<Variable&>(var).set_gradient_edge({grad_fn, output_nr});
+      } else {
+        grad_fn->add_input_metadata(Function::undefined_input());
       }
-      output_nr++;
     }
   }
 }
@@ -428,7 +428,6 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block
   if (requires_grad) {
     grad_fn = std::make_shared<CopyBackwards>();
     grad_fn->set_next_edges(collect_next_edges(self, src));
-    grad_fn->set_num_inputs(1);
    grad_fn->src_type = &src.type();
    grad_fn->src_device = src.is_cuda() ? src.get_device() : -1;
  }
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 8f4d72398a..88aec1573e 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -288,6 +288,49 @@ static variable_list call_post_hooks(Function& fn, variable_list outputs, variab
   return outputs;
 }
 
+static bool is_compatible_type(const at::Type& expected, const at::Type& actual) {
+  // Types are compatible if they exactly match or if the gradient is a sparse
+  // version of the expected type.
+  return expected == actual || (actual.is_sparse() &&
+      expected == actual.toBackend(toDense(actual.backend())));
+}
+
+template<typename F>
+static void validate_outputs(const edge_list& edges, const variable_list& grads, const F& format_error) {
+  if (grads.size() != edges.size()) {
+    std::stringstream ss;
+    ss << "invalid number of gradients - expected ";
+    ss << edges.size() << ", but got " << grads.size();
+    throw std::runtime_error(format_error(ss.str()));
+  }
+  for (size_t i = 0; i < grads.size(); i++) {
+    const auto& edge = edges[i];
+    if (!edge.is_valid()) continue;
+
+    const auto& metadata = edge.function->input_metadata(edge.input_nr);
+    const auto& output = grads[i];
+    if (!output.defined()) {
+      // FIXME: TestJit.test_ge_optimized fails this assertion.
+      // std::stringstream ss;
+      // ss << "undefined gradient at index " << i;
+      // throw std::runtime_error(format_error(ss.str()));
+      continue;
+    }
+    if (!grads[i].sizes().equals(metadata.shape())) {
+      std::stringstream ss;
+      ss << "invalid gradient at index " << i << " - expected shape ";
+      ss << metadata.shape() << " but got " << grads[i].sizes();
+      throw std::runtime_error(format_error(ss.str()));
+    }
+    if (!is_compatible_type(metadata.type(), grads[i].type())) {
+      std::stringstream ss;
+      ss << "invalid gradient at index " << i << " - expected type ";
+      ss << metadata.type() << " but got " << grads[i].type();
+      throw std::runtime_error(format_error(ss.str()));
+    }
+  }
+}
+
 static variable_list call_function(FunctionTask& task) {
   bool prev_checkpoint_valid_state = checkpoint_valid;
   checkpoint_valid = task.base->can_checkpoint() && prev_checkpoint_valid_state;
@@ -298,6 +341,11 @@ static variable_list call_function(FunctionTask& task) {
     fn.will_release_variables();
   }
   auto outputs = fn(inputs);
+  validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
+    std::ostringstream ss;
+    ss << "Function " << fn.name() << " returned an " << msg;
+    return ss.str();
+  });
   checkpoint_valid = prev_checkpoint_valid_state;
   return call_post_hooks(fn, std::move(outputs), std::move(inputs));
 }
@@ -323,13 +371,6 @@ auto Engine::evaluate_function(FunctionTask& task) -> void {
     fn.release_variables();
   }
 
-  if (outputs.size() != fn.num_outputs()) {
-    std::stringstream ss;
-    ss << "Function '" << fn.name() << "' returned an invalid number of outputs - expected ";
-    ss << fn.num_outputs() << ", but got " << outputs.size();
-    throw std::runtime_error(ss.str());
-  }
-
   int num_outputs = outputs.size();
   if (num_outputs == 0) return; // Don't even acquire the mutex
   std::lock_guard<std::mutex> lock(task.base->mutex);
@@ -426,6 +467,11 @@ auto Engine::execute(const edge_list& input_roots,
                      bool create_graph,
                      const edge_list& outputs) -> variable_list {
   std::call_once(start_threads_flag, &Engine::start_threads, this);
+
+  validate_outputs(input_roots, inputs, [](const std::string& msg) {
+    return msg;
+  });
+
   // Callbacks are only valid for the duration of this run and should always be cleared
   ClearCallbacks _cb_guard(final_callbacks, post_callbacks_lock);
diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp
index 47b4e0620e..d116069149 100644
--- a/torch/csrc/autograd/function.cpp
+++ b/torch/csrc/autograd/function.cpp
@@ -19,8 +19,8 @@ namespace torch { namespace autograd {
 
 thread_local uint64_t Function::next_sequence_nr_ = 0;
 
-auto Function::name() -> std::string {
-  return std::string(typeid(*this).name());
+auto Function::name() const -> std::string {
+  return at::demangle(typeid(*this).name());
 }
 
 // This function is analogous to make_trace which operates on PythonOp, but this
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 77e9b21846..d76296a10a 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -5,6 +5,7 @@
 #include "torch/csrc/autograd/grad_mode.h"
 #include "torch/csrc/autograd/profiler.h"
 #include "torch/csrc/autograd/saved_variable.h"
+#include "torch/csrc/autograd/type_and_shape.h"
 #include "torch/csrc/autograd/variable.h"
 #include "torch/csrc/jit/tracer.h"
 #include "torch/csrc/utils/auto_unique_ptr.h"
@@ -91,18 +92,15 @@ struct Function : std::enable_shared_from_this<Function> {
   /// in the backward() pass, with higher sequence numbers prioritized
   /// before lower sequence numbers.
   explicit Function(
-      uint32_t num_inputs,
       uint64_t sequence_nr,
       edge_list&& next_edges = edge_list())
       : sequence_nr_(sequence_nr),
-        num_inputs_(num_inputs),
         next_edges_(std::move(next_edges)) {}
 
   explicit Function(
-      uint32_t num_inputs = 0,
       edge_list&& next_edges = edge_list())
-      : Function(num_inputs, next_sequence_nr_++, std::move(next_edges)) {}
-
+      : Function(next_sequence_nr_++, std::move(next_edges)) {}
+
   /// Functions are neither copyable nor moveable.
   Function(const Function& other) = delete;
   Function(Function&& other) = delete;
@@ -123,20 +121,37 @@ struct Function : std::enable_shared_from_this<Function> {
   // Graph Connectivity API
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-  // Inputs
+  // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the
+  // forward function.
+
+  // Marker for expected undefined input
+  struct undefined_input {};
 
-  /// Increments the number of inputs of the function and returns the previous
-  /// value.
-  uint32_t bump_inputs() noexcept {
-    return num_inputs_++;
+  /// Adds the type and shape metadata for a new input. Returns the index of
+  /// of the new input.
+  uint32_t add_input_metadata(const at::Type& type, at::IntList shape) noexcept {
+    uint32_t input_nr = input_metadata_.size();
+    input_metadata_.emplace_back(type, shape);
+    return input_nr;
   }
 
-  void set_num_inputs(uint32_t num_inputs) noexcept {
-    num_inputs_ = num_inputs;
+  /// Adds a placeholder for an input that will not be used.
+  uint32_t add_input_metadata(undefined_input u) noexcept {
+    uint32_t input_nr = input_metadata_.size();
+    input_metadata_.emplace_back();
+    return input_nr;
   }
 
   uint32_t num_inputs() const noexcept {
-    return num_inputs_;
+    return input_metadata_.size();
+  }
+
+  const TypeAndShape& input_metadata(size_t index) const {
+    return input_metadata_[index];
+  }
+
+  void clear_input_metadata() {
+    input_metadata_.clear();
   }
 
   // Outputs ("Next Edges")
@@ -185,7 +200,7 @@ struct Function : std::enable_shared_from_this<Function> {
   }
 
   /// Returns the name of the dynamic type of the function, for debugging.
-  virtual std::string name();
+  virtual std::string name() const;
 
   /// Returns true if the particular output edge is active, and that particular
   /// output of this function should be computed.
@@ -312,12 +327,12 @@ struct Function : std::enable_shared_from_this<Function> {
   // fields.
   const uint64_t sequence_nr_;
 
-  uint32_t num_inputs_;
   edge_list next_edges_;
   PyObject* pyobj_ = nullptr; // weak reference
   std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
   std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
   auto_unique_ptr<jit::tracer::FunctionTracingState> tracing_state_;
+  at::SmallVector<TypeAndShape, 2> input_metadata_;
 };
 
 /// See Function::is_traceable() for definition.
@@ -355,13 +370,14 @@ struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
 /// `input_nr` thus equal to `function->num_inputs()`. Additionally, it
 /// increments the `Function`'s number of inputs by one. Approximately
 /// equivalent to `variable.set_gradient_edge(function,
-/// function->bump_inputs())`. If you don't want the `Function`'s `num_inputs`
-/// to be incremented, use `set_gradient_edge` directly.
+/// function->add_input_metadata(variable.type(), variable.sizes()))`.
+/// If you don't want the `Function`'s `num_inputs` to be incremented, use
+/// `set_gradient_edge` directly.
 inline void create_gradient_edge(
     Variable& variable,
     std::shared_ptr<Function> function) {
   // Copy before move.
-  const auto input_nr = function->bump_inputs();
+  const auto input_nr = function->add_input_metadata(variable.type(), variable.sizes());
   variable.set_gradient_edge({std::move(function), input_nr});
 }
 
diff --git a/torch/csrc/autograd/functions/accumulate_grad.cpp b/torch/csrc/autograd/functions/accumulate_grad.cpp
index 31fc2434ea..fdbe54d3fc 100644
--- a/torch/csrc/autograd/functions/accumulate_grad.cpp
+++ b/torch/csrc/autograd/functions/accumulate_grad.cpp
@@ -13,12 +13,13 @@ using at::Tensor;
 
 namespace torch { namespace autograd {
 
-// AccumulateGrad sets sequence_nr to the max value so it's always called 
+// AccumulateGrad sets sequence_nr to the max value so it's always called
 // ASAP during backwards.
 AccumulateGrad::AccumulateGrad(Variable variable_)
-    : Function(/*num_inputs=*/1
-    , /*sequence_nr=*/UINT64_MAX)
-    , variable(std::move(variable_)) {}
+    : Function(/*sequence_nr=*/UINT64_MAX)
+    , variable(std::move(variable_)) {
+  add_input_metadata(variable.type(), variable.sizes());
+}
 
 auto AccumulateGrad::apply(const variable_list& grads) -> variable_list {
   // XXX: this method is not thread-safe!
diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h
index a630c3c79e..7ac4d2a3fa 100644
--- a/torch/csrc/autograd/functions/basic_ops.h
+++ b/torch/csrc/autograd/functions/basic_ops.h
@@ -12,7 +12,7 @@ namespace torch { namespace autograd {
 
 struct Error : public Function {
   Error(std::string msg, edge_list&& next_edges)
-    : Function(/*num_inputs=*/0, std::move(next_edges))
+    : Function(std::move(next_edges))
     , msg(std::move(msg)) {}
 
   Error(std::string msg)
@@ -35,7 +35,7 @@ struct DelayedError : public Function {
 
 struct GraphRoot : public Function {
   GraphRoot(edge_list functions, variable_list inputs)
-      : Function(/*num_inputs=*/0, std::move(functions)),
+      : Function(std::move(functions)),
        outputs(std::move(inputs)) {}
 
   virtual variable_list apply(const variable_list& inputs) {
diff --git a/torch/csrc/autograd/functions/special.cpp b/torch/csrc/autograd/functions/special.cpp
index a5c5a6228e..88ac969122 100644
--- a/torch/csrc/autograd/functions/special.cpp
+++ b/torch/csrc/autograd/functions/special.cpp
@@ -17,7 +17,9 @@ namespace torch { namespace autograd {
 // Used when an output has multiple uses (there's only one entry
 // in next_edges per output).
 struct Replicate : public Function {
-  Replicate() : Function(/*num_inputs=*/1) {}
+  Replicate(const at::Type& type, at::IntList shape) : Function() {
+    add_input_metadata(type, shape);
+  }
 
   virtual variable_list apply(const variable_list& inputs) {
     TORCH_ASSERT(inputs.size() == 1);
@@ -236,6 +238,7 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
   // This detaches the subgraph from the full backward graph.
   for (auto& begin : subgraph.boundary.begins) {
     const auto& edge = begin.function->next_edge(begin.input_nr);
+
     begin.function->set_next_edge(
         begin.input_nr, Edge(ends_to_outputs.at(edge), 0));
   }
@@ -265,7 +268,7 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
       // the same Variable has been returned multiple times, and
       // is repeated in this list.
       if (output.grad_fn_unsafe() == this) {
-        auto replicate = std::make_shared<Replicate>();
+        auto replicate = std::make_shared<Replicate>(output.type(), output.sizes());
        replicate->add_next_edge({this_shared, output.output_nr()});
        output.set_gradient_edge({std::move(replicate), 0});
        repeated_outputs.emplace(&output);
@@ -274,7 +277,8 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
       // perform any allocations until we actually see repeated outputs.
       if (repeated_outputs.count(&output) > 0) {
         auto & replicate = output.grad_fn();
-        replicate->add_next_edge({this_shared, num_inputs_++});
+        auto input_nr = add_input_metadata(output.type(), output.sizes());
+        replicate->add_next_edge({this_shared, input_nr});
       } else {
         autograd::create_gradient_edge(output, this_shared);
       }
diff --git a/torch/csrc/autograd/functions/special.h b/torch/csrc/autograd/functions/special.h
index 076a4faa98..273b139e23 100644
--- a/torch/csrc/autograd/functions/special.h
+++ b/torch/csrc/autograd/functions/special.h
@@ -15,7 +15,9 @@ namespace torch { namespace autograd {
 
 struct EvalOutput : Function {
   explicit EvalOutput(const Edge& next_edge_)
-      : Function(/*num_inputs=*/1), next_edge(next_edge_) {}
+      : Function(), next_edge(next_edge_) {
+    add_input_metadata(undefined_input());
+  }
 
   virtual variable_list apply(const variable_list& inputs) override {
     throw std::logic_error("EvalOutput::apply() called");
diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp
index 75ec8180c9..5df72d5263 100644
--- a/torch/csrc/autograd/functions/tensor.cpp
+++ b/torch/csrc/autograd/functions/tensor.cpp
@@ -36,12 +36,13 @@ CopySlices::CopySlices(
     const Variable& base_var,
     at::TensorGeometry view_,
     std::shared_ptr<Function> fn_)
-    : Function(/*num_inputs=*/1),
+    : Function(),
       base(base_var),
       view(std::move(view_)),
       fn(std::move(fn_)) {
   // Take the next_edges of fn as our own, except for index 0 which goes
   // to base instead of the view.
+  add_input_metadata(base_var.type(), base_var.sizes());
   const auto num_outputs = fn->num_outputs();
   next_edges_.reserve(num_outputs);
   add_next_edge(base_var.gradient_edge());
diff --git a/torch/csrc/autograd/functions/utils.cpp b/torch/csrc/autograd/functions/utils.cpp
index 09939a109a..485572d9ee 100644
--- a/torch/csrc/autograd/functions/utils.cpp
+++ b/torch/csrc/autograd/functions/utils.cpp
@@ -29,7 +29,7 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
       autograd::create_gradient_edge(variable, grad_fn);
       result.push_back(std::move(variable));
     } else {
-      grad_fn->bump_inputs();
+      grad_fn->add_input_metadata(Function::undefined_input());
       result.emplace_back();
     }
   }
diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp
index e07d88feee..e52e06a341 100644
--- a/torch/csrc/autograd/python_engine.cpp
+++ b/torch/csrc/autograd/python_engine.cpp
@@ -134,7 +134,7 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
     }
   }
 
-  edge_list output_edges;
+  std::vector<Edge> output_edges;
   if (inputs != nullptr) {
     int num_inputs = PyTuple_GET_SIZE(inputs);
     output_edges.reserve(num_inputs);
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 3fb53c9a6d..4672d535ef 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -207,7 +207,7 @@ auto PyFunction::release_variables() -> void {
   f->has_freed_buffers = 1;
 }
 
-auto PyFunction::name() -> std::string {
+auto PyFunction::name() const -> std::string {
   AutoGIL gil;
   auto f = (THPFunction*) obj;
   auto name = std::string(Py_TYPE(f)->tp_name);
@@ -245,7 +245,7 @@ static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg)
 
 static int THPFunction_clear(THPFunction *self)
 {
-  self->cdata.set_num_inputs(0);
+  self->cdata.clear_input_metadata();
 
   Py_CLEAR(self->needs_input_grad);
 
@@ -293,7 +293,6 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
   new (&self->input_info) std::vector<VariableInfo>();
   new (&self->saved_variables) std::vector<SavedVariable>();
   new (&self->is_variable_input) std::vector<bool>();
-  self->cdata.set_num_inputs(0);
   return obj;
 }
 
@@ -425,6 +424,10 @@ static void _wrap_outputs(THPFunction *self,
     // Note that output Variables may be repeated. In that case, the last call
     // to set_history wins.
     auto var = as_variable(obj, i);
+    if (cdata) {
+      auto output_nr = cdata->add_input_metadata(var.type(), var.sizes());
+      TORCH_ASSERT(i == (int)output_nr);
+    }
     set_history(var, i, is_input, is_modified, is_differentiable);
 
     if (is_executable) {
@@ -616,7 +619,7 @@ PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const Unpacked
   THPObjectPtr outputs(PyTuple_New(num_outputs));
   if (!outputs) throw python_error();
 
-  grad_fn->cdata.set_num_inputs(num_outputs);
+  grad_fn->cdata.clear_input_metadata();
 
   // Record type, device, and size information about inputs
   if (is_executable) {
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h
index 562ab79a28..529bcaf454 100644
--- a/torch/csrc/autograd/python_function.h
+++ b/torch/csrc/autograd/python_function.h
@@ -35,7 +35,7 @@ struct PyFunction : public Function {
   variable_list legacy_apply(const variable_list& inputs);
 
   virtual void release_variables() override;
-  virtual std::string name() override;
+  virtual std::string name() const override;
   virtual std::shared_ptr<Function> get_shared_ptr() override;
   virtual bool is_traceable() override;
diff --git a/torch/csrc/autograd/python_legacy_variable.cpp b/torch/csrc/autograd/python_legacy_variable.cpp
index 179c54a1e3..33dd281ab0 100644
--- a/torch/csrc/autograd/python_legacy_variable.cpp
+++ b/torch/csrc/autograd/python_legacy_variable.cpp
@@ -57,7 +57,7 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject
   Variable var;
   if (grad_fn) {
     auto grad_fn_ = THPFunction_asFunction((THPFunction*)grad_fn);
-    Edge edge(grad_fn_, grad_fn_->bump_inputs());
+    Edge edge(grad_fn_, grad_fn_->add_input_metadata(tensor.type(), tensor.sizes()));
     var = make_variable(std::move(tensor), std::move(edge));
   } else {
     var = make_variable(std::move(tensor), requires_grad);
diff --git a/torch/csrc/autograd/type_and_shape.h b/torch/csrc/autograd/type_and_shape.h
new file mode 100644
index 0000000000..01a62fa2cf
--- /dev/null
+++ b/torch/csrc/autograd/type_and_shape.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include "torch/csrc/assertions.h"
+
+namespace torch { namespace autograd {
+
+/// A tensor's type and shape. Each Function records the required type and
+/// shape of its inputs. If is_valid() is false, then the corresponding input
+/// is not used and may be an undefined tensor.
+struct TypeAndShape {
+  TypeAndShape() : type_(nullptr) {}
+
+  TypeAndShape(const at::Type& type, at::IntList shape)
+    : type_(&type) , shape_(shape) {}
+
+  bool is_valid() const {
+    return type_ != nullptr;
+  }
+
+  const at::Type& type() const {
+    TORCH_ASSERT(type_);
+    return *type_;
+  }
+
+  at::IntList shape() const {
+    return shape_;
+  }
+
+  const at::Type* type_;
+  at::DimVector shape_;
+};
+
+}}
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index f0b4c1df9d..48a35f4f88 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -126,10 +126,12 @@ void Variable::Impl::backward(
 }
 
 void Variable::Impl::set_data(Tensor new_data) {
-  data_ = std::move(new_data);
-  if (data_.type() != *type_) {
-    type_ = VariableType::getType(data_);
+  if (new_data.type() != data_.type()) {
+    type_ = VariableType::getType(new_data.type());
+    // Clear grad_accumulator if it exists, since it stores the old type info.
+    grad_accumulator_.reset();
   }
+  data_ = std::move(new_data);
 }
 
 Variable::ViewImpl::ViewImpl(Variable base, at::Tensor data, Edge gradient_edge)
@@ -158,7 +160,7 @@ std::shared_ptr<Function>& Variable::ViewImpl::get_grad_fn() {
     fn->stride = strides();
     fn->storage_offset = data_.storage_offset();
     fn->set_next_edges(collect_next_edges(base_));
-    fn->set_num_inputs(1);
+    fn->add_input_metadata(base_.type(), sizes());
     grad_fn_ = std::move(fn);
     attr_version = current_version;
   }
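The user-visible effect of the new validate_outputs check in engine.cpp is what test_invalid_gradients above exercises: a backward() that returns a gradient whose shape or type does not match the metadata recorded for the corresponding forward output now raises a RuntimeError. A minimal Python sketch, modeled on that test (the class name and the printed message wording are illustrative assumptions, not part of the change):

    import torch
    from torch.autograd import Function

    class BadGrad(Function):  # hypothetical Function for illustration
        @staticmethod
        def forward(ctx, x):
            return x * 2

        @staticmethod
        def backward(ctx, grad_output):
            # Wrong on purpose: the forward input/output was 5x5, so the engine
            # rejects this gradient against the recorded shape metadata.
            return torch.randn(10, dtype=torch.float)

    x = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
    try:
        BadGrad.apply(x).sum().backward()
    except RuntimeError as e:
        print(e)  # message contains "expected shape"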
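The test changes from torch.ones(1) to torch.ones(()) (and from {1} to {} in the C++ test) follow from the same check: sum() produces a zero-dimensional result, so an explicitly supplied gradient must be zero-dimensional to match the recorded metadata. A short sketch under that assumption:

    import torch

    x = torch.randn(5, 5, requires_grad=True)
    out = x.sum()                   # zero-dimensional output
    out.backward(torch.ones(()))    # matches the recorded scalar shape
    # out.backward(torch.ones(1))   # shape (1,) would now trip the shape check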