diff options
author | Tongzhou Wang <SsnL@users.noreply.github.com> | 2018-05-31 09:42:56 -0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2018-05-31 09:42:56 -0400 |
commit | f9926e4ce5c54e79df72a4284b6170083fcdaba4 (patch) | |
tree | db05f9794c4d3d3c8bd5c7026eeb91cb9e6d3c6a /test | |
parent | 5596260b9e9b051400e6fcc8b0fad39ee918335e (diff) | |
download | pytorch-f9926e4ce5c54e79df72a4284b6170083fcdaba4.tar.gz pytorch-f9926e4ce5c54e79df72a4284b6170083fcdaba4.tar.bz2 pytorch-f9926e4ce5c54e79df72a4284b6170083fcdaba4.zip |
Fix EmbeddingBag max_norm option (#7959)
* fix EmbeddingBag max_norm option
* flake8
* add warning to the embedding bag arg change
Diffstat (limited to 'test')
-rw-r--r-- | test/test_nn.py | 19 |
1 file changed, 10 insertions, 9 deletions
diff --git a/test/test_nn.py b/test/test_nn.py index 05177b720d..7ab47567c6 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1626,9 +1626,9 @@ class TestNN(NNTestCase): self.assertEqual(es_weight_grad, expected_grad_weight, dtype2prec[dtype]) # now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length - def _test_vs_Embedding(N, D, B, L): - es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse).to(device, dtype) - e = nn.Embedding(N, D).to(device, dtype) + def _test_vs_Embedding(N, D, B, L, max_norm=None): + es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype) + e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype) e.weight.data.copy_(es.weight.data) input = torch.randint(N, (B, L), device=device, dtype=torch.long) offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L) @@ -1656,8 +1656,9 @@ class TestNN(NNTestCase): N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50) _test_vs_Embedding(N, D, B, L) - for p in itertools.product([1, 2], repeat=4): - _test_vs_Embedding(*p) + for max_norm in (None, 3): + for p in itertools.product([1, 2], repeat=4): + _test_vs_Embedding(*p, max_norm=max_norm) # check that giving illegal input combos raises error es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse) @@ -6780,27 +6781,27 @@ new_module_tests = [ dict( module_name='Embedding', constructor_args=(4, 3), - input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)), + input_fn=lambda: torch.randperm(2).repeat(1, 2), jacobian_input=False, check_gradgrad=False, ), dict( module_name='EmbeddingBag', constructor_args=(4, 3), - input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)), + input_fn=lambda: torch.randperm(2).repeat(1, 2), jacobian_input=False, check_gradgrad=False, ), dict( fullname='EmbeddingBag_sparse', constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True), - input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)), + input_fn=lambda: torch.randperm(2).repeat(1, 2), jacobian_input=False, check_gradgrad=False, ), dict( constructor=lambda: nn.Embedding(4, 3, sparse=True), - input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)), + input_fn=lambda: torch.randperm(2).repeat(1, 2), jacobian_input=False, fullname='Embedding_sparse', check_gradgrad=False,