author    David Riazati <davidriazati@fb.com>  2018-12-18 11:43:45 -0800
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-12-18 12:03:08 -0800
commit    3118124cd6132588c4bf6614cbff1a3abb977bd9 (patch)
tree      f0ad51b99c24b4729d3f59b0f5ba0a48926e766a
parent    f4c504593cc73c4c3939f8d3e8e012c4d47bfd8d (diff)
Add (Un)Fold modules to standard library (#14759)
Summary: Depends on #14597 for the corresponding aten ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14759

Differential Revision: D13325356

Pulled By: driazati

fbshipit-source-id: 99e39449c1ccfa293de05672c31a11e580bdd11f
 test/test_jit.py                        |  2
 torch/csrc/jit/register_special_ops.cpp |  9
 torch/jit/__init__.py                   |  3
 torch/nn/functional.py                  | 16
 torch/nn/modules/fold.py                |  8
 5 files changed, 32 insertions, 6 deletions
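
The net effect of the patch: nn.Fold and nn.Unfold become weak modules, so they can be called from TorchScript. A minimal sketch of the usage this enables, using the ScriptModule/script_method API of this era (the module, names, and shapes below are illustrative, not part of the patch):

    import torch
    import torch.nn as nn

    class Patches(torch.jit.ScriptModule):
        def __init__(self):
            super(Patches, self).__init__()
            # Weak module: its forward is compiled when used from script.
            self.unfold = nn.Unfold(kernel_size=3, padding=1)

        @torch.jit.script_method
        def forward(self, x):
            return self.unfold(x)

    m = Patches()
    y = m(torch.randn(1, 2, 8, 8))  # y.shape == (1, 2*3*3, 8*8) == (1, 18, 64)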
diff --git a/test/test_jit.py b/test/test_jit.py
index 89ba4cc21d..8d88e8c3de 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -10635,6 +10635,8 @@ EXCLUDE_MODULE_EXPORT_IMPORT = {
'MaxPool3d',
'AdaptiveAvgPool2d',
'AdaptiveAvgPool3d',
+ 'Fold',
+ 'Unfold',
}
# NB: JIT script tests for all nn functional interfaces, script mode does
diff --git a/torch/csrc/jit/register_special_ops.cpp b/torch/csrc/jit/register_special_ops.cpp
index 9ae08a5d1c..e467823a71 100644
--- a/torch/csrc/jit/register_special_ops.cpp
+++ b/torch/csrc/jit/register_special_ops.cpp
@@ -115,6 +115,15 @@ RegisterOperators reg({
return 0;
};
}),
+ Operator(
+ "aten::_assert_int_or_pair(int[] vals, str name, str message) -> Tensor",
+ [](const Node* node) {
+ return [](Stack& stack) {
+ // Everything is a list at the point this is used, so don't do anything
+ drop(stack, 3);
+ return 0;
+ };
+ }),
});
}
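
The new aten::_assert_int_or_pair operator is deliberately a no-op: by the time scripted code reaches it, every argument has already been broadcast to a list, so the op only has to consume its inputs. A conceptual Python model of the stack behavior of the C++ lambda above (hypothetical helper, for illustration only):

    def _assert_int_or_pair_op(stack):
        # Pop (vals, name, message) without inspecting them --
        # the Python analogue of drop(stack, 3) above.
        del stack[-3:]
        return 0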
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index dc1499ffc8..ab06a34fac 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -1441,8 +1441,7 @@ def _get_builtin_table():
_builtin_table[id(torch.nn.functional.upsample_nearest)] = "aten::__upsample_nearest"
_builtin_table[id(torch.nn.functional.upsample)] = "aten::__upsample"
_builtin_table[id(torch.nn.functional.upsample_bilinear)] = "aten::__upsample_bilinear"
- _builtin_table[id(torch.nn.functional.fold)] = "aten::fold"
- _builtin_table[id(torch.nn.functional.unfold)] = "aten::unfold"
+ _builtin_table[id(torch.nn.functional.assert_int_or_pair)] = "aten::_assert_int_or_pair"
return _builtin_table
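
The builtin table maps Python function objects (keyed by id()) to aten operator names; when the script compiler sees a call to one of these functions, it emits the mapped operator instead of compiling the Python body. A quick probe of the mapping added above, assuming the internal _get_builtin_table helper is callable as defined in this file:

    import torch
    import torch.nn.functional as F

    table = torch.jit._get_builtin_table()
    # With this patch, the eager assert helper resolves to the no-op script op:
    print(table[id(F.assert_int_or_pair)])  # 'aten::_assert_int_or_pair'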
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index b789dde02f..72c5c27dff 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -2826,7 +2826,9 @@ def assert_int_or_pair(arg, arg_name, message):
assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)
+@torch._jit_internal.weak_script
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+ # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa
r"""Extracts sliding local blocks from an batched input tensor.
.. warning::
@@ -2843,13 +2845,17 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
assert_int_or_pair(padding, 'padding', msg)
assert_int_or_pair(stride, 'stride', msg)
- return torch._C._nn.thnn_im2col(input, _pair(kernel_size),
- _pair(dilation), _pair(padding), _pair(stride))
+ ret = torch._C._nn.thnn_im2col(input, _pair(kernel_size),
+ _pair(dilation), _pair(padding), _pair(stride))
else:
raise NotImplementedError("Input Error: Only 4D input Tensors are supported (got {}D)".format(input.dim()))
+ ret = input # TODO: remove when jit supports exception control flow
+ return ret
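
The BroadcastingList2[int] annotations tell the script compiler that each of these arguments may be a single int or a 2-element list, mirroring what _pair does in eager mode. For reference, _pair (from torch.nn.modules.utils) broadcasts scalars and passes pairs through:

    from torch.nn.modules.utils import _pair

    print(_pair(3))       # (3, 3) -- a scalar is broadcast to a pair
    print(_pair((2, 4)))  # (2, 4) -- pairs pass through unchanged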
+@torch._jit_internal.weak_script
def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
+ # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa
r"""Combines an array of sliding local blocks into a large containing
tensor.
@@ -2867,7 +2873,9 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
assert_int_or_pair(padding, 'padding', msg)
assert_int_or_pair(stride, 'stride', msg)
- return torch._C._nn.thnn_col2im(input, _pair(output_size), _pair(kernel_size),
- _pair(dilation), _pair(padding), _pair(stride))
+ ret = torch._C._nn.thnn_col2im(input, _pair(output_size), _pair(kernel_size),
+ _pair(dilation), _pair(padding), _pair(stride))
else:
raise NotImplementedError("Input Error: Only 3D input Tensors are supported (got {}D)".format(input.dim()))
+ ret = input # TODO: remove when jit supports exception control flow
+ return ret
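
The ret = input assignments exist only because script mode's control-flow analysis does not yet treat raise as terminating a branch, so every path must still define the returned variable. The same pattern in isolation (plain Python, illustrative only):

    def only_even(x):
        # type: (int) -> int
        if x % 2 == 0:
            ret = x
        else:
            raise ValueError("odd input")
            ret = x  # unreachable; placates the compiler until the jit
                     # supports exception control flow
        return ret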
diff --git a/torch/nn/modules/fold.py b/torch/nn/modules/fold.py
index d003582092..03adaef5e7 100644
--- a/torch/nn/modules/fold.py
+++ b/torch/nn/modules/fold.py
@@ -1,8 +1,10 @@
# coding=utf-8
from .module import Module
from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
+@weak_module
class Fold(Module):
r"""Combines an array of sliding local blocks into a large containing
tensor.
@@ -87,6 +89,8 @@ class Fold(Module):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
+ __constants__ = ['output_size', 'kernel_size', 'dilation', 'padding',
+ 'stride']
def __init__(self, output_size, kernel_size, dilation=1, padding=0, stride=1):
super(Fold, self).__init__()
@@ -96,6 +100,7 @@ class Fold(Module):
self.padding = padding
self.stride = stride
+ @weak_script_method
def forward(self, input):
return F.fold(input, self.output_size, self.kernel_size, self.dilation,
self.padding, self.stride)
@@ -107,6 +112,7 @@ class Fold(Module):
)
+@weak_module
class Unfold(Module):
r"""Extracts sliding local blocks from a batched input tensor.
@@ -201,6 +207,7 @@ class Unfold(Module):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
+ __constants__ = ['kernel_size', 'dilation', 'padding', 'stride']
def __init__(self, kernel_size, dilation=1, padding=0, stride=1):
super(Unfold, self).__init__()
@@ -209,6 +216,7 @@ class Unfold(Module):
self.padding = padding
self.stride = stride
+ @weak_script_method
def forward(self, input):
return F.unfold(input, self.kernel_size, self.dilation,
self.padding, self.stride)
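
For context, the two modules are inverses up to summation of overlapping blocks. A small eager-mode round trip (shapes chosen for illustration):

    import torch
    import torch.nn as nn

    unfold = nn.Unfold(kernel_size=3, padding=1)
    fold = nn.Fold(output_size=(8, 8), kernel_size=3, padding=1)

    x = torch.randn(1, 2, 8, 8)
    blocks = unfold(x)  # (1, 2*3*3, 64): one column per sliding 3x3 block
    y = fold(blocks)    # (1, 2, 8, 8): overlapping entries are summed,
                        #   so y generally differs from x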