summaryrefslogtreecommitdiff
path: root/test/test_nn.py
diff options
context:
space:
mode:
authorvishwakftw <cs15btech11043@iith.ac.in>2019-02-25 10:32:48 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-25 10:35:33 -0800
commit724c7e76c69ac26140ceb61576de8b96bb8417e8 (patch)
treea1b95a7d2bdf1f94356f0f09f6c3b4eef64ded90 /test/test_nn.py
parentf9ba3831ef3ae8501612d51353907717231871f2 (diff)
downloadpytorch-724c7e76c69ac26140ceb61576de8b96bb8417e8.tar.gz
pytorch-724c7e76c69ac26140ceb61576de8b96bb8417e8.tar.bz2
pytorch-724c7e76c69ac26140ceb61576de8b96bb8417e8.zip
Fix reduction='none' in poisson_nll_loss (#17358)
Summary: Changelog: - Modify `if` to `elif` in reduction mode comparison - Add error checking for reduction mode Pull Request resolved: https://github.com/pytorch/pytorch/pull/17358 Differential Revision: D14190523 Pulled By: zou3519 fbshipit-source-id: 2b734d284dc4c40679923606a1aa148e6a0abeb8
Diffstat (limited to 'test/test_nn.py')
-rw-r--r--test/test_nn.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/test/test_nn.py b/test/test_nn.py
index 022a5a555b..f2a9beb06d 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -4301,6 +4301,19 @@ class TestNN(NNTestCase):
with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
F.nll_loss(x, t)
+ def test_poisson_nll_loss_reduction_modes(self):
+ input = torch.tensor([0.5, 1.5, 2.5])
+ target = torch.tensor([1., 2., 3.])
+ component_wise_loss = torch.exp(input) - target * input
+ self.assertEqual(component_wise_loss,
+ F.poisson_nll_loss(input, target, reduction='none'))
+ self.assertEqual(torch.sum(component_wise_loss),
+ F.poisson_nll_loss(input, target, reduction='sum'))
+ self.assertEqual(torch.mean(component_wise_loss),
+ F.poisson_nll_loss(input, target, reduction='mean'))
+ with self.assertRaisesRegex(ValueError, 'is not valid'):
+ F.poisson_nll_loss(input, target, reduction='total')
+
def test_KLDivLoss_batch_mean(self):
input_shape = (2, 5)
log_prob1 = F.log_softmax(torch.randn(input_shape), 1)