author | Richard Zou <zou3519@gmail.com> | 2019-04-16 08:51:01 -0700
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-16 08:57:20 -0700
commit | 3b29cbaf861d5daaff7c7506b62a252b9c7d5614 (patch)
tree | c2519ec366c4fb5ade80353944d2ffd3e64ad9f0 /test
parent | 35015762307a68654b8c72d690b7e3e58c3ad137 (diff)
download | pytorch-3b29cbaf861d5daaff7c7506b62a252b9c7d5614.tar.gz pytorch-3b29cbaf861d5daaff7c7506b62a252b9c7d5614.tar.bz2 pytorch-3b29cbaf861d5daaff7c7506b62a252b9c7d5614.zip
Enable half for CUDA dense EmbeddingBag backward. (#19293)
Summary:
I audited the relevant kernel and saw that it does most of its accumulation in float,
so enabling half should be fine.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19293
Differential Revision: D14942274
Pulled By: zou3519
fbshipit-source-id: 36996ba0fbb29fbfb12b27bfe9c0ad1eb012ba3c
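
For reference, a minimal sketch of the path this change enables, assuming a CUDA-capable build: a dense (non-sparse) nn.EmbeddingBag with float16 weights on CUDA, run forward and backward. The sizes and values below are illustrative and are not taken from the patch.

```python
import torch
import torch.nn as nn

# Illustrative only: dense EmbeddingBag in half precision on CUDA,
# exercising the backward path this commit enables.
device = 'cuda'
es = nn.EmbeddingBag(10, 4, mode='sum', sparse=False).to(device, torch.half)
input = torch.randint(0, 10, (8,), device=device)   # bag contents (indices)
offsets = torch.tensor([0, 4], device=device)        # two bags of four indices each
per_sample_weights = torch.randn(8, device=device, dtype=torch.half,
                                 requires_grad=True)

out = es(input, offsets, per_sample_weights)          # forward in float16
out.sum().backward()                                  # dense half backward on CUDA
print(es.weight.grad.dtype, per_sample_weights.grad.dtype)
```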
Diffstat (limited to 'test')
-rw-r--r-- | test/test_nn.py | 41
1 file changed, 30 insertions, 11 deletions
diff --git a/test/test_nn.py b/test/test_nn.py
index c0935024aa..d489e50a2d 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2381,7 +2381,8 @@ class TestNN(NNTestCase):
             self.assertEqual(es_weight_grad, e.weight.grad, needed_prec)
 
         if test_per_sample_weights and trainable_per_sample_weights:
-            self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad)
+            self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad,
+                             dtype2prec[dtype])
 
     def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
         # check a known test example
@@ -2653,16 +2654,21 @@ class TestNN(NNTestCase):
             expected = self._embedding_bag_reference_impl(
                 input, reference_weights, offsets, mode, ref_per_sample_weights)
             result = es(input, offsets, per_sample_weights)
-            self.assertEqual(result, expected)
+            self.assertEqual(result, expected, prec=dtype2prec[dtype])
 
             grad = torch.randn_like(expected)
             result.backward(grad)
             expected.backward(grad)
-            self.assertEqual(es.weight.grad, reference_weights.grad)
+            self.assertEqual(es.weight.grad, reference_weights.grad,
+                             dtype2prec[dtype])
             if trainable_scale:
-                self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad)
+                self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
+                                 prec=dtype2prec[dtype])
 
-        dtypes = (torch.float, torch.double)
+        if device == 'cuda':
+            dtypes = (torch.float, torch.double, torch.half)
+        else:
+            dtypes = (torch.float, torch.double)
         modes = ('sum',)
         trainable_scale = (True, False)
         for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale):
@@ -2677,12 +2683,7 @@ class TestNN(NNTestCase):
 
     @staticmethod
     def _test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cpu'):
-        dtypes = (torch.float, torch.double)
-        modes = ('sum',)
-        sparsity = (True, False)
-        trainable_scale = (True, False)
-        for dtype, mode, sparse, trainable_per_sample_weights in \
-                itertools.product(dtypes, modes, sparsity, trainable_scale):
+        def run_tests(dtype, mode, sparse, trainable_per_sample_weights):
             kwargs = dict(test_per_sample_weights=True, device=device,
                           mode=mode, dtype=dtype, sparse=sparse,
                           trainable_per_sample_weights=trainable_per_sample_weights)
@@ -2699,6 +2700,24 @@ class TestNN(NNTestCase):
             # Large embedding_dim
             self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs)
 
+        dtypes = (torch.float, torch.double)
+        modes = ('sum',)
+        sparsity = (True, False)
+        trainable_scale = (True, False)
+        for dtype, mode, sparse, trainable_per_sample_weights in \
+                itertools.product(dtypes, modes, sparsity, trainable_scale):
+            run_tests(dtype, mode, sparse, trainable_per_sample_weights)
+
+        # Test CUDA Dense on half precision
+        if device == 'cuda':
+            dtypes = (torch.half,)
+            modes = ('sum',)
+            sparsity = (False,)
+            trainable_scale = (True, False)
+            for dtype, mode, sparse, trainable_per_sample_weights in \
+                    itertools.product(dtypes, modes, sparsity, trainable_scale):
+                run_tests(dtype, mode, sparse, trainable_per_sample_weights)
+
     def test_EmbeddingBag_per_sample_weights_and_no_offsets(self):
         self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self)
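
The test changes above route each comparison through a per-dtype tolerance, dtype2prec[dtype], since float16 accumulation cannot meet the default precision. A minimal sketch of that pattern follows; the map values are illustrative and may differ from the actual dtype2prec in the test suite.

```python
import torch

# Illustrative per-dtype tolerance map in the spirit of the test suite's dtype2prec;
# half precision gets a much looser tolerance than float/double.
dtype2prec = {torch.float: 1e-5, torch.double: 1e-5, torch.half: 1e-2}

def assert_close(actual, expected, dtype):
    # Compare in float32 with a tolerance appropriate for the dtype under test.
    assert torch.allclose(actual.float(), expected.float(), atol=dtype2prec[dtype])
```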