author     Peter Goldsborough <peter@goldsborough.me>   2018-02-12 20:26:26 -0800
committer  Soumith Chintala <soumith@gmail.com>         2018-02-12 23:26:26 -0500
commit     2d5fbe6e0de1e5dfa292afedec34b277c14d7c10 (patch)
tree       44003eb36bec371c3fddf3e6965a714caf47b61f /tools
parent     0ef10385b2d82b87d89124a6dc3e0d95dfb97a51 (diff)
Improve Variable interface (#5127)
* Improve Variable interface
* Address comments from @apaszke and @colesbury
* std::string::operator= is not noexcept
* Remove ir.h from tracer_state.h to improve build times
* Make Variable a struct and pack SavedVariable fields
* Implement as_variable_ref (usage sketched just after this list)
* grad_fn_ptr() -> grad_fn_unsafe()
* Reduce hackiness of set_type hack
* Include variable.h and edge.h in tracer_state.h because it uses them
* class Variable -> struct Variable because Windows can't even
* Make Variable::output_nr uint32_t instead of int
* Add comment about tracing state
* Replace more static_cast<Variable&> casts and improve docs
* Remove SavedVariable destructor and construct members in init list
* Clarify docs for Variable
* Variable::set_version -> set_version_counter
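The bullets above rename most of the touch points that generated code uses on a Variable. As a reading aid only (not part of this commit), here is a minimal sketch of the post-change calls that appear in the diff below, assuming it is compiled inside the PyTorch tree with the autograd headers available; the helper names record_single_output and wrap_tensor are hypothetical.

    #include <ATen/ATen.h>
    #include "torch/csrc/autograd/edge.h"
    #include "torch/csrc/autograd/function.h"
    #include "torch/csrc/autograd/variable.h"

    #include <memory>
    #include <utility>

    using torch::autograd::Function;
    using torch::autograd::Variable;
    using torch::autograd::as_variable_ref;
    using torch::autograd::make_variable;

    // Hypothetical helper; `self` is assumed to already be a Variable, as it is
    // in the generated VariableType code.
    void record_single_output(at::Tensor& self, std::shared_ptr<Function> grad_fn) {
      Variable& var = as_variable_ref(self);           // was static_cast<Variable&>(self)
      var.set_gradient_edge({std::move(grad_fn), 0});  // was var.get()->output_nr / _grad_fn
      var.bump_version();                              // was var.version_counter().increment()
    }

    // Wrapping a plain at::Tensor now passes requires_grad explicitly.
    at::Tensor wrap_tensor(at::Tensor data) {
      return make_variable(std::move(data), /*requires_grad=*/false);
    }

The braced argument is the new Edge aggregate of {grad_fn, input_nr}, which replaces writing output_nr and _grad_fn through var.get().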
Diffstat (limited to 'tools')
-rw-r--r--  tools/autograd/gen_autograd_functions.py             |  2
-rw-r--r--  tools/autograd/templates/Functions.cpp                |  6
-rw-r--r--  tools/autograd/templates/VariableType.cpp             | 91
-rw-r--r--  tools/autograd/templates/VariableType.h               |  5
-rw-r--r--  tools/autograd/templates/python_torch_functions.cpp   |  4
5 files changed, 47 insertions, 61 deletions
diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py
index 8707ab8010..0a50b5c4b3 100644
--- a/tools/autograd/gen_autograd_functions.py
+++ b/tools/autograd/gen_autograd_functions.py
@@ -129,7 +129,7 @@ def process_function(func):
         name = arg['name']
         if arg['type'] == 'Tensor' or (arg['type'] == 'Scalar' and is_output):
             saved_variables.append('SavedVariable {}_;'.format(name))
-            release_variables.append('{}_.data.reset();'.format(name))
+            release_variables.append('{}_.reset_data();'.format(name))
             ptr = 'shared_from_this()' if is_output else ''
             unpack.append('auto {} = {}_.unpack({});'.format(name, name, ptr))
         elif arg['type'] == 'TensorList':
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 1aae50dcfd..90ad2ebbf8 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -478,7 +478,7 @@ Tensor select_backward_scalar(Tensor grad, const Tensor & input, const Tensor &
 #ifdef WITH_SCALARS
   grad_input.masked_fill_(input == value, grad);
 #else
-  auto grad_data = static_cast<Variable&>(grad).data();
+  auto grad_data = as_variable_ref(grad).data();
   grad_input.masked_fill_(input == value, Scalar(grad_data[0]));
 #endif
   return grad_input;
@@ -1088,9 +1088,9 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
   for (auto s : input.sizes().slice(2)) {
     M *= s;
   }
-  auto mu = unsqueeze_dim1(make_variable(training ? save_mean : running_mean), input);
+  auto mu = unsqueeze_dim1(make_variable(training ? save_mean : running_mean, /*requires_grad=*/false), input);
   auto input_sub_mu = input - mu;
-  auto sigma2_eps_neg_1_2 = unsqueeze_dim1(make_variable(training ? save_std : running_var.add(Scalar(eps)).pow(-0.5)), input);
+  auto sigma2_eps_neg_1_2 = unsqueeze_dim1(make_variable(training ? save_std : running_var.add(Scalar(eps)).pow(-0.5), /*requires_grad=*/false), input);
   auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
   auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);
diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp
index 96fbfaa2ce..42495ebea1 100644
--- a/tools/autograd/templates/VariableType.cpp
+++ b/tools/autograd/templates/VariableType.cpp
@@ -5,6 +5,7 @@
 #include "torch/csrc/autograd/variable.h"
 #include "torch/csrc/autograd/function.h"
+#include "torch/csrc/autograd/edge.h"
 #include "torch/csrc/autograd/grad_mode.h"
 #include "torch/csrc/autograd/saved_variable.h"
 #include "torch/csrc/autograd/generated/Functions.h"
@@ -28,7 +29,6 @@ using namespace at;
 using namespace torch::autograd::generated;
 
 namespace torch { namespace autograd {
-
 // Helper methods for working with Attributes (torch/csrc/jit/attributes.h)
 
 // The overloaded accessors are convenient for the generated code (since we
@@ -74,7 +74,7 @@ std::unique_ptr<Storage> VariableType::storageWithAllocator(int64_t size, std::u
   return baseType->storageWithAllocator(size, std::move(allocator));
 }
 Tensor VariableType::unsafeTensorFromTH(void * th_pointer, bool retain) const {
-  return make_variable(baseType->unsafeTensorFromTH(th_pointer, retain), false);
+  return make_variable(baseType->unsafeTensorFromTH(th_pointer, retain), /*requires_grad=*/false);
 }
 std::unique_ptr<Generator> VariableType::generator() const {
   return baseType->generator();
@@ -164,7 +164,7 @@ Variable & VariableType::checked_cast_variable(const Tensor & t, const char * na
     runtime_error("Expected object of type Variable but found type %s for argument #%d '%s'",
         t.type().toString(), pos, name);
   }
-  return static_cast<Variable&>(const_cast<Tensor&>(t));
+  return as_variable_ref(const_cast<Tensor&>(t));
 }
 
 Tensor & VariableType::unpack(const Tensor & t, const char * name, int pos) {
@@ -207,49 +207,35 @@ static std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
     return SavedVariable{tensor, false /* is output */};
   });
 }
 
-static Tensor as_variable(Tensor tensor) {
-  return make_variable(std::move(tensor));
-}
-
-static std::tuple<Tensor, Tensor>
-as_variable(std::tuple<Tensor, Tensor> tensors) {
-  return std::make_tuple<>(
-      make_variable(std::move(std::get<0>(tensors))),
-      make_variable(std::move(std::get<1>(tensors))));
+template <typename... Tensors, size_t... Is>
+std::tuple<Tensors...> as_variable_impl(
+    std::tuple<Tensors...> tensors,
+    Indices<Is...>) {
+  // Expand the integer parameter pack into a sequence of Variable
+  // constructions. This turns into (boolean omitted):
+  // Variable(std::get<0>(tensors)), Variable(std::get<1>(tensors)), ...
+  return std::tuple<Tensors...>(
+      make_variable(std::get<Is>(tensors), /*requires_grad=*/false)...);
 }
 
-static std::tuple<Tensor, Tensor, Tensor>
-as_variable(std::tuple<Tensor, Tensor, Tensor> tensors) {
-  return std::make_tuple<>(
-      make_variable(std::move(std::get<0>(tensors))),
-      make_variable(std::move(std::get<1>(tensors))),
-      make_variable(std::move(std::get<2>(tensors))));
+template <typename... Tensors>
+std::tuple<Tensors...> as_variable(std::tuple<Tensors...> tensors) {
+  // `sizeof...(Tensors)` gets us the size of the `Tensors` parameter pack at
+  // compile time. We use it to parameterize a `MakeIndices` class, which will
+  // expand into an Indices object containing the numbers 0 to
+  // sizeof...(Tensors) - 1.
+  return as_variable_impl(
+      tensors, typename MakeIndices<sizeof...(Tensors)>::indices());
 }
 
-static std::tuple<Tensor, Tensor, Tensor, Tensor>
-as_variable(std::tuple<Tensor, Tensor, Tensor, Tensor> tensors) {
-  return std::make_tuple<>(
-      make_variable(std::move(std::get<0>(tensors))),
-      make_variable(std::move(std::get<1>(tensors))),
-      make_variable(std::move(std::get<2>(tensors))),
-      make_variable(std::move(std::get<3>(tensors))));
-}
-
-static std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
-as_variable(std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> tensors) {
-  return std::make_tuple<>(
-      make_variable(std::move(std::get<0>(tensors))),
-      make_variable(std::move(std::get<1>(tensors))),
-      make_variable(std::move(std::get<2>(tensors))),
-      make_variable(std::move(std::get<3>(tensors))),
-      make_variable(std::move(std::get<4>(tensors)))
-  );
+static Tensor as_variable(Tensor tensor) {
+  return make_variable(std::move(tensor), /*requires_grad=*/false);
 }
 
 static std::vector<Tensor> as_variable(TensorList tl) {
   std::vector<Tensor> variables;
   for (auto& t : tl) {
-    variables.emplace_back(make_variable(std::move(t)));
+    variables.emplace_back(make_variable(std::move(t), /*requires_grad=*/false));
   }
   return variables;
 }
@@ -316,20 +302,20 @@ static void throw_error_out_requires_grad(const char* name) {
 
 static void rebase_history(Tensor& tensor, std::shared_ptr<Function> grad_fn) {
   if (grad_fn && tensor.defined()) {
-    auto& var = static_cast<Variable&>(tensor);
+    auto& var = as_variable_ref(tensor);
     grad_fn->num_inputs = 1;
-    var.rebase_history(0, std::move(grad_fn));
+    var.rebase_history({std::move(grad_fn), 0});
   }
 }
 
 static void rebase_history(TensorList tensors, std::shared_ptr<Function> grad_fn) {
   if (grad_fn) {
     grad_fn->num_inputs = tensors.size();
-    int output_nr = 0;
+    uint32_t output_nr = 0;
     for (auto& tensor : tensors) {
       if (tensor.defined()) {
-        auto& var = static_cast<Variable&>(const_cast<Tensor&>(tensor));
-        var.rebase_history(output_nr, grad_fn);
+        auto& var = as_variable_ref(const_cast<Tensor&>(tensor));
+        var.rebase_history({grad_fn, output_nr});
       }
       output_nr++;
     }
@@ -340,22 +326,20 @@
 // overload for functions with multiple differentiable outputs.
 static void set_history(Tensor& tensor, std::shared_ptr<Function> grad_fn) {
   if (grad_fn && tensor.defined()) {
-    auto& var = static_cast<Variable&>(tensor);
+    auto& var = as_variable_ref(tensor);
     grad_fn->num_inputs = 1;
-    var.get()->output_nr = 0;
-    var.get()->_grad_fn = std::move(grad_fn);
+    var.set_gradient_edge({std::move(grad_fn), 0});
   }
 }
 
 static void set_history(TensorList tensors, std::shared_ptr<Function> grad_fn) {
   if (grad_fn) {
     grad_fn->num_inputs = tensors.size();
-    int64_t output_nr = 0;
+    uint32_t output_nr = 0;
     for (auto& tensor : tensors) {
       if (tensor.defined()) {
-        auto& var = static_cast<Variable&>(const_cast<Tensor&>(tensor));
-        var.get()->output_nr = output_nr;
-        var.get()->_grad_fn = grad_fn;
+        auto& var = as_variable_ref(const_cast<Tensor&>(tensor));
+        var.set_gradient_edge({grad_fn, output_nr});
      }
       output_nr++;
     }
@@ -378,9 +362,8 @@ template<typename... Args> inline variable_list flatten(Args&&... args) {
   return out; // RVO
 }
 
-static void increment_version(const Tensor & t) {
-  auto& var = static_cast<const Variable&>(t);
-  var.version_counter().increment();
+static void increment_version(Tensor & t) {
+  as_variable_ref(t).bump_version();
 }
 
 static bool isFloatingPoint(ScalarType s) {
@@ -411,7 +394,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block
 
 Tensor & VariableType::resize_(Tensor & self, IntList size) const {
   auto& self_ = unpack(self, "self", 0);
-  if (static_cast<Variable&>(self).requires_grad()) {
+  if (as_variable_ref(self).requires_grad()) {
     at::runtime_error("cannot resize variables that require grad");
   }
   baseType->resize_(self_, size);
@@ -421,7 +404,7 @@ Tensor & VariableType::resize_(Tensor & self, IntList size) const {
 Tensor & VariableType::resize_as_(Tensor & self, const Tensor & the_template) const {
   auto& self_ = unpack(self, "self", 0);
   auto& the_template_ = unpack(the_template, "the_template", 1);
-  if (static_cast<Variable&>(self).requires_grad()) {
+  if (as_variable_ref(self).requires_grad()) {
     at::runtime_error("cannot resize variables that require grad");
   }
   baseType->resize_as_(self_, the_template_);
diff --git a/tools/autograd/templates/VariableType.h b/tools/autograd/templates/VariableType.h
index fd59d54a29..938a425088 100644
--- a/tools/autograd/templates/VariableType.h
+++ b/tools/autograd/templates/VariableType.h
@@ -3,6 +3,10 @@
 // ${generated_comment}
 
 #include <ATen/ATen.h>
+
+#include <cstdint> // for size_t
+#include <functional> // for function
+#include <memory> // for unique_ptr
 #include <string>
 #include <vector>
@@ -56,7 +60,6 @@ private:
   static at::Tensor unpack_opt(const Tensor & t, const char * name, int pos);
   static std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos);
 
-private:
   at::Type* baseType;
   std::string str;
 };
diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp
index 0490af5ab3..00a6178f84 100644
--- a/tools/autograd/templates/python_torch_functions.cpp
+++ b/tools/autograd/templates/python_torch_functions.cpp
@@ -25,7 +25,7 @@ using namespace torch::autograd::utils;
 
 namespace torch { namespace autograd {
 
 static Tensor set_requires_grad(Tensor self, bool requires_grad) {
-  static_cast<Variable&>(self).get()->_requires_grad = requires_grad;
+  as_variable_ref(self).set_requires_grad(requires_grad);
   return self;
 }
@@ -70,7 +70,7 @@ static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg)
 {
   HANDLE_TH_ERRORS
   auto data = torch::utils::tensor_from_numpy(arg);
-  return THPVariable_Wrap(make_variable(std::move(data)));
+  return THPVariable_Wrap(make_variable(std::move(data), /*requires_grad=*/false));
   END_HANDLE_TH_ERRORS
 }
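The variadic as_variable_impl/as_variable pair in VariableType.cpp replaces the fixed-arity tuple overloads with a single pack expansion over an index sequence. The standalone sketch below is not from the commit: it substitutes the C++14 std::index_sequence machinery for PyTorch's own Indices/MakeIndices helpers and a toy wrap() for make_variable, so it compiles on its own (e.g. g++ -std=c++14).

    #include <cstddef>
    #include <iostream>
    #include <tuple>
    #include <utility>

    struct Wrapped { int value; };

    // Stand-in for make_variable: wraps a single element.
    Wrapped wrap(int x) { return Wrapped{x}; }

    template <typename... Ts, std::size_t... Is>
    auto wrap_all_impl(std::tuple<Ts...> t, std::index_sequence<Is...>) {
      // The pack expansion turns into:
      // std::make_tuple(wrap(std::get<0>(t)), wrap(std::get<1>(t)), ...)
      return std::make_tuple(wrap(std::get<Is>(t))...);
    }

    template <typename... Ts>
    auto wrap_all(std::tuple<Ts...> t) {
      // sizeof...(Ts) is the pack size at compile time; index_sequence_for
      // expands to the indices 0 .. sizeof...(Ts)-1, which is what the PR's
      // MakeIndices<sizeof...(Tensors)>::indices() produces.
      return wrap_all_impl(std::move(t), std::index_sequence_for<Ts...>{});
    }

    int main() {
      auto wrapped = wrap_all(std::make_tuple(1, 2, 3));
      std::cout << std::get<2>(wrapped).value << "\n";  // prints 3
    }

Because the index pack is generated at compile time for any arity, one template covers the cases that previously needed separate two-, three-, four-, and five-tensor overloads.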