diff options
author | Eli Uriegas <1700823+seemethere@users.noreply.github.com> | 2021-12-10 11:41:40 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-10 11:41:40 -0800 |
commit | 3e412cd6dff635c0a58d85749e42e9ee1866a84e (patch) | |
tree | 01a49dc0a242a53270b553eaf71e187c5ab695e8 | |
parent | 302ee7bfb604ebef384602c56e3853efed262030 (diff) | |
download | pytorch-3e412cd6dff635c0a58d85749e42e9ee1866a84e.tar.gz pytorch-3e412cd6dff635c0a58d85749e42e9ee1866a84e.tar.bz2 pytorch-3e412cd6dff635c0a58d85749e42e9ee1866a84e.zip |
[release/1.10] fix pybind issue for get_autocast_cpu_dtype and get_autocast_gpu_dtype (#66396) (#69620)
Co-authored-by: XiaobingSuper <xiaobing.zhang@intel.com>
-rw-r--r-- | test/test_autocast.py | 8 | ||||
-rw-r--r-- | torch/csrc/autograd/init.cpp | 8 |
2 files changed, 14 insertions, 2 deletions
diff --git a/test/test_autocast.py b/test/test_autocast.py index a722c1a04d..4e9597c6d7 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -119,5 +119,13 @@ class TestAutocastCPU(TestCase): for op, args in self.autocast_lists.torch_need_autocast_promote: self._run_autocast_outofplace(op, args, torch.float32) +class TestTorchAutocast(TestCase): + def test_autocast_fast_dtype(self): + gpu_fast_dtype = torch.get_autocast_gpu_dtype() + cpu_fast_dtype = torch.get_autocast_cpu_dtype() + self.assertEqual(gpu_fast_dtype, torch.half) + self.assertEqual(cpu_fast_dtype, torch.bfloat16) + + if __name__ == '__main__': run_tests() diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 3551267459..9d8550f50e 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -390,14 +390,18 @@ static const char* scalarTypeName(const at::ScalarType type) { static PyObject * get_autocast_gpu_dtype(PyObject* _unused, PyObject *arg){ HANDLE_TH_ERRORS at::ScalarType current_dtype = at::autocast::get_autocast_gpu_dtype(); - return THPDtype_New(current_dtype, scalarTypeName(current_dtype)); + auto dtype = (PyObject*)torch::getTHPDtype(current_dtype); + Py_INCREF(dtype); + return dtype; END_HANDLE_TH_ERRORS } static PyObject * get_autocast_cpu_dtype(PyObject* _unused, PyObject *arg){ HANDLE_TH_ERRORS at::ScalarType current_dtype = at::autocast::get_autocast_cpu_dtype(); - return THPDtype_New(current_dtype, scalarTypeName(current_dtype)); + auto dtype = (PyObject*)torch::getTHPDtype(current_dtype); + Py_INCREF(dtype); + return dtype; END_HANDLE_TH_ERRORS } |