diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/common_nn.py | 173 |
1 file changed, 173 insertions, 0 deletions
# NOTE(review): this span is a cgit-rendered unified diff of test/common_nn.py
# collapsed onto a few physical lines.  The body below reconstructs the Python
# that the diff ADDS: three circular-padding reference helpers and their
# entries in the shared ``new_module_tests`` table.  Unchanged diff context
# cut off at the hunk edges (tail of ``ctcloss_reference``, head of
# ``loss_reference_fns``) is intentionally not reproduced.


def _circular_wrap(t, dim, before, after):
    """Circularly extend ``t`` along ``dim``.

    Copies the last ``before`` slices onto the front and the first ``after``
    slices onto the back.  ``Tensor.narrow`` is used instead of negative
    slicing because ``t[..., -0:]`` selects the WHOLE dimension, so the
    original ``input[:, :, -pad[0]:]`` form silently mis-pads whenever a pad
    amount is 0.  ``narrow`` also raises if ``before``/``after`` exceeds the
    dimension size, which an in-bounds circular pad never does.
    """
    pieces = []
    if before > 0:
        pieces.append(t.narrow(dim, t.size(dim) - before, before))
    pieces.append(t)
    if after > 0:
        pieces.append(t.narrow(dim, 0, after))
    return torch.cat(pieces, dim=dim) if len(pieces) > 1 else t


def padding1d_circular(input, pad):
    r"""Reference circular padding for 3-d input; ``pad`` is (left, right).

    input:
    [[[0., 1., 2.],
      [3., 4., 5.]]]
    pad: (1, 2)
    output:
    [[[2., 0., 1., 2., 0., 1.],
      [5., 3., 4., 5., 3., 4.]]]
    """
    return _circular_wrap(input, 2, pad[0], pad[1])


def padding2d_circular(input, pad):
    r"""Reference circular padding for 4-d input;
    ``pad`` is (left, right, top, bottom).

    input:
    [[[[0., 1., 2.],
       [3., 4., 5.]]]]
    pad: (1, 2, 2, 1)
    output:
    [[[[2., 0., 1., 2., 0., 1.],
       [5., 3., 4., 5., 3., 4.],
       [2., 0., 1., 2., 0., 1.],
       [5., 3., 4., 5., 3., 4.],
       [2., 0., 1., 2., 0., 1.]]]]
    """
    # Wrap rows (dim 2) first, then columns (dim 3), matching F.pad's
    # circular mode for 4-d tensors.
    input = _circular_wrap(input, 2, pad[2], pad[3])
    return _circular_wrap(input, 3, pad[0], pad[1])


def padding3d_circular(input, pad):
    r"""Reference circular padding for 5-d input;
    ``pad`` is (left, right, top, bottom, front, back).

    Wraps depth (dim 2), then height (dim 3), then width (dim 4), matching
    ``F.pad(..., mode='circular')`` for 5-d tensors.
    """
    input = _circular_wrap(input, 2, pad[4], pad[5])
    input = _circular_wrap(input, 3, pad[2], pad[3])
    return _circular_wrap(input, 4, pad[0], pad[1])


# Circular-padding cases appended to the shared module-test table.  The
# ``reference_fn`` lambdas bind the helpers above by name at call time, so
# definition order relative to this list does not matter.
new_module_tests += [
    dict(
        fullname='Padding12_1dcircular',
        constructor=wrap_functional(F.pad, pad=(1, 2), mode='circular'),
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
        reference_fn=lambda i, _: padding1d_circular(i, (1, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding31_1dcircular',
        constructor=wrap_functional(F.pad, pad=(3, 1), mode='circular'),
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
        reference_fn=lambda i, _: padding1d_circular(i, (3, 1)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding33_1dcircular',
        constructor=wrap_functional(F.pad, pad=(3, 3), mode='circular'),
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
        reference_fn=lambda i, _: padding1d_circular(i, (3, 3)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding1221_2dcircular',
        constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1), mode='circular'),
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]),
        reference_fn=lambda i, _: padding2d_circular(i, (1, 2, 2, 1)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding2322_2dcircular',
        constructor=wrap_functional(F.pad, pad=(2, 3, 2, 2), mode='circular'),
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]),
        reference_fn=lambda i, _: padding2d_circular(i, (2, 3, 2, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding3331_2dcircular',
        constructor=wrap_functional(F.pad, pad=(3, 3, 3, 1), mode='circular'),
        input_fn=lambda: torch.arange(9, out=torch.DoubleTensor()).reshape([1, 1, 3, 3]),
        reference_fn=lambda i, _: padding2d_circular(i, (3, 3, 3, 1)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding122112_3dcircular',
        constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1, 1, 2), mode='circular'),
        input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
        reference_fn=lambda i, _: padding3d_circular(i, (1, 2, 2, 1, 1, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding322112_3dcircular',
        constructor=wrap_functional(F.pad, pad=(3, 2, 2, 1, 1, 2), mode='circular'),
        input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
        reference_fn=lambda i, _: padding3d_circular(i, (3, 2, 2, 1, 1, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding332122_3dcircular',
        constructor=wrap_functional(F.pad, pad=(3, 3, 2, 1, 2, 2), mode='circular'),
        input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
        reference_fn=lambda i, _: padding3d_circular(i, (3, 3, 2, 1, 2, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        module_name='Conv1d',
        constructor_args=(3, 4, 2, 2, (1,), 1, 1, True, 'circular'),
        input_size=(2, 3, 5,),
        cudnn=True,
        desc='stride1_pad1circular',
    ),
    dict(
        module_name='Conv1d',
        constructor_args=(3, 4, 2, 2, (2,), 1, 1, True, 'circular'),
        input_size=(2, 3, 5,),
        cudnn=True,
        desc='stride1_pad2circular',
    ),
    dict(
        module_name='Conv2d',
        constructor_args=(3, 4, (3, 3), (2, 2), (1, 2), 1, 1, True, 'circular'),
        input_size=(2, 3, 3, 3),
        cudnn=True,
        desc='pad2circular'
    ),
    dict(
        module_name='Conv3d',
        constructor_args=(3, 4, 2, 2, (1, 2, 3), 1, 1, True, 'circular'),
        input_size=(2, 3, 3, 3, 3),
        cudnn=True,
        desc='stride_pad1circular',
    ),
]