summaryrefslogtreecommitdiff
path: root/torch/legacy/nn/MultiLabelMarginCriterion.py
blob: 42d6f7ac9111934ae6733059d36d67900cdd60ef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from .Criterion import Criterion


class MultiLabelMarginCriterion(Criterion):
    """Criterion that delegates to the backend's MultiLabelMarginCriterion kernels.

    Both passes forward ``input``, the ``long``-converted ``target``, the
    ``isTarget`` buffer and the ``sizeAverage`` flag to ``self._backend``.

    NOTE(review): ``self._backend`` and ``self.gradInput`` are not assigned
    in this class; presumably the ``Criterion`` base class provides them --
    confirm against ``Criterion``.
    """

    def __init__(self, sizeAverage=True):
        """Store the configuration and create the buffers shared by both passes.

        Args:
            sizeAverage: flag handed unchanged to the backend kernels
                (presumably selects averaging of the loss -- confirm against
                the backend implementation).
        """
        super(MultiLabelMarginCriterion, self).__init__()
        self.sizeAverage = sizeAverage
        # Empty tensor passed to the backend in both the forward and the
        # backward call.
        self.isTarget = torch.Tensor()
        # Created on the first updateOutput call so that it is built from
        # the input via ``input.new``.
        self.output_tensor = None

    def updateOutput(self, input, target):
        """Run the backend forward kernel and return the resulting loss.

        Args:
            input: tensor handed to the backend as-is.
            target: tensor of targets; converted with ``.long()`` before
                being passed to the backend.

        Returns:
            Element 0 of ``self.output_tensor`` after the backend call; the
            same value is stored in ``self.output``.
        """
        if self.output_tensor is None:
            # One-element holder that the backend receives as its output
            # argument.
            self.output_tensor = input.new(1)

        labels = target.long()
        backend = self._backend
        kernel = backend.MultiLabelMarginCriterion_updateOutput
        kernel(
            backend.library_state,
            input,
            labels,
            self.output_tensor,
            self.isTarget,
            self.sizeAverage
        )

        loss = self.output_tensor[0]
        self.output = loss
        return loss

    def updateGradInput(self, input, target):
        """Run the backend backward kernel and return ``self.gradInput``.

        Args:
            input: tensor handed to the backend as-is.
            target: tensor of targets; converted with ``.long()`` before
                being passed to the backend.

        Returns:
            ``self.gradInput``, which is passed to the backend call as an
            argument.
        """
        labels = target.long()
        backend = self._backend
        kernel = backend.MultiLabelMarginCriterion_updateGradInput
        kernel(
            backend.library_state,
            input,
            labels,
            self.gradInput,
            self.isTarget,
            self.sizeAverage
        )
        return self.gradInput