summaryrefslogtreecommitdiff
path: root/test/common_nn.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/common_nn.py')
-rw-r--r--test/common_nn.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/test/common_nn.py b/test/common_nn.py
index 2e53339c04..31c2a8bffb 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -423,6 +423,22 @@ def cosineembeddingloss_reference(input1, input2, target, margin=0, size_average
return output
+def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
+ size_average=True, reduce=True):
+ d_p = torch.pairwise_distance(anchor, positive, p, eps)
+ d_n = torch.pairwise_distance(anchor, negative, p, eps)
+ if swap:
+ d_s = torch.pairwise_distance(positive, negative, p, eps)
+ d_n = torch.min(d_n, d_s)
+
+ output = torch.clamp(margin + d_p - d_n, min=0.0)
+ if reduce and size_average:
+ return output.mean()
+ elif reduce:
+ return output.sum()
+ return output
+
+
loss_reference_fns = {
'KLDivLoss': kldivloss_reference,
'NLLLoss': nllloss_reference,
@@ -433,6 +449,7 @@ loss_reference_fns = {
'SoftMarginLoss': softmarginloss_reference,
'MultiMarginLoss': multimarginloss_reference,
'CosineEmbeddingLoss': cosineembeddingloss_reference,
+ 'TripletMarginLoss': tripletmarginloss_reference,
}