diff options
author | li-roy <8813817+li-roy@users.noreply.github.com> | 2018-03-17 08:10:48 -0700 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2018-03-17 11:10:48 -0400 |
commit | e876b5d9d00d11e99be70b69448545a948356e98 (patch) | |
tree | cc149c9da6b38f8b07bc10734c7a8ef5a040ebc0 /test/common_nn.py | |
parent | 32462e0ac460c29f58373e83aead482a06655ee2 (diff) | |
download | pytorch-e876b5d9d00d11e99be70b69448545a948356e98.tar.gz pytorch-e876b5d9d00d11e99be70b69448545a948356e98.tar.bz2 pytorch-e876b5d9d00d11e99be70b69448545a948356e98.zip |
implement TripletMarginLoss as a native function (#5680)
* implement TripletMarginLoss as a native function
* implement TripletMarginLoss as native function
* fix compile error
* address comments
* address comments
* Add keepdim arg to pairwise distance
Diffstat (limited to 'test/common_nn.py')
-rw-r--r-- | test/common_nn.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/test/common_nn.py b/test/common_nn.py index 2e53339c04..31c2a8bffb 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -423,6 +423,22 @@ def cosineembeddingloss_reference(input1, input2, target, margin=0, size_average return output +def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, + size_average=True, reduce=True): + d_p = torch.pairwise_distance(anchor, positive, p, eps) + d_n = torch.pairwise_distance(anchor, negative, p, eps) + if swap: + d_s = torch.pairwise_distance(positive, negative, p, eps) + d_n = torch.min(d_n, d_s) + + output = torch.clamp(margin + d_p - d_n, min=0.0) + if reduce and size_average: + return output.mean() + elif reduce: + return output.sum() + return output + + loss_reference_fns = { 'KLDivLoss': kldivloss_reference, 'NLLLoss': nllloss_reference, @@ -433,6 +449,7 @@ loss_reference_fns = { 'SoftMarginLoss': softmarginloss_reference, 'MultiMarginLoss': multimarginloss_reference, 'CosineEmbeddingLoss': cosineembeddingloss_reference, + 'TripletMarginLoss': tripletmarginloss_reference, } |