summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSsnL <tongzhou.wang.1994@gmail.com>2019-01-27 12:08:09 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-27 12:11:14 -0800
commitc863a759a0caa13f10f639b66de43e0b3b0f9dd4 (patch)
treec99e07c82752203de9e58d43809c5ce2accc110e /test
parent6944461a76b3fe9df02679497191f79fddb1b69a (diff)
downloadpytorch-c863a759a0caa13f10f639b66de43e0b3b0f9dd4.tar.gz
pytorch-c863a759a0caa13f10f639b66de43e0b3b0f9dd4.tar.bz2
pytorch-c863a759a0caa13f10f639b66de43e0b3b0f9dd4.zip
Fix slogdet sign requiring grad when input requires grad (#16337)
Summary: The real fix for https://github.com/pytorch/pytorch/issues/15605. This is sort of BC breaking because now ```py In [1]: import torch In [2]: a = torch.randn(3, 3, requires_grad=True) In [3]: a.slogdet() Out[3]: (tensor(1.), tensor(0.1356, grad_fn=<SlogdetBackward>)) In [4]: a.slogdet()[0].requires_grad Out[4]: False ``` while before this patch ` a.slogdet()[0]` requires grad with `grad_fn=<SlogdetBackward>`. But any use of backproping through this value will meet the error in #15605 so I don't think this is a problem. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16337 Differential Revision: D13832644 Pulled By: soumith fbshipit-source-id: f96c477e99edcbdbd966888e5c5ea7fd058429a8
Diffstat (limited to 'test')
-rw-r--r--test/test_autograd.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 6525870fd9..31c138a798 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -204,6 +204,26 @@ class TestAutograd(TestCase):
x_grad, x_grad_clone = compute_grad(create_graph=True)
self.assertEqual(x_grad, x_grad_clone)
+ def test_slogdet_sign(self):
+ a = torch.randn(3, 3, requires_grad=True)
+ s, logdet = a.slogdet()
+
+ # test that sign should not require grad
+ self.assertFalse(s.requires_grad)
+
+ # test that backward through computation involving sign works
+ def sign_mul_logdet(mat):
+ s, logdet = mat.slogdet()
+ return s * logdet
+
+ u, s, v = a.detach().svd()
+ s.abs_().clamp_(0.0001)
+ for sign in (-1, 1):
+ s[-1] = sign
+ mat = torch.chain_matmul(u, s.diag(), v.t()).requires_grad_()
+ gradcheck(sign_mul_logdet, mat)
+ gradgradcheck(sign_mul_logdet, mat)
+
def test_sum_to_with_empty_dim_grad(self):
a = torch.rand(4, 0, requires_grad=True)
b = torch.rand(4, 1, requires_grad=True)