diff options
author | vishwakftw <cs15btech11043@iith.ac.in> | 2019-02-25 10:32:48 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-25 10:35:33 -0800 |
commit | 724c7e76c69ac26140ceb61576de8b96bb8417e8 (patch) | |
tree | a1b95a7d2bdf1f94356f0f09f6c3b4eef64ded90 /test/test_nn.py | |
parent | f9ba3831ef3ae8501612d51353907717231871f2 (diff) | |
download | pytorch-724c7e76c69ac26140ceb61576de8b96bb8417e8.tar.gz pytorch-724c7e76c69ac26140ceb61576de8b96bb8417e8.tar.bz2 pytorch-724c7e76c69ac26140ceb61576de8b96bb8417e8.zip |
Fix reduction='none' in poisson_nll_loss (#17358)
Summary:
Changelog:
- Modify `if` to `elif` in reduction mode comparison
- Add error checking for reduction mode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17358
Differential Revision: D14190523
Pulled By: zou3519
fbshipit-source-id: 2b734d284dc4c40679923606a1aa148e6a0abeb8
Diffstat (limited to 'test/test_nn.py')
-rw-r--r-- | test/test_nn.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/test/test_nn.py b/test/test_nn.py index 022a5a555b..f2a9beb06d 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4301,6 +4301,19 @@ class TestNN(NNTestCase): with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'): F.nll_loss(x, t) + def test_poisson_nll_loss_reduction_modes(self): + input = torch.tensor([0.5, 1.5, 2.5]) + target = torch.tensor([1., 2., 3.]) + component_wise_loss = torch.exp(input) - target * input + self.assertEqual(component_wise_loss, + F.poisson_nll_loss(input, target, reduction='none')) + self.assertEqual(torch.sum(component_wise_loss), + F.poisson_nll_loss(input, target, reduction='sum')) + self.assertEqual(torch.mean(component_wise_loss), + F.poisson_nll_loss(input, target, reduction='mean')) + with self.assertRaisesRegex(ValueError, 'is not valid'): + F.poisson_nll_loss(input, target, reduction='total') + def test_KLDivLoss_batch_mean(self): input_shape = (2, 5) log_prob1 = F.log_softmax(torch.randn(input_shape), 1) |