author     Kaiyu Shi <skyisno.1@gmail.com>    2018-04-03 01:52:33 +0800
committer  Edward Z. Yang <ezyang@mit.edu>    2018-04-02 13:52:33 -0400
commit     605307f8f3c249d9279030502d2aac98d4170b83 (patch)
tree       b1327aea596e5a62965087be7b311281885f540f
parent     7355f5cd8dc52a048d8c367cabfed9e888acd586 (diff)
Add support for printing extra information in Module and refactor redundant code (#5936)
This PR enables users to print extra information from their subclassed nn.Module. For now, the user-defined string is simply inserted after the module name; the exact format is open for discussion in this PR. Before this PR, users had to override __repr__ and copy-paste the source code from Module.

* Add support for extra information on Module
* Rewrite the repr method of Module
* Fix flake8
* Change the __repr__ to get_extra_repr in Linear
* Fix extra new-line for empty line
* Add test for __repr__ method
* Fix bug of block string indent
* Add indent for multi-line repr test
* Address review comments
* Update tutorial for creating nn.Module
* Fix flake8, add extra_repr of Bilinear
* Refactor DropoutNd
* Change to extra_repr in some Modules
* Fix flake8
* Refactor padding modules
* Refactor pooling modules
* Fix typo
* Change to extra_repr
* Fix bug for GroupNorm
* Fix bug for LayerNorm
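The commit message above refers to the new extra_repr hook on nn.Module. A minimal sketch of how a user-defined module would use it after this change (the MyLinear class, its sizes, and the printed output are illustrative, not part of this commit)::

    import torch
    import torch.nn as nn

    class MyLinear(nn.Module):
        def __init__(self, in_features, out_features):
            super(MyLinear, self).__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = nn.Parameter(torch.randn(out_features, in_features))

        def forward(self, input):
            return input.matmul(self.weight.t())

        def extra_repr(self):
            # The returned string is placed inside the parentheses of repr().
            return 'in_features={}, out_features={}'.format(
                self.in_features, self.out_features)

    print(MyLinear(3, 4))
    # MyLinear(in_features=3, out_features=4)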
-rw-r--r--docs/source/notes/extending.rst7
-rw-r--r--test/test_nn.py18
-rw-r--r--torch/nn/modules/activation.py105
-rw-r--r--torch/nn/modules/batchnorm.py7
-rw-r--r--torch/nn/modules/container.py5
-rw-r--r--torch/nn/modules/conv.py7
-rw-r--r--torch/nn/modules/dropout.py63
-rw-r--r--torch/nn/modules/fold.py21
-rw-r--r--torch/nn/modules/linear.py19
-rw-r--r--torch/nn/modules/module.py37
-rw-r--r--torch/nn/modules/normalization.py30
-rw-r--r--torch/nn/modules/padding.py100
-rw-r--r--torch/nn/modules/pixelshuffle.py4
-rw-r--r--torch/nn/modules/pooling.py536
-rw-r--r--torch/nn/modules/rnn.py14
-rw-r--r--torch/nn/modules/sparse.py14
-rw-r--r--torch/nn/modules/upsampling.py4
17 files changed, 402 insertions, 589 deletions
diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst
index e9b7ef26b2..216da7795c 100644
--- a/docs/source/notes/extending.rst
+++ b/docs/source/notes/extending.rst
@@ -173,6 +173,13 @@ This is how a ``Linear`` module can be implemented::
# See the autograd section for explanation of what happens here.
return LinearFunction.apply(input, self.weight, self.bias)
+ def extra_repr(self):
+ # (Optional)Set the extra information about this module. You can test
+ # it by printing an object of this class.
+ return 'in_features={}, out_features={}, bias={}'.format(
+ self.in_features, self.out_features, self.bias is not None
+ )
+
Writing custom C extensions
---------------------------
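The documentation hunk above adds extra_repr to the tutorial's Linear example. Assuming that class exposes in_features, out_features and bias attributes (as the format string above expects), printing an instance would produce roughly::

    >>> print(Linear(3, 4))
    Linear(in_features=3, out_features=4, bias=True)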
diff --git a/test/test_nn.py b/test/test_nn.py
index d8ada0ddbd..31c575f52a 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -768,6 +768,24 @@ class TestNN(NNTestCase):
for key in keys:
self.assertTrue(hasattr(linear, key))
+ def test_repr(self):
+ # no extra information or sub-modules
+ empty_sequential = nn.Sequential()
+ expected_repr_empty = 'Sequential()'
+ self.assertEqual(repr(empty_sequential), expected_repr_empty)
+
+ # one liner extra information
+ linear = nn.Linear(1, 1)
+ expected_repr_linear = 'Linear(in_features=1, out_features=1, bias=True)'
+ self.assertEqual(repr(linear), expected_repr_linear)
+
+ # sub-modules repr
+ sequential = nn.Sequential(linear)
+ expected_repr_sequential = 'Sequential(\n' \
+ ' (0): Linear(in_features=1, out_features=1, bias=True)\n' \
+ ')'
+ self.assertEqual(repr(sequential), expected_repr_sequential)
+
def test_dir_digit(self):
model = nn.Sequential(nn.Linear(2, 2))
keys = dir(model)
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index d4a2a46db4..18535cf67e 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -45,12 +45,11 @@ class Threshold(Module):
def forward(self, input):
return F.threshold(input, self.threshold, self.value, self.inplace)
- def __repr__(self):
+ def extra_repr(self):
inplace_str = ', inplace' if self.inplace else ''
- return self.__class__.__name__ + ' (' \
- + str(self.threshold) \
- + ', ' + str(self.value) \
- + inplace_str + ')'
+ return 'threshold={}, value={}{}'.format(
+ self.threshold, self.value, inplace_str
+ )
class ReLU(Threshold):
@@ -77,10 +76,9 @@ class ReLU(Threshold):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 0, inplace)
- def __repr__(self):
+ def extra_repr(self):
inplace_str = 'inplace' if self.inplace else ''
- return self.__class__.__name__ + '(' \
- + inplace_str + ')'
+ return inplace_str
class RReLU(Module):
@@ -129,12 +127,9 @@ class RReLU(Module):
def forward(self, input):
return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
- def __repr__(self):
+ def extra_repr(self):
inplace_str = ', inplace' if self.inplace else ''
- return self.__class__.__name__ + '(' \
- + str(self.lower) \
- + ', ' + str(self.upper) \
- + inplace_str + ')'
+ return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
class Hardtanh(Module):
@@ -191,12 +186,11 @@ class Hardtanh(Module):
def forward(self, input):
return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
- def __repr__(self):
+ def extra_repr(self):
inplace_str = ', inplace' if self.inplace else ''
- return self.__class__.__name__ + '(' \
- + 'min_val=' + str(self.min_val) \
- + ', max_val=' + str(self.max_val) \
- + inplace_str + ')'
+ return 'min_val={}, max_val={}{}'.format(
+ self.min_val, self.max_val, inplace_str
+ )
class ReLU6(Hardtanh):
@@ -222,10 +216,9 @@ class ReLU6(Hardtanh):
def __init__(self, inplace=False):
super(ReLU6, self).__init__(0, 6, inplace)
- def __repr__(self):
+ def extra_repr(self):
inplace_str = 'inplace' if self.inplace else ''
- return self.__class__.__name__ + '(' \
- + inplace_str + ')'
+ return inplace_str
class Sigmoid(Module):
@@ -248,9 +241,6 @@ class Sigmoid(Module):
def forward(self, input):
return torch.sigmoid(input)
- def __repr__(self):
- return self.__class__.__name__ + '()'
-
class Tanh(Module):
r"""Applies element-wise,
@@ -273,9 +263,6 @@ class Tanh(Module):
def forward(self, input):
return torch.tanh(input)
- def __repr__(self):
- return self.__class__.__name__ + '()'
-
class ELU(Module):
r"""Applies element-wise,
@@ -307,11 +294,9 @@ class ELU(Module):
def forward(self, input):
return F.elu(input, self.alpha, self.inplace)
- def __repr__(self):
+ def extra_repr(self):
inplace_str = ', inplace' if self.inplace else ''
- return self.__class__.__name__ + '(' \
- + 'alpha=' + str(self.alpha) \
- + inplace_str + ')'
+ return 'alpha={}{}'.format(self.alpha, inplace_str)
class SELU(Module):
@@ -348,9 +333,9 @@ class SELU(Module):
def forward(self, input):
return F.selu(input, self.inplace)
- def __repr__(self):
- inplace_str = '(inplace)' if self.inplace else ''
- return self.__class__.__name__ + inplace_str
+ def extra_repr(self):
+ inplace_str = 'inplace' if self.inplace else ''
+ return inplace_str
class GLU(Module):
@@ -380,8 +365,8 @@ class GLU(Module):
def forward(self, input):
return F.glu(input, self.dim)
- def __repr__(self):
- return '{}(dim={})'.format(self.__class__.__name__, self.dim)
+ def extra_repr(self):
+ return 'dim={}'.format(self.dim)
class Hardshrink(Module):
@@ -420,9 +405,8 @@ class Hardshrink(Module):
def forward(self, input):
return F.hardshrink(input, self.lambd)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.lambd) + ')'
+ def extra_repr(self):
+ return '{}'.format(self.lambd)
class LeakyReLU(Module):
@@ -462,11 +446,9 @@ class LeakyReLU(Module):
def forward(self, input):
return F.leaky_relu(input, self.negative_slope, self.inplace)
- def __repr__(self):
+ def extra_repr(self):
inplace_str = ', inplace' if self.inplace else ''
- return self.__class__.__name__ + '(' \
- + str(self.negative_slope) \
- + inplace_str + ')'
+ return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
class LogSigmoid(Module):
@@ -489,9 +471,6 @@ class LogSigmoid(Module):
def forward(self, input):
return F.logsigmoid(input)
- def __repr__(self):
- return self.__class__.__name__ + '()'
-
class Softplus(Module):
r"""Applies element-wise :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`
@@ -528,10 +507,8 @@ class Softplus(Module):
def forward(self, input):
return F.softplus(input, self.beta, self.threshold)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'beta=' + str(self.beta) \
- + ', threshold=' + str(self.threshold) + ')'
+ def extra_repr(self):
+ return 'beta={}, threshold={}'.format(self.beta, self.threshold)
class Softshrink(Module):
@@ -571,9 +548,8 @@ class Softshrink(Module):
def forward(self, input):
return F.softshrink(input, self.lambd)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.lambd) + ')'
+ def extra_repr(self):
+ return str(self.lambd)
class PReLU(Module):
@@ -621,9 +597,8 @@ class PReLU(Module):
def forward(self, input):
return F.prelu(input, self.weight)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'num_parameters=' + str(self.num_parameters) + ')'
+ def extra_repr(self):
+ return 'num_parameters={}'.format(self.num_parameters)
class Softsign(Module):
@@ -646,9 +621,6 @@ class Softsign(Module):
def forward(self, input):
return F.softsign(input)
- def __repr__(self):
- return self.__class__.__name__ + '()'
-
class Tanhshrink(Module):
r"""Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)`
@@ -670,9 +642,6 @@ class Tanhshrink(Module):
def forward(self, input):
return F.tanhshrink(input)
- def __repr__(self):
- return self.__class__.__name__ + '()'
-
class Softmin(Module):
r"""Applies the Softmin function to an n-dimensional input Tensor
@@ -706,9 +675,6 @@ class Softmin(Module):
def forward(self, input):
return F.softmin(input, self.dim, _stacklevel=5)
- def __repr__(self):
- return self.__class__.__name__ + '()'
-
class Softmax(Module):
r"""Applies the Softmax function to an n-dimensional input Tensor
@@ -754,9 +720,6 @@ class Softmax(Module):
def forward(self, input):
return F.softmax(input, self.dim, _stacklevel=5)
- def __repr__(self):
- return self.__class__.__name__ + '()'
-
class Softmax2d(Module):
r"""Applies SoftMax over features to each spatial location.
@@ -784,9 +747,6 @@ class Softmax2d(Module):
assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
return F.softmax(input, 1, _stacklevel=5)
- def __repr__(self):
- return self.__class__.__name__ + '()'
-
class LogSoftmax(Module):
r"""Applies the `Log(Softmax(x))` function to an n-dimensional input Tensor.
@@ -824,6 +784,3 @@ class LogSoftmax(Module):
def forward(self, input):
return F.log_softmax(input, self.dim, _stacklevel=5)
-
- def __repr__(self):
- return self.__class__.__name__ + '()'
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index 63ef02c11e..5c85c3dda1 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -48,10 +48,9 @@ class _BatchNorm(Module):
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training or not self.track_running_stats, self.momentum, self.eps)
- def __repr__(self):
- return ('{name}({num_features}, eps={eps}, momentum={momentum},'
- ' affine={affine}, track_running_stats={track_running_stats})'
- .format(name=self.__class__.__name__, **self.__dict__))
+ def extra_repr(self):
+ return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
+ 'track_running_stats={track_running_stats}'.format(**self.__dict__)
class BatchNorm1d(_BatchNorm):
diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py
index e4430db691..7ab4ec8d54 100644
--- a/torch/nn/modules/container.py
+++ b/torch/nn/modules/container.py
@@ -267,13 +267,12 @@ class ParameterList(Module):
self.register_parameter(str(offset + i), param)
return self
- def __repr__(self):
- tmpstr = self.__class__.__name__ + '(\n'
+ def extra_repr(self):
+ tmpstr = ''
for k, p in self._parameters.items():
size_str = 'x'.join(str(size) for size in p.size())
device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
parastr = 'Parameter containing: [{} of size {}{}]'.format(
torch.typename(p.data), size_str, device_str)
tmpstr = tmpstr + ' (' + k + '): ' + parastr + '\n'
- tmpstr = tmpstr + ')'
return tmpstr
diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py
index a5c3252cc5..0a3ef76516 100644
--- a/torch/nn/modules/conv.py
+++ b/torch/nn/modules/conv.py
@@ -46,8 +46,8 @@ class _ConvNd(Module):
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
- def __repr__(self):
- s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
+ def extra_repr(self):
+ s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}')
if self.padding != (0,) * len(self.padding):
s += ', padding={padding}'
@@ -59,8 +59,7 @@ class _ConvNd(Module):
s += ', groups={groups}'
if self.bias is None:
s += ', bias=False'
- s += ')'
- return s.format(name=self.__class__.__name__, **self.__dict__)
+ return s.format(**self.__dict__)
class Conv1d(_ConvNd):
diff --git a/torch/nn/modules/dropout.py b/torch/nn/modules/dropout.py
index afc5e10044..e090095472 100644
--- a/torch/nn/modules/dropout.py
+++ b/torch/nn/modules/dropout.py
@@ -2,7 +2,22 @@ from .module import Module
from .. import functional as F
-class Dropout(Module):
+class _DropoutNd(Module):
+
+ def __init__(self, p=0.5, inplace=False):
+ super(_DropoutNd, self).__init__()
+ if p < 0 or p > 1:
+ raise ValueError("dropout probability has to be between 0 and 1, "
+ "but got {}".format(p))
+ self.p = p
+ self.inplace = inplace
+
+ def extra_repr(self):
+ inplace_str = ', inplace' if self.inplace else ''
+ return 'p={}{}'.format(self.p, inplace_str)
+
+
+class Dropout(_DropoutNd):
r"""During training, randomly zeroes some of the elements of the input
tensor with probability :attr:`p` using samples from a Bernoulli
distribution. The elements to zero are randomized on every forward call.
@@ -34,25 +49,11 @@ class Dropout(Module):
detectors: https://arxiv.org/abs/1207.0580
"""
- def __init__(self, p=0.5, inplace=False):
- super(Dropout, self).__init__()
- if p < 0 or p > 1:
- raise ValueError("dropout probability has to be between 0 and 1, "
- "but got {}".format(p))
- self.p = p
- self.inplace = inplace
-
def forward(self, input):
return F.dropout(input, self.p, self.training, self.inplace)
- def __repr__(self):
- inplace_str = ', inplace' if self.inplace else ''
- return self.__class__.__name__ + '(' \
- + 'p=' + str(self.p) \
- + inplace_str + ')'
-
-class Dropout2d(Module):
+class Dropout2d(_DropoutNd):
r"""Randomly zeroes whole channels of the input tensor.
The channels to zero-out are randomized on every forward call.
@@ -87,25 +88,11 @@ class Dropout2d(Module):
http://arxiv.org/abs/1411.4280
"""
- def __init__(self, p=0.5, inplace=False):
- super(Dropout2d, self).__init__()
- if p < 0 or p > 1:
- raise ValueError("dropout probability has to be between 0 and 1, "
- "but got {}".format(p))
- self.p = p
- self.inplace = inplace
-
def forward(self, input):
return F.dropout2d(input, self.p, self.training, self.inplace)
- def __repr__(self):
- inplace_str = ', inplace' if self.inplace else ''
- return self.__class__.__name__ + '(' \
- + 'p=' + str(self.p) \
- + inplace_str + ')'
-
-class Dropout3d(Module):
+class Dropout3d(_DropoutNd):
r"""Randomly zeroes whole channels of the input tensor.
The channels to zero are randomized on every forward call.
@@ -140,23 +127,9 @@ class Dropout3d(Module):
http://arxiv.org/abs/1411.4280
"""
- def __init__(self, p=0.5, inplace=False):
- super(Dropout3d, self).__init__()
- if p < 0 or p > 1:
- raise ValueError("dropout probability has to be between 0 and 1, "
- "but got {}".format(p))
- self.p = p
- self.inplace = inplace
-
def forward(self, input):
return F.dropout3d(input, self.p, self.training, self.inplace)
- def __repr__(self):
- inplace_str = ', inplace' if self.inplace else ''
- return self.__class__.__name__ + '(' \
- + 'p=' + str(self.p) \
- + inplace_str + ')'
-
class AlphaDropout(Module):
r"""Applies Alpha Dropout over the input.
diff --git a/torch/nn/modules/fold.py b/torch/nn/modules/fold.py
index 95aa6ec935..1b285b8667 100644
--- a/torch/nn/modules/fold.py
+++ b/torch/nn/modules/fold.py
@@ -68,13 +68,11 @@ class Fold(Module):
return F.fold(input, self.output_size, self.kernel_size, self.dilation,
self.padding, self.stride)
- def __repr__(self):
- return self.__class__.__name__ + ' (' \
- + 'output_size=' + str(self.output_size) \
- + ', kernel_size=' + str(self.kernel_size) \
- + ', dilation=' + str(self.dilation) \
- + ', padding=' + str(self.padding) \
- + ', stride=' + str(self.stride) + ')'
+ def extra_repr(self):
+ return 'output_size={output_size}, kernel_size={kernel_size}, ' \
+ 'dilation={dilation}, padding={padding}, stride={stride}'.format(
+ **self.__dict__
+ )
class Unfold(Module):
@@ -140,9 +138,6 @@ class Unfold(Module):
return F.unfold(input, self.kernel_size, self.dilation,
self.padding, self.stride)
- def __repr__(self):
- return self.__class__.__name__ + ' (' \
- + 'kernel_size=' + str(self.kernel_size) \
- + ', dilation=' + str(self.dilation) \
- + ', padding=' + str(self.padding) \
- + ', stride=' + str(self.stride) + ')'
+ def extra_repr(self):
+ return 'kernel_size={kernel_size}, dilation={dilation}, padding={padding},' \
+ ' stride={stride}'.format(**self.__dict__)
diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py
index dcbb2f729b..2af36229af 100644
--- a/torch/nn/modules/linear.py
+++ b/torch/nn/modules/linear.py
@@ -54,11 +54,10 @@ class Linear(Module):
def forward(self, input):
return F.linear(input, self.weight, self.bias)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'in_features=' + str(self.in_features) \
- + ', out_features=' + str(self.out_features) \
- + ', bias=' + str(self.bias is not None) + ')'
+ def extra_repr(self):
+ return 'in_features={}, out_features={}, bias={}'.format(
+ self.in_features, self.out_features, self.bias is not None
+ )
class Bilinear(Module):
@@ -115,11 +114,9 @@ class Bilinear(Module):
def forward(self, input1, input2):
return F.bilinear(input1, input2, self.weight, self.bias)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'in1_features=' + str(self.in1_features) \
- + ', in2_features=' + str(self.in2_features) \
- + ', out_features=' + str(self.out_features) \
- + ', bias=' + str(self.bias is not None) + ')'
+ def extra_repr(self):
+ return 'in1_features={}, in2_features={}, out_features={}, bias={}'.format(
+ self.in1_features, self.in2_features, self.out_features, self.bias is not None
+ )
# TODO: PartialLinear - maybe in sparse?
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 129d8bc59a..3482164562 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -766,14 +766,39 @@ class Module(object):
def _get_name(self):
return self.__class__.__name__
+ def extra_repr(self):
+ r"""Set the extra representation of the module
+
+ To print customized extra information, you should reimplement
+ this method in your own modules. Both single-line and multi-line
+ strings are acceptable.
+ """
+ return ''
+
def __repr__(self):
- tmpstr = self._get_name() + '(\n'
+ # We treat the extra repr like the sub-module, one item per line
+ extra_lines = []
+ extra_repr = self.extra_repr()
+ # empty string will be split into list ['']
+ if extra_repr:
+ extra_lines = extra_repr.split('\n')
+ child_lines = []
for key, module in self._modules.items():
- modstr = module.__repr__()
- modstr = _addindent(modstr, 2)
- tmpstr = tmpstr + ' (' + key + '): ' + modstr + '\n'
- tmpstr = tmpstr + ')'
- return tmpstr
+ mod_str = repr(module)
+ mod_str = _addindent(mod_str, 2)
+ child_lines.append('(' + key + '): ' + mod_str)
+ lines = extra_lines + child_lines
+
+ main_str = self._get_name() + '('
+ if lines:
+ # simple one-liner info, which most builtin Modules will use
+ if len(extra_lines) == 1 and not child_lines:
+ main_str += extra_lines[0]
+ else:
+ main_str += '\n ' + '\n '.join(lines) + '\n'
+
+ main_str += ')'
+ return main_str
def __dir__(self):
module_attrs = dir(self.__class__)
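The rewritten __repr__ above gathers extra_repr() lines and child-module lines, collapsing to a one-liner when there is exactly one extra line and no children. A sketch of the two layouts this produces, consistent with the new test in test_nn.py::

    >>> import torch.nn as nn
    >>> print(nn.Linear(1, 1))                  # one extra line, no children
    Linear(in_features=1, out_features=1, bias=True)
    >>> print(nn.Sequential(nn.Linear(1, 1)))   # child modules force the multi-line form
    Sequential(
      (0): Linear(in_features=1, out_features=1, bias=True)
    )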
diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py
index e29f194d9e..1136d25e1e 100644
--- a/torch/nn/modules/normalization.py
+++ b/torch/nn/modules/normalization.py
@@ -46,12 +46,8 @@ class LocalResponseNorm(Module):
return F.local_response_norm(input, self.size, self.alpha, self.beta,
self.k)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.size) \
- + ', alpha=' + str(self.alpha) \
- + ', beta=' + str(self.beta) \
- + ', k=' + str(self.k) + ')'
+ def extra_repr(self):
+ return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
class CrossMapLRN2d(Module):
@@ -67,12 +63,8 @@ class CrossMapLRN2d(Module):
return self._backend.CrossMapLRN2d(self.size, self.alpha, self.beta,
self.k)(input)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.size) \
- + ', alpha=' + str(self.alpha) \
- + ', beta=' + str(self.beta) \
- + ', k=' + str(self.k) + ')'
+ def extra_repr(self):
+ return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
class LayerNorm(Module):
@@ -153,10 +145,9 @@ class LayerNorm(Module):
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps)
- def __repr__(self):
- return ('{name}({normalized_shape}, eps={eps}, '
- ' elementwise_affine={elementwise_affine},'
- .format(name=self.__class__.__name__, **self.__dict__))
+ def extra_repr(self):
+ return '{normalized_shape}, eps={eps}, ' \
+ 'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
class GroupNorm(Module):
@@ -223,10 +214,9 @@ class GroupNorm(Module):
return F.group_norm(
input, self.num_groups, self.weight, self.bias, self.eps)
- def __repr__(self):
- return ('{name}({num_groups}, {num_channels}, eps={eps}, '
- 'affine={affine},'
- .format(name=self.__class__.__name__, **self.__dict__))
+ def extra_repr(self):
+ return '{num_groups}, {num_channels}, eps={eps}, ' \
+ 'affine={affine}'.format(**self.__dict__)
# TODO: ContrastiveNorm2d
diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py
index c87062dc43..cd691ffb3d 100644
--- a/torch/nn/modules/padding.py
+++ b/torch/nn/modules/padding.py
@@ -6,7 +6,20 @@ from .. import functional as F
# TODO: grad_output size asserts in THNN
-class ConstantPad1d(Module):
+class _ConstantPadNd(Module):
+
+ def __init__(self, value):
+ super(_ConstantPadNd, self).__init__()
+ self.value = value
+
+ def forward(self, input):
+ return F.pad(input, self.padding, 'constant', self.value)
+
+ def extra_repr(self):
+ return 'padding={}, value={}'.format(self.padding, self.value)
+
+
+class ConstantPad1d(_ConstantPadNd):
r"""Pads the input tensor boundaries with a constant value.
Args:
@@ -30,19 +43,11 @@ class ConstantPad1d(Module):
"""
def __init__(self, padding, value):
- super(ConstantPad1d, self).__init__()
+ super(ConstantPad1d, self).__init__(value)
self.padding = _pair(padding)
- self.value = value
-
- def forward(self, input):
- return F.pad(input, self.padding, 'constant', self.value)
-
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.padding) + ')'
-class ConstantPad2d(Module):
+class ConstantPad2d(_ConstantPadNd):
r"""Pads the input tensor boundaries with a constant value.
For Nd-padding, use :meth:`nn.functional.pad()`.
@@ -70,19 +75,11 @@ class ConstantPad2d(Module):
"""
def __init__(self, padding, value):
- super(ConstantPad2d, self).__init__()
+ super(ConstantPad2d, self).__init__(value)
self.padding = _quadruple(padding)
- self.value = value
-
- def forward(self, input):
- return F.pad(input, self.padding, 'constant', self.value)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.padding) + ')'
-
-class ConstantPad3d(Module):
+class ConstantPad3d(_ConstantPadNd):
r"""Pads the input tensor boundaries with a constant value.
Args:
@@ -109,19 +106,20 @@ class ConstantPad3d(Module):
"""
def __init__(self, padding, value):
- super(ConstantPad3d, self).__init__()
+ super(ConstantPad3d, self).__init__(value)
self.padding = _ntuple(6)(padding)
- self.value = value
+
+
+class _ReflectionPadNd(Module):
def forward(self, input):
- return F.pad(input, self.padding, 'constant', self.value)
+ return F.pad(input, self.padding, 'reflect')
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.padding) + ')'
+ def extra_repr(self):
+ return '{}'.format(self.padding)
-class ReflectionPad1d(Module):
+class ReflectionPad1d(_ReflectionPadNd):
r"""Pads the input tensor using the reflection of the input boundary.
Args:
@@ -148,15 +146,8 @@ class ReflectionPad1d(Module):
super(ReflectionPad1d, self).__init__()
self.padding = _pair(padding)
- def forward(self, input):
- return F.pad(input, self.padding, 'reflect')
-
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.padding) + ')'
-
-class ReflectionPad2d(Module):
+class ReflectionPad2d(_ReflectionPadNd):
r"""Pads the input tensor using the reflection of the input boundary.
Args:
@@ -185,15 +176,17 @@ class ReflectionPad2d(Module):
super(ReflectionPad2d, self).__init__()
self.padding = _quadruple(padding)
+
+class _ReplicationPadNd(Module):
+
def forward(self, input):
- return F.pad(input, self.padding, 'reflect')
+ return F.pad(input, self.padding, 'replicate')
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.padding) + ')'
+ def extra_repr(self):
+ return '{}'.format(self.padding)
-class ReplicationPad1d(Module):
+class ReplicationPad1d(_ReplicationPadNd):
r"""Pads the input tensor using replication of the input boundary.
Args:
@@ -220,15 +213,8 @@ class ReplicationPad1d(Module):
super(ReplicationPad1d, self).__init__()
self.padding = _pair(padding)
- def forward(self, input):
- return F.pad(input, self.padding, 'replicate')
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.padding) + ')'
-
-
-class ReplicationPad2d(Module):
+class ReplicationPad2d(_ReplicationPadNd):
r"""Pads the input tensor using replication of the input boundary.
Args:
@@ -257,15 +243,8 @@ class ReplicationPad2d(Module):
super(ReplicationPad2d, self).__init__()
self.padding = _quadruple(padding)
- def forward(self, input):
- return F.pad(input, self.padding, 'replicate')
-
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.padding) + ')'
-
-class ReplicationPad3d(Module):
+class ReplicationPad3d(_ReplicationPadNd):
r"""Pads the input tensor using replication of the input boundary.
Args:
@@ -295,13 +274,6 @@ class ReplicationPad3d(Module):
super(ReplicationPad3d, self).__init__()
self.padding = _ntuple(6)(padding)
- def forward(self, input):
- return F.pad(input, self.padding, 'replicate')
-
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.padding) + ')'
-
class ZeroPad2d(ConstantPad2d):
r"""Pads the input tensor boundaries with zero.
diff --git a/torch/nn/modules/pixelshuffle.py b/torch/nn/modules/pixelshuffle.py
index 24c6e00d18..786ba40128 100644
--- a/torch/nn/modules/pixelshuffle.py
+++ b/torch/nn/modules/pixelshuffle.py
@@ -39,5 +39,5 @@ class PixelShuffle(Module):
def forward(self, input):
return F.pixel_shuffle(input, self.upscale_factor)
- def __repr__(self):
- return self.__class__.__name__ + '(upscale_factor=' + str(self.upscale_factor) + ')'
+ def extra_repr(self):
+ return 'upscale_factor={}'.format(self.upscale_factor)
diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py
index 6f00417545..8292569973 100644
--- a/torch/nn/modules/pooling.py
+++ b/torch/nn/modules/pooling.py
@@ -6,7 +6,24 @@ from .utils import _single, _pair, _triple
from .. import functional as F
-class MaxPool1d(Module):
+class _MaxPoolNd(Module):
+
+ def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
+ return_indices=False, ceil_mode=False):
+ super(_MaxPoolNd, self).__init__()
+ self.kernel_size = kernel_size
+ self.stride = stride or kernel_size
+ self.padding = padding
+ self.dilation = dilation
+ self.return_indices = return_indices
+ self.ceil_mode = ceil_mode
+
+ def extra_repr(self):
+ return 'kernel_size={kernel_size}, stride={stride}, padding={padding}' \
+ ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
+
+
+class MaxPool1d(_MaxPoolNd):
r"""Applies a 1D max pooling over an input signal composed of several input
planes.
@@ -52,31 +69,17 @@ class MaxPool1d(Module):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
- def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
- return_indices=False, ceil_mode=False):
- super(MaxPool1d, self).__init__()
- self.kernel_size = kernel_size
- self.stride = stride or kernel_size
- self.padding = padding
- self.dilation = dilation
- self.return_indices = return_indices
- self.ceil_mode = ceil_mode
-
def forward(self, input):
return F.max_pool1d(input, self.kernel_size, self.stride,
self.padding, self.dilation, self.ceil_mode,
self.return_indices)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'kernel_size=' + str(self.kernel_size) \
- + ', stride=' + str(self.stride) \
- + ', padding=' + str(self.padding) \
- + ', dilation=' + str(self.dilation) \
- + ', ceil_mode=' + str(self.ceil_mode) + ')'
+ def extra_repr(self):
+ return 'kernel_size={kernel_size}, stride={stride}, padding={padding}' \
+ ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
-class MaxPool2d(Module):
+class MaxPool2d(_MaxPoolNd):
r"""Applies a 2D max pooling over an input signal composed of several input
planes.
@@ -134,38 +137,102 @@ class MaxPool2d(Module):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
- def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
- return_indices=False, ceil_mode=False):
- super(MaxPool2d, self).__init__()
- self.kernel_size = kernel_size
- self.stride = stride or kernel_size
- self.padding = padding
- self.dilation = dilation
- self.return_indices = return_indices
- self.ceil_mode = ceil_mode
-
def forward(self, input):
return F.max_pool2d(input, self.kernel_size, self.stride,
self.padding, self.dilation, self.ceil_mode,
self.return_indices)
- def __repr__(self):
- kh, kw = _pair(self.kernel_size)
- dh, dw = _pair(self.stride)
- padh, padw = _pair(self.padding)
- dilh, dilw = _pair(self.dilation)
- padding_str = ', padding=(' + str(padh) + ', ' + str(padw) + ')' \
- if padh != 0 or padw != 0 else ''
- dilation_str = (', dilation=(' + str(dilh) + ', ' + str(dilw) + ')'
- if dilh != 0 and dilw != 0 else '')
- ceil_str = ', ceil_mode=' + str(self.ceil_mode)
- return self.__class__.__name__ + '(' \
- + 'kernel_size=(' + str(kh) + ', ' + str(kw) + ')' \
- + ', stride=(' + str(dh) + ', ' + str(dw) + ')' \
- + padding_str + dilation_str + ceil_str + ')'
-
-
-class MaxUnpool1d(Module):
+# def extra_repr(self):
+# kh, kw = _pair(self.kernel_size)
+# dh, dw = _pair(self.stride)
+# padh, padw = _pair(self.padding)
+# dilh, dilw = _pair(self.dilation)
+# padding_str = ', padding=(' + str(padh) + ', ' + str(padw) + ')' \
+# if padh != 0 or padw != 0 else ''
+# dilation_str = (', dilation=(' + str(dilh) + ', ' + str(dilw) + ')'
+# if dilh != 0 and dilw != 0 else '')
+# ceil_str = ', ceil_mode=' + str(self.ceil_mode)
+# return 'kernel_size=(' + str(kh) + ', ' + str(kw) + ')' \
+# + ', stride=(' + str(dh) + ', ' + str(dw) + ')' \
+# + padding_str + dilation_str + ceil_str
+
+
+class MaxPool3d(_MaxPoolNd):
+ r"""Applies a 3D max pooling over an input signal composed of several input
+ planes.
+
+ In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
+ output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
+ can be precisely described as:
+
+ .. math::
+
+ \begin{align*}
+ \text{out}(N_i, C_j, d, h, w) &= \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1}
+ \text{input}(N_i, C_j, \text{stride}[0] * k + d,\\ &\text{stride}[1] * h + m, \text{stride}[2] * w + n)
+ \end{align*}
+
+ If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+ for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+
+ The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
+
+ - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
+ - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
+ the second `int` for the height dimension and the third `int` for the width dimension
+
+ Args:
+ kernel_size: the size of the window to take a max over
+ stride: the stride of the window. Default value is :attr:`kernel_size`
+ padding: implicit zero padding to be added on all three sides
+ dilation: a parameter that controls the stride of elements in the window
+ return_indices: if ``True``, will return the max indices along with the outputs.
+ Useful when Unpooling later
+ ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+
+ Shape:
+ - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where
+
+ .. math::
+ D_{out} = \left\lfloor\frac{D_{in} + 2 * \text{padding}[0] - \text{dilation}[0] *
+ (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
+
+ H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding}[1] - \text{dilation}[1] *
+ (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
+
+ W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding}[2] - \text{dilation}[2] *
+ (\text{kernel_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
+
+ Examples::
+
+ >>> # pool of square window of size=3, stride=2
+ >>> m = nn.MaxPool3d(3, stride=2)
+ >>> # pool of non-square window
+ >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2))
+ >>> input = torch.randn(20, 16, 50,44, 31)
+ >>> output = m(input)
+
+ .. _link:
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+ """
+
+ def forward(self, input):
+ return F.max_pool3d(input, self.kernel_size, self.stride,
+ self.padding, self.dilation, self.ceil_mode,
+ self.return_indices)
+
+
+class _MaxUnpoolNd(Module):
+
+ def extra_repr(self):
+ return 'kernel_size={}, stride={}, padding={}'.format(
+ self.kernel_size, self.stride, self.padding
+ )
+
+
+class MaxUnpool1d(_MaxUnpoolNd):
r"""Computes a partial inverse of :class:`MaxPool1d`.
:class:`MaxPool1d` is not fully invertible, since the non-maximal values are lost.
@@ -232,21 +299,15 @@ class MaxUnpool1d(Module):
def __init__(self, kernel_size, stride=None, padding=0):
super(MaxUnpool1d, self).__init__()
self.kernel_size = _single(kernel_size)
- self.stride = _single(stride if stride is not None else kernel_size)
+ self.stride = _single(stride or kernel_size)
self.padding = _single(padding)
def forward(self, input, indices, output_size=None):
return F.max_unpool1d(input, indices, self.kernel_size, self.stride,
self.padding, output_size)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'kernel_size=' + str(self.kernel_size) \
- + ', stride=' + str(self.stride) \
- + ', padding=' + str(self.padding) + ')'
-
-class MaxUnpool2d(Module):
+class MaxUnpool2d(_MaxUnpoolNd):
r"""Computes a partial inverse of :class:`MaxPool2d`.
:class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost.
@@ -317,21 +378,15 @@ class MaxUnpool2d(Module):
def __init__(self, kernel_size, stride=None, padding=0):
super(MaxUnpool2d, self).__init__()
self.kernel_size = _pair(kernel_size)
- self.stride = _pair(stride if stride is not None else kernel_size)
+ self.stride = _pair(stride or kernel_size)
self.padding = _pair(padding)
def forward(self, input, indices, output_size=None):
return F.max_unpool2d(input, indices, self.kernel_size, self.stride,
self.padding, output_size)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'kernel_size=' + str(self.kernel_size) \
- + ', stride=' + str(self.stride) \
- + ', padding=' + str(self.padding) + ')'
-
-class MaxUnpool3d(Module):
+class MaxUnpool3d(_MaxUnpoolNd):
r"""Computes a partial inverse of :class:`MaxPool3d`.
:class:`MaxPool3d` is not fully invertible, since the non-maximal values are lost.
@@ -383,21 +438,32 @@ class MaxUnpool3d(Module):
def __init__(self, kernel_size, stride=None, padding=0):
super(MaxUnpool3d, self).__init__()
self.kernel_size = _triple(kernel_size)
- self.stride = _triple(stride if stride is not None else kernel_size)
+ self.stride = _triple(stride or kernel_size)
self.padding = _triple(padding)
def forward(self, input, indices, output_size=None):
return F.max_unpool3d(input, indices, self.kernel_size, self.stride,
self.padding, output_size)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'kernel_size=' + str(self.kernel_size) \
- + ', stride=' + str(self.stride) \
- + ', padding=' + str(self.padding) + ')'
+class _AvgPoolNd(Module):
-class AvgPool1d(Module):
+ def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
+ count_include_pad=True):
+ super(_AvgPoolNd, self).__init__()
+ self.kernel_size = _single(kernel_size)
+ self.stride = _single(stride if stride is not None else kernel_size)
+ self.padding = _single(padding)
+ self.ceil_mode = ceil_mode
+ self.count_include_pad = count_include_pad
+
+ def extra_repr(self):
+ return 'kernel_size={}, stride={}, padding={}'.format(
+ self.kernel_size, self.stride, self.padding
+ )
+
+
+class AvgPool1d(_AvgPoolNd):
r"""Applies a 1D average pooling over an input signal composed of several
input planes.
@@ -444,30 +510,13 @@ class AvgPool1d(Module):
[torch.FloatTensor of size (1,1,3)]
"""
- def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
- count_include_pad=True):
- super(AvgPool1d, self).__init__()
- self.kernel_size = _single(kernel_size)
- self.stride = _single(stride if stride is not None else kernel_size)
- self.padding = _single(padding)
- self.ceil_mode = ceil_mode
- self.count_include_pad = count_include_pad
-
def forward(self, input):
return F.avg_pool1d(
input, self.kernel_size, self.stride, self.padding, self.ceil_mode,
self.count_include_pad)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'kernel_size=' + str(self.kernel_size) \
- + ', stride=' + str(self.stride) \
- + ', padding=' + str(self.padding) \
- + ', ceil_mode=' + str(self.ceil_mode) \
- + ', count_include_pad=' + str(self.count_include_pad) + ')'
-
-class AvgPool2d(Module):
+class AvgPool2d(_AvgPoolNd):
r"""Applies a 2D average pooling over an input signal composed of several input
planes.
@@ -519,114 +568,12 @@ class AvgPool2d(Module):
>>> output = m(input)
"""
- def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
- count_include_pad=True):
- super(AvgPool2d, self).__init__()
- self.kernel_size = kernel_size
- self.stride = stride or kernel_size
- self.padding = padding
- self.ceil_mode = ceil_mode
- self.count_include_pad = count_include_pad
-
def forward(self, input):
return F.avg_pool2d(input, self.kernel_size, self.stride,
self.padding, self.ceil_mode, self.count_include_pad)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'kernel_size=' + str(self.kernel_size) \
- + ', stride=' + str(self.stride) \
- + ', padding=' + str(self.padding) \
- + ', ceil_mode=' + str(self.ceil_mode) \
- + ', count_include_pad=' + str(self.count_include_pad) + ')'
-
-
-class MaxPool3d(Module):
- r"""Applies a 3D max pooling over an input signal composed of several input
- planes.
-
- In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
- output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
- can be precisely described as:
-
- .. math::
-
- \begin{align*}
- \text{out}(N_i, C_j, d, h, w) &= \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1}
- \text{input}(N_i, C_j, \text{stride}[0] * k + d,\\ &\text{stride}[1] * h + m, \text{stride}[2] * w + n)
- \end{align*}
-
- If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
- for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
- It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
-
- The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
-
- - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
- - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
- the second `int` for the height dimension and the third `int` for the width dimension
-
- Args:
- kernel_size: the size of the window to take a max over
- stride: the stride of the window. Default value is :attr:`kernel_size`
- padding: implicit zero padding to be added on all three sides
- dilation: a parameter that controls the stride of elements in the window
- return_indices: if ``True``, will return the max indices along with the outputs.
- Useful when Unpooling later
- ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
-
- Shape:
- - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
- - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where
-
- .. math::
- D_{out} = \left\lfloor\frac{D_{in} + 2 * \text{padding}[0] - \text{dilation}[0] *
- (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
-
- H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding}[1] - \text{dilation}[1] *
- (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
-
- W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding}[2] - \text{dilation}[2] *
- (\text{kernel_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
-
- Examples::
-
- >>> # pool of square window of size=3, stride=2
- >>> m = nn.MaxPool3d(3, stride=2)
- >>> # pool of non-square window
- >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2))
- >>> input = torch.randn(20, 16, 50,44, 31)
- >>> output = m(input)
-
- .. _link:
- https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
- """
- def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
- return_indices=False, ceil_mode=False):
- super(MaxPool3d, self).__init__()
- self.kernel_size = kernel_size
- self.stride = stride or kernel_size
- self.padding = padding
- self.dilation = dilation
- self.return_indices = return_indices
- self.ceil_mode = ceil_mode
-
- def forward(self, input):
- return F.max_pool3d(input, self.kernel_size, self.stride,
- self.padding, self.dilation, self.ceil_mode,
- self.return_indices)
-
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'kernel_size=' + str(self.kernel_size) \
- + ', stride=' + str(self.stride) \
- + ', padding=' + str(self.padding) \
- + ', dilation=' + str(self.dilation) \
- + ', ceil_mode=' + str(self.ceil_mode) + ')'
-
-
-class AvgPool3d(Module):
+class AvgPool3d(_AvgPoolNd):
r"""Applies a 3D average pooling over an input signal composed of several input
planes.
@@ -683,15 +630,6 @@ class AvgPool3d(Module):
>>> output = m(input)
"""
- def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
- count_include_pad=True):
- super(AvgPool3d, self).__init__()
- self.kernel_size = kernel_size
- self.stride = stride or kernel_size
- self.padding = padding
- self.ceil_mode = ceil_mode
- self.count_include_pad = count_include_pad
-
def forward(self, input):
return F.avg_pool3d(input, self.kernel_size, self.stride,
self.padding, self.ceil_mode, self.count_include_pad)
@@ -702,14 +640,6 @@ class AvgPool3d(Module):
self.__dict__.setdefault('ceil_mode', False)
self.__dict__.setdefault('count_include_pad', True)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'kernel_size=' + str(self.kernel_size) \
- + ', stride=' + str(self.stride) \
- + ', padding=' + str(self.padding) \
- + ', ceil_mode=' + str(self.ceil_mode) \
- + ', count_include_pad=' + str(self.count_include_pad) + ')'
-
class FractionalMaxPool2d(Module):
r"""Applies a 2D fractional max pooling over an input signal composed of several input planes.
@@ -768,7 +698,58 @@ class FractionalMaxPool2d(Module):
_random_samples=samples)
-class LPPool2d(Module):
+class _LPPoolNd(Module):
+
+ def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False):
+ super(_LPPoolNd, self).__init__()
+ self.norm_type = norm_type
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.ceil_mode = ceil_mode
+
+ def extra_repr(self):
+ return 'norm_type={norm_type}, kernel_size{kernel_size}, stride={stride}, ' \
+ 'ceil_mode={ceil_mode}'.format(**self.__dict__)
+
+
+class LPPool1d(_LPPoolNd):
+ r"""Applies a 1D power-average pooling over an input signal composed of several input
+ planes.
+
+ On each window, the function computed is:
+
+ .. math::
+ f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}
+
+ - At p = infinity, one gets Max Pooling
+ - At p = 1, one gets Sum Pooling (which is proportional to Average Pooling)
+
+ Args:
+ kernel_size: a single int, the size of the window
+ stride: a single int, the stride of the window. Default value is :attr:`kernel_size`
+ ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+
+ Shape:
+ - Input: :math:`(N, C, L_{in})`
+ - Output: :math:`(N, C, L_{out})` where
+
+ .. math::
+ L_{out} = \left\lfloor\frac{L_{in} +
+ 2 * \text{padding} - \text{kernel_size}}{\text{stride}} + 1\right\rfloor
+
+ Examples::
+ >>> # power-2 pool of window of length 3, with stride 2.
+ >>> m = nn.LPPool1d(2, 3, stride=2)
+ >>> input = torch.randn(20, 16, 50)
+ >>> output = m(input)
+ """
+
+ def forward(self, input):
+ return F.lp_pool1d(input, self.norm_type, self.kernel_size,
+ self.stride, self.ceil_mode)
+
+
+class LPPool2d(_LPPoolNd):
r"""Applies a 2D power-average pooling over an input signal composed of several input
planes.
@@ -813,77 +794,23 @@ class LPPool2d(Module):
"""
- def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False):
- super(LPPool2d, self).__init__()
- self.norm_type = norm_type
- self.kernel_size = kernel_size
- self.stride = stride
- self.ceil_mode = ceil_mode
-
def forward(self, input):
return F.lp_pool2d(input, self.norm_type, self.kernel_size,
self.stride, self.ceil_mode)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.norm_type) + ', ' \
- + str(self.kernel_size) + ', ' \
- + 'stride=' + str(self.stride) + ', ' \
- + 'ceil_mode=' + str(self.ceil_mode) + ')'
+class _AdaptiveMaxPoolNd(Module):
-class LPPool1d(Module):
- r"""Applies a 1D power-average pooling over an input signal composed of several input
- planes.
-
- On each window, the function computed is:
-
- .. math::
- f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}
-
- - At p = infinity, one gets Max Pooling
- - At p = 1, one gets Sum Pooling (which is proportional to Average Pooling)
-
- Args:
- kernel_size: a single int, the size of the window
- stride: a single int, the stride of the window. Default value is :attr:`kernel_size`
- ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
-
- Shape:
- - Input: :math:`(N, C, L_{in})`
- - Output: :math:`(N, C, L_{out})` where
-
- .. math::
- L_{out} = \left\lfloor\frac{L_{in} +
- 2 * \text{padding} - \text{kernel_size}}{\text{stride}} + 1\right\rfloor
-
- Examples::
- >>> # power-2 pool of window of length 3, with stride 2.
- >>> m = nn.LPPool1d(2, 3, stride=2)
- >>> input = torch.randn(20, 16, 50)
- >>> output = m(input)
- """
-
- def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False):
- super(LPPool1d, self).__init__()
- self.norm_type = norm_type
- self.kernel_size = kernel_size
- self.stride = stride
- self.ceil_mode = ceil_mode
-
- def forward(self, input):
- return F.lp_pool1d(input, self.norm_type, self.kernel_size,
- self.stride, self.ceil_mode)
+ def __init__(self, output_size, return_indices=False):
+ super(_AdaptiveMaxPoolNd, self).__init__()
+ self.output_size = output_size
+ self.return_indices = return_indices
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + str(self.norm_type) + ', ' \
- + str(self.kernel_size) + ', ' \
- + 'stride=' + str(self.stride) + ', ' \
- + 'ceil_mode=' + str(self.ceil_mode) + ')'
+ def extra_repr(self):
+ return 'output_size={}'.format(self.output_size)
-class AdaptiveMaxPool1d(Module):
+class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
The output size is H, for any input size.
@@ -902,20 +829,11 @@ class AdaptiveMaxPool1d(Module):
"""
- def __init__(self, output_size, return_indices=False):
- super(AdaptiveMaxPool1d, self).__init__()
- self.output_size = output_size
- self.return_indices = return_indices
-
def forward(self, input):
return F.adaptive_max_pool1d(input, self.output_size, self.return_indices)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'output_size=' + str(self.output_size) + ')'
-
-class AdaptiveMaxPool2d(Module):
+class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
The output is of size H x W, for any input size.
@@ -945,20 +863,11 @@ class AdaptiveMaxPool2d(Module):
"""
- def __init__(self, output_size, return_indices=False):
- super(AdaptiveMaxPool2d, self).__init__()
- self.output_size = output_size
- self.return_indices = return_indices
-
def forward(self, input):
return F.adaptive_max_pool2d(input, self.output_size, self.return_indices)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'output_size=' + str(self.output_size) + ')'
-
-class AdaptiveMaxPool3d(Module):
+class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
The output is of size D x H x W, for any input size.
@@ -989,20 +898,21 @@ class AdaptiveMaxPool3d(Module):
"""
- def __init__(self, output_size, return_indices=False):
- super(AdaptiveMaxPool3d, self).__init__()
- self.output_size = output_size
- self.return_indices = return_indices
-
def forward(self, input):
return F.adaptive_max_pool3d(input, self.output_size, self.return_indices)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'output_size=' + str(self.output_size) + ')'
+class _AdaptiveAvgPoolNd(Module):
-class AdaptiveAvgPool1d(Module):
+ def __init__(self, output_size):
+ super(_AdaptiveAvgPoolNd, self).__init__()
+ self.output_size = output_size
+
+ def extra_repr(self):
+ return 'output_size={}'.format(self.output_size)
+
+
+class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
The output size is H, for any input size.
@@ -1019,19 +929,11 @@ class AdaptiveAvgPool1d(Module):
"""
- def __init__(self, output_size):
- super(AdaptiveAvgPool1d, self).__init__()
- self.output_size = output_size
-
def forward(self, input):
return F.adaptive_avg_pool1d(input, self.output_size)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'output_size=' + str(self.output_size) + ')'
-
-class AdaptiveAvgPool2d(Module):
+class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
The output is of size H x W, for any input size.
@@ -1059,19 +961,11 @@ class AdaptiveAvgPool2d(Module):
"""
- def __init__(self, output_size):
- super(AdaptiveAvgPool2d, self).__init__()
- self.output_size = output_size
-
def forward(self, input):
return F.adaptive_avg_pool2d(input, self.output_size)
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'output_size=' + str(self.output_size) + ')'
-
-class AdaptiveAvgPool3d(Module):
+class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
The output is of size D x H x W, for any input size.
@@ -1099,13 +993,5 @@ class AdaptiveAvgPool3d(Module):
"""
- def __init__(self, output_size):
- super(AdaptiveAvgPool3d, self).__init__()
- self.output_size = output_size
-
def forward(self, input):
return F.adaptive_avg_pool3d(input, self.output_size)
-
- def __repr__(self):
- return self.__class__.__name__ + '(' \
- + 'output_size=' + str(self.output_size) + ')'
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 3675be08fd..02731cffb7 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -195,8 +195,8 @@ class RNNBase(Module):
output = PackedSequence(output, batch_sizes)
return output, hidden
- def __repr__(self):
- s = '{name}({input_size}, {hidden_size}'
+ def extra_repr(self):
+ s = '{input_size}, {hidden_size}'
if self.num_layers != 1:
s += ', num_layers={num_layers}'
if self.bias is not True:
@@ -207,8 +207,7 @@ class RNNBase(Module):
s += ', dropout={dropout}'
if self.bidirectional is not False:
s += ', bidirectional={bidirectional}'
- s += ')'
- return s.format(name=self.__class__.__name__, **self.__dict__)
+ return s.format(**self.__dict__)
def __setstate__(self, d):
super(RNNBase, self).__setstate__(d)
@@ -487,14 +486,13 @@ class GRU(RNNBase):
class RNNCellBase(Module):
- def __repr__(self):
- s = '{name}({input_size}, {hidden_size}'
+ def extra_repr(self):
+ s = '{input_size}, {hidden_size}'
if 'bias' in self.__dict__ and self.bias is not True:
s += ', bias={bias}'
if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
s += ', nonlinearity={nonlinearity}'
- s += ')'
- return s.format(name=self.__class__.__name__, **self.__dict__)
+ return s.format(**self.__dict__)
def check_forward_input(self, input):
if input.size(1) != self.input_size:
diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py
index b52d94a6f2..adc19cd723 100644
--- a/torch/nn/modules/sparse.py
+++ b/torch/nn/modules/sparse.py
@@ -106,8 +106,8 @@ class Embedding(Module):
input, self.weight, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
- def __repr__(self):
- s = '{name}({num_embeddings}, {embedding_dim}'
+ def extra_repr(self):
+ s = '{num_embeddings}, {embedding_dim}'
if self.padding_idx is not None:
s += ', padding_idx={padding_idx}'
if self.max_norm is not None:
@@ -118,8 +118,7 @@ class Embedding(Module):
s += ', scale_grad_by_freq={scale_grad_by_freq}'
if self.sparse is not False:
s += ', sparse=True'
- s += ')'
- return s.format(name=self.__class__.__name__, **self.__dict__)
+ return s.format(**self.__dict__)
@classmethod
def from_pretrained(cls, embeddings, freeze=True):
@@ -238,8 +237,8 @@ class EmbeddingBag(Module):
self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.mode, self.sparse)
- def __repr__(self):
- s = '{name}({num_embeddings}, {embedding_dim}'
+ def extra_repr(self):
+ s = '{num_embeddings}, {embedding_dim}'
if self.max_norm is not None:
s += ', max_norm={max_norm}'
if self.norm_type != 2:
@@ -247,7 +246,6 @@ class EmbeddingBag(Module):
if self.scale_grad_by_freq is not False:
s += ', scale_grad_by_freq={scale_grad_by_freq}'
s += ', mode={mode}'
- s += ')'
- return s.format(name=self.__class__.__name__, **self.__dict__)
+ return s.format(**self.__dict__)
# TODO: SparseLinear
diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py
index 681e9ff59f..ec00fbaeb8 100644
--- a/torch/nn/modules/upsampling.py
+++ b/torch/nn/modules/upsampling.py
@@ -138,13 +138,13 @@ class Upsample(Module):
def forward(self, input):
return F.upsample(input, self.size, self.scale_factor, self.mode, self.align_corners)
- def __repr__(self):
+ def extra_repr(self):
if self.scale_factor is not None:
info = 'scale_factor=' + str(self.scale_factor)
else:
info = 'size=' + str(self.size)
info += ', mode=' + self.mode
- return self.__class__.__name__ + '(' + info + ')'
+ return info
class UpsamplingNearest2d(Upsample):