author     Christian Puhrsch <cpuhrsch@fb.com>  2019-01-11 13:28:52 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-01-11 13:32:35 -0800
commit     d33159a4261fed36e147cc2351aff0741f0fa715 (patch)
tree       ea16b669843d77f8cc78b281f4a99d8720377d9a /test
parent     926e718d5fe129f67a10eef7ef8ce754b25c1e1e (diff)
Undo norm optimizations and add more documentation for parallel.h (#15885)
Summary: See https://github.com/pytorch/pytorch/issues/15602
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15885
Differential Revision: D13614841
Pulled By: cpuhrsch
fbshipit-source-id: 5d3e45f499d36ac287dbbc2e45798aa51eb5bfdf
Diffstat (limited to 'test')
-rw-r--r--  test/test_torch.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/test/test_torch.py b/test/test_torch.py
index 6d48ecc353..ce889325b6 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -844,7 +844,7 @@ class _TestTorchMixin(object):
     @staticmethod
     def _test_norm(self, device):
         # full reduction
-        x = torch.randn(5, device=device)
+        x = torch.randn(25, device=device)
         xn = x.cpu().numpy()
         for p in [0, 1, 2, 3, 4, inf, -inf]:
             res = x.norm(p).item()
@@ -852,7 +852,7 @@ class _TestTorchMixin(object):
             self.assertEqual(res, expected, "full reduction failed for {}-norm".format(p))

         # one dimension
-        x = torch.randn(5, 5, device=device)
+        x = torch.randn(25, 25, device=device)
         xn = x.cpu().numpy()
         for p in [0, 1, 2, 3, 4, inf, -inf]:
             res = x.norm(p, 1).cpu().numpy()
@@ -867,6 +867,9 @@ class _TestTorchMixin(object):
             self.assertEqual(res.shape, expected.shape)
             self.assertTrue(np.allclose(res, expected), "dim reduction failed for {}-norm".format(p))

+        # larger tensor sanity check
+        self.assertEqual(2 * torch.norm(torch.ones(10000)), torch.norm(torch.ones(40000)))
+
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     @skipIfNoLapack
     def test_norm(self):
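
The added sanity check holds because, for the default 2-norm, the norm of a ones vector of length n is sqrt(n), so quadrupling the length exactly doubles the norm: sqrt(4n) = 2 * sqrt(n). A minimal standalone sketch of that identity, not part of the commit (the sizes mirror the new test):

import math
import torch

# For the default 2-norm, norm(ones(4n)) == 2 * norm(ones(n)),
# since sqrt(4 * n) == 2 * sqrt(n).
n = 10000
small = torch.ones(n).norm().item()        # sqrt(10000) = 100.0
large = torch.ones(4 * n).norm().item()    # sqrt(40000) = 200.0
assert math.isclose(large, 2 * small, rel_tol=1e-5)
assert math.isclose(small, math.sqrt(n), rel_tol=1e-5)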