summaryrefslogtreecommitdiff
path: root/test/test_autocast.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_autocast.py')
-rw-r--r--test/test_autocast.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/test/test_autocast.py b/test/test_autocast.py
index a722c1a04d..4e9597c6d7 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -119,5 +119,13 @@ class TestAutocastCPU(TestCase):
for op, args in self.autocast_lists.torch_need_autocast_promote:
self._run_autocast_outofplace(op, args, torch.float32)
+class TestTorchAutocast(TestCase):
+ def test_autocast_fast_dtype(self):
+ gpu_fast_dtype = torch.get_autocast_gpu_dtype()
+ cpu_fast_dtype = torch.get_autocast_cpu_dtype()
+ self.assertEqual(gpu_fast_dtype, torch.half)
+ self.assertEqual(cpu_fast_dtype, torch.bfloat16)
+
+
if __name__ == '__main__':
run_tests()