author     Kaiyu Shi <skyisno.1@gmail.com>    2018-04-03 01:52:33 +0800
committer  Edward Z. Yang <ezyang@mit.edu>    2018-04-02 13:52:33 -0400
commit     605307f8f3c249d9279030502d2aac98d4170b83 (patch)
tree       b1327aea596e5a62965087be7b311281885f540f
parent     7355f5cd8dc52a048d8c367cabfed9e888acd586 (diff)
Add support for printing extra information in Module and refactor redundant codes (#5936)
This PR lets users print extra information about their subclassed nn.Module.
Currently the user-defined string is simply inserted at the end of the module name; the exact placement can be discussed in this PR.
Before this PR, users had to override __repr__ and copy-and-paste the source code from Module.
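As a minimal sketch of the new user-facing API, modeled on the `Linear` example added to `docs/source/notes/extending.rst` in this diff (the `MyLinear` class and its attributes are hypothetical and exist only to illustrate the pattern):

```python
import torch
import torch.nn as nn


class MyLinear(nn.Module):
    # Hypothetical module used only to demonstrate extra_repr().
    def __init__(self, in_features, out_features):
        super(MyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, input):
        return input.matmul(self.weight.t())

    def extra_repr(self):
        # The returned string is spliced into this module's printed name
        # by the rewritten nn.Module.__repr__.
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features)


print(MyLinear(2, 3))
# Expected output with this patch applied:
# MyLinear(in_features=2, out_features=3)
```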
* Add support for extra information on Module
* Rewrite the repr method of Module (see the usage sketch after this list)
* Fix flake8
* Change the __repr__ to get_extra_repr in Linear
* Fix extra new-line for empty line
* Add test for __repr__ method
* Fix bug of block string indent
* Add indent for multi-line repr test
* Address review comments
* Update tutorial for creating nn.Module
* Fix flake8, add extra_repr of bilinear
* Refactor DropoutNd
* Change to extra_repr in some Modules
* Fix flake8
* Refactor padding modules
* Refactor pooling module
* Fix typo
* Change to extra_repr
* Fix bug for GroupNorm
* Fix bug for LayerNorm
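Usage sketch of the rewritten `__repr__` for nested modules, assuming a PyTorch build with this patch applied; the expected output mirrors the strings asserted in the new `test_repr` test in the diff below:

```python
import torch.nn as nn

# Each child's extra_repr() is combined with its class name and indented.
model = nn.Sequential(nn.Linear(1, 1), nn.Dropout(p=0.5))
print(model)
# Sequential(
#   (0): Linear(in_features=1, out_features=1, bias=True)
#   (1): Dropout(p=0.5)
# )
```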
 docs/source/notes/extending.rst    |   7
 test/test_nn.py                    |  18
 torch/nn/modules/activation.py     | 105
 torch/nn/modules/batchnorm.py      |   7
 torch/nn/modules/container.py      |   5
 torch/nn/modules/conv.py           |   7
 torch/nn/modules/dropout.py        |  63
 torch/nn/modules/fold.py           |  21
 torch/nn/modules/linear.py         |  19
 torch/nn/modules/module.py         |  37
 torch/nn/modules/normalization.py  |  30
 torch/nn/modules/padding.py        | 100
 torch/nn/modules/pixelshuffle.py   |   4
 torch/nn/modules/pooling.py        | 536
 torch/nn/modules/rnn.py            |  14
 torch/nn/modules/sparse.py         |  14
 torch/nn/modules/upsampling.py     |   4
17 files changed, 402 insertions(+), 589 deletions(-)
diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst index e9b7ef26b2..216da7795c 100644 --- a/docs/source/notes/extending.rst +++ b/docs/source/notes/extending.rst @@ -173,6 +173,13 @@ This is how a ``Linear`` module can be implemented:: # See the autograd section for explanation of what happens here. return LinearFunction.apply(input, self.weight, self.bias) + def extra_repr(self): + # (Optional)Set the extra information about this module. You can test + # it by printing an object of this class. + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features, self.bias is not None + ) + Writing custom C extensions --------------------------- diff --git a/test/test_nn.py b/test/test_nn.py index d8ada0ddbd..31c575f52a 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -768,6 +768,24 @@ class TestNN(NNTestCase): for key in keys: self.assertTrue(hasattr(linear, key)) + def test_repr(self): + # no extra information or sub-modules + empty_sequential = nn.Sequential() + expected_repr_empty = 'Sequential()' + self.assertEqual(repr(empty_sequential), expected_repr_empty) + + # one liner extra information + linear = nn.Linear(1, 1) + expected_repr_linear = 'Linear(in_features=1, out_features=1, bias=True)' + self.assertEqual(repr(linear), expected_repr_linear) + + # sub-modules repr + sequential = nn.Sequential(linear) + expected_repr_sequential = 'Sequential(\n' \ + ' (0): Linear(in_features=1, out_features=1, bias=True)\n' \ + ')' + self.assertEqual(repr(sequential), expected_repr_sequential) + def test_dir_digit(self): model = nn.Sequential(nn.Linear(2, 2)) keys = dir(model) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index d4a2a46db4..18535cf67e 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -45,12 +45,11 @@ class Threshold(Module): def forward(self, input): return F.threshold(input, self.threshold, self.value, self.inplace) - def __repr__(self): + def extra_repr(self): inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + ' (' \ - + str(self.threshold) \ - + ', ' + str(self.value) \ - + inplace_str + ')' + return 'threshold={}, value={}{}'.format( + self.threshold, self.value, inplace_str + ) class ReLU(Threshold): @@ -77,10 +76,9 @@ class ReLU(Threshold): def __init__(self, inplace=False): super(ReLU, self).__init__(0, 0, inplace) - def __repr__(self): + def extra_repr(self): inplace_str = 'inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + inplace_str + ')' + return inplace_str class RReLU(Module): @@ -129,12 +127,9 @@ class RReLU(Module): def forward(self, input): return F.rrelu(input, self.lower, self.upper, self.training, self.inplace) - def __repr__(self): + def extra_repr(self): inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + str(self.lower) \ - + ', ' + str(self.upper) \ - + inplace_str + ')' + return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str) class Hardtanh(Module): @@ -191,12 +186,11 @@ class Hardtanh(Module): def forward(self, input): return F.hardtanh(input, self.min_val, self.max_val, self.inplace) - def __repr__(self): + def extra_repr(self): inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + 'min_val=' + str(self.min_val) \ - + ', max_val=' + str(self.max_val) \ - + inplace_str + ')' + return 'min_val={}, max_val={}{}'.format( + self.min_val, self.max_val, inplace_str + ) class 
ReLU6(Hardtanh): @@ -222,10 +216,9 @@ class ReLU6(Hardtanh): def __init__(self, inplace=False): super(ReLU6, self).__init__(0, 6, inplace) - def __repr__(self): + def extra_repr(self): inplace_str = 'inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + inplace_str + ')' + return inplace_str class Sigmoid(Module): @@ -248,9 +241,6 @@ class Sigmoid(Module): def forward(self, input): return torch.sigmoid(input) - def __repr__(self): - return self.__class__.__name__ + '()' - class Tanh(Module): r"""Applies element-wise, @@ -273,9 +263,6 @@ class Tanh(Module): def forward(self, input): return torch.tanh(input) - def __repr__(self): - return self.__class__.__name__ + '()' - class ELU(Module): r"""Applies element-wise, @@ -307,11 +294,9 @@ class ELU(Module): def forward(self, input): return F.elu(input, self.alpha, self.inplace) - def __repr__(self): + def extra_repr(self): inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + 'alpha=' + str(self.alpha) \ - + inplace_str + ')' + return 'alpha={}{}'.format(self.alpha, inplace_str) class SELU(Module): @@ -348,9 +333,9 @@ class SELU(Module): def forward(self, input): return F.selu(input, self.inplace) - def __repr__(self): - inplace_str = '(inplace)' if self.inplace else '' - return self.__class__.__name__ + inplace_str + def extra_repr(self): + inplace_str = 'inplace' if self.inplace else '' + return inplace_str class GLU(Module): @@ -380,8 +365,8 @@ class GLU(Module): def forward(self, input): return F.glu(input, self.dim) - def __repr__(self): - return '{}(dim={})'.format(self.__class__.__name__, self.dim) + def extra_repr(self): + return 'dim={}'.format(self.dim) class Hardshrink(Module): @@ -420,9 +405,8 @@ class Hardshrink(Module): def forward(self, input): return F.hardshrink(input, self.lambd) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.lambd) + ')' + def extra_repr(self): + return '{}'.format(self.lambd) class LeakyReLU(Module): @@ -462,11 +446,9 @@ class LeakyReLU(Module): def forward(self, input): return F.leaky_relu(input, self.negative_slope, self.inplace) - def __repr__(self): + def extra_repr(self): inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + str(self.negative_slope) \ - + inplace_str + ')' + return 'negative_slope={}{}'.format(self.negative_slope, inplace_str) class LogSigmoid(Module): @@ -489,9 +471,6 @@ class LogSigmoid(Module): def forward(self, input): return F.logsigmoid(input) - def __repr__(self): - return self.__class__.__name__ + '()' - class Softplus(Module): r"""Applies element-wise :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))` @@ -528,10 +507,8 @@ class Softplus(Module): def forward(self, input): return F.softplus(input, self.beta, self.threshold) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'beta=' + str(self.beta) \ - + ', threshold=' + str(self.threshold) + ')' + def extra_repr(self): + return 'beta={}, threshold={}'.format(self.beta, self.threshold) class Softshrink(Module): @@ -571,9 +548,8 @@ class Softshrink(Module): def forward(self, input): return F.softshrink(input, self.lambd) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.lambd) + ')' + def extra_repr(self): + return str(self.lambd) class PReLU(Module): @@ -621,9 +597,8 @@ class PReLU(Module): def forward(self, input): return F.prelu(input, self.weight) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'num_parameters=' 
+ str(self.num_parameters) + ')' + def extra_repr(self): + return 'num_parameters={}'.format(self.num_parameters) class Softsign(Module): @@ -646,9 +621,6 @@ class Softsign(Module): def forward(self, input): return F.softsign(input) - def __repr__(self): - return self.__class__.__name__ + '()' - class Tanhshrink(Module): r"""Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)` @@ -670,9 +642,6 @@ class Tanhshrink(Module): def forward(self, input): return F.tanhshrink(input) - def __repr__(self): - return self.__class__.__name__ + '()' - class Softmin(Module): r"""Applies the Softmin function to an n-dimensional input Tensor @@ -706,9 +675,6 @@ class Softmin(Module): def forward(self, input): return F.softmin(input, self.dim, _stacklevel=5) - def __repr__(self): - return self.__class__.__name__ + '()' - class Softmax(Module): r"""Applies the Softmax function to an n-dimensional input Tensor @@ -754,9 +720,6 @@ class Softmax(Module): def forward(self, input): return F.softmax(input, self.dim, _stacklevel=5) - def __repr__(self): - return self.__class__.__name__ + '()' - class Softmax2d(Module): r"""Applies SoftMax over features to each spatial location. @@ -784,9 +747,6 @@ class Softmax2d(Module): assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input' return F.softmax(input, 1, _stacklevel=5) - def __repr__(self): - return self.__class__.__name__ + '()' - class LogSoftmax(Module): r"""Applies the `Log(Softmax(x))` function to an n-dimensional input Tensor. @@ -824,6 +784,3 @@ class LogSoftmax(Module): def forward(self, input): return F.log_softmax(input, self.dim, _stacklevel=5) - - def __repr__(self): - return self.__class__.__name__ + '()' diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 63ef02c11e..5c85c3dda1 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -48,10 +48,9 @@ class _BatchNorm(Module): input, self.running_mean, self.running_var, self.weight, self.bias, self.training or not self.track_running_stats, self.momentum, self.eps) - def __repr__(self): - return ('{name}({num_features}, eps={eps}, momentum={momentum},' - ' affine={affine}, track_running_stats={track_running_stats})' - .format(name=self.__class__.__name__, **self.__dict__)) + def extra_repr(self): + return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ + 'track_running_stats={track_running_stats}'.format(**self.__dict__) class BatchNorm1d(_BatchNorm): diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index e4430db691..7ab4ec8d54 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -267,13 +267,12 @@ class ParameterList(Module): self.register_parameter(str(offset + i), param) return self - def __repr__(self): - tmpstr = self.__class__.__name__ + '(\n' + def extra_repr(self): + tmpstr = '' for k, p in self._parameters.items(): size_str = 'x'.join(str(size) for size in p.size()) device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device()) parastr = 'Parameter containing: [{} of size {}{}]'.format( torch.typename(p.data), size_str, device_str) tmpstr = tmpstr + ' (' + k + '): ' + parastr + '\n' - tmpstr = tmpstr + ')' return tmpstr diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index a5c3252cc5..0a3ef76516 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -46,8 +46,8 @@ class _ConvNd(Module): if self.bias is not None: self.bias.data.uniform_(-stdv, stdv) - def __repr__(self): - s = 
('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}') if self.padding != (0,) * len(self.padding): s += ', padding={padding}' @@ -59,8 +59,7 @@ class _ConvNd(Module): s += ', groups={groups}' if self.bias is None: s += ', bias=False' - s += ')' - return s.format(name=self.__class__.__name__, **self.__dict__) + return s.format(**self.__dict__) class Conv1d(_ConvNd): diff --git a/torch/nn/modules/dropout.py b/torch/nn/modules/dropout.py index afc5e10044..e090095472 100644 --- a/torch/nn/modules/dropout.py +++ b/torch/nn/modules/dropout.py @@ -2,7 +2,22 @@ from .module import Module from .. import functional as F -class Dropout(Module): +class _DropoutNd(Module): + + def __init__(self, p=0.5, inplace=False): + super(_DropoutNd, self).__init__() + if p < 0 or p > 1: + raise ValueError("dropout probability has to be between 0 and 1, " + "but got {}".format(p)) + self.p = p + self.inplace = inplace + + def extra_repr(self): + inplace_str = ', inplace' if self.inplace else '' + return 'p={}{}'.format(self.p, inplace_str) + + +class Dropout(_DropoutNd): r"""During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p` using samples from a Bernoulli distribution. The elements to zero are randomized on every forward call. @@ -34,25 +49,11 @@ class Dropout(Module): detectors: https://arxiv.org/abs/1207.0580 """ - def __init__(self, p=0.5, inplace=False): - super(Dropout, self).__init__() - if p < 0 or p > 1: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - self.p = p - self.inplace = inplace - def forward(self, input): return F.dropout(input, self.p, self.training, self.inplace) - def __repr__(self): - inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + 'p=' + str(self.p) \ - + inplace_str + ')' - -class Dropout2d(Module): +class Dropout2d(_DropoutNd): r"""Randomly zeroes whole channels of the input tensor. The channels to zero-out are randomized on every forward call. @@ -87,25 +88,11 @@ class Dropout2d(Module): http://arxiv.org/abs/1411.4280 """ - def __init__(self, p=0.5, inplace=False): - super(Dropout2d, self).__init__() - if p < 0 or p > 1: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - self.p = p - self.inplace = inplace - def forward(self, input): return F.dropout2d(input, self.p, self.training, self.inplace) - def __repr__(self): - inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + 'p=' + str(self.p) \ - + inplace_str + ')' - -class Dropout3d(Module): +class Dropout3d(_DropoutNd): r"""Randomly zeroes whole channels of the input tensor. The channels to zero are randomized on every forward call. @@ -140,23 +127,9 @@ class Dropout3d(Module): http://arxiv.org/abs/1411.4280 """ - def __init__(self, p=0.5, inplace=False): - super(Dropout3d, self).__init__() - if p < 0 or p > 1: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - self.p = p - self.inplace = inplace - def forward(self, input): return F.dropout3d(input, self.p, self.training, self.inplace) - def __repr__(self): - inplace_str = ', inplace' if self.inplace else '' - return self.__class__.__name__ + '(' \ - + 'p=' + str(self.p) \ - + inplace_str + ')' - class AlphaDropout(Module): r"""Applies Alpha Dropout over the input. 
diff --git a/torch/nn/modules/fold.py b/torch/nn/modules/fold.py index 95aa6ec935..1b285b8667 100644 --- a/torch/nn/modules/fold.py +++ b/torch/nn/modules/fold.py @@ -68,13 +68,11 @@ class Fold(Module): return F.fold(input, self.output_size, self.kernel_size, self.dilation, self.padding, self.stride) - def __repr__(self): - return self.__class__.__name__ + ' (' \ - + 'output_size=' + str(self.output_size) \ - + ', kernel_size=' + str(self.kernel_size) \ - + ', dilation=' + str(self.dilation) \ - + ', padding=' + str(self.padding) \ - + ', stride=' + str(self.stride) + ')' + def extra_repr(self): + return 'output_size={output_size}, kernel_size={kernel_size}, ' \ + 'dilation={dilation}, padding={padding}, stride={stride}'.format( + **self.__dict__ + ) class Unfold(Module): @@ -140,9 +138,6 @@ class Unfold(Module): return F.unfold(input, self.kernel_size, self.dilation, self.padding, self.stride) - def __repr__(self): - return self.__class__.__name__ + ' (' \ - + 'kernel_size=' + str(self.kernel_size) \ - + ', dilation=' + str(self.dilation) \ - + ', padding=' + str(self.padding) \ - + ', stride=' + str(self.stride) + ')' + def extra_repr(self): + return 'kernel_size={kernel_size}, dilation={dilation}, padding={padding},' \ + ' stride={stride}'.format(**self.__dict__) diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index dcbb2f729b..2af36229af 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -54,11 +54,10 @@ class Linear(Module): def forward(self, input): return F.linear(input, self.weight, self.bias) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'in_features=' + str(self.in_features) \ - + ', out_features=' + str(self.out_features) \ - + ', bias=' + str(self.bias is not None) + ')' + def extra_repr(self): + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features, self.bias is not None + ) class Bilinear(Module): @@ -115,11 +114,9 @@ class Bilinear(Module): def forward(self, input1, input2): return F.bilinear(input1, input2, self.weight, self.bias) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'in1_features=' + str(self.in1_features) \ - + ', in2_features=' + str(self.in2_features) \ - + ', out_features=' + str(self.out_features) \ - + ', bias=' + str(self.bias is not None) + ')' + def extra_repr(self): + return 'in1_features={}, in2_features={}, out_features={}, bias={}'.format( + self.in1_features, self.in2_features, self.out_features, self.bias is not None + ) # TODO: PartialLinear - maybe in sparse? diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 129d8bc59a..3482164562 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -766,14 +766,39 @@ class Module(object): def _get_name(self): return self.__class__.__name__ + def extra_repr(self): + r"""Set the extra representation of the module + + To print customized extra information, you should reimplement + this method in your own modules. Both single-line and multi-line + strings are acceptable. 
+ """ + return '' + def __repr__(self): - tmpstr = self._get_name() + '(\n' + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split('\n') + child_lines = [] for key, module in self._modules.items(): - modstr = module.__repr__() - modstr = _addindent(modstr, 2) - tmpstr = tmpstr + ' (' + key + '): ' + modstr + '\n' - tmpstr = tmpstr + ')' - return tmpstr + mod_str = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append('(' + key + '): ' + mod_str) + lines = extra_lines + child_lines + + main_str = self._get_name() + '(' + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += '\n ' + '\n '.join(lines) + '\n' + + main_str += ')' + return main_str def __dir__(self): module_attrs = dir(self.__class__) diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index e29f194d9e..1136d25e1e 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -46,12 +46,8 @@ class LocalResponseNorm(Module): return F.local_response_norm(input, self.size, self.alpha, self.beta, self.k) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.size) \ - + ', alpha=' + str(self.alpha) \ - + ', beta=' + str(self.beta) \ - + ', k=' + str(self.k) + ')' + def extra_repr(self): + return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__) class CrossMapLRN2d(Module): @@ -67,12 +63,8 @@ class CrossMapLRN2d(Module): return self._backend.CrossMapLRN2d(self.size, self.alpha, self.beta, self.k)(input) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.size) \ - + ', alpha=' + str(self.alpha) \ - + ', beta=' + str(self.beta) \ - + ', k=' + str(self.k) + ')' + def extra_repr(self): + return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__) class LayerNorm(Module): @@ -153,10 +145,9 @@ class LayerNorm(Module): return F.layer_norm( input, self.normalized_shape, self.weight, self.bias, self.eps) - def __repr__(self): - return ('{name}({normalized_shape}, eps={eps}, ' - ' elementwise_affine={elementwise_affine},' - .format(name=self.__class__.__name__, **self.__dict__)) + def extra_repr(self): + return '{normalized_shape}, eps={eps}, ' \ + 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) class GroupNorm(Module): @@ -223,10 +214,9 @@ class GroupNorm(Module): return F.group_norm( input, self.num_groups, self.weight, self.bias, self.eps) - def __repr__(self): - return ('{name}({num_groups}, {num_channels}, eps={eps}, ' - 'affine={affine},' - .format(name=self.__class__.__name__, **self.__dict__)) + def extra_repr(self): + return '{num_groups}, {num_channels}, eps={eps}, ' \ + 'affine={affine}'.format(**self.__dict__) # TODO: ContrastiveNorm2d diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py index c87062dc43..cd691ffb3d 100644 --- a/torch/nn/modules/padding.py +++ b/torch/nn/modules/padding.py @@ -6,7 +6,20 @@ from .. 
import functional as F # TODO: grad_output size asserts in THNN -class ConstantPad1d(Module): +class _ConstantPadNd(Module): + + def __init__(self, value): + super(_ConstantPadNd, self).__init__() + self.value = value + + def forward(self, input): + return F.pad(input, self.padding, 'constant', self.value) + + def extra_repr(self): + return 'padding={}, value={}'.format(self.padding, self.value) + + +class ConstantPad1d(_ConstantPadNd): r"""Pads the input tensor boundaries with a constant value. Args: @@ -30,19 +43,11 @@ class ConstantPad1d(Module): """ def __init__(self, padding, value): - super(ConstantPad1d, self).__init__() + super(ConstantPad1d, self).__init__(value) self.padding = _pair(padding) - self.value = value - - def forward(self, input): - return F.pad(input, self.padding, 'constant', self.value) - - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.padding) + ')' -class ConstantPad2d(Module): +class ConstantPad2d(_ConstantPadNd): r"""Pads the input tensor boundaries with a constant value. For Nd-padding, use :meth:`nn.functional.pad()`. @@ -70,19 +75,11 @@ class ConstantPad2d(Module): """ def __init__(self, padding, value): - super(ConstantPad2d, self).__init__() + super(ConstantPad2d, self).__init__(value) self.padding = _quadruple(padding) - self.value = value - - def forward(self, input): - return F.pad(input, self.padding, 'constant', self.value) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.padding) + ')' - -class ConstantPad3d(Module): +class ConstantPad3d(_ConstantPadNd): r"""Pads the input tensor boundaries with a constant value. Args: @@ -109,19 +106,20 @@ class ConstantPad3d(Module): """ def __init__(self, padding, value): - super(ConstantPad3d, self).__init__() + super(ConstantPad3d, self).__init__(value) self.padding = _ntuple(6)(padding) - self.value = value + + +class _ReflectionPadNd(Module): def forward(self, input): - return F.pad(input, self.padding, 'constant', self.value) + return F.pad(input, self.padding, 'reflect') - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.padding) + ')' + def extra_repr(self): + return '{}'.format(self.padding) -class ReflectionPad1d(Module): +class ReflectionPad1d(_ReflectionPadNd): r"""Pads the input tensor using the reflection of the input boundary. Args: @@ -148,15 +146,8 @@ class ReflectionPad1d(Module): super(ReflectionPad1d, self).__init__() self.padding = _pair(padding) - def forward(self, input): - return F.pad(input, self.padding, 'reflect') - - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.padding) + ')' - -class ReflectionPad2d(Module): +class ReflectionPad2d(_ReflectionPadNd): r"""Pads the input tensor using the reflection of the input boundary. Args: @@ -185,15 +176,17 @@ class ReflectionPad2d(Module): super(ReflectionPad2d, self).__init__() self.padding = _quadruple(padding) + +class _ReplicationPadNd(Module): + def forward(self, input): - return F.pad(input, self.padding, 'reflect') + return F.pad(input, self.padding, 'replicate') - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.padding) + ')' + def extra_repr(self): + return '{}'.format(self.padding) -class ReplicationPad1d(Module): +class ReplicationPad1d(_ReplicationPadNd): r"""Pads the input tensor using replication of the input boundary. 
Args: @@ -220,15 +213,8 @@ class ReplicationPad1d(Module): super(ReplicationPad1d, self).__init__() self.padding = _pair(padding) - def forward(self, input): - return F.pad(input, self.padding, 'replicate') - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.padding) + ')' - - -class ReplicationPad2d(Module): +class ReplicationPad2d(_ReplicationPadNd): r"""Pads the input tensor using replication of the input boundary. Args: @@ -257,15 +243,8 @@ class ReplicationPad2d(Module): super(ReplicationPad2d, self).__init__() self.padding = _quadruple(padding) - def forward(self, input): - return F.pad(input, self.padding, 'replicate') - - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.padding) + ')' - -class ReplicationPad3d(Module): +class ReplicationPad3d(_ReplicationPadNd): r"""Pads the input tensor using replication of the input boundary. Args: @@ -295,13 +274,6 @@ class ReplicationPad3d(Module): super(ReplicationPad3d, self).__init__() self.padding = _ntuple(6)(padding) - def forward(self, input): - return F.pad(input, self.padding, 'replicate') - - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.padding) + ')' - class ZeroPad2d(ConstantPad2d): r"""Pads the input tensor boundaries with zero. diff --git a/torch/nn/modules/pixelshuffle.py b/torch/nn/modules/pixelshuffle.py index 24c6e00d18..786ba40128 100644 --- a/torch/nn/modules/pixelshuffle.py +++ b/torch/nn/modules/pixelshuffle.py @@ -39,5 +39,5 @@ class PixelShuffle(Module): def forward(self, input): return F.pixel_shuffle(input, self.upscale_factor) - def __repr__(self): - return self.__class__.__name__ + '(upscale_factor=' + str(self.upscale_factor) + ')' + def extra_repr(self): + return 'upscale_factor={}'.format(self.upscale_factor) diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 6f00417545..8292569973 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -6,7 +6,24 @@ from .utils import _single, _pair, _triple from .. import functional as F -class MaxPool1d(Module): +class _MaxPoolNd(Module): + + def __init__(self, kernel_size, stride=None, padding=0, dilation=1, + return_indices=False, ceil_mode=False): + super(_MaxPoolNd, self).__init__() + self.kernel_size = kernel_size + self.stride = stride or kernel_size + self.padding = padding + self.dilation = dilation + self.return_indices = return_indices + self.ceil_mode = ceil_mode + + def extra_repr(self): + return 'kernel_size={kernel_size}, stride={stride}, padding={padding}' \ + ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__) + + +class MaxPool1d(_MaxPoolNd): r"""Applies a 1D max pooling over an input signal composed of several input planes. 
@@ -52,31 +69,17 @@ class MaxPool1d(Module): https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ - def __init__(self, kernel_size, stride=None, padding=0, dilation=1, - return_indices=False, ceil_mode=False): - super(MaxPool1d, self).__init__() - self.kernel_size = kernel_size - self.stride = stride or kernel_size - self.padding = padding - self.dilation = dilation - self.return_indices = return_indices - self.ceil_mode = ceil_mode - def forward(self, input): return F.max_pool1d(input, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode, self.return_indices) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'kernel_size=' + str(self.kernel_size) \ - + ', stride=' + str(self.stride) \ - + ', padding=' + str(self.padding) \ - + ', dilation=' + str(self.dilation) \ - + ', ceil_mode=' + str(self.ceil_mode) + ')' + def extra_repr(self): + return 'kernel_size={kernel_size}, stride={stride}, padding={padding}' \ + ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__) -class MaxPool2d(Module): +class MaxPool2d(_MaxPoolNd): r"""Applies a 2D max pooling over an input signal composed of several input planes. @@ -134,38 +137,102 @@ class MaxPool2d(Module): https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ - def __init__(self, kernel_size, stride=None, padding=0, dilation=1, - return_indices=False, ceil_mode=False): - super(MaxPool2d, self).__init__() - self.kernel_size = kernel_size - self.stride = stride or kernel_size - self.padding = padding - self.dilation = dilation - self.return_indices = return_indices - self.ceil_mode = ceil_mode - def forward(self, input): return F.max_pool2d(input, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode, self.return_indices) - def __repr__(self): - kh, kw = _pair(self.kernel_size) - dh, dw = _pair(self.stride) - padh, padw = _pair(self.padding) - dilh, dilw = _pair(self.dilation) - padding_str = ', padding=(' + str(padh) + ', ' + str(padw) + ')' \ - if padh != 0 or padw != 0 else '' - dilation_str = (', dilation=(' + str(dilh) + ', ' + str(dilw) + ')' - if dilh != 0 and dilw != 0 else '') - ceil_str = ', ceil_mode=' + str(self.ceil_mode) - return self.__class__.__name__ + '(' \ - + 'kernel_size=(' + str(kh) + ', ' + str(kw) + ')' \ - + ', stride=(' + str(dh) + ', ' + str(dw) + ')' \ - + padding_str + dilation_str + ceil_str + ')' - - -class MaxUnpool1d(Module): +# def extra_repr(self): +# kh, kw = _pair(self.kernel_size) +# dh, dw = _pair(self.stride) +# padh, padw = _pair(self.padding) +# dilh, dilw = _pair(self.dilation) +# padding_str = ', padding=(' + str(padh) + ', ' + str(padw) + ')' \ +# if padh != 0 or padw != 0 else '' +# dilation_str = (', dilation=(' + str(dilh) + ', ' + str(dilw) + ')' +# if dilh != 0 and dilw != 0 else '') +# ceil_str = ', ceil_mode=' + str(self.ceil_mode) +# return 'kernel_size=(' + str(kh) + ', ' + str(kw) + ')' \ +# + ', stride=(' + str(dh) + ', ' + str(dw) + ')' \ +# + padding_str + dilation_str + ceil_str + + +class MaxPool3d(_MaxPoolNd): + r"""Applies a 3D max pooling over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`, + output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)` + can be precisely described as: + + .. 
math:: + + \begin{align*} + \text{out}(N_i, C_j, d, h, w) &= \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} + \text{input}(N_i, C_j, \text{stride}[0] * k + d,\\ &\text{stride}[1] * h + m, \text{stride}[2] * w + n) + \end{align*} + + If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Args: + kernel_size: the size of the window to take a max over + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: implicit zero padding to be added on all three sides + dilation: a parameter that controls the stride of elements in the window + return_indices: if ``True``, will return the max indices along with the outputs. + Useful when Unpooling later + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where + + .. math:: + D_{out} = \left\lfloor\frac{D_{in} + 2 * \text{padding}[0] - \text{dilation}[0] * + (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + + H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding}[1] - \text{dilation}[1] * + (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + + W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding}[2] - \text{dilation}[2] * + (\text{kernel_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.MaxPool3d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2)) + >>> input = torch.randn(20, 16, 50,44, 31) + >>> output = m(input) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + def forward(self, input): + return F.max_pool3d(input, self.kernel_size, self.stride, + self.padding, self.dilation, self.ceil_mode, + self.return_indices) + + +class _MaxUnpoolNd(Module): + + def extra_repr(self): + return 'kernel_size={}, stride={}, padding={}'.format( + self.kernel_size, self.stride, self.padding + ) + + +class MaxUnpool1d(_MaxUnpoolNd): r"""Computes a partial inverse of :class:`MaxPool1d`. :class:`MaxPool1d` is not fully invertible, since the non-maximal values are lost. 
@@ -232,21 +299,15 @@ class MaxUnpool1d(Module): def __init__(self, kernel_size, stride=None, padding=0): super(MaxUnpool1d, self).__init__() self.kernel_size = _single(kernel_size) - self.stride = _single(stride if stride is not None else kernel_size) + self.stride = _single(stride or kernel_size) self.padding = _single(padding) def forward(self, input, indices, output_size=None): return F.max_unpool1d(input, indices, self.kernel_size, self.stride, self.padding, output_size) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'kernel_size=' + str(self.kernel_size) \ - + ', stride=' + str(self.stride) \ - + ', padding=' + str(self.padding) + ')' - -class MaxUnpool2d(Module): +class MaxUnpool2d(_MaxUnpoolNd): r"""Computes a partial inverse of :class:`MaxPool2d`. :class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost. @@ -317,21 +378,15 @@ class MaxUnpool2d(Module): def __init__(self, kernel_size, stride=None, padding=0): super(MaxUnpool2d, self).__init__() self.kernel_size = _pair(kernel_size) - self.stride = _pair(stride if stride is not None else kernel_size) + self.stride = _pair(stride or kernel_size) self.padding = _pair(padding) def forward(self, input, indices, output_size=None): return F.max_unpool2d(input, indices, self.kernel_size, self.stride, self.padding, output_size) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'kernel_size=' + str(self.kernel_size) \ - + ', stride=' + str(self.stride) \ - + ', padding=' + str(self.padding) + ')' - -class MaxUnpool3d(Module): +class MaxUnpool3d(_MaxUnpoolNd): r"""Computes a partial inverse of :class:`MaxPool3d`. :class:`MaxPool3d` is not fully invertible, since the non-maximal values are lost. @@ -383,21 +438,32 @@ class MaxUnpool3d(Module): def __init__(self, kernel_size, stride=None, padding=0): super(MaxUnpool3d, self).__init__() self.kernel_size = _triple(kernel_size) - self.stride = _triple(stride if stride is not None else kernel_size) + self.stride = _triple(stride or kernel_size) self.padding = _triple(padding) def forward(self, input, indices, output_size=None): return F.max_unpool3d(input, indices, self.kernel_size, self.stride, self.padding, output_size) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'kernel_size=' + str(self.kernel_size) \ - + ', stride=' + str(self.stride) \ - + ', padding=' + str(self.padding) + ')' +class _AvgPoolNd(Module): -class AvgPool1d(Module): + def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, + count_include_pad=True): + super(_AvgPoolNd, self).__init__() + self.kernel_size = _single(kernel_size) + self.stride = _single(stride if stride is not None else kernel_size) + self.padding = _single(padding) + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + + def extra_repr(self): + return 'kernel_size={}, stride={}, padding={}'.format( + self.kernel_size, self.stride, self.padding + ) + + +class AvgPool1d(_AvgPoolNd): r"""Applies a 1D average pooling over an input signal composed of several input planes. 
@@ -444,30 +510,13 @@ class AvgPool1d(Module): [torch.FloatTensor of size (1,1,3)] """ - def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, - count_include_pad=True): - super(AvgPool1d, self).__init__() - self.kernel_size = _single(kernel_size) - self.stride = _single(stride if stride is not None else kernel_size) - self.padding = _single(padding) - self.ceil_mode = ceil_mode - self.count_include_pad = count_include_pad - def forward(self, input): return F.avg_pool1d( input, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'kernel_size=' + str(self.kernel_size) \ - + ', stride=' + str(self.stride) \ - + ', padding=' + str(self.padding) \ - + ', ceil_mode=' + str(self.ceil_mode) \ - + ', count_include_pad=' + str(self.count_include_pad) + ')' - -class AvgPool2d(Module): +class AvgPool2d(_AvgPoolNd): r"""Applies a 2D average pooling over an input signal composed of several input planes. @@ -519,114 +568,12 @@ class AvgPool2d(Module): >>> output = m(input) """ - def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, - count_include_pad=True): - super(AvgPool2d, self).__init__() - self.kernel_size = kernel_size - self.stride = stride or kernel_size - self.padding = padding - self.ceil_mode = ceil_mode - self.count_include_pad = count_include_pad - def forward(self, input): return F.avg_pool2d(input, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'kernel_size=' + str(self.kernel_size) \ - + ', stride=' + str(self.stride) \ - + ', padding=' + str(self.padding) \ - + ', ceil_mode=' + str(self.ceil_mode) \ - + ', count_include_pad=' + str(self.count_include_pad) + ')' - - -class MaxPool3d(Module): - r"""Applies a 3D max pooling over an input signal composed of several input - planes. - - In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`, - output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)` - can be precisely described as: - - .. math:: - - \begin{align*} - \text{out}(N_i, C_j, d, h, w) &= \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} - \text{input}(N_i, C_j, \text{stride}[0] * k + d,\\ &\text{stride}[1] * h + m, \text{stride}[2] * w + n) - \end{align*} - - If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides - for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. - It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - - The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - - - a single ``int`` -- in which case the same value is used for the depth, height and width dimension - - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, - the second `int` for the height dimension and the third `int` for the width dimension - - Args: - kernel_size: the size of the window to take a max over - stride: the stride of the window. Default value is :attr:`kernel_size` - padding: implicit zero padding to be added on all three sides - dilation: a parameter that controls the stride of elements in the window - return_indices: if ``True``, will return the max indices along with the outputs. 
- Useful when Unpooling later - ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape - - Shape: - - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` - - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where - - .. math:: - D_{out} = \left\lfloor\frac{D_{in} + 2 * \text{padding}[0] - \text{dilation}[0] * - (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor - - H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding}[1] - \text{dilation}[1] * - (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor - - W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding}[2] - \text{dilation}[2] * - (\text{kernel_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor - - Examples:: - - >>> # pool of square window of size=3, stride=2 - >>> m = nn.MaxPool3d(3, stride=2) - >>> # pool of non-square window - >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2)) - >>> input = torch.randn(20, 16, 50,44, 31) - >>> output = m(input) - - .. _link: - https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md - """ - def __init__(self, kernel_size, stride=None, padding=0, dilation=1, - return_indices=False, ceil_mode=False): - super(MaxPool3d, self).__init__() - self.kernel_size = kernel_size - self.stride = stride or kernel_size - self.padding = padding - self.dilation = dilation - self.return_indices = return_indices - self.ceil_mode = ceil_mode - - def forward(self, input): - return F.max_pool3d(input, self.kernel_size, self.stride, - self.padding, self.dilation, self.ceil_mode, - self.return_indices) - - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'kernel_size=' + str(self.kernel_size) \ - + ', stride=' + str(self.stride) \ - + ', padding=' + str(self.padding) \ - + ', dilation=' + str(self.dilation) \ - + ', ceil_mode=' + str(self.ceil_mode) + ')' - - -class AvgPool3d(Module): +class AvgPool3d(_AvgPoolNd): r"""Applies a 3D average pooling over an input signal composed of several input planes. @@ -683,15 +630,6 @@ class AvgPool3d(Module): >>> output = m(input) """ - def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, - count_include_pad=True): - super(AvgPool3d, self).__init__() - self.kernel_size = kernel_size - self.stride = stride or kernel_size - self.padding = padding - self.ceil_mode = ceil_mode - self.count_include_pad = count_include_pad - def forward(self, input): return F.avg_pool3d(input, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) @@ -702,14 +640,6 @@ class AvgPool3d(Module): self.__dict__.setdefault('ceil_mode', False) self.__dict__.setdefault('count_include_pad', True) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'kernel_size=' + str(self.kernel_size) \ - + ', stride=' + str(self.stride) \ - + ', padding=' + str(self.padding) \ - + ', ceil_mode=' + str(self.ceil_mode) \ - + ', count_include_pad=' + str(self.count_include_pad) + ')' - class FractionalMaxPool2d(Module): r"""Applies a 2D fractional max pooling over an input signal composed of several input planes. 
@@ -768,7 +698,58 @@ class FractionalMaxPool2d(Module): _random_samples=samples) -class LPPool2d(Module): +class _LPPoolNd(Module): + + def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False): + super(_LPPoolNd, self).__init__() + self.norm_type = norm_type + self.kernel_size = kernel_size + self.stride = stride + self.ceil_mode = ceil_mode + + def extra_repr(self): + return 'norm_type={norm_type}, kernel_size{kernel_size}, stride={stride}, ' \ + 'ceil_mode={ceil_mode}'.format(**self.__dict__) + + +class LPPool1d(_LPPoolNd): + r"""Applies a 1D power-average pooling over an input signal composed of several input + planes. + + On each window, the function computed is: + + .. math:: + f(X) = \sqrt[p]{\sum_{x \in X} x^{p}} + + - At p = infinity, one gets Max Pooling + - At p = 1, one gets Sum Pooling (which is proportional to Average Pooling) + + Args: + kernel_size: a single int, the size of the window + stride: a single int, the stride of the window. Default value is :attr:`kernel_size` + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Shape: + - Input: :math:`(N, C, L_{in})` + - Output: :math:`(N, C, L_{out})` where + + .. math:: + L_{out} = \left\lfloor\frac{L_{in} + + 2 * \text{padding} - \text{kernel_size}}{\text{stride}} + 1\right\rfloor + + Examples:: + >>> # power-2 pool of window of length 3, with stride 2. + >>> m = nn.LPPool1d(2, 3, stride=2) + >>> input = torch.randn(20, 16, 50) + >>> output = m(input) + """ + + def forward(self, input): + return F.lp_pool1d(input, self.norm_type, self.kernel_size, + self.stride, self.ceil_mode) + + +class LPPool2d(_LPPoolNd): r"""Applies a 2D power-average pooling over an input signal composed of several input planes. @@ -813,77 +794,23 @@ class LPPool2d(Module): """ - def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False): - super(LPPool2d, self).__init__() - self.norm_type = norm_type - self.kernel_size = kernel_size - self.stride = stride - self.ceil_mode = ceil_mode - def forward(self, input): return F.lp_pool2d(input, self.norm_type, self.kernel_size, self.stride, self.ceil_mode) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.norm_type) + ', ' \ - + str(self.kernel_size) + ', ' \ - + 'stride=' + str(self.stride) + ', ' \ - + 'ceil_mode=' + str(self.ceil_mode) + ')' +class _AdaptiveMaxPoolNd(Module): -class LPPool1d(Module): - r"""Applies a 1D power-average pooling over an input signal composed of several input - planes. - - On each window, the function computed is: - - .. math:: - f(X) = \sqrt[p]{\sum_{x \in X} x^{p}} - - - At p = infinity, one gets Max Pooling - - At p = 1, one gets Sum Pooling (which is proportional to Average Pooling) - - Args: - kernel_size: a single int, the size of the window - stride: a single int, the stride of the window. Default value is :attr:`kernel_size` - ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape - - Shape: - - Input: :math:`(N, C, L_{in})` - - Output: :math:`(N, C, L_{out})` where - - .. math:: - L_{out} = \left\lfloor\frac{L_{in} + - 2 * \text{padding} - \text{kernel_size}}{\text{stride}} + 1\right\rfloor - - Examples:: - >>> # power-2 pool of window of length 3, with stride 2. 
- >>> m = nn.LPPool1d(2, 3, stride=2) - >>> input = torch.randn(20, 16, 50) - >>> output = m(input) - """ - - def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False): - super(LPPool1d, self).__init__() - self.norm_type = norm_type - self.kernel_size = kernel_size - self.stride = stride - self.ceil_mode = ceil_mode - - def forward(self, input): - return F.lp_pool1d(input, self.norm_type, self.kernel_size, - self.stride, self.ceil_mode) + def __init__(self, output_size, return_indices=False): + super(_AdaptiveMaxPoolNd, self).__init__() + self.output_size = output_size + self.return_indices = return_indices - def __repr__(self): - return self.__class__.__name__ + '(' \ - + str(self.norm_type) + ', ' \ - + str(self.kernel_size) + ', ' \ - + 'stride=' + str(self.stride) + ', ' \ - + 'ceil_mode=' + str(self.ceil_mode) + ')' + def extra_repr(self): + return 'output_size={}'.format(self.output_size) -class AdaptiveMaxPool1d(Module): +class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd): r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes. The output size is H, for any input size. @@ -902,20 +829,11 @@ class AdaptiveMaxPool1d(Module): """ - def __init__(self, output_size, return_indices=False): - super(AdaptiveMaxPool1d, self).__init__() - self.output_size = output_size - self.return_indices = return_indices - def forward(self, input): return F.adaptive_max_pool1d(input, self.output_size, self.return_indices) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'output_size=' + str(self.output_size) + ')' - -class AdaptiveMaxPool2d(Module): +class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd): r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes. The output is of size H x W, for any input size. @@ -945,20 +863,11 @@ class AdaptiveMaxPool2d(Module): """ - def __init__(self, output_size, return_indices=False): - super(AdaptiveMaxPool2d, self).__init__() - self.output_size = output_size - self.return_indices = return_indices - def forward(self, input): return F.adaptive_max_pool2d(input, self.output_size, self.return_indices) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'output_size=' + str(self.output_size) + ')' - -class AdaptiveMaxPool3d(Module): +class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd): r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes. The output is of size D x H x W, for any input size. @@ -989,20 +898,21 @@ class AdaptiveMaxPool3d(Module): """ - def __init__(self, output_size, return_indices=False): - super(AdaptiveMaxPool3d, self).__init__() - self.output_size = output_size - self.return_indices = return_indices - def forward(self, input): return F.adaptive_max_pool3d(input, self.output_size, self.return_indices) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'output_size=' + str(self.output_size) + ')' +class _AdaptiveAvgPoolNd(Module): -class AdaptiveAvgPool1d(Module): + def __init__(self, output_size): + super(_AdaptiveAvgPoolNd, self).__init__() + self.output_size = output_size + + def extra_repr(self): + return 'output_size={}'.format(self.output_size) + + +class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd): r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes. The output size is H, for any input size. 
@@ -1019,19 +929,11 @@ class AdaptiveAvgPool1d(Module): """ - def __init__(self, output_size): - super(AdaptiveAvgPool1d, self).__init__() - self.output_size = output_size - def forward(self, input): return F.adaptive_avg_pool1d(input, self.output_size) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'output_size=' + str(self.output_size) + ')' - -class AdaptiveAvgPool2d(Module): +class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd): r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes. The output is of size H x W, for any input size. @@ -1059,19 +961,11 @@ class AdaptiveAvgPool2d(Module): """ - def __init__(self, output_size): - super(AdaptiveAvgPool2d, self).__init__() - self.output_size = output_size - def forward(self, input): return F.adaptive_avg_pool2d(input, self.output_size) - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'output_size=' + str(self.output_size) + ')' - -class AdaptiveAvgPool3d(Module): +class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd): r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes. The output is of size D x H x W, for any input size. @@ -1099,13 +993,5 @@ class AdaptiveAvgPool3d(Module): """ - def __init__(self, output_size): - super(AdaptiveAvgPool3d, self).__init__() - self.output_size = output_size - def forward(self, input): return F.adaptive_avg_pool3d(input, self.output_size) - - def __repr__(self): - return self.__class__.__name__ + '(' \ - + 'output_size=' + str(self.output_size) + ')' diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index 3675be08fd..02731cffb7 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -195,8 +195,8 @@ class RNNBase(Module): output = PackedSequence(output, batch_sizes) return output, hidden - def __repr__(self): - s = '{name}({input_size}, {hidden_size}' + def extra_repr(self): + s = '{input_size}, {hidden_size}' if self.num_layers != 1: s += ', num_layers={num_layers}' if self.bias is not True: @@ -207,8 +207,7 @@ class RNNBase(Module): s += ', dropout={dropout}' if self.bidirectional is not False: s += ', bidirectional={bidirectional}' - s += ')' - return s.format(name=self.__class__.__name__, **self.__dict__) + return s.format(**self.__dict__) def __setstate__(self, d): super(RNNBase, self).__setstate__(d) @@ -487,14 +486,13 @@ class GRU(RNNBase): class RNNCellBase(Module): - def __repr__(self): - s = '{name}({input_size}, {hidden_size}' + def extra_repr(self): + s = '{input_size}, {hidden_size}' if 'bias' in self.__dict__ and self.bias is not True: s += ', bias={bias}' if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh": s += ', nonlinearity={nonlinearity}' - s += ')' - return s.format(name=self.__class__.__name__, **self.__dict__) + return s.format(**self.__dict__) def check_forward_input(self, input): if input.size(1) != self.input_size: diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index b52d94a6f2..adc19cd723 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -106,8 +106,8 @@ class Embedding(Module): input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse) - def __repr__(self): - s = '{name}({num_embeddings}, {embedding_dim}' + def extra_repr(self): + s = '{num_embeddings}, {embedding_dim}' if self.padding_idx is not None: s += ', padding_idx={padding_idx}' if self.max_norm is not None: @@ -118,8 +118,7 @@ class Embedding(Module): s += ', 
scale_grad_by_freq={scale_grad_by_freq}' if self.sparse is not False: s += ', sparse=True' - s += ')' - return s.format(name=self.__class__.__name__, **self.__dict__) + return s.format(**self.__dict__) @classmethod def from_pretrained(cls, embeddings, freeze=True): @@ -238,8 +237,8 @@ class EmbeddingBag(Module): self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse) - def __repr__(self): - s = '{name}({num_embeddings}, {embedding_dim}' + def extra_repr(self): + s = '{num_embeddings}, {embedding_dim}' if self.max_norm is not None: s += ', max_norm={max_norm}' if self.norm_type != 2: @@ -247,7 +246,6 @@ class EmbeddingBag(Module): if self.scale_grad_by_freq is not False: s += ', scale_grad_by_freq={scale_grad_by_freq}' s += ', mode={mode}' - s += ')' - return s.format(name=self.__class__.__name__, **self.__dict__) + return s.format(**self.__dict__) # TODO: SparseLinear diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py index 681e9ff59f..ec00fbaeb8 100644 --- a/torch/nn/modules/upsampling.py +++ b/torch/nn/modules/upsampling.py @@ -138,13 +138,13 @@ class Upsample(Module): def forward(self, input): return F.upsample(input, self.size, self.scale_factor, self.mode, self.align_corners) - def __repr__(self): + def extra_repr(self): if self.scale_factor is not None: info = 'scale_factor=' + str(self.scale_factor) else: info = 'size=' + str(self.size) info += ', mode=' + self.mode - return self.__class__.__name__ + '(' + info + ')' + return info class UpsamplingNearest2d(Upsample): |