summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorZachary DeVito <zdevito@gmail.com>2018-06-15 14:56:19 -0700
committerGitHub <noreply@github.com>2018-06-15 14:56:19 -0700
commitd9686145021248ae77b643adec678e547358c48b (patch)
tree7605841e3c1d149f3e522f4ec724e628f388b83c /torch
parent711e5a6ceb46fcfe90dc8ca176c94c4f44dfbc17 (diff)
downloadpytorch-d9686145021248ae77b643adec678e547358c48b.tar.gz
pytorch-d9686145021248ae77b643adec678e547358c48b.tar.bz2
pytorch-d9686145021248ae77b643adec678e547358c48b.zip
Enable open registration of VariableType objects (#8540)
We have 2 use cases where we want to experiment with new base ATen tensor types: * BatchTensor for matchbox * Tensors that live on accelerators It is possible to subclass TensorImpl to implement these but VariableType does not work with them because it cannot find the equivalent variable type in the registry. This commit changes the way we implement type -> variable(type) lookup so that torch::register_variable_type_for can be called on any at::Type. Lookups are still done using arrays so there should be no perf impact from the change.
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/autograd/aten_variable_hooks.cpp3
1 files changed, 2 insertions, 1 deletions
diff --git a/torch/csrc/autograd/aten_variable_hooks.cpp b/torch/csrc/autograd/aten_variable_hooks.cpp
index b1ef0c948c..7a2c3974c2 100644
--- a/torch/csrc/autograd/aten_variable_hooks.cpp
+++ b/torch/csrc/autograd/aten_variable_hooks.cpp
@@ -16,7 +16,8 @@ REGISTER_VARIABLE_HOOKS(VariableHooks)
// Pre-condition: backend/scalar_type is a valid type in the type_registry
void VariableHooks::registerVariableTypeFor(at::Context* context, at::Backend backend, at::ScalarType scalar_type) const {
- register_variable_type_for(context, backend, scalar_type);
+ auto* baseType = context->getTypeRaw(backend, scalar_type);
+ register_variable_type_for(baseType);
}
}} // torch::autograd