diff options
Diffstat (limited to 'aten/src/ATen/Context.cpp')
-rw-r--r-- | aten/src/ATen/Context.cpp | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index b05475699c..45ebade245 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -101,10 +101,13 @@ TypeExtendedInterface& getType(TensorOptions options) { options.backend(), typeMetaToScalarType(options.dtype()), options.is_variable()); } +// NOTE: We also check `at::NonVariableTypeMode`, and if it's enabled we always +// return non-Variable type in this function. +// See NOTE [ Treating Variables as non-Variables in type dispatch ] TypeExtendedInterface& getType(const TensorImpl* impl) { Backend backend = tensorTypeIdToBackend(impl->type_id()); return globalContext().getType( - backend, typeMetaToScalarType(impl->dtype()), impl->is_variable()); + backend, typeMetaToScalarType(impl->dtype()), impl->is_variable() && !at::NonVariableTypeMode::is_enabled()); } TypeExtendedInterface& getType(const Tensor& t) { |