author     Anders Papitto <anderspapitto@gmail.com>  2018-10-12 13:08:03 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-10-12 13:14:44 -0700
commit     49256ddb4a80d5385779c9d60c49004a39535d6a (patch)
tree       0135d9f1fcc65a469a699c8325719e88edd9b70e /torch
parent     3f52a0aad7bcc4147243c3bda2022413b2842de8 (diff)
download   pytorch-49256ddb4a80d5385779c9d60c49004a39535d6a.tar.gz
           pytorch-49256ddb4a80d5385779c9d60c49004a39535d6a.tar.bz2
           pytorch-49256ddb4a80d5385779c9d60c49004a39535d6a.zip
split generated VariableType.cpp (#12493)
Summary:
On my devgpu, this brings the time taken for `touch torch/csrc/jit/type.h && time python setup.py rebuild develop` (debug mode, multicore build) down from 75 seconds to 62 seconds. For the libtorch `ninja install` portion, which this change affects, the reduction is from 52 seconds to 35 seconds.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/12493

Reviewed By: zdevito
Differential Revision: D10315988
Pulled By: anderspapitto
fbshipit-source-id: 316dc4ab81134aaa17a568cfc07408b7ced08c2e
Diffstat (limited to 'torch')
-rw-r--r--  torch/CMakeLists.txt                        |  13
-rw-r--r--  torch/csrc/autograd/VariableTypeManual.cpp  | 336
-rw-r--r--  torch/csrc/autograd/VariableTypeUtils.h     | 182
3 files changed, 529 insertions(+), 2 deletions(-)
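
As context for the diff below: the commit replaces the single generated VariableType.cpp with five shards (VariableType-0.cpp through VariableType-4.cpp) that all share the declarations in the generated VariableType.h, so the five translation units can be compiled in parallel. The following is a minimal toy sketch of that layout, collapsed into one file for brevity; the class and functions are hypothetical stand-ins, not the actual generated code.

#include <cstdio>

// --- "VariableTypeToy.h": the shared declaration, analogous to the generated VariableType.h ---
struct VariableTypeToy {
  int add(int a, int b) const;  // definition would live in shard 0
  int mul(int a, int b) const;  // definition would live in shard 1
};

// --- "VariableTypeToy-0.cpp": one shard's worth of method definitions ---
int VariableTypeToy::add(int a, int b) const { return a + b; }

// --- "VariableTypeToy-1.cpp": another shard, compiled concurrently with the first ---
int VariableTypeToy::mul(int a, int b) const { return a * b; }

int main() {
  VariableTypeToy t;
  std::printf("%d %d\n", t.add(2, 3), t.mul(2, 3));  // prints "5 6"
  return 0;
}
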
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 70a970025b..5e2e1d920b 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -70,7 +70,11 @@ add_custom_command(
"${TORCH_SRC_DIR}/csrc/nn/THNN.cpp"
"${TORCH_SRC_DIR}/csrc/nn/THCUNN.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType.h"
- "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType.cpp"
+ "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType-0.cpp"
+ "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType-1.cpp"
+ "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType-2.cpp"
+ "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType-3.cpp"
+ "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType-4.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.h"
"${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions.h"
@@ -131,12 +135,17 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/autograd/functions/tensor.cpp
${TORCH_SRC_DIR}/csrc/autograd/functions/utils.cpp
${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.cpp
- ${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType.cpp
+ ${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType-0.cpp
+ ${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType-1.cpp
+ ${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType-2.cpp
+ ${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType-3.cpp
+ ${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType-4.cpp
${TORCH_SRC_DIR}/csrc/autograd/grad_mode.cpp
${TORCH_SRC_DIR}/csrc/autograd/input_buffer.cpp
${TORCH_SRC_DIR}/csrc/autograd/profiler.cpp
${TORCH_SRC_DIR}/csrc/autograd/saved_variable.cpp
${TORCH_SRC_DIR}/csrc/autograd/variable.cpp
+ ${TORCH_SRC_DIR}/csrc/autograd/VariableTypeManual.cpp
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
${TORCH_SRC_DIR}/csrc/jit/autodiff.cpp
${TORCH_SRC_DIR}/csrc/jit/export.cpp
diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp
new file mode 100644
index 0000000000..dbe10cf68a
--- /dev/null
+++ b/torch/csrc/autograd/VariableTypeManual.cpp
@@ -0,0 +1,336 @@
+#include "torch/csrc/autograd/VariableTypeUtils.h"
+
+using namespace at;
+using namespace torch::autograd::generated;
+
+namespace torch { namespace autograd {
+
+VariableType::VariableType(Context* context, TypeExtendedInterface* baseType)
+ : TypeDefault(baseType->type_id(), /*is_variable=*/true, /*is_undefined=*/false)
+ , baseType(baseType)
+ , id_(context->freshTypeID()) {
+ str = std::string("Variable[") + baseType->toString() + "]";
+}
+
+ScalarType VariableType::scalarType() const {
+ return baseType->scalarType();
+}
+caffe2::TypeMeta VariableType::typeMeta() const {
+ return baseType->typeMeta();
+}
+Backend VariableType::backend() const {
+ return baseType->backend();
+}
+Allocator* VariableType::allocator() const {
+ return baseType->allocator();
+}
+Device VariableType::getDeviceFromPtr(void * data) const {
+ return baseType->getDeviceFromPtr(data);
+}
+Storage VariableType::storage(bool resizable) const {
+ return baseType->storage();
+}
+Storage VariableType::storage(size_t size, bool resizable) const {
+ return baseType->storage(size);
+}
+Storage VariableType::storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const {
+ return baseType->storageFromBlob(data, size, deleter);
+}
+Storage VariableType::unsafeStorageFromTH(void * th_pointer, bool retain) const {
+ return baseType->unsafeStorageFromTH(th_pointer, retain);
+}
+Storage VariableType::storageWithAllocator(int64_t size, Allocator* allocator) const {
+ return baseType->storageWithAllocator(size, allocator);
+}
+Tensor VariableType::unsafeTensorFromTH(void * th_pointer, bool retain) const {
+ return make_variable(baseType->unsafeTensorFromTH(th_pointer, retain), /*requires_grad=*/false);
+}
+std::unique_ptr<Generator> VariableType::generator() const {
+ return baseType->generator();
+}
+
+const char * VariableType::toString() const {
+ return str.c_str();
+}
+size_t VariableType::elementSizeInBytes() const {
+ return baseType->elementSizeInBytes();
+}
+Type & VariableType::toBackend(Backend b) const {
+ return *getVariableTypeFromBaseType(baseType->toBackend(b));
+}
+Type & VariableType::toScalarType(ScalarType s) const {
+ return *getVariableTypeFromBaseType(baseType->toScalarType(s));
+}
+TypeID VariableType::ID() const {
+ return static_cast<TypeID>(id_);
+}
+
+std::vector<std::unique_ptr<Type>> type_to_variable_type;
+
+// XXX - this is not thread-safe with concurrent uses of Variables
+void register_variable_type_for(TypeExtendedInterface* baseType) {
+ AT_ASSERT(baseType);
+ size_t base_id = static_cast<size_t>(baseType->ID());
+ if(type_to_variable_type.size() <= base_id) {
+ type_to_variable_type.resize(base_id + 1);
+ }
+ type_to_variable_type[base_id].reset(new VariableType(&at::globalContext(), baseType));
+}
+
+struct VariableTypeRegistry {
+ VariableTypeRegistry() {
+ auto& context = at::globalContext();
+ for (int p = 0; p < static_cast<int>(Backend::NumOptions); ++p) {
+ for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); ++s) {
+ auto baseType = context.getNonVariableTypeRaw(static_cast<Backend>(p), static_cast<ScalarType>(s));
+ if (baseType && baseType->backend() != Backend::Undefined) {
+ register_variable_type_for(baseType);
+ }
+ }
+ }
+ }
+};
+
+struct VariableHooks : public at::VariableHooksInterface {
+ VariableHooks(at::VariableHooksArgs) {}
+ void registerVariableTypeFor(at::LegacyTypeDispatch*, at::Backend, at::ScalarType) const override;
+ at::Type& getVariableTypeFromBaseType(const at::Type&) const override;
+};
+
+// Sigh, the registry doesn't support namespaces :(
+using at::RegistererVariableHooksRegistry;
+using at::VariableHooksRegistry;
+
+// WARNING: YOU MUST DO THE NEXT TWO STATIC INITIALIZERS IN THIS ORDER.
+//
+// If you do it in the other order, this is what can happen if
+// these static initializers are called before Context is
+// initialized:
+//
+// - VariableHooks::registerVariableTypeFor will be activated
+// to register a variable type
+//
+// - We run the constructor of VariableTypeRegistry, which
+// calls at::globalContext()
+//
+// - Context is not initialized yet, so we call the constructor
+// of Context
+//
+// - We register CPU types, calling VariableHooks::registerVariableTypeFor
+//
+// - We register the CPU type as a variable type
+//
+// - In VariableTypeRegistry, we try to register the Variable type AGAIN!!
+// Disaster.
+//
+static VariableTypeRegistry registry;
+REGISTER_VARIABLE_HOOKS(VariableHooks)
+
+// Pre-condition: backend/scalar_type is a valid type in the type_registry
+void VariableHooks::registerVariableTypeFor(at::LegacyTypeDispatch* context, at::Backend backend, at::ScalarType scalar_type) const {
+ auto* baseType = context->getNonVariableTypeRaw(backend, scalar_type);
+ register_variable_type_for(static_cast<at::TypeExtendedInterface*>(baseType));
+}
+
+at::Type& VariableHooks::getVariableTypeFromBaseType(const at::Type& baseType) const {
+ return *VariableType::getVariableTypeFromBaseType(baseType);
+}
+
+bool VariableType::isVariableType(const at::Type& type) {
+ return type.is_variable();
+}
+
+at::TypeExtendedInterface* VariableType::getVariableTypeFromBaseType(const at::Type& baseType) {
+ auto id = static_cast<size_t>(baseType.ID());
+ if(id >= type_to_variable_type.size())
+ return nullptr;
+ return static_cast<at::TypeExtendedInterface*>(type_to_variable_type[id].get());
+}
+
+namespace {
+std::vector<at::Type*> allTypesForBackends(at::ArrayRef<at::Backend> backends) {
+ auto& context = at::globalContext();
+ std::vector<Type*> res;
+ res.reserve(backends.size() * static_cast<int>(ScalarType::NumOptions));
+ for (auto p : backends) {
+ for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); s++) {
+ auto baseType = context.getNonVariableTypeRaw(static_cast<Backend>(p), static_cast<ScalarType>(s));
+ if (baseType) {
+ res.emplace_back(VariableType::getVariableTypeFromBaseType(*baseType));
+ }
+ }
+ }
+ return res;
+}
+}
+
+std::vector<at::Type*> VariableType::allCPUTypes() {
+ return allTypesForBackends({ Backend::CPU, Backend::SparseCPU });
+}
+
+std::vector<at::Type*> VariableType::allCUDATypes() {
+ at::globalContext().lazyInitCUDA();
+ return allTypesForBackends({ Backend::CUDA, Backend::SparseCUDA });
+}
+
+Variable & VariableType::checked_cast_variable(const Tensor & t, const char * name, int pos) {
+ if (!t.defined()) {
+ AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'");
+ }
+ if (!isVariableType(t.type())) {
+ AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " for argument #", pos, " '", name, "'");
+ }
+ return as_variable_ref(const_cast<Tensor&>(t));
+}
+
+Tensor & VariableType::unpack(const Tensor & t, const char * name, int pos) {
+ return checked_cast_variable(t, name, pos).data();
+}
+
+SparseTensorRef VariableType::unpack(SparseTensorRef t, const char * name, int pos) {
+ return SparseTensorRef(checked_cast_variable(t.tref, name, pos).data());
+}
+
+Tensor VariableType::unpack_opt(const Tensor & t, const char * name, int pos) {
+ if (!t.defined()) {
+ return Tensor();
+ }
+ return unpack(t, name, pos);
+}
+
+std::vector<at::Tensor> VariableType::unpack(at::TensorList tl, const char *name, int pos) {
+ std::vector<at::Tensor> ret(tl.size());
+ for (size_t i = 0; i < tl.size(); ++i) {
+ const auto &t = tl[i];
+ if (!t.defined()) {
+ AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor at position #", i, " "
+ "for iterable argument #", pos, " '", name, "'");
+ }
+ if (!isVariableType(t.type())) {
+ AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " at position #", i, " "
+ "for iterable argument #", pos, " '", name, "'");
+ }
+ ret[i] = static_cast<const Variable&>(t).data();
+ }
+ return ret;
+}
+
+void VariableType::backward(Tensor & self, at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const {
+ as_variable_ref(self).backward(gradient, keep_graph, create_graph);
+}
+
+void VariableType::set_data(Tensor & self, Tensor new_data) const {
+ as_variable_ref(self).set_data(new_data);
+}
+Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
+ jit::Node* node = nullptr;
+ if(torch::jit::tracer::isTracing()) {
+ auto& graph = jit::tracer::getTracingState()->graph;
+    // If there are no views of self, an in-place copy is equivalent to
+    // making sure src is expanded to the same size as self.
+ node = graph->create(jit::aten::expand_as, /*num_outputs=*/0);
+ jit::tracer::addInputs(node, "src", src);
+ jit::tracer::addInputs(node, "self", self);
+ graph->appendNode(node);
+ jit::tracer::ensureUnique("copy_ (possibly due to an assignment)", self);
+ }
+ // TODO: once copy is exposed in Declarations.yaml we may be able to bind
+ // it automatically
+ auto& self_ = unpack(self, "self", 0);
+ auto& src_ = unpack(src, "src", 1);
+ check_inplace(self);
+ std::shared_ptr<CopyBackwards> grad_fn;
+ auto requires_grad = compute_requires_grad(self, src);
+ requires_grad &= isFloatingPoint(self.type().scalarType());
+ if (requires_grad) {
+ grad_fn = std::make_shared<CopyBackwards>();
+ grad_fn->set_next_edges(collect_next_edges(self, src));
+ grad_fn->src_type = &src.type();
+ if (src.is_cuda()) {
+ grad_fn->src_device = src.get_device();
+ }
+ }
+ if (self.is_sparse() && src.is_sparse()) baseType->copy_sparse_to_sparse_(self_, src_, non_blocking);
+ else if (!self.is_sparse() && !src.is_sparse()) baseType->s_copy_(self_, src_, non_blocking);
+ else AT_ERROR("copy_() between dense and sparse Tensors is not implemented! Found self type = ", self.type(), " and src type = ", src.type());
+ increment_version(self);
+  rebase_history(as_variable_ref(self), std::move(grad_fn));
+ if(torch::jit::tracer::isTracing()) {
+ jit::tracer::addOutput(node, self);
+ }
+ return self;
+}
+
+Tensor & VariableType::_s_copy_from(const Tensor & self, Tensor & dst, bool non_blocking) const {
+ AT_ERROR("copy_from does not support automatic differentiation; use copy_ instead");
+}
+
+Tensor & VariableType::resize_(Tensor & self, IntList size) const {
+ auto& self_ = unpack(self, "self", 0);
+ if (as_variable_ref(self).requires_grad()) {
+ AT_ERROR("cannot resize variables that require grad");
+ }
+ if (torch::jit::tracer::isTracing()) {
+ jit::tracer::ArgumentStash::popIntList("size");
+ jit::tracer::warn("resize_", jit::tracer::WARN_RESIZE);
+ jit::tracer::delValueTrace(self);
+ }
+ baseType->resize_(self_, size);
+ return self;
+}
+
+Tensor & VariableType::resize_as_(Tensor & self, const Tensor & the_template) const {
+ auto& self_ = unpack(self, "self", 0);
+ auto& the_template_ = unpack(the_template, "the_template", 1);
+ if (as_variable_ref(self).requires_grad()) {
+ AT_ERROR("cannot resize variables that require grad");
+ }
+ if (torch::jit::tracer::isTracing()) {
+ jit::tracer::warn("resize_as_", jit::tracer::WARN_RESIZE);
+ jit::tracer::delValueTrace(self);
+ }
+ baseType->resize_as_(self_, the_template_);
+ return self;
+}
+
+Tensor VariableType::detach(const Tensor & self) const {
+ profiler::RecordFunction profiler("detach");
+ torch::jit::Node* node = nullptr;
+ if (jit::tracer::isTracing()) {
+ auto& graph = jit::tracer::getTracingState()->graph;
+ node = graph->create(jit::aten::detach, /*num_outputs=*/0);
+ jit::tracer::recordSourceLocation(node);
+ jit::tracer::addInputs(node, "self", self);
+ graph->appendNode(node);
+
+ }
+ // <NON_GENERATED_CODE>
+ auto result = as_variable_ref(const_cast<Tensor&>(self)).detach();
+ // </NON_GENERATED_CODE>
+ if (jit::tracer::isTracing()) {
+ jit::tracer::addOutput(node, result);
+ }
+ return result;
+}
+
+Tensor & VariableType::detach_(Tensor & self) const {
+ profiler::RecordFunction profiler("detach_");
+ torch::jit::Node* node = nullptr;
+ if (jit::tracer::isTracing()) {
+ auto& graph = jit::tracer::getTracingState()->graph;
+ node = graph->create(jit::aten::detach, /*num_outputs=*/0);
+ jit::tracer::recordSourceLocation(node);
+ jit::tracer::addInputs(node, "self", self);
+ graph->appendNode(node);
+ jit::tracer::ensureUnique("detach_", self);
+ }
+ // <NON_GENERATED_CODE>
+ as_variable_ref(self).detach_();
+ // </NON_GENERATED_CODE>
+ if (jit::tracer::isTracing()) {
+ jit::tracer::addOutput(node, self);
+ }
+ return self;
+}
+
+}} // namespace torch::autograd
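
Side note (not part of the commit): the warning above, that the `registry` static must be defined before `REGISTER_VARIABLE_HOOKS(VariableHooks)`, relies on the C++ guarantee that non-local statics defined in the same translation unit are initialized top to bottom. A self-contained sketch of that guarantee with made-up names, standing in loosely for the bulk registration and the hooks registration:

#include <cstdio>
#include <vector>

// Toy registry recording which type ids have been registered (a stand-in for
// type_to_variable_type).
static std::vector<int> registered;

// Constructing this performs the bulk registration, like VariableTypeRegistry.
struct BulkRegistrar {
  BulkRegistrar() {
    for (int id = 0; id < 3; ++id) registered.push_back(id);
    std::printf("bulk registration done, %zu entries\n", registered.size());
  }
};

// Constructing this installs hooks that may trigger further registrations,
// standing in for REGISTER_VARIABLE_HOOKS(VariableHooks).
struct HookInstaller {
  HookInstaller() {
    std::printf("hooks installed, %zu entries already present\n", registered.size());
  }
};

// Within one translation unit these run strictly in declaration order, so the
// bulk registration is guaranteed to finish before the hooks are installed.
// Swapping the two definitions reintroduces the hazard the comment describes.
static BulkRegistrar bulk;
static HookInstaller hooks;

int main() { return 0; }
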
diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h
new file mode 100644
index 0000000000..d28149e128
--- /dev/null
+++ b/torch/csrc/autograd/VariableTypeUtils.h
@@ -0,0 +1,182 @@
+#include "torch/csrc/autograd/generated/VariableType.h"
+
+#include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/autograd/function.h"
+#include "torch/csrc/autograd/edge.h"
+#include "torch/csrc/autograd/grad_mode.h"
+#include "torch/csrc/autograd/saved_variable.h"
+#include "torch/csrc/autograd/generated/Functions.h"
+#include "torch/csrc/autograd/functions/tensor.h"
+#include "torch/csrc/autograd/functions/basic_ops.h"
+#include "torch/csrc/jit/tracer.h"
+#include "torch/csrc/jit/constants.h"
+#include "torch/csrc/jit/symbolic_variable.h"
+#include "torch/csrc/jit/ir.h"
+
+#include "torch/csrc/utils/variadic.h"
+#include "torch/csrc/autograd/functions/utils.h"
+
+#include <ATen/core/VariableHooksInterface.h>
+
+#include <array>
+#include <cstddef>
+#include <functional>
+#include <initializer_list>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#ifdef _MSC_VER
+#ifdef Type
+#undef Type
+#endif
+#endif
+
+using namespace at;
+using namespace torch::autograd::generated;
+
+namespace torch { namespace autograd {
+
+extern std::vector<std::unique_ptr<Type>> type_to_variable_type;
+
+inline void check_inplace(const Tensor& tensor) {
+ auto& var = static_cast<const Variable&>(tensor);
+ if (var.requires_grad() && var.is_leaf() && GradMode::is_enabled()) {
+ AT_ERROR(
+ "a leaf Variable that requires grad has been used in an in-place operation.");
+ }
+}
+
+inline void throw_error_out_requires_grad(const char* name) {
+ AT_ERROR(
+ name, "(): functions with out=... arguments don't support automatic differentiation, "
+ "but one of the arguments requires grad.");
+}
+
+// TODO: Blegh, bare references
+
+inline void rebase_history(Variable& var, std::shared_ptr<Function> grad_fn) {
+ if (grad_fn && var.defined()) {
+ grad_fn->add_input_metadata(var);
+ var.rebase_history({std::move(grad_fn), 0});
+ }
+}
+
+inline void rebase_history(ArrayRef<Variable> vars, std::shared_ptr<Function> grad_fn) {
+ if (grad_fn) {
+ for (auto& var : vars) {
+ if (var.defined()) {
+ // TODO: eliminate const_cast
+ auto output_nr = grad_fn->add_input_metadata(var);
+ const_cast<Variable&>(var).rebase_history({grad_fn, output_nr});
+ } else {
+ grad_fn->add_input_metadata(Function::undefined_input());
+ }
+ }
+ }
+}
+
+inline void increment_version(Tensor & t) {
+ as_variable_ref(t).bump_version();
+}
+
+inline bool isFloatingPoint(ScalarType s) {
+ return s == kFloat || s == kDouble || s == kHalf;
+}
+
+struct Flatten : IterArgs<Flatten> {
+ Flatten(variable_list& out) : out(out) {}
+ variable_list& out;
+ void operator()(const at::Tensor& x) { out.emplace_back(x); }
+ void operator()(at::ArrayRef<at::Tensor> xs) {
+ out.insert(out.end(), xs.begin(), xs.end());
+ }
+};
+
+template<typename... Args> inline variable_list flatten_tensor_args(Args&&... args) {
+ variable_list out;
+ out.reserve(count_tensors(std::forward<Args>(args)...));
+ Flatten(out).apply(std::forward<Args>(args)...);
+ return out; // RVO
+}
+
+inline Tensor as_view(const Tensor & base, Tensor tensor) {
+ auto base_var = Variable(base);
+ if (base_var.is_view()) {
+ base_var = base_var.base();
+ }
+ return make_variable_view(std::move(base_var), std::move(tensor));
+}
+
+inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor> tensors) {
+ auto base_var = Variable(base);
+ if (base_var.is_view()) {
+ base_var = base_var.base();
+ }
+ for(Tensor &tensor : tensors) {
+ tensor = make_variable_view(base_var, std::move(tensor));
+ }
+ return tensors;
+}
+
+inline void check_no_requires_grad(const Tensor& tensor, const char* name) {
+ auto& var = static_cast<const Variable&>(tensor);
+ if (var.defined() && var.requires_grad()) {
+ std::string msg = "the derivative for '";
+ msg += name;
+ msg += "' is not implemented";
+ throw std::runtime_error(msg);
+ }
+}
+
+// It is assumed that saved tensor lists are never in-place outputs.
+inline std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
+ return fmap(tensors, [](const Tensor& tensor) -> SavedVariable {
+ return SavedVariable{tensor, false /* is output */}; });
+}
+
+inline Tensor as_variable(Tensor tensor) {
+ return make_variable(std::move(tensor), /*requires_grad=*/false);
+}
+
+inline std::vector<Tensor> as_variable(TensorList tl) {
+ return fmap(tl, [](const Tensor& t) -> Tensor {
+ return make_variable(std::move(t), /*requires_grad=*/false);
+ });
+}
+
+template <typename... Tensors, size_t... Is>
+std::tuple<Tensors...> as_variable_impl(
+ std::tuple<Tensors...> tensors,
+ Indices<Is...>) {
+  // Expand the integer parameter pack into a sequence of Variable
+  // constructions. This turns into (with the requires_grad boolean omitted):
+  // Variable(std::get<0>(tensors)), Variable(std::get<1>(tensors)), ...
+ return std::tuple<Tensors...>(
+ as_variable(std::get<Is>(tensors))...);
+}
+
+// NB: Because this was not forward declared, recursive std::tuple won't work.
+// You can probably rejigger this to make it supported if you really need it.
+template <typename... Tensors>
+std::tuple<Tensors...> as_variable(std::tuple<Tensors...> tensors) {
+ // `sizeof...(Tensors)` gets us the size of the `Tensors` parameter pack at
+ // compile time. We use it to parameterize a `MakeIndices` class, which will
+ // expand into an Indices object containing the numbers 0 to
+ // sizeof...(Tensors) - 1.
+ return as_variable_impl(
+ tensors, typename MakeIndices<sizeof...(Tensors)>::indices());
+}
+
+inline std::vector<std::vector<int64_t>> to_args_sizes(TensorList tensors) {
+ std::vector<std::vector<int64_t>> args_sizes(tensors.size());
+ for (size_t i = 0; i < tensors.size(); ++i) {
+ args_sizes[i] = tensors[i].sizes().vec();
+ }
+ return args_sizes;
+}
+
+}} // namespace torch::autograd
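
Side note (not part of the commit): `as_variable_impl` together with `MakeIndices` is the standard index-sequence trick for applying a function to every element of a tuple. A minimal standalone version of the same idea, written against std::index_sequence rather than the torch-internal MakeIndices helper; the wrap/wrap_all names are hypothetical.

#include <cstdio>
#include <tuple>
#include <utility>

// Stand-in for as_variable(Tensor): here we just "wrap" an int by adding one.
inline int wrap(int x) { return x + 1; }

// Expanding the index pack 0..N-1 turns the return statement into
// wrap(std::get<0>(t)), wrap(std::get<1>(t)), ...
template <typename... Ts, std::size_t... Is>
std::tuple<Ts...> wrap_all_impl(std::tuple<Ts...> t, std::index_sequence<Is...>) {
  return std::tuple<Ts...>(wrap(std::get<Is>(t))...);
}

// sizeof...(Ts) gives the pack size at compile time and parameterizes the
// index sequence, just as MakeIndices<sizeof...(Tensors)> does in the header.
template <typename... Ts>
std::tuple<Ts...> wrap_all(std::tuple<Ts...> t) {
  return wrap_all_impl(std::move(t), std::index_sequence_for<Ts...>{});
}

int main() {
  auto out = wrap_all(std::make_tuple(1, 2, 3));
  std::printf("%d %d %d\n", std::get<0>(out), std::get<1>(out), std::get<2>(out));  // prints "2 3 4"
  return 0;
}
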