author     Narine Kokhlikyan <narine@fb.com>  2019-03-18 12:21:52 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-03-18 12:33:20 -0700
commit     670f509984efce8aab0b0b5b166ff01472af6f1c (patch)
tree       2e08299b465ce4de367a24c16b65570786e00f82 /test
parent     2b7a5d1876c2a34da6d16e843ddb429493bd666e (diff)
Circular Convolution Function via circular padding (#17240)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17240

Added circular padding, in addition to the existing zero padding, to Conv1d, Conv2d and Conv3d, based on the solution suggested in https://github.com/pytorch/pytorch/issues/3858.

Reviewed By: ezyang

Differential Revision: D14126416

fbshipit-source-id: a2f1587503ee0cfff98d5cb0d5b0a600ef8aaeb4
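For context, a minimal sketch of the behavior exercised by the tests below; it is not part of the diff, and it assumes a PyTorch build that includes this change. The tensor values mirror the padding1d_circular docstring added in the second hunk.

    import torch
    import torch.nn.functional as F

    x = torch.arange(6, dtype=torch.double).reshape(1, 2, 3)
    # x: [[[0., 1., 2.],
    #      [3., 4., 5.]]]

    # Circular padding wraps values around instead of filling with zeros.
    y = F.pad(x, (1, 2), mode='circular')
    # y: [[[2., 0., 1., 2., 0., 1.],
    #      [5., 3., 4., 5., 3., 4.]]]

    # With this patch, the Conv modules also accept padding_mode='circular'.
    conv = torch.nn.Conv1d(2, 4, kernel_size=2, padding=1, padding_mode='circular')
    out = conv(x.float())  # input is padded circularly along the last dimension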
Diffstat (limited to 'test')
-rw-r--r--  test/common_nn.py | 173
1 file changed, 173 insertions(+), 0 deletions(-)
diff --git a/test/common_nn.py b/test/common_nn.py
index ca82a5fec6..afac197479 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -2260,6 +2260,107 @@ new_module_tests = [
         input_size=(),
         desc='scalar',
     ),
+    dict(
+        fullname='Padding12_1dcircular',
+        constructor=wrap_functional(F.pad, pad=(1, 2), mode='circular'),
+        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
+        reference_fn=lambda i, _: padding1d_circular(i, (1, 2)),
+        skip_double=TEST_WITH_ROCM,
+        pickle=False,
+    ),
+    dict(
+        fullname='Padding31_1dcircular',
+        constructor=wrap_functional(F.pad, pad=(3, 1), mode='circular'),
+        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
+        reference_fn=lambda i, _: padding1d_circular(i, (3, 1)),
+        skip_double=TEST_WITH_ROCM,
+        pickle=False,
+    ),
+    dict(
+        fullname='Padding33_1dcircular',
+        constructor=wrap_functional(F.pad, pad=(3, 3), mode='circular'),
+        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
+        reference_fn=lambda i, _: padding1d_circular(i, (3, 3)),
+        skip_double=TEST_WITH_ROCM,
+        pickle=False,
+    ),
+    dict(
+        fullname='Padding1221_2dcircular',
+        constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1), mode='circular'),
+        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]),
+        reference_fn=lambda i, _: padding2d_circular(i, (1, 2, 2, 1)),
+        skip_double=TEST_WITH_ROCM,
+        pickle=False,
+    ),
+    dict(
+        fullname='Padding2322_2dcircular',
+        constructor=wrap_functional(F.pad, pad=(2, 3, 2, 2), mode='circular'),
+        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]),
+        reference_fn=lambda i, _: padding2d_circular(i, (2, 3, 2, 2)),
+        skip_double=TEST_WITH_ROCM,
+        pickle=False,
+    ),
+    dict(
+        fullname='Padding3331_2dcircular',
+        constructor=wrap_functional(F.pad, pad=(3, 3, 3, 1), mode='circular'),
+        input_fn=lambda: torch.arange(9, out=torch.DoubleTensor()).reshape([1, 1, 3, 3]),
+        reference_fn=lambda i, _: padding2d_circular(i, (3, 3, 3, 1)),
+        skip_double=TEST_WITH_ROCM,
+        pickle=False,
+    ),
+    dict(
+        fullname='Padding122112_3dcircular',
+        constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1, 1, 2), mode='circular'),
+        input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
+        reference_fn=lambda i, _: padding3d_circular(i, (1, 2, 2, 1, 1, 2)),
+        skip_double=TEST_WITH_ROCM,
+        pickle=False,
+    ),
+    dict(
+        fullname='Padding322112_3dcircular',
+        constructor=wrap_functional(F.pad, pad=(3, 2, 2, 1, 1, 2), mode='circular'),
+        input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
+        reference_fn=lambda i, _: padding3d_circular(i, (3, 2, 2, 1, 1, 2)),
+        skip_double=TEST_WITH_ROCM,
+        pickle=False,
+    ),
+    dict(
+        fullname='Padding332122_3dcircular',
+        constructor=wrap_functional(F.pad, pad=(3, 3, 2, 1, 2, 2), mode='circular'),
+        input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
+        reference_fn=lambda i, _: padding3d_circular(i, (3, 3, 2, 1, 2, 2)),
+        skip_double=TEST_WITH_ROCM,
+        pickle=False,
+    ),
+
+    dict(
+        module_name='Conv1d',
+        constructor_args=(3, 4, 2, 2, (1,), 1, 1, True, 'circular'),
+        input_size=(2, 3, 5,),
+        cudnn=True,
+        desc='stride1_pad1circular',
+    ),
+    dict(
+        module_name='Conv1d',
+        constructor_args=(3, 4, 2, 2, (2,), 1, 1, True, 'circular'),
+        input_size=(2, 3, 5,),
+        cudnn=True,
+        desc='stride1_pad2circular',
+    ),
+    dict(
+        module_name='Conv2d',
+        constructor_args=(3, 4, (3, 3), (2, 2), (1, 2), 1, 1, True, 'circular'),
+        input_size=(2, 3, 3, 3),
+        cudnn=True,
+        desc='pad2circular',
+    ),
+    dict(
+        module_name='Conv3d',
+        constructor_args=(3, 4, 2, 2, (1, 2, 3), 1, 1, True, 'circular'),
+        input_size=(2, 3, 3, 3, 3),
+        cudnn=True,
+        desc='stride_pad1circular',
+    ),
 ]
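The positional constructor_args in the Conv entries above follow the nn.Conv1d/2d/3d signature (in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode). As a sketch of the expansion, the first Conv1d entry is equivalent to:

    import torch.nn as nn

    m = nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, stride=2,
                  padding=1, dilation=1, groups=1, bias=True,
                  padding_mode='circular')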
@@ -2501,6 +2602,78 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
     output = output.to(dt)
     return output
+
+def padding1d_circular(input, pad):
+    r"""input:
+            [[[0., 1., 2.],
+              [3., 4., 5.]]]
+        pad: (1, 2)
+        output:
+            [[[2., 0., 1., 2., 0., 1.],
+              [5., 3., 4., 5., 3., 4.]]]
+    """
+    # Wrap the last pad[0] entries around to the front and the first
+    # pad[1] entries around to the back of the width dimension (dim 2).
+    return torch.cat([input[:, :, -pad[0]:], input,
+                      input[:, :, 0:pad[1]]], dim=2)
+
+
+def padding2d_circular(input, pad):
+    r"""input:
+            [[[[0., 1., 2.],
+               [3., 4., 5.]]]]
+        pad: (1, 2, 2, 1)
+        output:
+            [[[[2., 0., 1., 2., 0., 1.],
+               [5., 3., 4., 5., 3., 4.],
+               [2., 0., 1., 2., 0., 1.],
+               [5., 3., 4., 5., 3., 4.],
+               [2., 0., 1., 2., 0., 1.]]]]
+    """
+    # Wrap the height dimension (dim 2) first, then the width dimension (dim 3).
+    input = torch.cat([input[:, :, -pad[2]:], input, input[:, :, 0:pad[3]]], dim=2)
+    return torch.cat([input[:, :, :, -pad[0]:], input, input[:, :, :, 0:pad[1]]], dim=3)
+
+
+def padding3d_circular(input, pad):
+    r"""input:
+            [[[[[ 0.,  1.,  2.],
+                [ 3.,  4.,  5.]],
+               [[ 6.,  7.,  8.],
+                [ 9., 10., 11.]]]]]
+        pad: (1, 2, 2, 1, 1, 2)
+        output:
+            [[[[[ 8.,  6.,  7.,  8.,  6.,  7.],
+                [11.,  9., 10., 11.,  9., 10.],
+                [ 8.,  6.,  7.,  8.,  6.,  7.],
+                [11.,  9., 10., 11.,  9., 10.],
+                [ 8.,  6.,  7.,  8.,  6.,  7.]],
+
+               [[ 2.,  0.,  1.,  2.,  0.,  1.],
+                [ 5.,  3.,  4.,  5.,  3.,  4.],
+                [ 2.,  0.,  1.,  2.,  0.,  1.],
+                [ 5.,  3.,  4.,  5.,  3.,  4.],
+                [ 2.,  0.,  1.,  2.,  0.,  1.]],
+
+               [[ 8.,  6.,  7.,  8.,  6.,  7.],
+                [11.,  9., 10., 11.,  9., 10.],
+                [ 8.,  6.,  7.,  8.,  6.,  7.],
+                [11.,  9., 10., 11.,  9., 10.],
+                [ 8.,  6.,  7.,  8.,  6.,  7.]],
+
+               [[ 2.,  0.,  1.,  2.,  0.,  1.],
+                [ 5.,  3.,  4.,  5.,  3.,  4.],
+                [ 2.,  0.,  1.,  2.,  0.,  1.],
+                [ 5.,  3.,  4.,  5.,  3.,  4.],
+                [ 2.,  0.,  1.,  2.,  0.,  1.]],
+
+               [[ 8.,  6.,  7.,  8.,  6.,  7.],
+                [11.,  9., 10., 11.,  9., 10.],
+                [ 8.,  6.,  7.,  8.,  6.,  7.],
+                [11.,  9., 10., 11.,  9., 10.],
+                [ 8.,  6.,  7.,  8.,  6.,  7.]]]]]
+    """
+    # Wrap depth (dim 2), then height (dim 3), then width (dim 4).
+    input = torch.cat([input[:, :, -pad[4]:], input, input[:, :, 0:pad[5]]], dim=2)
+    input = torch.cat([input[:, :, :, -pad[2]:], input, input[:, :, :, 0:pad[3]]], dim=3)
+    return torch.cat([input[:, :, :, :, -pad[0]:], input, input[:, :, :, :, 0:pad[1]]], dim=4)
+
+
 loss_reference_fns = {
     'KLDivLoss': kldivloss_reference,
     'NLLLoss': nllloss_reference,
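The padding*_circular helpers serve as pure-indexing reference implementations for F.pad's circular mode. An illustrative sanity check, assuming test/common_nn.py from this commit is importable:

    import torch
    import torch.nn.functional as F
    from common_nn import padding1d_circular

    x = torch.arange(6, dtype=torch.double).reshape(1, 2, 3)
    assert torch.equal(padding1d_circular(x, (1, 2)),
                       F.pad(x, (1, 2), mode='circular'))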