1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
|
import torch
from .Module import Module
from .utils import clear
class RReLU(Module):
    """RReLU layer written against the legacy updateOutput/updateGradInput API.

    No element-wise math happens in this class: both passes forward the
    tensors, the ``[lower, upper]`` bounds and the flags to the backend's
    ``RReLU_updateOutput`` / ``RReLU_updateGradInput`` kernels. Presumably
    this is the randomized leaky ReLU (negative slope sampled from
    ``[lower, upper]`` while training); that sampling lives in the backend
    kernel, so confirm the exact semantics there.

    Args:
        lower: lower bound forwarded to the backend; the constructor asserts
            ``0 <= lower <= upper``.
        upper: upper bound forwarded to the backend; asserted ``>= 0``.
        inplace: forwarded unchanged to both backend kernels.
    """

    def __init__(self, lower=1. / 8, upper=1. / 3, inplace=False):
        super(RReLU, self).__init__()
        self.lower, self.upper = lower, upper
        self.inplace = inplace
        # Same condition as: lower <= upper, lower >= 0 and upper >= 0.
        # NOTE(review): assert is stripped under `python -O`; consider an
        # explicit raise if this validation must always run.
        assert 0 <= self.lower <= self.upper and self.upper >= 0
        # Buffer handed to both the forward and backward kernels; presumably
        # it holds the per-element sampled slopes -- TODO confirm.
        self.noise = torch.Tensor()
        # Training flag forwarded to both backend kernels.
        self.train = True

    def updateOutput(self, input):
        """Run the forward kernel and return ``self.output``."""
        # Non-CUDA inputs get torch's default generator; for CUDA inputs a
        # literal 0 is passed instead (meaning defined by the backend).
        rng = 0 if input.is_cuda else torch.default_generator
        forward_kernel = self._backend.RReLU_updateOutput
        forward_kernel(
            self._backend.library_state,
            input,
            self.output,
            self.noise,
            self.lower,
            self.upper,
            self.train,
            self.inplace,
            rng,
        )
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Run the backward kernel and return ``self.gradInput``.

        ``self.gradInput`` is handed to the backend kernel together with the
        same ``noise`` buffer used by the forward pass, then returned.
        """
        backward_kernel = self._backend.RReLU_updateGradInput
        backward_kernel(
            self._backend.library_state,
            input,
            gradOutput,
            self.gradInput,
            self.noise,
            self.lower,
            self.upper,
            self.train,
            self.inplace,
        )
        return self.gradInput

    def __repr__(self):
        """Return the base repr followed by the bounds to four decimals."""
        base = super(RReLU, self).__repr__()
        bounds = '({:.4f}, {:.4f})'.format(self.lower, self.upper)
        return base + bounds

    def clearState(self):
        """Clear the ``noise`` buffer, then defer to the base class."""
        parent = super(RReLU, self)
        # `clear` comes from the package's .utils module; presumably it
        # resets the named attribute -- see that helper for specifics.
        clear(self, 'noise')
        return parent.clearState()
|