diff options
author | Christian Puhrsch <cpuhrsch@fb.com> | 2019-01-11 13:28:52 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-11 13:32:35 -0800 |
commit | d33159a4261fed36e147cc2351aff0741f0fa715 (patch) | |
tree | ea16b669843d77f8cc78b281f4a99d8720377d9a /test | |
parent | 926e718d5fe129f67a10eef7ef8ce754b25c1e1e (diff) | |
download | pytorch-d33159a4261fed36e147cc2351aff0741f0fa715.tar.gz pytorch-d33159a4261fed36e147cc2351aff0741f0fa715.tar.bz2 pytorch-d33159a4261fed36e147cc2351aff0741f0fa715.zip |
Undo norm optimizations and add more documentation for parallel.h (#15885)
Summary:
See https://github.com/pytorch/pytorch/issues/15602
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15885
Differential Revision: D13614841
Pulled By: cpuhrsch
fbshipit-source-id: 5d3e45f499d36ac287dbbc2e45798aa51eb5bfdf
Diffstat (limited to 'test')
-rw-r--r-- | test/test_torch.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/test/test_torch.py b/test/test_torch.py index 6d48ecc353..ce889325b6 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -844,7 +844,7 @@ class _TestTorchMixin(object): @staticmethod def _test_norm(self, device): # full reduction - x = torch.randn(5, device=device) + x = torch.randn(25, device=device) xn = x.cpu().numpy() for p in [0, 1, 2, 3, 4, inf, -inf]: res = x.norm(p).item() @@ -852,7 +852,7 @@ class _TestTorchMixin(object): self.assertEqual(res, expected, "full reduction failed for {}-norm".format(p)) # one dimension - x = torch.randn(5, 5, device=device) + x = torch.randn(25, 25, device=device) xn = x.cpu().numpy() for p in [0, 1, 2, 3, 4, inf, -inf]: res = x.norm(p, 1).cpu().numpy() @@ -867,6 +867,9 @@ class _TestTorchMixin(object): self.assertEqual(res.shape, expected.shape) self.assertTrue(np.allclose(res, expected), "dim reduction failed for {}-norm".format(p)) + # larger tensor sanity check + self.assertEqual(2 * torch.norm(torch.ones(10000)), torch.norm(torch.ones(40000))) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") @skipIfNoLapack def test_norm(self): |