diff options
author | David Riazati <davidriazati@fb.com> | 2018-12-18 11:43:45 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-12-18 12:03:08 -0800 |
commit | 3118124cd6132588c4bf6614cbff1a3abb977bd9 (patch) | |
tree | f0ad51b99c24b4729d3f59b0f5ba0a48926e766a | |
parent | f4c504593cc73c4c3939f8d3e8e012c4d47bfd8d (diff) | |
download | pytorch-3118124cd6132588c4bf6614cbff1a3abb977bd9.tar.gz pytorch-3118124cd6132588c4bf6614cbff1a3abb977bd9.tar.bz2 pytorch-3118124cd6132588c4bf6614cbff1a3abb977bd9.zip |
Add (Un)Fold modules to standard library (#14759)
Summary:
Depends on #14597 for the corresponding aten ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14759
Differential Revision: D13325356
Pulled By: driazati
fbshipit-source-id: 99e39449c1ccfa293de05672c31a11e580bdd11f
-rw-r--r-- | test/test_jit.py | 2 | ||||
-rw-r--r-- | torch/csrc/jit/register_special_ops.cpp | 9 | ||||
-rw-r--r-- | torch/jit/__init__.py | 3 | ||||
-rw-r--r-- | torch/nn/functional.py | 16 | ||||
-rw-r--r-- | torch/nn/modules/fold.py | 8 |
5 files changed, 32 insertions, 6 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index 89ba4cc21d..8d88e8c3de 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -10635,6 +10635,8 @@ EXCLUDE_MODULE_EXPORT_IMPORT = { 'MaxPool3d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', + 'Fold', + 'Unfold', } # NB: JIT script tests for all nn functional interfaces, script mode does diff --git a/torch/csrc/jit/register_special_ops.cpp b/torch/csrc/jit/register_special_ops.cpp index 9ae08a5d1c..e467823a71 100644 --- a/torch/csrc/jit/register_special_ops.cpp +++ b/torch/csrc/jit/register_special_ops.cpp @@ -115,6 +115,15 @@ RegisterOperators reg({ return 0; }; }), + Operator( + "aten::_assert_int_or_pair(int[] vals, str name, str message) -> Tensor", + [](const Node* node) { + return [](Stack& stack) { + // Everything is a list at the point this is used, so don't do anything + drop(stack, 3); + return 0; + }; + }), }); } diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index dc1499ffc8..ab06a34fac 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1441,8 +1441,7 @@ def _get_builtin_table(): _builtin_table[id(torch.nn.functional.upsample_nearest)] = "aten::__upsample_nearest" _builtin_table[id(torch.nn.functional.upsample)] = "aten::__upsample" _builtin_table[id(torch.nn.functional.upsample_bilinear)] = "aten::__upsample_bilinear" - _builtin_table[id(torch.nn.functional.fold)] = "aten::fold" - _builtin_table[id(torch.nn.functional.unfold)] = "aten::unfold" + _builtin_table[id(torch.nn.functional.assert_int_or_pair)] = "aten::_assert_int_or_pair" return _builtin_table diff --git a/torch/nn/functional.py b/torch/nn/functional.py index b789dde02f..72c5c27dff 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2826,7 +2826,9 @@ def assert_int_or_pair(arg, arg_name, message): assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name) +@torch._jit_internal.weak_script def unfold(input, kernel_size, dilation=1, padding=0, stride=1): + # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa r"""Extracts sliding local blocks from an batched input tensor. .. warning:: @@ -2843,13 +2845,17 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1): assert_int_or_pair(padding, 'padding', msg) assert_int_or_pair(stride, 'stride', msg) - return torch._C._nn.thnn_im2col(input, _pair(kernel_size), - _pair(dilation), _pair(padding), _pair(stride)) + ret = torch._C._nn.thnn_im2col(input, _pair(kernel_size), + _pair(dilation), _pair(padding), _pair(stride)) else: raise NotImplementedError("Input Error: Only 4D input Tensors are supported (got {}D)".format(input.dim())) + ret = input # TODO: remove when jit supports exception control flow + return ret +@torch._jit_internal.weak_script def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1): + # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa r"""Combines an array of sliding local blocks into a large containing tensor. @@ -2867,7 +2873,9 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1): assert_int_or_pair(padding, 'padding', msg) assert_int_or_pair(stride, 'stride', msg) - return torch._C._nn.thnn_col2im(input, _pair(output_size), _pair(kernel_size), - _pair(dilation), _pair(padding), _pair(stride)) + ret = torch._C._nn.thnn_col2im(input, _pair(output_size), _pair(kernel_size), + _pair(dilation), _pair(padding), _pair(stride)) else: raise NotImplementedError("Input Error: Only 3D input Tensors are supported (got {}D)".format(input.dim())) + ret = input # TODO: remove when jit supports exception control flow + return ret diff --git a/torch/nn/modules/fold.py b/torch/nn/modules/fold.py index d003582092..03adaef5e7 100644 --- a/torch/nn/modules/fold.py +++ b/torch/nn/modules/fold.py @@ -1,8 +1,10 @@ # coding=utf-8 from .module import Module from .. import functional as F +from ..._jit_internal import weak_module, weak_script_method +@weak_module class Fold(Module): r"""Combines an array of sliding local blocks into a large containing tensor. @@ -87,6 +89,8 @@ class Fold(Module): https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + __constants__ = ['output_size', 'kernel_size', 'dilation', 'padding', + 'stride'] def __init__(self, output_size, kernel_size, dilation=1, padding=0, stride=1): super(Fold, self).__init__() @@ -96,6 +100,7 @@ class Fold(Module): self.padding = padding self.stride = stride + @weak_script_method def forward(self, input): return F.fold(input, self.output_size, self.kernel_size, self.dilation, self.padding, self.stride) @@ -107,6 +112,7 @@ class Fold(Module): ) +@weak_module class Unfold(Module): r"""Extracts sliding local blocks from a batched input tensor. @@ -201,6 +207,7 @@ class Unfold(Module): https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + __constants__ = ['kernel_size', 'dilation', 'padding', 'stride'] def __init__(self, kernel_size, dilation=1, padding=0, stride=1): super(Unfold, self).__init__() @@ -209,6 +216,7 @@ class Unfold(Module): self.padding = padding self.stride = stride + @weak_script_method def forward(self, input): return F.unfold(input, self.kernel_size, self.dilation, self.padding, self.stride) |