diff options
author | Narine Kokhlikyan <narine@fb.com> | 2019-03-18 12:21:52 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-03-18 12:33:20 -0700 |
commit | 670f509984efce8aab0b0b5b166ff01472af6f1c (patch) | |
tree | 2e08299b465ce4de367a24c16b65570786e00f82 /test | |
parent | 2b7a5d1876c2a34da6d16e843ddb429493bd666e (diff) | |
download | pytorch-670f509984efce8aab0b0b5b166ff01472af6f1c.tar.gz pytorch-670f509984efce8aab0b0b5b166ff01472af6f1c.tar.bz2 pytorch-670f509984efce8aab0b0b5b166ff01472af6f1c.zip |
Circular Convolution Function via circular padding (#17240)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17240
Added circular padding in addition to zero padding to Conv1D, Conv2D and Conv3D based on the solution suggested in: https://github.com/pytorch/pytorch/issues/3858
Reviewed By: ezyang
Differential Revision: D14126416
fbshipit-source-id: a2f1587503ee0cfff98d5cb0d5b0a600ef8aaeb4
Diffstat (limited to 'test')
-rw-r--r-- | test/common_nn.py | 173 |
1 files changed, 173 insertions, 0 deletions
diff --git a/test/common_nn.py b/test/common_nn.py index ca82a5fec6..afac197479 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -2260,6 +2260,107 @@ new_module_tests = [ input_size=(), desc='scalar', ), + dict( + fullname='Padding12_1dcircular', + constructor=wrap_functional(F.pad, pad=(1, 2), mode='circular'), + input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]), + reference_fn=lambda i, _: padding1d_circular(i, (1, 2)), + skip_double=TEST_WITH_ROCM, + pickle=False, + ), + dict( + fullname='Padding31_1dcircular', + constructor=wrap_functional(F.pad, pad=(3, 1), mode='circular'), + input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]), + reference_fn=lambda i, _: padding1d_circular(i, (3, 1)), + skip_double=TEST_WITH_ROCM, + pickle=False, + ), + dict( + fullname='Padding33_1dcircular', + constructor=wrap_functional(F.pad, pad=(3, 3), mode='circular'), + input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]), + reference_fn=lambda i, _: padding1d_circular(i, (3, 3)), + skip_double=TEST_WITH_ROCM, + pickle=False, + ), + dict( + fullname='Padding1221_2dcircular', + constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1), mode='circular'), + input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]), + reference_fn=lambda i, _: padding2d_circular(i, (1, 2, 2, 1)), + skip_double=TEST_WITH_ROCM, + pickle=False, + ), + dict( + fullname='Padding2322_2dcircular', + constructor=wrap_functional(F.pad, pad=(2, 3, 2, 2), mode='circular'), + input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]), + reference_fn=lambda i, _: padding2d_circular(i, (2, 3, 2, 2)), + skip_double=TEST_WITH_ROCM, + pickle=False, + ), + dict( + fullname='Padding3331_2dcircular', + constructor=wrap_functional(F.pad, pad=(3, 3, 3, 1), mode='circular'), + input_fn=lambda: torch.arange(9, out=torch.DoubleTensor()).reshape([1, 1, 3, 3]), + reference_fn=lambda i, _: padding2d_circular(i, (3, 3, 3, 1)), + skip_double=TEST_WITH_ROCM, + pickle=False, + ), + dict( + fullname='Padding122112_3dcircular', + constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1, 1, 2), mode='circular'), + input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]), + reference_fn=lambda i, _: padding3d_circular(i, (1, 2, 2, 1, 1, 2)), + skip_double=TEST_WITH_ROCM, + pickle=False, + ), + dict( + fullname='Padding322112_3dcircular', + constructor=wrap_functional(F.pad, pad=(3, 2, 2, 1, 1, 2), mode='circular'), + input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]), + reference_fn=lambda i, _: padding3d_circular(i, (3, 2, 2, 1, 1, 2)), + skip_double=TEST_WITH_ROCM, + pickle=False, + ), + dict( + fullname='Padding332122_3dcircular', + constructor=wrap_functional(F.pad, pad=(3, 3, 2, 1, 2, 2), mode='circular'), + input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]), + reference_fn=lambda i, _: padding3d_circular(i, (3, 3, 2, 1, 2, 2)), + skip_double=TEST_WITH_ROCM, + pickle=False, + ), + + dict( + module_name='Conv1d', + constructor_args=(3, 4, 2, 2, (1,), 1, 1, True, 'circular'), + input_size=(2, 3, 5,), + cudnn=True, + desc='stride1_pad1circular', + ), + dict( + module_name='Conv1d', + constructor_args=(3, 4, 2, 2, (2,), 1, 1, True, 'circular'), + input_size=(2, 3, 5,), + cudnn=True, + desc='stride1_pad2circular', + ), + dict( + module_name='Conv2d', + constructor_args=(3, 4, (3, 3), (2, 2), (1, 2), 1, 1, True, 'circular'), + input_size=(2, 3, 3, 3), + cudnn=True, + desc='pad2circular' + ), + dict( + module_name='Conv3d', + constructor_args=(3, 4, 2, 2, (1, 2, 3), 1, 1, True, 'circular'), + input_size=(2, 3, 3, 3, 3), + cudnn=True, + desc='stride_pad1circular', + ), ] @@ -2501,6 +2602,78 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0 output = output.to(dt) return output + +def padding1d_circular(input, pad): + r""" input: + [[[0., 1., 2.], + [3., 4., 5.]]] + pad: (1, 2) + output: + [[[2., 0., 1., 2., 0., 1.], + [5., 3., 4., 5., 3., 4.]]] + """ + return torch.cat([input[:, :, -pad[0]:], input, + input[:, :, 0:pad[1]]], dim=2) + + +def padding2d_circular(input, pad): + r"""input: + [[[[0., 1., 2], + [3., 4., 5.]]]] + pad: (1, 2, 2, 1) + output: + [[[[2., 0., 1., 2., 0., 1.], + [5., 3., 4., 5., 3., 4.], + [2., 0., 1., 2., 0., 1.], + [5., 3., 4., 5., 3., 4.], + [2., 0., 1., 2., 0., 1.]]]] + """ + input = torch.cat([input[:, :, -pad[2]:], input, input[:, :, 0:pad[3]]], dim=2) + return torch.cat([input[:, :, :, -pad[0]:], input, input[:, :, :, 0:pad[1]]], dim=3) + + +def padding3d_circular(input, pad): + r"""input: + [[[[[ 0., 1., 2.], + [ 3., 4., 5.]], + [[ 6., 7., 8.], + [ 9., 10., 11.]]]]] + pad: (1, 2, 2, 1, 1, 2) + output: [[[[[ 8., 6., 7., 8., 6., 7.], + [11., 9., 10., 11., 9., 10.], + [ 8., 6., 7., 8., 6., 7.], + [11., 9., 10., 11., 9., 10.], + [ 8., 6., 7., 8., 6., 7.]], + + [[ 2., 0., 1., 2., 0., 1.], + [ 5., 3., 4., 5., 3., 4.], + [ 2., 0., 1., 2., 0., 1.], + [ 5., 3., 4., 5., 3., 4.], + [ 2., 0., 1., 2., 0., 1.]], + + [[ 8., 6., 7., 8., 6., 7.], + [11., 9., 10., 11., 9., 10.], + [ 8., 6., 7., 8., 6., 7.], + [11., 9., 10., 11., 9., 10.], + [ 8., 6., 7., 8., 6., 7.]], + + [[ 2., 0., 1., 2., 0., 1.], + [ 5., 3., 4., 5., 3., 4.], + [ 2., 0., 1., 2., 0., 1.], + [ 5., 3., 4., 5., 3., 4.], + [ 2., 0., 1., 2., 0., 1.]], + + [[ 8., 6., 7., 8., 6., 7.], + [11., 9., 10., 11., 9., 10.], + [ 8., 6., 7., 8., 6., 7.], + [11., 9., 10., 11., 9., 10.], + [ 8., 6., 7., 8., 6., 7.]]]]] + """ + input = torch.cat([input[:, :, -pad[4]:], input, input[:, :, 0:pad[5]]], dim=2) + input = torch.cat([input[:, :, :, -pad[2]:], input, input[:, :, :, 0:pad[3]]], dim=3) + return torch.cat([input[:, :, :, :, -pad[0]:], input, input[:, :, :, :, 0:pad[1]]], dim=4) + + loss_reference_fns = { 'KLDivLoss': kldivloss_reference, 'NLLLoss': nllloss_reference, |