diff options
Diffstat (limited to 'torch/nn/modules/activation.py')
-rw-r--r-- | torch/nn/modules/activation.py | 105 |
1 files changed, 31 insertions, 74 deletions
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index d4a2a46db4..18535cf67e 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -45,12 +45,11 @@ class Threshold(Module): def forward(self, input): return F.threshold(input, self.threshold, self.value, self.inplace) - def __repr__(self): + def extra_repr(self): inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + ' (' \ - + str(self.threshold) \ - + ', ' + str(self.value) \ - + inplace_str + ')' + return 'threshold={}, value={}{}'.format( + self.threshold, self.value, inplace_str + ) class ReLU(Threshold): @@ -77,10 +76,9 @@ class ReLU(Threshold): def __init__(self, inplace=False): super(ReLU, self).__init__(0, 0, inplace) - def __repr__(self): + def extra_repr(self): inplace_str = 'inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + inplace_str + ')' + return inplace_str class RReLU(Module): @@ -129,12 +127,9 @@ class RReLU(Module): def forward(self, input): return F.rrelu(input, self.lower, self.upper, self.training, self.inplace) - def __repr__(self): + def extra_repr(self): inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + str(self.lower) \ - + ', ' + str(self.upper) \ - + inplace_str + ')' + return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str) class Hardtanh(Module): @@ -191,12 +186,11 @@ class Hardtanh(Module): def forward(self, input): return F.hardtanh(input, self.min_val, self.max_val, self.inplace) - def __repr__(self): + def extra_repr(self): inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + 'min_val=' + str(self.min_val) \ - + ', max_val=' + str(self.max_val) \ - + inplace_str + ')' + return 'min_val={}, max_val={}{}'.format( + self.min_val, self.max_val, inplace_str + ) class ReLU6(Hardtanh): @@ -222,10 +216,9 @@ class ReLU6(Hardtanh): def __init__(self, inplace=False): super(ReLU6, self).__init__(0, 6, inplace) - def __repr__(self): + def extra_repr(self): inplace_str = 'inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + inplace_str + ')' + return inplace_str class Sigmoid(Module): @@ -248,9 +241,6 @@ class Sigmoid(Module): def forward(self, input): return torch.sigmoid(input) - def __repr__(self): - return self.__class__.__name__ + '()' - class Tanh(Module): r"""Applies element-wise, @@ -273,9 +263,6 @@ class Tanh(Module): def forward(self, input): return torch.tanh(input) - def __repr__(self): - return self.__class__.__name__ + '()' - class ELU(Module): r"""Applies element-wise, @@ -307,11 +294,9 @@ class ELU(Module): def forward(self, input): return F.elu(input, self.alpha, self.inplace) - def __repr__(self): + def extra_repr(self): inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + 'alpha=' + str(self.alpha) \ - + inplace_str + ')' + return 'alpha={}{}'.format(self.alpha, inplace_str) class SELU(Module): @@ -348,9 +333,9 @@ class SELU(Module): def forward(self, input): return F.selu(input, self.inplace) - def __repr__(self): - inplace_str = '(inplace)' if self.inplace else '' - return self.__class__.__name__ + inplace_str + def extra_repr(self): + inplace_str = 'inplace' if self.inplace else '' + return inplace_str class GLU(Module): @@ -380,8 +365,8 @@ class GLU(Module): def forward(self, input): return F.glu(input, self.dim) - def __repr__(self): - return '{}(dim={})'.format(self.__class__.__name__, self.dim) + def extra_repr(self): + return 'dim={}'.format(self.dim) class Hardshrink(Module): @@ -420,9 +405,8 @@ class Hardshrink(Module): def forward(self, input): return F.hardshrink(input, self.lambd) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.lambd) + ')' + def extra_repr(self): + return '{}'.format(self.lambd) class LeakyReLU(Module): @@ -462,11 +446,9 @@ class LeakyReLU(Module): def forward(self, input): return F.leaky_relu(input, self.negative_slope, self.inplace) - def __repr__(self): + def extra_repr(self): inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + str(self.negative_slope) \ - + inplace_str + ')' + return 'negative_slope={}{}'.format(self.negative_slope, inplace_str) class LogSigmoid(Module): @@ -489,9 +471,6 @@ class LogSigmoid(Module): def forward(self, input): return F.logsigmoid(input) - def __repr__(self): - return self.__class__.__name__ + '()' - class Softplus(Module): r"""Applies element-wise :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))` @@ -528,10 +507,8 @@ class Softplus(Module): def forward(self, input): return F.softplus(input, self.beta, self.threshold) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'beta=' + str(self.beta) \ - + ', threshold=' + str(self.threshold) + ')' + def extra_repr(self): + return 'beta={}, threshold={}'.format(self.beta, self.threshold) class Softshrink(Module): @@ -571,9 +548,8 @@ class Softshrink(Module): def forward(self, input): return F.softshrink(input, self.lambd) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.lambd) + ')' + def extra_repr(self): + return str(self.lambd) class PReLU(Module): @@ -621,9 +597,8 @@ class PReLU(Module): def forward(self, input): return F.prelu(input, self.weight) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'num_parameters=' + str(self.num_parameters) + ')' + def extra_repr(self): + return 'num_parameters={}'.format(self.num_parameters) class Softsign(Module): @@ -646,9 +621,6 @@ class Softsign(Module): def forward(self, input): return F.softsign(input) - def __repr__(self): - return self.__class__.__name__ + '()' - class Tanhshrink(Module): r"""Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)` @@ -670,9 +642,6 @@ class Tanhshrink(Module): def forward(self, input): return F.tanhshrink(input) - def __repr__(self): - return self.__class__.__name__ + '()' - class Softmin(Module): r"""Applies the Softmin function to an n-dimensional input Tensor @@ -706,9 +675,6 @@ class Softmin(Module): def forward(self, input): return F.softmin(input, self.dim, _stacklevel=5) - def __repr__(self): - return self.__class__.__name__ + '()' - class Softmax(Module): r"""Applies the Softmax function to an n-dimensional input Tensor @@ -754,9 +720,6 @@ class Softmax(Module): def forward(self, input): return F.softmax(input, self.dim, _stacklevel=5) - def __repr__(self): - return self.__class__.__name__ + '()' - class Softmax2d(Module): r"""Applies SoftMax over features to each spatial location. @@ -784,9 +747,6 @@ class Softmax2d(Module): assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input' return F.softmax(input, 1, _stacklevel=5) - def __repr__(self): - return self.__class__.__name__ + '()' - class LogSoftmax(Module): r"""Applies the `Log(Softmax(x))` function to an n-dimensional input Tensor. @@ -824,6 +784,3 @@ class LogSoftmax(Module): def forward(self, input): return F.log_softmax(input, self.dim, _stacklevel=5) - - def __repr__(self): - return self.__class__.__name__ + '()' |