diff options
author | kshitij12345 <kshitijkalambarkar@gmail.com> | 2021-09-21 07:21:15 -0700 |
---|---|---|
committer | Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com> | 2021-09-21 07:29:48 -0700 |
commit | 9c23f6eb7d87ec9fe7af7658e58f0b5f0eaf19df (patch) | |
tree | 19272378c016249cf457a7efc2b12ec061528811 | |
parent | d35ee431d88a3bf2186120b308d96c4ebb85f65c (diff) | |
download | pytorch-9c23f6eb7d87ec9fe7af7658e58f0b5f0eaf19df.tar.gz pytorch-9c23f6eb7d87ec9fe7af7658e58f0b5f0eaf19df.tar.bz2 pytorch-9c23f6eb7d87ec9fe7af7658e58f0b5f0eaf19df.zip |
[nn] TripletMarginLoss and PairwiseDistance : no batch dim (#64882)
Summary:
Reference: https://github.com/pytorch/pytorch/issues/60585
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64882
Reviewed By: malfet
Differential Revision: D31055577
Pulled By: jbschlosser
fbshipit-source-id: 2f0a5a08619b672026b48a78bc7d83a6dccba0bf
-rw-r--r-- | aten/src/ATen/native/Distance.cpp | 7 | ||||
-rw-r--r-- | aten/src/ATen/native/Loss.cpp | 12 | ||||
-rw-r--r-- | test/test_nn.py | 15 | ||||
-rw-r--r-- | torch/nn/modules/distance.py | 9 | ||||
-rw-r--r-- | torch/nn/modules/loss.py | 6 | ||||
-rw-r--r-- | torch/testing/_internal/common_nn.py | 39 |
6 files changed, 78 insertions, 10 deletions
diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp index b79c3e91c6..7974840dd3 100644 --- a/aten/src/ATen/native/Distance.cpp +++ b/aten/src/ATen/native/Distance.cpp @@ -14,7 +14,12 @@ DEFINE_DISPATCH(cdist_stub); DEFINE_DISPATCH(cdist_backward_stub); Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps, bool keepdim) { - return at::norm(x1 - x2 + eps, p, 1, keepdim); + // Since either x1 or x2 could be broadcasted + auto x1_dim = x1.dim(); + auto x2_dim = x2.dim(); + auto output_dim = x1_dim > x2_dim ? x1_dim : x2_dim; + auto innermost_dim = output_dim - 1; + return at::norm(x1 - x2 + eps, p, innermost_dim, keepdim); } // This is to guarantee that the contiguous memory is passed to the backward pass diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp index 6c4c21bd1a..4054655599 100644 --- a/aten/src/ATen/native/Loss.cpp +++ b/aten/src/ATen/native/Loss.cpp @@ -79,6 +79,18 @@ Tensor hinge_embedding_loss(const Tensor& self, const Tensor& target, double mar Tensor triplet_margin_loss(const Tensor& anchor, const Tensor& positive, const Tensor& negative, double margin, double p, double eps, bool swap, int64_t reduction) { + auto a_dim = anchor.dim(); + auto p_dim = positive.dim(); + auto n_dim = negative.dim(); + TORCH_CHECK( + a_dim == p_dim && p_dim == n_dim, + "All inputs should have same dimension but got ", + a_dim, + "D, ", + p_dim, + "D and ", + n_dim, + "D inputs.") auto dist_pos = at::pairwise_distance(anchor, positive, p, eps); auto dist_neg = at::pairwise_distance(anchor, negative, p, eps); if (swap) { diff --git a/test/test_nn.py b/test/test_nn.py index f324350f9e..92357d9ce1 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -9594,6 +9594,21 @@ class TestNN(NNTestCase): self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'), loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none')) + def test_triplet_margin_loss_invalid(self): + input1 = torch.randn(5, 10, requires_grad=True) + input2 = torch.randn(5, 10, requires_grad=True) + input3 = torch.randn(5, 10, requires_grad=True) + input_1d = torch.randn(10, requires_grad=True) + + with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"): + F.triplet_margin_loss(input1, input2, input_1d) + + with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"): + F.triplet_margin_loss(input1, input_1d, input3) + + with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"): + F.triplet_margin_loss(input_1d, input2, input3) + def test_pointwise_loss_target_grad_none_reduction(self): i = torch.randn(5, 10) t = torch.randn(5, 10, requires_grad=True) diff --git a/torch/nn/modules/distance.py b/torch/nn/modules/distance.py index 29501e9010..00513ac2aa 100644 --- a/torch/nn/modules/distance.py +++ b/torch/nn/modules/distance.py @@ -6,7 +6,7 @@ from torch import Tensor class PairwiseDistance(Module): r""" - Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm: + Computes the pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm: .. math :: \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}. @@ -18,9 +18,10 @@ class PairwiseDistance(Module): keepdim (bool, optional): Determines whether or not to keep the vector dimension. Default: False Shape: - - Input1: :math:`(N, D)` where `D = vector dimension` - - Input2: :math:`(N, D)`, same shape as the Input1 - - Output: :math:`(N)`. If :attr:`keepdim` is ``True``, then :math:`(N, 1)`. + - Input1: :math:`(N, D)` or :math:`(D)` where `N = batch dimension` and `D = vector dimension` + - Input2: :math:`(N, D)` or :math:`(D)`, same shape as the Input1 + - Output: :math:`(N)` or :math:`()` based on input dimension. + If :attr:`keepdim` is ``True``, then :math:`(N, 1)` or :math:`(1)` based on input dimension. Examples:: >>> pdist = nn.PairwiseDistance(p=2) >>> input1 = torch.randn(100, 128) diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index e0989e5b44..6d295aa733 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -1436,9 +1436,9 @@ class TripletMarginLoss(_Loss): specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` Shape: - - Input: :math:`(N, D)` where :math:`D` is the vector dimension. - - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar - otherwise. + - Input: :math:`(N, D)` or :math`(D)` where :math:`D` is the vector dimension. + - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'`` and + input shape is :math`(N, D)`; a scalar otherwise. Examples:: diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index 58f30d8742..74895bb746 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -1305,9 +1305,15 @@ def single_batch_reference_fn(input, parameters, module): The module is passed the input and target in batched form with a single item. The output is squeezed to compare with the no-batch input. """ - single_batch_input = input.unsqueeze(0) + def unsqueeze_inp(inp): + if isinstance(inp, (list, tuple)): + return [t.unsqueeze(0) for t in inp] + return inp.unsqueeze(0) + + single_batch_input = unsqueeze_inp(input) + single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input with freeze_rng_state(): - return module(single_batch_input).squeeze(0) + return module(*single_batch_input).squeeze(0) new_module_tests = [ @@ -3944,6 +3950,33 @@ new_module_tests = [ pickle=False, ), dict( + module_name='PairwiseDistance', + input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)), + ), + dict( + module_name='PairwiseDistance', + input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)), + desc='broadcast_lhs' + ), + dict( + module_name='PairwiseDistance', + input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)), + desc='broadcast_rhs' + ), + dict( + module_name='PairwiseDistance', + constructor_args=(1.5, 1e-05, True), + cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)', + input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)), + desc='with_non_default_args', + ), + dict( + module_name='PairwiseDistance', + input_fn=lambda: (torch.randn(8), torch.randn(8)), + reference_fn=single_batch_reference_fn, + desc='no_batch_dim', + ), + dict( module_name='TransformerEncoderLayer', constructor_args=(4, 2, 16, 0.0), cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) @@ -5445,6 +5478,8 @@ classification_criterion_no_batch = [ ('SoftMarginLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)), ('NLLLoss', lambda: F.log_softmax(torch.randn(3), dim=0), lambda: torch.tensor(1)), ('CosineEmbeddingLoss', lambda: (torch.randn(9), torch.randn(9)), lambda: torch.tensor(1)), + # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative + ('TripletMarginLoss', lambda: (torch.randn(9), torch.randn(9)), lambda: torch.randn(9)), ] classification_criterion_no_batch_extra_info: Dict[str, dict] = { 'MultiLabelMarginLoss': {'check_gradgrad': False}, |