summaryrefslogtreecommitdiff
path: root/test/test_nn.py
diff options
context:
space:
mode:
authormc-robinson <matthew.robinson@yale.edu>2019-03-24 19:17:00 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-24 19:22:14 -0700
commit8bc5b867093f1b44f5168f3e41b1bdf0ba3e05a7 (patch)
treea16ec7e849c73aeb589793b7b481925f9e88bcbf /test/test_nn.py
parentca962f0f95185e76b1f43a4423d28986ee0da191 (diff)
downloadpytorch-8bc5b867093f1b44f5168f3e41b1bdf0ba3e05a7.tar.gz
pytorch-8bc5b867093f1b44f5168f3e41b1bdf0ba3e05a7.tar.bz2
pytorch-8bc5b867093f1b44f5168f3e41b1bdf0ba3e05a7.zip
Added tensor size warning to F.mse_loss() (#18349)
Summary: To address the issue of broadcasting giving the wrong result in `nn.MSELoss()` as mentioned here https://github.com/pytorch/pytorch/issues/16045 . In particular, the issue often arises when computing the loss between tensors with shapes (n, 1) and (n,) Pull Request resolved: https://github.com/pytorch/pytorch/pull/18349 Differential Revision: D14594176 Pulled By: soumith fbshipit-source-id: f23ae68a4bf42f3554ad7678a314ba2c7532a6db
Diffstat (limited to 'test/test_nn.py')
-rw-r--r--test/test_nn.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/test/test_nn.py b/test/test_nn.py
index 897255c215..b93f850021 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -4363,6 +4363,18 @@ class TestNN(NNTestCase):
def test_loss_equal_input_target_shape(self):
self._test_loss_equal_input_target_shape(lambda x: x)
+ def test_mse_loss_size_warning(self):
+ i = torch.randn((10, 1), requires_grad=True)
+ t = torch.randn((10,))
+ with warnings.catch_warnings(record=True) as w:
+ # Ensure warnings are being shown
+ warnings.simplefilter("always")
+ # Trigger Warning
+ F.mse_loss(i, t)
+ # Check warning occurs
+ self.assertEqual(len(w), 1)
+ self.assertIn('Please ensure they have the same size.', str(w[0]))
+
def test_nll_loss_mismatched_batch(self):
x = torch.randn((10, 3), requires_grad=True)
# t should have size (10,)