author | Sam Gross <colesbury@gmail.com> | 2018-06-06 18:09:53 -0400
---|---|---
committer | GitHub <noreply@github.com> | 2018-06-06 18:09:53 -0400
commit | 12229afd0054be5ced0ae36fccf7eae3dffa30a1 (patch) |
tree | 6d05c881d136feee4f7148abe572d435b95e8812 /test |
parent | 36b8cc54836474b07b96cd3b5eefdd7d9a503c2f (diff) |
Record shape and type in autograd to validate gradients (#8168)
The check that each returned gradient is defined is currently disabled, because
TestJit.test_ge_optimized would otherwise trigger the error.
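
In practice this means that a custom autograd `Function` whose `backward` returns a gradient whose shape or dtype does not match the corresponding input now fails loudly instead of silently propagating a bad gradient. A minimal sketch of the behavior exercised by the new `test_invalid_gradients` test in the diff below (`BadGrad` is just an illustrative stand-in for the test's `MyFunction`):

```python
import torch
from torch.autograd import Function

class BadGrad(Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        # Deliberately wrong: the forward input is a 5x5 float tensor,
        # but we return a 10-element gradient.
        return torch.randn(10, dtype=torch.float)

x = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
try:
    BadGrad.apply(x).sum().backward()
except RuntimeError as e:
    # With the recorded shape/type metadata, autograd rejects the
    # mismatched gradient instead of continuing the backward pass.
    print(e)
```
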
Diffstat (limited to 'test')
mode | file | lines changed
---|---|---
-rw-r--r-- | test/cpp/api/misc.cpp | 2
-rw-r--r-- | test/test_autograd.py | 26
-rw-r--r-- | test/test_jit.py | 2
3 files changed, 23 insertions, 7 deletions
diff --git a/test/cpp/api/misc.cpp b/test/cpp/api/misc.cpp
index a494e3389e..0f4fa33ddf 100644
--- a/test/cpp/api/misc.cpp
+++ b/test/cpp/api/misc.cpp
@@ -71,7 +71,7 @@ TEST_CASE("autograd") {
   }
   SECTION("custom gradient inputs") {
     z.sum().backward(
-        autograd::make_variable(at::ones(at::CPU(at::kFloat), {1}) * 2));
+        autograd::make_variable(at::ones(at::CPU(at::kFloat), {}) * 2));
     REQUIRE(x.grad().allclose(y * 2));
   }
   // Assume everything else is safe from PyTorch tests.
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 0c19e24208..4c26e58aa6 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -75,7 +75,7 @@ class TestAutograd(TestCase):
         x = torch.randn(5, 5, requires_grad=True)
         y = torch.randn(5, 5, requires_grad=True)
         result = cls.apply(x, 2, y)
-        go = torch.ones(1, requires_grad=True)
+        go = torch.ones((), requires_grad=True)
         result.sum().backward(go, create_graph=True)
         self.assertEqual(x.grad.data, y.data + torch.ones(5, 5))

@@ -173,6 +173,23 @@ class TestAutograd(TestCase):
         MyFunction()(y).sum().backward()
         self.assertEqual(v.grad.data, torch.zeros(shape))

+    def test_invalid_gradients(self):
+        class MyFunction(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x * 2
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                return torch.randn(10, dtype=torch.float)
+
+        with self.assertRaisesRegex(RuntimeError, 'expected shape'):
+            input = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
+            MyFunction.apply(input).sum().backward()
+        with self.assertRaisesRegex(RuntimeError, 'expected type'):
+            input = torch.randn(10, dtype=torch.double, requires_grad=True)
+            MyFunction.apply(input).sum().backward()
+
     def test_accumulate_grad(self):
         grad_output = torch.ones(5, 5)

@@ -495,7 +512,6 @@ class TestAutograd(TestCase):

     def test_sparse_backward(self):
         class FixedGradientFunction(Function):
-
             def __init__(self, grad):
                 self.grad = grad

@@ -524,15 +540,15 @@ class TestAutograd(TestCase):
         dense_fn = FixedGradientFunction(dense_grad)

         # sparse first
-        x = torch.randn(5, 5, requires_grad=True)
+        x = torch.randn(size, requires_grad=True)
         (sparse_fn1(x) + dense_fn(x) + sparse_fn2(x)).sum().backward()
         self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
         # dense first
-        x = torch.randn(5, 5, requires_grad=True)
+        x = torch.randn(size, requires_grad=True)
         (dense_fn(x) + sparse_fn1(x) + sparse_fn2(x)).sum().backward()
         self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
         # sparse only
-        x = torch.randn(5, 5, requires_grad=True)
+        x = torch.randn(size, requires_grad=True)
         (sparse_fn1(x) + sparse_fn2(x)).sum().backward()
         self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)
diff --git a/test/test_jit.py b/test/test_jit.py
index d56f5ad5f8..4b6ee940b5 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1152,7 +1152,7 @@ class TestScript(JitTestCase):
         out = func(x, y)
         self.assertEqual(func(x, y), x + y)

-        grad = torch.randn(2, 3)
+        grad = torch.randn(2, 3, dtype=torch.float)
         out.backward(grad)
         self.assertEqual(x.grad, grad)
         self.assertEqual(y.grad, grad.sum(dim=0))
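
The remaining test edits adapt existing code to the stricter checking rather than exercise it: `x.sum()` yields a 0-dimensional tensor, so the seed gradient handed to `backward()` changes from `torch.ones(1)` (shape `[1]`) to `torch.ones(())` (shape `[]`), and the C++ test likewise switches `{1}` to `{}`. A small sketch of the shape-matched usage the updated tests settle on (plain PyTorch, nothing beyond what the tests themselves do; that a shape-`[1]` seed would now be flagged is an inference from these edits):

```python
import torch

x = torch.randn(5, 5, requires_grad=True)
out = x.sum()                 # 0-dimensional (scalar) output

# Seed gradient with the same 0-dim shape as `out`, mirroring the tests'
# move from torch.ones(1) to torch.ones(()).
out.backward(torch.ones(()))
print(x.grad.shape)           # torch.Size([5, 5])
```
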