summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEli Uriegas <1700823+seemethere@users.noreply.github.com>2021-12-10 11:41:40 -0800
committerGitHub <noreply@github.com>2021-12-10 11:41:40 -0800
commit3e412cd6dff635c0a58d85749e42e9ee1866a84e (patch)
tree01a49dc0a242a53270b553eaf71e187c5ab695e8
parent302ee7bfb604ebef384602c56e3853efed262030 (diff)
downloadpytorch-3e412cd6dff635c0a58d85749e42e9ee1866a84e.tar.gz
pytorch-3e412cd6dff635c0a58d85749e42e9ee1866a84e.tar.bz2
pytorch-3e412cd6dff635c0a58d85749e42e9ee1866a84e.zip
[release/1.10] fix pybind issue for get_autocast_cpu_dtype and get_autocast_gpu_dtype (#66396) (#69620)
Co-authored-by: XiaobingSuper <xiaobing.zhang@intel.com>
-rw-r--r--test/test_autocast.py8
-rw-r--r--torch/csrc/autograd/init.cpp8
2 files changed, 14 insertions, 2 deletions
diff --git a/test/test_autocast.py b/test/test_autocast.py
index a722c1a04d..4e9597c6d7 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -119,5 +119,13 @@ class TestAutocastCPU(TestCase):
for op, args in self.autocast_lists.torch_need_autocast_promote:
self._run_autocast_outofplace(op, args, torch.float32)
+class TestTorchAutocast(TestCase):
+ def test_autocast_fast_dtype(self):
+ gpu_fast_dtype = torch.get_autocast_gpu_dtype()
+ cpu_fast_dtype = torch.get_autocast_cpu_dtype()
+ self.assertEqual(gpu_fast_dtype, torch.half)
+ self.assertEqual(cpu_fast_dtype, torch.bfloat16)
+
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 3551267459..9d8550f50e 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -390,14 +390,18 @@ static const char* scalarTypeName(const at::ScalarType type) {
static PyObject * get_autocast_gpu_dtype(PyObject* _unused, PyObject *arg){
HANDLE_TH_ERRORS
at::ScalarType current_dtype = at::autocast::get_autocast_gpu_dtype();
- return THPDtype_New(current_dtype, scalarTypeName(current_dtype));
+ auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
+ Py_INCREF(dtype);
+ return dtype;
END_HANDLE_TH_ERRORS
}
static PyObject * get_autocast_cpu_dtype(PyObject* _unused, PyObject *arg){
HANDLE_TH_ERRORS
at::ScalarType current_dtype = at::autocast::get_autocast_cpu_dtype();
- return THPDtype_New(current_dtype, scalarTypeName(current_dtype));
+ auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
+ Py_INCREF(dtype);
+ return dtype;
END_HANDLE_TH_ERRORS
}