summaryrefslogtreecommitdiff
path: root/aten/src/ATen/Context.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'aten/src/ATen/Context.cpp')
-rw-r--r--aten/src/ATen/Context.cpp5
1 files changed, 4 insertions, 1 deletions
diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index b05475699c..45ebade245 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -101,10 +101,13 @@ TypeExtendedInterface& getType(TensorOptions options) {
options.backend(), typeMetaToScalarType(options.dtype()), options.is_variable());
}
+// NOTE: We also check `at::NonVariableTypeMode`, and if it's enabled we always
+// return non-Variable type in this function.
+// See NOTE [ Treating Variables as non-Variables in type dispatch ]
TypeExtendedInterface& getType(const TensorImpl* impl) {
Backend backend = tensorTypeIdToBackend(impl->type_id());
return globalContext().getType(
- backend, typeMetaToScalarType(impl->dtype()), impl->is_variable());
+ backend, typeMetaToScalarType(impl->dtype()), impl->is_variable() && !at::NonVariableTypeMode::is_enabled());
}
TypeExtendedInterface& getType(const Tensor& t) {