summaryrefslogtreecommitdiff
path: root/test/common_nn.py
diff options
context:
space:
mode:
authorTongzhou Wang <SsnL@users.noreply.github.com>2018-05-23 11:03:12 -0400
committerGitHub <noreply@github.com>2018-05-23 11:03:12 -0400
commite3e15b5d9534f1c10170e169f1423e9298648e86 (patch)
treed74cb50a982ca753d74e539a27d66731189130e2 /test/common_nn.py
parent6a604f16cc452424c58ce456da7f46e1619b9e18 (diff)
downloadpytorch-e3e15b5d9534f1c10170e169f1423e9298648e86.tar.gz
pytorch-e3e15b5d9534f1c10170e169f1423e9298648e86.tar.bz2
pytorch-e3e15b5d9534f1c10170e169f1423e9298648e86.zip
[PyTorch] [gradcheck] change backward() to grad() (#7710)
* Change backward calls to grad to avoid memory leak from #7343; Replace unnecesary create_graph=True with retain_graph=True * fix gradgradcheck use of make_non_contiguous * allow non-contguous target * remove unnecessray .grad.zero_() * remove contiguous_detach * fix PReLU double backward always returning ggW as a scalar * let noncontig gO require grad * move requires_grad to return
Diffstat (limited to 'test/common_nn.py')
-rw-r--r--test/common_nn.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/test/common_nn.py b/test/common_nn.py
index 51b883ba9e..48a5211296 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -7,7 +7,7 @@ from itertools import product
import torch
import torch.cuda
from common import TestCase, to_gpu, freeze_rng_state, is_iterable
-from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors, contiguous
+from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors
import torch.backends.cudnn
# tarfile module tries to obtain a file object name in python 3.3
@@ -783,9 +783,8 @@ class NNTestCase(TestCase):
return self._forward(module, input).detach()
res = tuple()
- input = contiguous(input)
if jacobian_input:
- res += get_numerical_jacobian(fw, input, input, eps=1e-6),
+ res += get_numerical_jacobian(fw, input, eps=1e-6),
if jacobian_parameters:
param, _ = self._get_parameters(module)
res += torch.cat([get_numerical_jacobian(fw, input, p, eps=1e-6) for p in param], 0),
@@ -813,8 +812,8 @@ class NNTestCase(TestCase):
input_t = iter_tensors(input)
numerical_t = iter_tensors(numerical_d_x)
for x, d_x in zip(input_t, numerical_t):
- x = x.view(-1)
- d_x = d_x.view(-1)
+ x = x.view(-1).data
+ d_x = d_x.view(-1).data
for i in range(x.nelement()):
original = x[i].item()
x[i] = original + eps