summaryrefslogtreecommitdiff
path: root/torch/legacy/nn/RReLU.py
blob: e1c9c83a52168dca5aa4e04b3059dd0d7cdb8278 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from .Module import Module
from .utils import clear

class RReLU(Module):
    """Randomized leaky ReLU layer for the legacy nn API.

    All numeric work is delegated to the backend kernels
    ``RReLU_updateOutput`` / ``RReLU_updateGradInput``. This class only stores
    the configuration and the ``noise`` buffer that both kernels receive.

    Args:
        lower: lower bound forwarded to the backend. Presumably this is the
            minimum negative-slope coefficient sampled in training mode; the
            exact semantics live in the backend kernel. Must satisfy
            ``0 <= lower <= upper``.
        upper: upper bound forwarded to the backend. Presumably this is the
            maximum negative-slope coefficient.
        inplace: flag forwarded to both backend kernels. Presumably it lets
            the kernel reuse ``input`` as the output storage; confirm against
            the backend.

    Raises:
        AssertionError: if the bounds do not satisfy ``0 <= lower <= upper``.
            This includes NaN bounds.
    """

    def __init__(self, lower=1./8, upper=1./3, inplace=False):
        super(RReLU, self).__init__()
        self.lower = lower
        self.upper = upper
        self.inplace = inplace

        # Validate with an explicit raise instead of a bare `assert`, which is
        # removed entirely under `python -O`. AssertionError is kept so code
        # that relied on the previous `assert` sees the same exception type.
        # `upper >= 0` follows from `0 <= lower <= upper`. The chained
        # comparison is also False for NaN bounds, so NaN is still rejected.
        if not (0 <= self.lower <= self.upper):
            raise AssertionError(
                'RReLU requires 0 <= lower <= upper, got lower={}, upper={}'
                .format(self.lower, self.upper))

        # Buffer handed to both backend calls. Presumably the forward kernel
        # stores its per-element factors here and the backward kernel reuses
        # them — confirm in the backend implementation.
        self.noise = torch.Tensor()
        # Forwarded to both kernels as the training-mode flag.
        # NOTE(review): Module.__init__ may already set this; verify before
        # removing.
        self.train = True

    def updateOutput(self, input):
        """Run the backend forward kernel and return ``self.output``."""
        # CPU inputs get torch's default generator. CUDA inputs pass 0,
        # presumably so the CUDA backend uses its own RNG state (unverified).
        generator = torch.default_generator if not input.is_cuda else 0
        self._backend.RReLU_updateOutput(
            self._backend.library_state,
            input,
            self.output,
            self.noise,
            self.lower,
            self.upper,
            self.train,
            self.inplace,
            generator
        )
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Run the backend backward kernel and return ``self.gradInput``.

        The same ``noise`` buffer, bounds, training flag, and inplace flag
        that the forward call received are passed to the backend again here.
        """
        self._backend.RReLU_updateGradInput(
            self._backend.library_state,
            input,
            gradOutput,
            self.gradInput,
            self.noise,
            self.lower,
            self.upper,
            self.train,
            self.inplace
        )
        return self.gradInput

    def __repr__(self):
        # Appends both bounds to the base representation, using 4 decimals.
        return super(RReLU, self).__repr__() + '({:.4f}, {:.4f})'.format(self.lower, self.upper)

    def clearState(self):
        """Release the ``noise`` buffer, then defer to the base class.

        The buffer is reset through the shared ``clear`` helper.
        """
        clear(self, 'noise')
        return super(RReLU, self).clearState()