author | Thomas Viehmann <tv.github@beamnet.de> | 2018-06-24 23:15:13 +0200
---|---|---
committer | Tongzhou Wang <SsnL@users.noreply.github.com> | 2018-06-24 17:15:13 -0400 |
commit | fc22bf3e82723178015708ae1265ec14710c2dec (patch) |
tree | 6a37495f32c7b7fae1b0bfee16baeea2213c2720 /test/test_nn.py |
parent | 3598356420d6adc77aa70893205138d499451fc1 (diff) |
Spectral norm improvements (#8590)
* Spectral norm improvements
- Don't do power iterations on the weight in eval mode.
To facilitate this, register weight as a buffer so that a module
with spectral norm can be used in eval mode immediately after
loading a state dict (#8208)
- Use weight instead of weight_orig as the weight when removing
spectral norm
- Add a dim parameter for cases where the normalization should be
computed w.r.t. a dimension other than 0 (#7865); see the usage
sketch below
* Add and update spectral norm tests
* More spectral norm tests
Thank you, Simon, for the suggestions.
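For context, here is a minimal usage sketch. It is not part of the commit; the `dim=1` choice and the expected `weight_u` shape are assumptions mirroring `test_spectral_norm_dim` in the diff below.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm, remove_spectral_norm

# For ConvTranspose2d, the output-channel axis of `weight` is dim 1,
# so the power iteration should run over that dimension (assumption
# based on the shape check in test_spectral_norm_dim).
m = spectral_norm(nn.ConvTranspose2d(3, 4, (5, 6)), dim=1)
print(m.weight_u.shape)  # expected: torch.Size([4]), i.e. the size of weight dim 1

# With `weight` registered as a buffer, the module can run in eval
# mode right away, without a prior training-mode forward pass.
m.eval()
out = m(torch.randn(2, 3, 10, 12))

# Removing spectral norm restores a plain `weight` parameter.
m = remove_spectral_norm(m)
```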
Diffstat (limited to 'test/test_nn.py')
-rw-r--r-- | test/test_nn.py | 40 |
1 file changed, 38 insertions(+), 2 deletions(-)
```diff
diff --git a/test/test_nn.py b/test/test_nn.py
index 6283d86c3a..bbec52945f 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1438,9 +1438,9 @@ class TestNN(NNTestCase):
         # weight_u should be just a reused buffer
         self.assertTrue(hasattr(m, 'weight_u'))
         self.assertTrue('weight_u' in m._buffers)
+        self.assertTrue('weight' in m._buffers)
         # weight should be a plain attribute, not counted as a buffer or a param
-        self.assertTrue(hasattr(m, 'weight'))
-        self.assertFalse('weight' in m._buffers or 'weight' in m._parameters)
+        self.assertFalse('weight' in m._parameters)
         # it should also be sharing storage as `weight_orig`
         self.assertEqual(m.weight_orig.storage(), m.weight.storage())
         self.assertEqual(m.weight_orig.size(), m.weight.size())
@@ -1453,6 +1453,42 @@ class TestNN(NNTestCase):
         self.assertTrue(hasattr(m, 'weight'))
         self.assertTrue('weight' in m._parameters)
 
+    def test_spectral_norm_eval_remove(self):
+        inp = torch.randn(3, 5)
+        m = nn.Linear(5, 7)
+        m = torch.nn.utils.spectral_norm(m)
+        x0 = m(inp)
+        m.eval()
+        # test that eval mode and removing / adding+removing doesn't change weight and output
+        x1 = m(inp)
+        x2 = m(inp)
+        self.assertEqual(x0, x1)
+        self.assertEqual(x0, x2)
+        m = torch.nn.utils.remove_spectral_norm(m)
+        x3 = m(inp)
+        self.assertEqual(x0, x3)
+        m = torch.nn.utils.spectral_norm(m)
+        m = torch.nn.utils.remove_spectral_norm(m)
+        x4 = m(inp)
+        self.assertEqual(x0, x4)
+        # check that removing after train doesn't change output
+        m.train()
+        m = torch.nn.utils.spectral_norm(m)
+        for i in range(5):
+            x0 = m(inp)
+        m = torch.nn.utils.remove_spectral_norm(m)
+        x1 = m(inp)
+        self.assertEqual(x0, x1)
+
+    def test_spectral_norm_dim(self):
+        inp = torch.randn(2, 3, 10, 12)
+        m = nn.ConvTranspose2d(3, 4, (5, 6))
+        m = torch.nn.utils.spectral_norm(m)
+        # this should not run into incompatible shapes
+        x = m(inp)
+        # check that u refers to the same dimension
+        self.assertEqual(m.weight_u.shape, m.weight_orig[0, :, 0, 0].shape)
+
     def test_spectral_norm_forward(self):
         input = torch.randn(3, 5)
         m = nn.Linear(5, 7)
```
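The first new test exercises the #8208 scenario that the buffer registration addresses. A hedged sketch of that failure mode, assuming plain `load_state_dict` round-tripping between two identically wrapped modules:

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Two identically wrapped modules; copy the state of one into the other.
src = spectral_norm(nn.Linear(5, 7))
dst = spectral_norm(nn.Linear(5, 7))
dst.load_state_dict(src.state_dict())

# Before this change, `weight` was a plain attribute and thus absent from
# the state dict, so eval-mode inference right after loading could fail;
# registering it as a buffer lets this run without a training-mode pass.
dst.eval()
out = dst(torch.randn(3, 5))
```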