diff options
Diffstat (limited to 'test/common_nn.py')
-rw-r--r-- | test/common_nn.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/test/common_nn.py b/test/common_nn.py index 2e53339c04..31c2a8bffb 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -423,6 +423,22 @@ def cosineembeddingloss_reference(input1, input2, target, margin=0, size_average return output +def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, + size_average=True, reduce=True): + d_p = torch.pairwise_distance(anchor, positive, p, eps) + d_n = torch.pairwise_distance(anchor, negative, p, eps) + if swap: + d_s = torch.pairwise_distance(positive, negative, p, eps) + d_n = torch.min(d_n, d_s) + + output = torch.clamp(margin + d_p - d_n, min=0.0) + if reduce and size_average: + return output.mean() + elif reduce: + return output.sum() + return output + + loss_reference_fns = { 'KLDivLoss': kldivloss_reference, 'NLLLoss': nllloss_reference, @@ -433,6 +449,7 @@ loss_reference_fns = { 'SoftMarginLoss': softmarginloss_reference, 'MultiMarginLoss': multimarginloss_reference, 'CosineEmbeddingLoss': cosineembeddingloss_reference, + 'TripletMarginLoss': tripletmarginloss_reference, } |