summaryrefslogtreecommitdiff
path: root/test/test_nn.py
diff options
context:
space:
mode:
authorThomas Viehmann <tv.github@beamnet.de>2018-06-24 23:15:13 +0200
committerTongzhou Wang <SsnL@users.noreply.github.com>2018-06-24 17:15:13 -0400
commitfc22bf3e82723178015708ae1265ec14710c2dec (patch)
tree6a37495f32c7b7fae1b0bfee16baeea2213c2720 /test/test_nn.py
parent3598356420d6adc77aa70893205138d499451fc1 (diff)
downloadpytorch-fc22bf3e82723178015708ae1265ec14710c2dec.tar.gz
pytorch-fc22bf3e82723178015708ae1265ec14710c2dec.tar.bz2
pytorch-fc22bf3e82723178015708ae1265ec14710c2dec.zip
Spectral norm improvements (#8590)
* Spectral norm improvements - Don't do iterations on weight in eval mode To facilitate this, register weight as buffer in order to be able to use module with spectral norm in eval mode after immediately after loading state dict (#8208) - Use weight instead of weight_orig as weight when removing spectral norm - Add dim parameter in case the normalization should occur w.r.t. a dimension other than 0 (#7865) * add and update spectral norm tests * More spectral norm tests Thank you, Simon, for the suggestions.
Diffstat (limited to 'test/test_nn.py')
-rw-r--r--test/test_nn.py40
1 files changed, 38 insertions, 2 deletions
diff --git a/test/test_nn.py b/test/test_nn.py
index 6283d86c3a..bbec52945f 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1438,9 +1438,9 @@ class TestNN(NNTestCase):
# weight_u should be just a reused buffer
self.assertTrue(hasattr(m, 'weight_u'))
self.assertTrue('weight_u' in m._buffers)
+ self.assertTrue('weight' in m._buffers)
# weight should be a plain attribute, not counted as a buffer or a param
- self.assertTrue(hasattr(m, 'weight'))
- self.assertFalse('weight' in m._buffers or 'weight' in m._parameters)
+ self.assertFalse('weight' in m._parameters)
# it should also be sharing storage as `weight_orig`
self.assertEqual(m.weight_orig.storage(), m.weight.storage())
self.assertEqual(m.weight_orig.size(), m.weight.size())
@@ -1453,6 +1453,42 @@ class TestNN(NNTestCase):
self.assertTrue(hasattr(m, 'weight'))
self.assertTrue('weight' in m._parameters)
+ def test_spectral_norm_eval_remove(self):
+ inp = torch.randn(3, 5)
+ m = nn.Linear(5, 7)
+ m = torch.nn.utils.spectral_norm(m)
+ x0 = m(inp)
+ m.eval()
+ # test that eval mode and removing / adding+removing doesn't change weight and output
+ x1 = m(inp)
+ x2 = m(inp)
+ self.assertEqual(x0, x1)
+ self.assertEqual(x0, x2)
+ m = torch.nn.utils.remove_spectral_norm(m)
+ x3 = m(inp)
+ self.assertEqual(x0, x3)
+ m = torch.nn.utils.spectral_norm(m)
+ m = torch.nn.utils.remove_spectral_norm(m)
+ x4 = m(inp)
+ self.assertEqual(x0, x4)
+ # check that removing after train doesn't change output
+ m.train()
+ m = torch.nn.utils.spectral_norm(m)
+ for i in range(5):
+ x0 = m(inp)
+ m = torch.nn.utils.remove_spectral_norm(m)
+ x1 = m(inp)
+ self.assertEqual(x0, x1)
+
+ def test_spectral_norm_dim(self):
+ inp = torch.randn(2, 3, 10, 12)
+ m = nn.ConvTranspose2d(3, 4, (5, 6))
+ m = torch.nn.utils.spectral_norm(m)
+ # this should not run into incompatible shapes
+ x = m(inp)
+ # check that u refers to the same dimension
+ self.assertEqual(m.weight_u.shape, m.weight_orig[0, :, 0, 0].shape)
+
def test_spectral_norm_forward(self):
input = torch.randn(3, 5)
m = nn.Linear(5, 7)