summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkshitij12345 <kshitijkalambarkar@gmail.com>2021-09-21 07:21:15 -0700
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>2021-09-21 07:29:48 -0700
commit9c23f6eb7d87ec9fe7af7658e58f0b5f0eaf19df (patch)
tree19272378c016249cf457a7efc2b12ec061528811
parentd35ee431d88a3bf2186120b308d96c4ebb85f65c (diff)
downloadpytorch-9c23f6eb7d87ec9fe7af7658e58f0b5f0eaf19df.tar.gz
pytorch-9c23f6eb7d87ec9fe7af7658e58f0b5f0eaf19df.tar.bz2
pytorch-9c23f6eb7d87ec9fe7af7658e58f0b5f0eaf19df.zip
[nn] TripletMarginLoss and PairwiseDistance : no batch dim (#64882)
Summary: Reference: https://github.com/pytorch/pytorch/issues/60585 Pull Request resolved: https://github.com/pytorch/pytorch/pull/64882 Reviewed By: malfet Differential Revision: D31055577 Pulled By: jbschlosser fbshipit-source-id: 2f0a5a08619b672026b48a78bc7d83a6dccba0bf
-rw-r--r--aten/src/ATen/native/Distance.cpp7
-rw-r--r--aten/src/ATen/native/Loss.cpp12
-rw-r--r--test/test_nn.py15
-rw-r--r--torch/nn/modules/distance.py9
-rw-r--r--torch/nn/modules/loss.py6
-rw-r--r--torch/testing/_internal/common_nn.py39
6 files changed, 78 insertions, 10 deletions
diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp
index b79c3e91c6..7974840dd3 100644
--- a/aten/src/ATen/native/Distance.cpp
+++ b/aten/src/ATen/native/Distance.cpp
@@ -14,7 +14,12 @@ DEFINE_DISPATCH(cdist_stub);
DEFINE_DISPATCH(cdist_backward_stub);
Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps, bool keepdim) {
- return at::norm(x1 - x2 + eps, p, 1, keepdim);
+ // Since either x1 or x2 could be broadcasted
+ auto x1_dim = x1.dim();
+ auto x2_dim = x2.dim();
+ auto output_dim = x1_dim > x2_dim ? x1_dim : x2_dim;
+ auto innermost_dim = output_dim - 1;
+ return at::norm(x1 - x2 + eps, p, innermost_dim, keepdim);
}
// This is to guarantee that the contiguous memory is passed to the backward pass
diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp
index 6c4c21bd1a..4054655599 100644
--- a/aten/src/ATen/native/Loss.cpp
+++ b/aten/src/ATen/native/Loss.cpp
@@ -79,6 +79,18 @@ Tensor hinge_embedding_loss(const Tensor& self, const Tensor& target, double mar
Tensor triplet_margin_loss(const Tensor& anchor, const Tensor& positive, const Tensor& negative, double margin,
double p, double eps, bool swap, int64_t reduction) {
+ auto a_dim = anchor.dim();
+ auto p_dim = positive.dim();
+ auto n_dim = negative.dim();
+ TORCH_CHECK(
+ a_dim == p_dim && p_dim == n_dim,
+ "All inputs should have same dimension but got ",
+ a_dim,
+ "D, ",
+ p_dim,
+ "D and ",
+ n_dim,
+ "D inputs.")
auto dist_pos = at::pairwise_distance(anchor, positive, p, eps);
auto dist_neg = at::pairwise_distance(anchor, negative, p, eps);
if (swap) {
diff --git a/test/test_nn.py b/test/test_nn.py
index f324350f9e..92357d9ce1 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -9594,6 +9594,21 @@ class TestNN(NNTestCase):
self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))
+ def test_triplet_margin_loss_invalid(self):
+ input1 = torch.randn(5, 10, requires_grad=True)
+ input2 = torch.randn(5, 10, requires_grad=True)
+ input3 = torch.randn(5, 10, requires_grad=True)
+ input_1d = torch.randn(10, requires_grad=True)
+
+ with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
+ F.triplet_margin_loss(input1, input2, input_1d)
+
+ with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
+ F.triplet_margin_loss(input1, input_1d, input3)
+
+ with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
+ F.triplet_margin_loss(input_1d, input2, input3)
+
def test_pointwise_loss_target_grad_none_reduction(self):
i = torch.randn(5, 10)
t = torch.randn(5, 10, requires_grad=True)
diff --git a/torch/nn/modules/distance.py b/torch/nn/modules/distance.py
index 29501e9010..00513ac2aa 100644
--- a/torch/nn/modules/distance.py
+++ b/torch/nn/modules/distance.py
@@ -6,7 +6,7 @@ from torch import Tensor
class PairwiseDistance(Module):
r"""
- Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
+ Computes the pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
.. math ::
\Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}.
@@ -18,9 +18,10 @@ class PairwiseDistance(Module):
keepdim (bool, optional): Determines whether or not to keep the vector dimension.
Default: False
Shape:
- - Input1: :math:`(N, D)` where `D = vector dimension`
- - Input2: :math:`(N, D)`, same shape as the Input1
- - Output: :math:`(N)`. If :attr:`keepdim` is ``True``, then :math:`(N, 1)`.
+ - Input1: :math:`(N, D)` or :math:`(D)` where `N = batch dimension` and `D = vector dimension`
+ - Input2: :math:`(N, D)` or :math:`(D)`, same shape as the Input1
+ - Output: :math:`(N)` or :math:`()` based on input dimension.
+ If :attr:`keepdim` is ``True``, then :math:`(N, 1)` or :math:`(1)` based on input dimension.
Examples::
>>> pdist = nn.PairwiseDistance(p=2)
>>> input1 = torch.randn(100, 128)
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index e0989e5b44..6d295aa733 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -1436,9 +1436,9 @@ class TripletMarginLoss(_Loss):
specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
Shape:
- - Input: :math:`(N, D)` where :math:`D` is the vector dimension.
- - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
- otherwise.
+ - Input: :math:`(N, D)` or :math`(D)` where :math:`D` is the vector dimension.
+ - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'`` and
+ input shape is :math`(N, D)`; a scalar otherwise.
Examples::
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index 58f30d8742..74895bb746 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -1305,9 +1305,15 @@ def single_batch_reference_fn(input, parameters, module):
The module is passed the input and target in batched form with a single item.
The output is squeezed to compare with the no-batch input.
"""
- single_batch_input = input.unsqueeze(0)
+ def unsqueeze_inp(inp):
+ if isinstance(inp, (list, tuple)):
+ return [t.unsqueeze(0) for t in inp]
+ return inp.unsqueeze(0)
+
+ single_batch_input = unsqueeze_inp(input)
+ single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input
with freeze_rng_state():
- return module(single_batch_input).squeeze(0)
+ return module(*single_batch_input).squeeze(0)
new_module_tests = [
@@ -3944,6 +3950,33 @@ new_module_tests = [
pickle=False,
),
dict(
+ module_name='PairwiseDistance',
+ input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
+ ),
+ dict(
+ module_name='PairwiseDistance',
+ input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)),
+ desc='broadcast_lhs'
+ ),
+ dict(
+ module_name='PairwiseDistance',
+ input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)),
+ desc='broadcast_rhs'
+ ),
+ dict(
+ module_name='PairwiseDistance',
+ constructor_args=(1.5, 1e-05, True),
+ cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)',
+ input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
+ desc='with_non_default_args',
+ ),
+ dict(
+ module_name='PairwiseDistance',
+ input_fn=lambda: (torch.randn(8), torch.randn(8)),
+ reference_fn=single_batch_reference_fn,
+ desc='no_batch_dim',
+ ),
+ dict(
module_name='TransformerEncoderLayer',
constructor_args=(4, 2, 16, 0.0),
cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
@@ -5445,6 +5478,8 @@ classification_criterion_no_batch = [
('SoftMarginLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)),
('NLLLoss', lambda: F.log_softmax(torch.randn(3), dim=0), lambda: torch.tensor(1)),
('CosineEmbeddingLoss', lambda: (torch.randn(9), torch.randn(9)), lambda: torch.tensor(1)),
+ # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative
+ ('TripletMarginLoss', lambda: (torch.randn(9), torch.randn(9)), lambda: torch.randn(9)),
]
classification_criterion_no_batch_extra_info: Dict[str, dict] = {
'MultiLabelMarginLoss': {'check_gradgrad': False},