summary refs log tree commit diff
path: root/test
diff options
context:
space:
mode:
authorRichard Zou <zou3519@gmail.com>2019-04-16 08:51:01 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-16 08:57:20 -0700
commit3b29cbaf861d5daaff7c7506b62a252b9c7d5614 (patch)
treec2519ec366c4fb5ade80353944d2ffd3e64ad9f0 /test
parent35015762307a68654b8c72d690b7e3e58c3ad137 (diff)
downloadpytorch-3b29cbaf861d5daaff7c7506b62a252b9c7d5614.tar.gz
pytorch-3b29cbaf861d5daaff7c7506b62a252b9c7d5614.tar.bz2
pytorch-3b29cbaf861d5daaff7c7506b62a252b9c7d5614.zip
Enable half for CUDA dense EmbeddingBag backward. (#19293)
Summary: I audited the relevant kernel and saw it accumulates a good deal into float so it should be fine. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19293 Differential Revision: D14942274 Pulled By: zou3519 fbshipit-source-id: 36996ba0fbb29fbfb12b27bfe9c0ad1eb012ba3c
Diffstat (limited to 'test')
-rw-r--r--test/test_nn.py41
1 file changed, 30 insertions, 11 deletions
diff --git a/test/test_nn.py b/test/test_nn.py
index c0935024aa..d489e50a2d 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2381,7 +2381,8 @@ class TestNN(NNTestCase):
self.assertEqual(es_weight_grad, e.weight.grad, needed_prec)
if test_per_sample_weights and trainable_per_sample_weights:
- self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad)
+ self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad,
+ dtype2prec[dtype])
def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
# check a known test example
@@ -2653,16 +2654,21 @@ class TestNN(NNTestCase):
expected = self._embedding_bag_reference_impl(
input, reference_weights, offsets, mode, ref_per_sample_weights)
result = es(input, offsets, per_sample_weights)
- self.assertEqual(result, expected)
+ self.assertEqual(result, expected, prec=dtype2prec[dtype])
grad = torch.randn_like(expected)
result.backward(grad)
expected.backward(grad)
- self.assertEqual(es.weight.grad, reference_weights.grad)
+ self.assertEqual(es.weight.grad, reference_weights.grad,
+ dtype2prec[dtype])
if trainable_scale:
- self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad)
+ self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
+ prec=dtype2prec[dtype])
- dtypes = (torch.float, torch.double)
+ if device == 'cuda':
+ dtypes = (torch.float, torch.double, torch.half)
+ else:
+ dtypes = (torch.float, torch.double)
modes = ('sum',)
trainable_scale = (True, False)
for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale):
@@ -2677,12 +2683,7 @@ class TestNN(NNTestCase):
@staticmethod
def _test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cpu'):
- dtypes = (torch.float, torch.double)
- modes = ('sum',)
- sparsity = (True, False)
- trainable_scale = (True, False)
- for dtype, mode, sparse, trainable_per_sample_weights in \
- itertools.product(dtypes, modes, sparsity, trainable_scale):
+ def run_tests(dtype, mode, sparse, trainable_per_sample_weights):
kwargs = dict(test_per_sample_weights=True, device=device,
mode=mode, dtype=dtype, sparse=sparse,
trainable_per_sample_weights=trainable_per_sample_weights)
@@ -2699,6 +2700,24 @@ class TestNN(NNTestCase):
# Large embedding_dim
self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs)
+ dtypes = (torch.float, torch.double)
+ modes = ('sum',)
+ sparsity = (True, False)
+ trainable_scale = (True, False)
+ for dtype, mode, sparse, trainable_per_sample_weights in \
+ itertools.product(dtypes, modes, sparsity, trainable_scale):
+ run_tests(dtype, mode, sparse, trainable_per_sample_weights)
+
+ # Test CUDA Dense on half precision
+ if device == 'cuda':
+ dtypes = (torch.half,)
+ modes = ('sum',)
+ sparsity = (False,)
+ trainable_scale = (True, False)
+ for dtype, mode, sparse, trainable_per_sample_weights in \
+ itertools.product(dtypes, modes, sparsity, trainable_scale):
+ run_tests(dtype, mode, sparse, trainable_per_sample_weights)
+
def test_EmbeddingBag_per_sample_weights_and_no_offsets(self):
self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self)