summaryrefslogtreecommitdiff
path: root/test/common_nn.py
diff options
context:
space:
mode:
authorli-roy <8813817+li-roy@users.noreply.github.com>2018-03-17 08:10:48 -0700
committerSoumith Chintala <soumith@gmail.com>2018-03-17 11:10:48 -0400
commite876b5d9d00d11e99be70b69448545a948356e98 (patch)
treecc149c9da6b38f8b07bc10734c7a8ef5a040ebc0 /test/common_nn.py
parent32462e0ac460c29f58373e83aead482a06655ee2 (diff)
downloadpytorch-e876b5d9d00d11e99be70b69448545a948356e98.tar.gz
pytorch-e876b5d9d00d11e99be70b69448545a948356e98.tar.bz2
pytorch-e876b5d9d00d11e99be70b69448545a948356e98.zip
implement TripletMarginLoss as a native function (#5680)
* implement TripletMarginLoss as a native function * implement TripletMarginLoss as native function * fix compile error * address comments * address comments * Add keepdim arg to pairwise distance
Diffstat (limited to 'test/common_nn.py')
-rw-r--r--test/common_nn.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/test/common_nn.py b/test/common_nn.py
index 2e53339c04..31c2a8bffb 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -423,6 +423,22 @@ def cosineembeddingloss_reference(input1, input2, target, margin=0, size_average
return output
+def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
+ size_average=True, reduce=True):
+ d_p = torch.pairwise_distance(anchor, positive, p, eps)
+ d_n = torch.pairwise_distance(anchor, negative, p, eps)
+ if swap:
+ d_s = torch.pairwise_distance(positive, negative, p, eps)
+ d_n = torch.min(d_n, d_s)
+
+ output = torch.clamp(margin + d_p - d_n, min=0.0)
+ if reduce and size_average:
+ return output.mean()
+ elif reduce:
+ return output.sum()
+ return output
+
+
loss_reference_fns = {
'KLDivLoss': kldivloss_reference,
'NLLLoss': nllloss_reference,
@@ -433,6 +449,7 @@ loss_reference_fns = {
'SoftMarginLoss': softmarginloss_reference,
'MultiMarginLoss': multimarginloss_reference,
'CosineEmbeddingLoss': cosineembeddingloss_reference,
+ 'TripletMarginLoss': tripletmarginloss_reference,
}