diff options
Diffstat (limited to 'test/test_autocast.py')
-rw-r--r-- | test/test_autocast.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/test/test_autocast.py b/test/test_autocast.py index a722c1a04d..4e9597c6d7 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -119,5 +119,13 @@ class TestAutocastCPU(TestCase): for op, args in self.autocast_lists.torch_need_autocast_promote: self._run_autocast_outofplace(op, args, torch.float32) +class TestTorchAutocast(TestCase): + def test_autocast_fast_dtype(self): + gpu_fast_dtype = torch.get_autocast_gpu_dtype() + cpu_fast_dtype = torch.get_autocast_cpu_dtype() + self.assertEqual(gpu_fast_dtype, torch.half) + self.assertEqual(cpu_fast_dtype, torch.bfloat16) + + if __name__ == '__main__': run_tests() |