summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTongzhou Wang <SsnL@users.noreply.github.com>2018-05-31 09:42:56 -0400
committerSoumith Chintala <soumith@gmail.com>2018-05-31 09:42:56 -0400
commitf9926e4ce5c54e79df72a4284b6170083fcdaba4 (patch)
treedb05f9794c4d3d3c8bd5c7026eeb91cb9e6d3c6a /test
parent5596260b9e9b051400e6fcc8b0fad39ee918335e (diff)
downloadpytorch-f9926e4ce5c54e79df72a4284b6170083fcdaba4.tar.gz
pytorch-f9926e4ce5c54e79df72a4284b6170083fcdaba4.tar.bz2
pytorch-f9926e4ce5c54e79df72a4284b6170083fcdaba4.zip
Fix EmbeddingBag max_norm option (#7959)
* fix EmbeddingBag max_norm option * flake8 * add warning to the embedding bag arg change
Diffstat (limited to 'test')
-rw-r--r--test/test_nn.py19
1 files changed, 10 insertions, 9 deletions
diff --git a/test/test_nn.py b/test/test_nn.py
index 05177b720d..7ab47567c6 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1626,9 +1626,9 @@ class TestNN(NNTestCase):
self.assertEqual(es_weight_grad, expected_grad_weight, dtype2prec[dtype])
# now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length
- def _test_vs_Embedding(N, D, B, L):
- es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse).to(device, dtype)
- e = nn.Embedding(N, D).to(device, dtype)
+ def _test_vs_Embedding(N, D, B, L, max_norm=None):
+ es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype)
+ e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype)
e.weight.data.copy_(es.weight.data)
input = torch.randint(N, (B, L), device=device, dtype=torch.long)
offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L)
@@ -1656,8 +1656,9 @@ class TestNN(NNTestCase):
N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50)
_test_vs_Embedding(N, D, B, L)
- for p in itertools.product([1, 2], repeat=4):
- _test_vs_Embedding(*p)
+ for max_norm in (None, 3):
+ for p in itertools.product([1, 2], repeat=4):
+ _test_vs_Embedding(*p, max_norm=max_norm)
# check that giving illegal input combos raises error
es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
@@ -6780,27 +6781,27 @@ new_module_tests = [
dict(
module_name='Embedding',
constructor_args=(4, 3),
- input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)),
+ input_fn=lambda: torch.randperm(2).repeat(1, 2),
jacobian_input=False,
check_gradgrad=False,
),
dict(
module_name='EmbeddingBag',
constructor_args=(4, 3),
- input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)),
+ input_fn=lambda:torch.randperm(2).repeat(1, 2),
jacobian_input=False,
check_gradgrad=False,
),
dict(
fullname='EmbeddingBag_sparse',
constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True),
- input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)),
+ input_fn=lambda: torch.randperm(2).repeat(1, 2),
jacobian_input=False,
check_gradgrad=False,
),
dict(
constructor=lambda: nn.Embedding(4, 3, sparse=True),
- input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)),
+ input_fn=lambda: torch.randperm(2).repeat(1, 2),
jacobian_input=False,
fullname='Embedding_sparse',
check_gradgrad=False,