Diffstat:
-rw-r--r--  aten/src/ATen/ATen.h                                 1
-rw-r--r--  aten/src/ATen/DimVector.h                           11
-rw-r--r--  aten/src/ATen/SmallVector.h                          6
-rw-r--r--  test/cpp/api/misc.cpp                                2
-rw-r--r--  test/test_autograd.py                               26
-rw-r--r--  test/test_jit.py                                     2
-rw-r--r--  tools/autograd/gen_autograd_functions.py             2
-rw-r--r--  tools/autograd/templates/VariableType.cpp           15
-rw-r--r--  torch/csrc/autograd/engine.cpp                      60
-rw-r--r--  torch/csrc/autograd/function.cpp                     4
-rw-r--r--  torch/csrc/autograd/function.h                      52
-rw-r--r--  torch/csrc/autograd/functions/accumulate_grad.cpp    9
-rw-r--r--  torch/csrc/autograd/functions/basic_ops.h            4
-rw-r--r--  torch/csrc/autograd/functions/special.cpp           10
-rw-r--r--  torch/csrc/autograd/functions/special.h              4
-rw-r--r--  torch/csrc/autograd/functions/tensor.cpp             3
-rw-r--r--  torch/csrc/autograd/functions/utils.cpp              2
-rw-r--r--  torch/csrc/autograd/python_engine.cpp                2
-rw-r--r--  torch/csrc/autograd/python_function.cpp             11
-rw-r--r--  torch/csrc/autograd/python_function.h                2
-rw-r--r--  torch/csrc/autograd/python_legacy_variable.cpp       2
-rw-r--r--  torch/csrc/autograd/type_and_shape.h                34
-rw-r--r--  torch/csrc/autograd/variable.cpp                    10
23 files changed, 208 insertions, 66 deletions
diff --git a/aten/src/ATen/ATen.h b/aten/src/ATen/ATen.h
index e41c2d9565..84568880fc 100644
--- a/aten/src/ATen/ATen.h
+++ b/aten/src/ATen/ATen.h
@@ -15,3 +15,4 @@
#include "ATen/TensorOperators.h"
#include "ATen/TensorMethods.h"
#include "ATen/Dispatch.h"
+#include "ATen/DimVector.h"
diff --git a/aten/src/ATen/DimVector.h b/aten/src/ATen/DimVector.h
new file mode 100644
index 0000000000..aaa4dc9c07
--- /dev/null
+++ b/aten/src/ATen/DimVector.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "SmallVector.h"
+#include <stdint.h>
+
+namespace at {
+
+/// A container for sizes or strides
+using DimVector = SmallVector<int64_t, 5>;
+
+}
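
A minimal usage sketch of the new at::DimVector alias (the helper function below is illustrative only, not part of this patch): sizes stay in inline storage for tensors of up to five dimensions and only spill to the heap beyond that.

    #include <ATen/ATen.h>

    // Copy a tensor's sizes into a DimVector; no heap allocation for <= 5 dims.
    at::DimVector copy_sizes(const at::Tensor& t) {
      at::DimVector sizes;
      for (auto s : t.sizes()) {
        sizes.push_back(s);
      }
      return sizes;
    }
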
diff --git a/aten/src/ATen/SmallVector.h b/aten/src/ATen/SmallVector.h
index 521e46082b..3a5926a06d 100644
--- a/aten/src/ATen/SmallVector.h
+++ b/aten/src/ATen/SmallVector.h
@@ -921,6 +921,12 @@ public:
SmallVectorImpl<T>::operator=(::std::move(RHS));
}
+ template<typename Container>
+ const SmallVector &operator=(const Container &RHS) {
+ this->assign(RHS.begin(), RHS.end());
+ return *this;
+ }
+
SmallVector(SmallVectorImpl<T> &&RHS) : SmallVectorImpl<T>(N) {
if (!RHS.empty())
SmallVectorImpl<T>::operator=(::std::move(RHS));
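
The templated operator= added above lets a SmallVector be assigned directly from any container exposing begin()/end(), such as at::IntList or std::vector. A hedged sketch with illustrative names:

    #include <ATen/ATen.h>
    #include <vector>

    void assign_example(const at::Tensor& t) {
      at::DimVector dims;
      dims = t.sizes();                  // at::IntList assigned via the new operator=
      std::vector<int64_t> v = {2, 3, 4};
      dims = v;                          // works for any container with begin()/end()
    }
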
diff --git a/test/cpp/api/misc.cpp b/test/cpp/api/misc.cpp
index a494e3389e..0f4fa33ddf 100644
--- a/test/cpp/api/misc.cpp
+++ b/test/cpp/api/misc.cpp
@@ -71,7 +71,7 @@ TEST_CASE("autograd") {
}
SECTION("custom gradient inputs") {
z.sum().backward(
- autograd::make_variable(at::ones(at::CPU(at::kFloat), {1}) * 2));
+ autograd::make_variable(at::ones(at::CPU(at::kFloat), {}) * 2));
REQUIRE(x.grad().allclose(y * 2));
}
// Assume everything else is safe from PyTorch tests.
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 0c19e24208..4c26e58aa6 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -75,7 +75,7 @@ class TestAutograd(TestCase):
x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True)
result = cls.apply(x, 2, y)
- go = torch.ones(1, requires_grad=True)
+ go = torch.ones((), requires_grad=True)
result.sum().backward(go, create_graph=True)
self.assertEqual(x.grad.data, y.data + torch.ones(5, 5))
@@ -173,6 +173,23 @@ class TestAutograd(TestCase):
MyFunction()(y).sum().backward()
self.assertEqual(v.grad.data, torch.zeros(shape))
+ def test_invalid_gradients(self):
+ class MyFunction(Function):
+ @staticmethod
+ def forward(ctx, x):
+ return x * 2
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return torch.randn(10, dtype=torch.float)
+
+ with self.assertRaisesRegex(RuntimeError, 'expected shape'):
+ input = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
+ MyFunction.apply(input).sum().backward()
+ with self.assertRaisesRegex(RuntimeError, 'expected type'):
+ input = torch.randn(10, dtype=torch.double, requires_grad=True)
+ MyFunction.apply(input).sum().backward()
+
def test_accumulate_grad(self):
grad_output = torch.ones(5, 5)
@@ -495,7 +512,6 @@ class TestAutograd(TestCase):
def test_sparse_backward(self):
class FixedGradientFunction(Function):
-
def __init__(self, grad):
self.grad = grad
@@ -524,15 +540,15 @@ class TestAutograd(TestCase):
dense_fn = FixedGradientFunction(dense_grad)
# sparse first
- x = torch.randn(5, 5, requires_grad=True)
+ x = torch.randn(size, requires_grad=True)
(sparse_fn1(x) + dense_fn(x) + sparse_fn2(x)).sum().backward()
self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
# dense first
- x = torch.randn(5, 5, requires_grad=True)
+ x = torch.randn(size, requires_grad=True)
(dense_fn(x) + sparse_fn1(x) + sparse_fn2(x)).sum().backward()
self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
# sparse only
- x = torch.randn(5, 5, requires_grad=True)
+ x = torch.randn(size, requires_grad=True)
(sparse_fn1(x) + sparse_fn2(x)).sum().backward()
self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)
diff --git a/test/test_jit.py b/test/test_jit.py
index d56f5ad5f8..4b6ee940b5 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1152,7 +1152,7 @@ class TestScript(JitTestCase):
out = func(x, y)
self.assertEqual(func(x, y), x + y)
- grad = torch.randn(2, 3)
+ grad = torch.randn(2, 3, dtype=torch.float)
out.backward(grad)
self.assertEqual(x.grad, grad)
self.assertEqual(y.grad, grad.sum(dim=0))
diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py
index a343050eb3..94ed1f8578 100644
--- a/tools/autograd/gen_autograd_functions.py
+++ b/tools/autograd/gen_autograd_functions.py
@@ -18,7 +18,7 @@ FUNCTION_DECLARATION = CodeTemplate("""\
struct ${op} : public ${superclass} {
using ${superclass}::${superclass};
variable_list apply(const variable_list& grads) override;
- std::string name() override { return "${op}"; }
+ std::string name() const override { return "${op}"; }
void release_variables() override {
${release_variables}
}
diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp
index 38fa9c2ed6..3dfb63313e 100644
--- a/tools/autograd/templates/VariableType.cpp
+++ b/tools/autograd/templates/VariableType.cpp
@@ -359,35 +359,35 @@ static void throw_error_out_requires_grad(const char* name) {
static void rebase_history(Variable& var, std::shared_ptr<Function> grad_fn) {
if (grad_fn && var.defined()) {
- grad_fn->set_num_inputs(1);
+ grad_fn->add_input_metadata(var.type(), var.sizes());
var.rebase_history({std::move(grad_fn), 0});
}
}
static void rebase_history(ArrayRef<Variable> vars, std::shared_ptr<Function> grad_fn) {
if (grad_fn) {
- grad_fn->set_num_inputs(vars.size());
- uint32_t output_nr = 0;
for (auto& var : vars) {
if (var.defined()) {
// TODO: eliminate const_cast
+ auto output_nr = grad_fn->add_input_metadata(var.type(), var.sizes());
const_cast<Variable&>(var).rebase_history({grad_fn, output_nr});
+ } else {
+ grad_fn->add_input_metadata(Function::undefined_input());
}
- output_nr++;
}
}
}
static void set_history(ArrayRef<Variable> vars, std::shared_ptr<Function> grad_fn) {
if (grad_fn) {
- grad_fn->set_num_inputs(vars.size());
- uint32_t output_nr = 0;
for (auto& var : vars) {
if (var.defined()) {
// TODO: eliminate const_cast
+ auto output_nr = grad_fn->add_input_metadata(var.type(), var.sizes());
const_cast<Variable&>(var).set_gradient_edge({grad_fn, output_nr});
+ } else {
+ grad_fn->add_input_metadata(Function::undefined_input());
}
- output_nr++;
}
}
}
@@ -428,7 +428,6 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block
if (requires_grad) {
grad_fn = std::make_shared<CopyBackwards>();
grad_fn->set_next_edges(collect_next_edges(self, src));
- grad_fn->set_num_inputs(1);
grad_fn->src_type = &src.type();
grad_fn->src_device = src.is_cuda() ? src.get_device() : -1;
}
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 8f4d72398a..88aec1573e 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -288,6 +288,49 @@ static variable_list call_post_hooks(Function& fn, variable_list outputs, variab
return outputs;
}
+static bool is_compatible_type(const at::Type& expected, const at::Type& actual) {
+ // Types are compatible if they exactly match or if the gradient is a sparse
+ // version of the expected type.
+ return expected == actual || (actual.is_sparse() &&
+ expected == actual.toBackend(toDense(actual.backend())));
+}
+
+template<typename F>
+static void validate_outputs(const edge_list& edges, const variable_list& grads, const F& format_error) {
+ if (grads.size() != edges.size()) {
+ std::stringstream ss;
+ ss << "invalid number of gradients - expected ";
+ ss << edges.size() << ", but got " << grads.size();
+ throw std::runtime_error(format_error(ss.str()));
+ }
+ for (size_t i = 0; i < grads.size(); i++) {
+ const auto& edge = edges[i];
+ if (!edge.is_valid()) continue;
+
+ const auto& metadata = edge.function->input_metadata(edge.input_nr);
+ const auto& output = grads[i];
+ if (!output.defined()) {
+ // FIXME: TestJit.test_ge_optimized fails this assertion.
+ // std::stringstream ss;
+ // ss << "undefined gradient at index " << i;
+ // throw std::runtime_error(format_error(ss.str()));
+ continue;
+ }
+ if (!grads[i].sizes().equals(metadata.shape())) {
+ std::stringstream ss;
+ ss << "invalid gradient at index " << i << " - expected shape ";
+ ss << metadata.shape() << " but got " << grads[i].sizes();
+ throw std::runtime_error(format_error(ss.str()));
+ }
+ if (!is_compatible_type(metadata.type(), grads[i].type())) {
+ std::stringstream ss;
+ ss << "invalid gradient at index " << i << " - expected type ";
+ ss << metadata.type() << " but got " << grads[i].type();
+ throw std::runtime_error(format_error(ss.str()));
+ }
+ }
+}
+
static variable_list call_function(FunctionTask& task) {
bool prev_checkpoint_valid_state = checkpoint_valid;
checkpoint_valid = task.base->can_checkpoint() && prev_checkpoint_valid_state;
@@ -298,6 +341,11 @@ static variable_list call_function(FunctionTask& task) {
fn.will_release_variables();
}
auto outputs = fn(inputs);
+ validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
+ std::ostringstream ss;
+ ss << "Function " << fn.name() << " returned an " << msg;
+ return ss.str();
+ });
checkpoint_valid = prev_checkpoint_valid_state;
return call_post_hooks(fn, std::move(outputs), std::move(inputs));
}
@@ -323,13 +371,6 @@ auto Engine::evaluate_function(FunctionTask& task) -> void {
fn.release_variables();
}
- if (outputs.size() != fn.num_outputs()) {
- std::stringstream ss;
- ss << "Function '" << fn.name() << "' returned an invalid number of outputs - expected ";
- ss << fn.num_outputs() << ", but got " << outputs.size();
- throw std::runtime_error(ss.str());
- }
-
int num_outputs = outputs.size();
if (num_outputs == 0) return; // Don't even acquire the mutex
std::lock_guard<std::mutex> lock(task.base->mutex);
@@ -426,6 +467,11 @@ auto Engine::execute(const edge_list& input_roots,
bool create_graph,
const edge_list& outputs) -> variable_list {
std::call_once(start_threads_flag, &Engine::start_threads, this);
+
+ validate_outputs(input_roots, inputs, [](const std::string& msg) {
+ return msg;
+ });
+
// Callbacks are only valid for the duration of this run and should always be cleared
ClearCallbacks _cb_guard(final_callbacks, post_callbacks_lock);
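
The gist of the per-gradient check that validate_outputs now performs, as a standalone sketch (the helper name is made up; the real code additionally builds the detailed error messages shown above):

    #include <ATen/ATen.h>

    // Returns true if `grad` may flow into an input recorded with the given
    // expected type and shape. A sparse gradient is accepted when its dense
    // counterpart matches the expected type, mirroring is_compatible_type().
    static bool gradient_is_acceptable(const at::Type& expected,
                                       at::IntList expected_shape,
                                       const at::Tensor& grad) {
      if (!grad.defined()) return true;  // currently tolerated (see FIXME above)
      bool shape_ok = grad.sizes().equals(expected_shape);
      bool type_ok = (expected == grad.type()) ||
          (grad.type().is_sparse() &&
           expected == grad.type().toBackend(at::toDense(grad.type().backend())));
      return shape_ok && type_ok;
    }
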
diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp
index 47b4e0620e..d116069149 100644
--- a/torch/csrc/autograd/function.cpp
+++ b/torch/csrc/autograd/function.cpp
@@ -19,8 +19,8 @@ namespace torch { namespace autograd {
thread_local uint64_t Function::next_sequence_nr_ = 0;
-auto Function::name() -> std::string {
- return std::string(typeid(*this).name());
+auto Function::name() const -> std::string {
+ return at::demangle(typeid(*this).name());
}
// This function is analogous to make_trace which operates on PythonOp, but this
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 77e9b21846..d76296a10a 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -5,6 +5,7 @@
#include "torch/csrc/autograd/grad_mode.h"
#include "torch/csrc/autograd/profiler.h"
#include "torch/csrc/autograd/saved_variable.h"
+#include "torch/csrc/autograd/type_and_shape.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/utils/auto_unique_ptr.h"
@@ -91,18 +92,15 @@ struct Function : std::enable_shared_from_this<Function> {
/// in the backward() pass, with higher sequence numbers prioritized
/// before lower sequence numbers.
explicit Function(
- uint32_t num_inputs,
uint64_t sequence_nr,
edge_list&& next_edges = edge_list())
: sequence_nr_(sequence_nr),
- num_inputs_(num_inputs),
next_edges_(std::move(next_edges)) {}
explicit Function(
- uint32_t num_inputs = 0,
edge_list&& next_edges = edge_list())
- : Function(num_inputs, next_sequence_nr_++, std::move(next_edges)) {}
-
+ : Function(next_sequence_nr_++, std::move(next_edges)) {}
+
/// Functions are neither copyable nor moveable.
Function(const Function& other) = delete;
Function(Function&& other) = delete;
@@ -123,20 +121,37 @@ struct Function : std::enable_shared_from_this<Function> {
// Graph Connectivity API
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // Inputs
+ // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the
+ // forward function.
+
+ // Marker for expected undefined input
+ struct undefined_input {};
- /// Increments the number of inputs of the function and returns the previous
- /// value.
- uint32_t bump_inputs() noexcept {
- return num_inputs_++;
+ /// Adds the type and shape metadata for a new input. Returns the index of
+ /// the new input.
+ uint32_t add_input_metadata(const at::Type& type, at::IntList shape) noexcept {
+ uint32_t input_nr = input_metadata_.size();
+ input_metadata_.emplace_back(type, shape);
+ return input_nr;
}
- void set_num_inputs(uint32_t num_inputs) noexcept {
- num_inputs_ = num_inputs;
+ /// Adds a placeholder for an input that will not be used.
+ uint32_t add_input_metadata(undefined_input u) noexcept {
+ uint32_t input_nr = input_metadata_.size();
+ input_metadata_.emplace_back();
+ return input_nr;
}
uint32_t num_inputs() const noexcept {
- return num_inputs_;
+ return input_metadata_.size();
+ }
+
+ const TypeAndShape& input_metadata(size_t index) const {
+ return input_metadata_[index];
+ }
+
+ void clear_input_metadata() {
+ input_metadata_.clear();
}
// Outputs ("Next Edges")
@@ -185,7 +200,7 @@ struct Function : std::enable_shared_from_this<Function> {
}
/// Returns the name of the dynamic type of the function, for debugging.
- virtual std::string name();
+ virtual std::string name() const;
/// Returns true if the particular output edge is active, and that particular
/// output of this function should be computed.
@@ -312,12 +327,12 @@ struct Function : std::enable_shared_from_this<Function> {
// fields.
const uint64_t sequence_nr_;
- uint32_t num_inputs_;
edge_list next_edges_;
PyObject* pyobj_ = nullptr; // weak reference
std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
auto_unique_ptr<jit::tracer::FunctionTracingState> tracing_state_;
+ at::SmallVector<TypeAndShape, 2> input_metadata_;
};
/// See Function::is_traceable() for definition.
@@ -355,13 +370,14 @@ struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
/// `input_nr` thus equal to `function->num_inputs()`. Additionally, it
/// increments the `Function`'s number of inputs by one. Approximately
/// equivalent to `variable.set_gradient_edge(function,
-/// function->bump_inputs())`. If you don't want the `Function`'s `num_inputs`
-/// to be incremented, use `set_gradient_edge` directly.
+/// function->add_input_metadata(variable.type(), variable.sizes()))`.
+/// If you don't want the `Function`'s `num_inputs` to be incremented, use
+/// `set_gradient_edge` directly.
inline void create_gradient_edge(
Variable& variable,
std::shared_ptr<Function> function) {
// Copy before move.
- const auto input_nr = function->bump_inputs();
+ const auto input_nr = function->add_input_metadata(variable.type(), variable.sizes());
variable.set_gradient_edge({std::move(function), input_nr});
}
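
For call sites migrating off set_num_inputs()/bump_inputs(): under the new API a Function subclass records one metadata entry per forward output in its constructor. The node below is hypothetical and only illustrates the pattern (real call sites such as AccumulateGrad in the next file follow it):

    #include "torch/csrc/autograd/function.h"

    namespace torch { namespace autograd {

    struct ExampleBackward : public Function {
      explicit ExampleBackward(const Variable& forward_output) : Function() {
        // Replaces the old set_num_inputs(1): record type and shape so the
        // engine can validate the gradient that will arrive at input 0.
        add_input_metadata(forward_output.type(), forward_output.sizes());
      }

      variable_list apply(const variable_list& grads) override {
        return grads;  // identity, for illustration only
      }
    };

    }} // namespace torch::autograd
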
diff --git a/torch/csrc/autograd/functions/accumulate_grad.cpp b/torch/csrc/autograd/functions/accumulate_grad.cpp
index 31fc2434ea..fdbe54d3fc 100644
--- a/torch/csrc/autograd/functions/accumulate_grad.cpp
+++ b/torch/csrc/autograd/functions/accumulate_grad.cpp
@@ -13,12 +13,13 @@ using at::Tensor;
namespace torch { namespace autograd {
-// AccumulateGrad sets sequence_nr to the max value so it's always called
+// AccumulateGrad sets sequence_nr to the max value so it's always called
// ASAP during backwards.
AccumulateGrad::AccumulateGrad(Variable variable_)
- : Function(/*num_inputs=*/1
- , /*sequence_nr=*/UINT64_MAX)
- , variable(std::move(variable_)) {}
+ : Function(/*sequence_nr=*/UINT64_MAX)
+ , variable(std::move(variable_)) {
+ add_input_metadata(variable.type(), variable.sizes());
+}
auto AccumulateGrad::apply(const variable_list& grads) -> variable_list {
// XXX: this method is not thread-safe!
diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h
index a630c3c79e..7ac4d2a3fa 100644
--- a/torch/csrc/autograd/functions/basic_ops.h
+++ b/torch/csrc/autograd/functions/basic_ops.h
@@ -12,7 +12,7 @@ namespace torch { namespace autograd {
struct Error : public Function {
Error(std::string msg, edge_list&& next_edges)
- : Function(/*num_inputs=*/0, std::move(next_edges))
+ : Function(std::move(next_edges))
, msg(std::move(msg)) {}
Error(std::string msg)
@@ -35,7 +35,7 @@ struct DelayedError : public Function {
struct GraphRoot : public Function {
GraphRoot(edge_list functions, variable_list inputs)
- : Function(/*num_inputs=*/0, std::move(functions)),
+ : Function(std::move(functions)),
outputs(std::move(inputs)) {}
virtual variable_list apply(const variable_list& inputs) {
diff --git a/torch/csrc/autograd/functions/special.cpp b/torch/csrc/autograd/functions/special.cpp
index a5c5a6228e..88ac969122 100644
--- a/torch/csrc/autograd/functions/special.cpp
+++ b/torch/csrc/autograd/functions/special.cpp
@@ -17,7 +17,9 @@ namespace torch { namespace autograd {
// Used when an output has multiple uses (there's only one entry
// in next_edges per output).
struct Replicate : public Function {
- Replicate() : Function(/*num_inputs=*/1) {}
+ Replicate(const at::Type& type, at::IntList shape) : Function() {
+ add_input_metadata(type, shape);
+ }
virtual variable_list apply(const variable_list& inputs) {
TORCH_ASSERT(inputs.size() == 1);
@@ -236,6 +238,7 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
// This detaches the subgraph from the full backward graph.
for (auto& begin : subgraph.boundary.begins) {
const auto& edge = begin.function->next_edge(begin.input_nr);
+
begin.function->set_next_edge(
begin.input_nr, Edge(ends_to_outputs.at(edge), 0));
}
@@ -265,7 +268,7 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
// the same Variable has been returned multiple times, and
// is repeated in this list.
if (output.grad_fn_unsafe() == this) {
- auto replicate = std::make_shared<Replicate>();
+ auto replicate = std::make_shared<Replicate>(output.type(), output.sizes());
replicate->add_next_edge({this_shared, output.output_nr()});
output.set_gradient_edge({std::move(replicate), 0});
repeated_outputs.emplace(&output);
@@ -274,7 +277,8 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
// perform any allocations until we actually see repeated outputs.
if (repeated_outputs.count(&output) > 0) {
auto & replicate = output.grad_fn();
- replicate->add_next_edge({this_shared, num_inputs_++});
+ auto input_nr = add_input_metadata(output.type(), output.sizes());
+ replicate->add_next_edge({this_shared, input_nr});
} else {
autograd::create_gradient_edge(output, this_shared);
}
diff --git a/torch/csrc/autograd/functions/special.h b/torch/csrc/autograd/functions/special.h
index 076a4faa98..273b139e23 100644
--- a/torch/csrc/autograd/functions/special.h
+++ b/torch/csrc/autograd/functions/special.h
@@ -15,7 +15,9 @@ namespace torch { namespace autograd {
struct EvalOutput : Function {
explicit EvalOutput(const Edge& next_edge_)
- : Function(/*num_inputs=*/1), next_edge(next_edge_) {}
+ : Function(), next_edge(next_edge_) {
+ add_input_metadata(undefined_input());
+ }
virtual variable_list apply(const variable_list& inputs) override {
throw std::logic_error("EvalOutput::apply() called");
diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp
index 75ec8180c9..5df72d5263 100644
--- a/torch/csrc/autograd/functions/tensor.cpp
+++ b/torch/csrc/autograd/functions/tensor.cpp
@@ -36,12 +36,13 @@ CopySlices::CopySlices(
const Variable& base_var,
at::TensorGeometry view_,
std::shared_ptr<Function> fn_)
- : Function(/*num_inputs=*/1),
+ : Function(),
base(base_var),
view(std::move(view_)),
fn(std::move(fn_)) {
// Take the next_edges of fn as our own, except for index 0 which goes
// to base instead of the view.
+ add_input_metadata(base_var.type(), base_var.sizes());
const auto num_outputs = fn->num_outputs();
next_edges_.reserve(num_outputs);
add_next_edge(base_var.gradient_edge());
diff --git a/torch/csrc/autograd/functions/utils.cpp b/torch/csrc/autograd/functions/utils.cpp
index 09939a109a..485572d9ee 100644
--- a/torch/csrc/autograd/functions/utils.cpp
+++ b/torch/csrc/autograd/functions/utils.cpp
@@ -29,7 +29,7 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
autograd::create_gradient_edge(variable, grad_fn);
result.push_back(std::move(variable));
} else {
- grad_fn->bump_inputs();
+ grad_fn->add_input_metadata(Function::undefined_input());
result.emplace_back();
}
}
diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp
index e07d88feee..e52e06a341 100644
--- a/torch/csrc/autograd/python_engine.cpp
+++ b/torch/csrc/autograd/python_engine.cpp
@@ -134,7 +134,7 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
}
}
- edge_list output_edges;
+ std::vector<Edge> output_edges;
if (inputs != nullptr) {
int num_inputs = PyTuple_GET_SIZE(inputs);
output_edges.reserve(num_inputs);
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 3fb53c9a6d..4672d535ef 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -207,7 +207,7 @@ auto PyFunction::release_variables() -> void {
f->has_freed_buffers = 1;
}
-auto PyFunction::name() -> std::string {
+auto PyFunction::name() const -> std::string {
AutoGIL gil;
auto f = (THPFunction*) obj;
auto name = std::string(Py_TYPE(f)->tp_name);
@@ -245,7 +245,7 @@ static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg)
static int THPFunction_clear(THPFunction *self)
{
- self->cdata.set_num_inputs(0);
+ self->cdata.clear_input_metadata();
Py_CLEAR(self->needs_input_grad);
@@ -293,7 +293,6 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
new (&self->input_info) std::vector<VariableInfo>();
new (&self->saved_variables) std::vector<SavedVariable>();
new (&self->is_variable_input) std::vector<bool>();
- self->cdata.set_num_inputs(0);
return obj;
}
@@ -425,6 +424,10 @@ static void _wrap_outputs(THPFunction *self,
// Note that output Variables may be repeated. In that case, the last call
// to set_history wins.
auto var = as_variable(obj, i);
+ if (cdata) {
+ auto output_nr = cdata->add_input_metadata(var.type(), var.sizes());
+ TORCH_ASSERT(i == (int)output_nr);
+ }
set_history(var, i, is_input, is_modified, is_differentiable);
if (is_executable) {
@@ -616,7 +619,7 @@ PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const Unpacked
THPObjectPtr outputs(PyTuple_New(num_outputs));
if (!outputs) throw python_error();
- grad_fn->cdata.set_num_inputs(num_outputs);
+ grad_fn->cdata.clear_input_metadata();
// Record type, device, and size information about inputs
if (is_executable) {
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h
index 562ab79a28..529bcaf454 100644
--- a/torch/csrc/autograd/python_function.h
+++ b/torch/csrc/autograd/python_function.h
@@ -35,7 +35,7 @@ struct PyFunction : public Function {
variable_list legacy_apply(const variable_list& inputs);
virtual void release_variables() override;
- virtual std::string name() override;
+ virtual std::string name() const override;
virtual std::shared_ptr<Function> get_shared_ptr() override;
virtual bool is_traceable() override;
diff --git a/torch/csrc/autograd/python_legacy_variable.cpp b/torch/csrc/autograd/python_legacy_variable.cpp
index 179c54a1e3..33dd281ab0 100644
--- a/torch/csrc/autograd/python_legacy_variable.cpp
+++ b/torch/csrc/autograd/python_legacy_variable.cpp
@@ -57,7 +57,7 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject
Variable var;
if (grad_fn) {
auto grad_fn_ = THPFunction_asFunction((THPFunction*)grad_fn);
- Edge edge(grad_fn_, grad_fn_->bump_inputs());
+ Edge edge(grad_fn_, grad_fn_->add_input_metadata(tensor.type(), tensor.sizes()));
var = make_variable(std::move(tensor), std::move(edge));
} else {
var = make_variable(std::move(tensor), requires_grad);
diff --git a/torch/csrc/autograd/type_and_shape.h b/torch/csrc/autograd/type_and_shape.h
new file mode 100644
index 0000000000..01a62fa2cf
--- /dev/null
+++ b/torch/csrc/autograd/type_and_shape.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include "torch/csrc/assertions.h"
+
+namespace torch { namespace autograd {
+
+/// A tensor's type and shape. Each Function records the required type and
+/// shape of its inputs. If is_valid() is false, then the corresponding input
+/// is not used and may be an undefined tensor.
+struct TypeAndShape {
+ TypeAndShape() : type_(nullptr) {}
+
+ TypeAndShape(const at::Type& type, at::IntList shape)
+ : type_(&type) , shape_(shape) {}
+
+ bool is_valid() const {
+ return type_ != nullptr;
+ }
+
+ const at::Type& type() const {
+ TORCH_ASSERT(type_);
+ return *type_;
+ }
+
+ at::IntList shape() const {
+ return shape_;
+ }
+
+ const at::Type* type_;
+ at::DimVector shape_;
+};
+
+}}
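
A short sketch of the two states a TypeAndShape entry can take (the helper is illustrative), matching how Function::add_input_metadata fills input_metadata_:

    #include "torch/csrc/autograd/type_and_shape.h"
    #include "torch/csrc/autograd/variable.h"

    namespace torch { namespace autograd {

    // A defined variable yields a usable entry; an undefined placeholder
    // yields a default-constructed entry with is_valid() == false.
    inline TypeAndShape metadata_for(const Variable& var) {
      if (var.defined()) {
        return TypeAndShape(var.type(), var.sizes());
      }
      return TypeAndShape();  // unused input, skipped by the engine's checks
    }

    }} // namespace torch::autograd
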
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index f0b4c1df9d..48a35f4f88 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -126,10 +126,12 @@ void Variable::Impl::backward(
}
void Variable::Impl::set_data(Tensor new_data) {
- data_ = std::move(new_data);
- if (data_.type() != *type_) {
- type_ = VariableType::getType(data_);
+ if (new_data.type() != data_.type()) {
+ type_ = VariableType::getType(new_data.type());
+ // Clear grad_accumulator if it exists, since it stores the old type info.
+ grad_accumulator_.reset();
}
+ data_ = std::move(new_data);
}
Variable::ViewImpl::ViewImpl(Variable base, at::Tensor data, Edge gradient_edge)
@@ -158,7 +160,7 @@ std::shared_ptr<Function>& Variable::ViewImpl::get_grad_fn() {
fn->stride = strides();
fn->storage_offset = data_.storage_offset();
fn->set_next_edges(collect_next_edges(base_));
- fn->set_num_inputs(1);
+ fn->add_input_metadata(base_.type(), sizes());
grad_fn_ = std::move(fn);
attr_version = current_version;
}