summaryrefslogtreecommitdiff
path: root/test/test_cuda.py
diff options
context:
space:
mode:
authorAiling Zhang <ailzhang@fb.com>2018-08-29 10:48:04 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-08-29 10:55:22 -0700
commita9469c9c8ab046a7961c1c357d84f60063507c4b (patch)
tree1640671a3fe2c51f38a838a08f1f886ce5c1ef8c /test/test_cuda.py
parentb41988c71ed7d40af7a314b2049a4b0d5909fed2 (diff)
downloadpytorch-a9469c9c8ab046a7961c1c357d84f60063507c4b.tar.gz
pytorch-a9469c9c8ab046a7961c1c357d84f60063507c4b.tar.bz2
pytorch-a9469c9c8ab046a7961c1c357d84f60063507c4b.zip
Fill eigenvector with zeros if not required (#10645)
Summary: Fix #10345, which only happens in CUDA case. * Instead of returning some random buffer, we fill it with zeros. * update torch.symeig doc. Pull Request resolved: https://github.com/pytorch/pytorch/pull/10645 Reviewed By: soumith Differential Revision: D9395762 Pulled By: ailzhang fbshipit-source-id: 0f3ed9bb6a919a9c1a4b8eb45188f65a68bfa9ba
Diffstat (limited to 'test/test_cuda.py')
-rw-r--r--test/test_cuda.py12
1 files changed, 1 insertions, 11 deletions
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 73ba388069..088919ad59 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1734,17 +1734,7 @@ class TestCuda(TestCase):
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_symeig(self):
- # Small case
- tensor = torch.randn(3, 3).cuda()
- tensor = torch.mm(tensor, tensor.t())
- eigval, eigvec = torch.symeig(tensor, eigenvectors=True)
- self.assertEqual(tensor, torch.mm(torch.mm(eigvec, eigval.diag()), eigvec.t()))
-
- # Large case
- tensor = torch.randn(257, 257).cuda()
- tensor = torch.mm(tensor, tensor.t())
- eigval, eigvec = torch.symeig(tensor, eigenvectors=True)
- self.assertEqual(tensor, torch.mm(torch.mm(eigvec, eigval.diag()), eigvec.t()))
+ TestTorch._test_symeig(self, lambda t: t.cuda())
def test_arange(self):
for t in ['IntTensor', 'LongTensor', 'FloatTensor', 'DoubleTensor']: