Diffstat (limited to 'torch/nn/modules/activation.py')
-rw-r--r--  torch/nn/modules/activation.py  105
1 file changed, 31 insertions(+), 74 deletions(-)
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index d4a2a46db4..18535cf67e 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -45,12 +45,11 @@ class Threshold(Module):
     def forward(self, input):
         return F.threshold(input, self.threshold, self.value, self.inplace)
 
-    def __repr__(self):
+    def extra_repr(self):
         inplace_str = ', inplace' if self.inplace else ''
-        return self.__class__.__name__ + ' (' \
-            + str(self.threshold) \
-            + ', ' + str(self.value) \
-            + inplace_str + ')'
+        return 'threshold={}, value={}{}'.format(
+            self.threshold, self.value, inplace_str
+        )
 
 
 class ReLU(Threshold):
@@ -77,10 +76,9 @@ class ReLU(Threshold):
     def __init__(self, inplace=False):
         super(ReLU, self).__init__(0, 0, inplace)
 
-    def __repr__(self):
+    def extra_repr(self):
         inplace_str = 'inplace' if self.inplace else ''
-        return self.__class__.__name__ + '(' \
-            + inplace_str + ')'
+        return inplace_str
 
 
 class RReLU(Module):
@@ -129,12 +127,9 @@ class RReLU(Module):
     def forward(self, input):
         return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
 
-    def __repr__(self):
+    def extra_repr(self):
         inplace_str = ', inplace' if self.inplace else ''
-        return self.__class__.__name__ + '(' \
-            + str(self.lower) \
-            + ', ' + str(self.upper) \
-            + inplace_str + ')'
+        return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
 
 
 class Hardtanh(Module):
@@ -191,12 +186,11 @@ class Hardtanh(Module):
     def forward(self, input):
         return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
 
-    def __repr__(self):
+    def extra_repr(self):
         inplace_str = ', inplace' if self.inplace else ''
-        return self.__class__.__name__ + '(' \
-            + 'min_val=' + str(self.min_val) \
-            + ', max_val=' + str(self.max_val) \
-            + inplace_str + ')'
+        return 'min_val={}, max_val={}{}'.format(
+            self.min_val, self.max_val, inplace_str
+        )
 
 
 class ReLU6(Hardtanh):
@@ -222,10 +216,9 @@ class ReLU6(Hardtanh):
     def __init__(self, inplace=False):
         super(ReLU6, self).__init__(0, 6, inplace)
 
-    def __repr__(self):
+    def extra_repr(self):
         inplace_str = 'inplace' if self.inplace else ''
-        return self.__class__.__name__ + '(' \
-            + inplace_str + ')'
+        return inplace_str
 
 
 class Sigmoid(Module):
@@ -248,9 +241,6 @@ class Sigmoid(Module):
     def forward(self, input):
         return torch.sigmoid(input)
 
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
 
 class Tanh(Module):
     r"""Applies element-wise,
@@ -273,9 +263,6 @@ class Tanh(Module):
     def forward(self, input):
         return torch.tanh(input)
 
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
 
 class ELU(Module):
     r"""Applies element-wise,
@@ -307,11 +294,9 @@ class ELU(Module):
     def forward(self, input):
         return F.elu(input, self.alpha, self.inplace)
 
-    def __repr__(self):
+    def extra_repr(self):
         inplace_str = ', inplace' if self.inplace else ''
-        return self.__class__.__name__ + '(' \
-            + 'alpha=' + str(self.alpha) \
-            + inplace_str + ')'
+        return 'alpha={}{}'.format(self.alpha, inplace_str)
 
 
 class SELU(Module):
@@ -348,9 +333,9 @@ class SELU(Module):
     def forward(self, input):
         return F.selu(input, self.inplace)
 
-    def __repr__(self):
-        inplace_str = '(inplace)' if self.inplace else ''
-        return self.__class__.__name__ + inplace_str
+    def extra_repr(self):
+        inplace_str = 'inplace' if self.inplace else ''
+        return inplace_str
 
 
 class GLU(Module):
@@ -380,8 +365,8 @@ class GLU(Module):
     def forward(self, input):
         return F.glu(input, self.dim)
 
-    def __repr__(self):
-        return '{}(dim={})'.format(self.__class__.__name__, self.dim)
+    def extra_repr(self):
+        return 'dim={}'.format(self.dim)
 
 
 class Hardshrink(Module):
@@ -420,9 +405,8 @@ class Hardshrink(Module):
     def forward(self, input):
         return F.hardshrink(input, self.lambd)
 
-    def __repr__(self):
-        return self.__class__.__name__ + '(' \
-            + str(self.lambd) + ')'
+    def extra_repr(self):
+        return '{}'.format(self.lambd)
 
 
 class LeakyReLU(Module):
@@ -462,11 +446,9 @@ class LeakyReLU(Module):
     def forward(self, input):
         return F.leaky_relu(input, self.negative_slope, self.inplace)
 
-    def __repr__(self):
+    def extra_repr(self):
         inplace_str = ', inplace' if self.inplace else ''
-        return self.__class__.__name__ + '(' \
-            + str(self.negative_slope) \
-            + inplace_str + ')'
+        return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
 
 
 class LogSigmoid(Module):
@@ -489,9 +471,6 @@ class LogSigmoid(Module):
     def forward(self, input):
         return F.logsigmoid(input)
 
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
 
 class Softplus(Module):
     r"""Applies element-wise :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`
@@ -528,10 +507,8 @@ class Softplus(Module):
     def forward(self, input):
         return F.softplus(input, self.beta, self.threshold)
 
-    def __repr__(self):
-        return self.__class__.__name__ + '(' \
-            + 'beta=' + str(self.beta) \
-            + ', threshold=' + str(self.threshold) + ')'
+    def extra_repr(self):
+        return 'beta={}, threshold={}'.format(self.beta, self.threshold)
 
 
 class Softshrink(Module):
@@ -571,9 +548,8 @@ class Softshrink(Module):
     def forward(self, input):
         return F.softshrink(input, self.lambd)
 
-    def __repr__(self):
-        return self.__class__.__name__ + '(' \
-            + str(self.lambd) + ')'
+    def extra_repr(self):
+        return str(self.lambd)
 
 
 class PReLU(Module):
@@ -621,9 +597,8 @@ class PReLU(Module):
     def forward(self, input):
         return F.prelu(input, self.weight)
 
-    def __repr__(self):
-        return self.__class__.__name__ + '(' \
-            + 'num_parameters=' + str(self.num_parameters) + ')'
+    def extra_repr(self):
+        return 'num_parameters={}'.format(self.num_parameters)
 
 
 class Softsign(Module):
@@ -646,9 +621,6 @@ class Softsign(Module):
     def forward(self, input):
         return F.softsign(input)
 
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
 
 class Tanhshrink(Module):
     r"""Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)`
@@ -670,9 +642,6 @@ class Tanhshrink(Module):
     def forward(self, input):
         return F.tanhshrink(input)
 
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
 
 class Softmin(Module):
     r"""Applies the Softmin function to an n-dimensional input Tensor
@@ -706,9 +675,6 @@ class Softmin(Module):
     def forward(self, input):
         return F.softmin(input, self.dim, _stacklevel=5)
 
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
 
 class Softmax(Module):
     r"""Applies the Softmax function to an n-dimensional input Tensor
@@ -754,9 +720,6 @@ class Softmax(Module):
     def forward(self, input):
         return F.softmax(input, self.dim, _stacklevel=5)
 
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
 
 class Softmax2d(Module):
     r"""Applies SoftMax over features to each spatial location.
@@ -784,9 +747,6 @@ class Softmax2d(Module):
         assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
         return F.softmax(input, 1, _stacklevel=5)
 
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
 
 class LogSoftmax(Module):
     r"""Applies the `Log(Softmax(x))` function to an n-dimensional input Tensor.
@@ -824,6 +784,3 @@ class LogSoftmax(Module):
 
     def forward(self, input):
         return F.log_softmax(input, self.dim, _stacklevel=5)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
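Every class in this diff drops its hand-rolled __repr__ and instead returns only its argument string from extra_repr(); the base nn.Module.__repr__ wraps that string in 'ClassName(...)'. Below is a minimal, self-contained sketch of that contract, using a stand-in Module class rather than the real one (the actual implementation in torch/nn/modules/module.py also indents and recurses into child modules):

class Module(object):
    def extra_repr(self):
        # Subclasses override this to describe their constructor arguments.
        return ''

    def __repr__(self):
        # The base class owns the 'ClassName(...)' framing; subclasses
        # only fill in what goes between the parentheses.
        return self.__class__.__name__ + '(' + self.extra_repr() + ')'


class LeakyReLU(Module):
    def __init__(self, negative_slope=0.01, inplace=False):
        self.negative_slope = negative_slope
        self.inplace = inplace

    def extra_repr(self):
        inplace_str = ', inplace' if self.inplace else ''
        return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)


print(LeakyReLU(0.1))        # LeakyReLU(negative_slope=0.1)
print(LeakyReLU(0.1, True))  # LeakyReLU(negative_slope=0.1, inplace)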