diff options
author | Zachary DeVito <zdevito@gmail.com> | 2018-06-15 14:56:19 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-15 14:56:19 -0700 |
commit | d9686145021248ae77b643adec678e547358c48b (patch) | |
tree | 7605841e3c1d149f3e522f4ec724e628f388b83c /torch | |
parent | 711e5a6ceb46fcfe90dc8ca176c94c4f44dfbc17 (diff) | |
download | pytorch-d9686145021248ae77b643adec678e547358c48b.tar.gz pytorch-d9686145021248ae77b643adec678e547358c48b.tar.bz2 pytorch-d9686145021248ae77b643adec678e547358c48b.zip |
Enable open registration of VariableType objects (#8540)
We have 2 use cases where we want to experiment with new base ATen
tensor types:
* BatchTensor for matchbox
* Tensors that live on accelerators
It is possible to subclass TensorImpl to implement these but VariableType
does not work with them because it cannot find the equivalent variable type
in the registry.
This commit changes the way we implement type -> variable(type) lookup so that
torch::register_variable_type_for can be called on any at::Type.
Lookups are still done using arrays so there should be no perf impact from the change.
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/autograd/aten_variable_hooks.cpp | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/torch/csrc/autograd/aten_variable_hooks.cpp b/torch/csrc/autograd/aten_variable_hooks.cpp index b1ef0c948c..7a2c3974c2 100644 --- a/torch/csrc/autograd/aten_variable_hooks.cpp +++ b/torch/csrc/autograd/aten_variable_hooks.cpp @@ -16,7 +16,8 @@ REGISTER_VARIABLE_HOOKS(VariableHooks) // Pre-condition: backend/scalar_type is a valid type in the type_registry void VariableHooks::registerVariableTypeFor(at::Context* context, at::Backend backend, at::ScalarType scalar_type) const { - register_variable_type_for(context, backend, scalar_type); + auto* baseType = context->getTypeRaw(backend, scalar_type); + register_variable_type_for(baseType); } }} // torch::autograd |