| field | value | date |
|---|---|---|
| author | SsnL <tongzhou.wang.1994@gmail.com> | 2019-01-27 12:08:09 -0800 |
| committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-27 12:11:14 -0800 |
| commit | c863a759a0caa13f10f639b66de43e0b3b0f9dd4 (patch) | |
| tree | c99e07c82752203de9e58d43809c5ce2accc110e /test | |
| parent | 6944461a76b3fe9df02679497191f79fddb1b69a (diff) | |
Fix slogdet sign requiring grad when input requires grad (#16337)
Summary:
The real fix for https://github.com/pytorch/pytorch/issues/15605.
This is sort of BC-breaking, because now:
```py
In [1]: import torch
In [2]: a = torch.randn(3, 3, requires_grad=True)
In [3]: a.slogdet()
Out[3]: (tensor(1.), tensor(0.1356, grad_fn=<SlogdetBackward>))
In [4]: a.slogdet()[0].requires_grad
Out[4]: False
```
whereas before this patch `a.slogdet()[0]` required grad with `grad_fn=<SlogdetBackward>`. But any attempt to backprop through that value would hit the error in #15605, so I don't think this is a problem.
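
For illustration, here is a minimal sketch of the post-patch behavior (variable names are just for the example): the sign output is detached, but gradients still flow through `logabsdet`, so backward through expressions that use the sign keeps working.

```py
import torch

a = torch.randn(3, 3, requires_grad=True)
sign, logabsdet = a.slogdet()

# After this patch the sign is detached from the graph...
assert not sign.requires_grad
# ...while logabsdet still tracks gradients.
assert logabsdet.requires_grad

# Backward through an expression involving the sign works, because the
# gradient only flows through logabsdet (sign is constant w.r.t. the input).
(sign * logabsdet).backward()
print(a.grad)
```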
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16337
Differential Revision: D13832644
Pulled By: soumith
fbshipit-source-id: f96c477e99edcbdbd966888e5c5ea7fd058429a8
Diffstat (limited to 'test')
-rw-r--r-- | test/test_autograd.py | 20
1 file changed, 20 insertions, 0 deletions
```diff
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 6525870fd9..31c138a798 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -204,6 +204,26 @@ class TestAutograd(TestCase):
         x_grad, x_grad_clone = compute_grad(create_graph=True)
         self.assertEqual(x_grad, x_grad_clone)
 
+    def test_slogdet_sign(self):
+        a = torch.randn(3, 3, requires_grad=True)
+        s, logdet = a.slogdet()
+
+        # test that sign should not require grad
+        self.assertFalse(s.requires_grad)
+
+        # test that backward through computation involving sign works
+        def sign_mul_logdet(mat):
+            s, logdet = mat.slogdet()
+            return s * logdet
+
+        u, s, v = a.detach().svd()
+        s.abs_().clamp_(0.0001)
+        for sign in (-1, 1):
+            s[-1] = sign
+            mat = torch.chain_matmul(u, s.diag(), v.t()).requires_grad_()
+            gradcheck(sign_mul_logdet, mat)
+            gradgradcheck(sign_mul_logdet, mat)
+
     def test_sum_to_with_empty_dim_grad(self):
         a = torch.rand(4, 0, requires_grad=True)
         b = torch.rand(4, 1, requires_grad=True)
```
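
For reference, a rough standalone sketch of what the new test exercises, run outside the test harness. Double precision is used here because gradcheck's finite differences are more reliable with it, and the random matrix (unlike the SVD construction in the test) is only almost surely non-singular.

```py
import torch
from torch.autograd import gradcheck, gradgradcheck

# Same function as in the test: multiply the sign by the log-determinant.
def sign_mul_logdet(mat):
    s, logdet = mat.slogdet()
    return s * logdet

# A random double-precision matrix is almost surely non-singular.
mat = torch.randn(3, 3, dtype=torch.double, requires_grad=True)
print(gradcheck(sign_mul_logdet, (mat,)))      # first-order gradient check
print(gradgradcheck(sign_mul_logdet, (mat,)))  # second-order gradient check
```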