diff options
author | James Reed <jamesreed@fb.com> | 2018-09-13 12:32:41 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-09-13 12:39:52 -0700 |
commit | 0f1ca569ceae07b800f037805aa60691b8a2e801 (patch) | |
tree | 2086524fd4263b2dfb65f5c1083cec8b881a1157 /torch/onnx | |
parent | acb6f18bab4bf7c801e445cf9b438cec827829ae (diff) | |
download | pytorch-0f1ca569ceae07b800f037805aa60691b8a2e801.tar.gz pytorch-0f1ca569ceae07b800f037805aa60691b8a2e801.tar.bz2 pytorch-0f1ca569ceae07b800f037805aa60691b8a2e801.zip |
End-to-end dynamic slicing with ONNX DynamicSlice experimental operator (#11255)
Summary:
Requires https://github.com/onnx/onnx/pull/1377
This PR makes it so that slices with dynamic boundary values can be exported from pytorch and run in caffe2 via ONNX.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11255
Differential Revision: D9790216
Pulled By: jamesr66a
fbshipit-source-id: 6adfcddc5788df4d34d7ca98341077140402a3e2
Diffstat (limited to 'torch/onnx')
-rw-r--r-- | torch/onnx/symbolic.py | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 30e8672be6..d5b586c384 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -151,7 +151,7 @@ def _unimplemented(op, msg): # increasing this number. This includes symbolic definitions NOT in this # file, so grep for "OpName" (with quotes) -_onnx_opset_version = 7 +_onnx_opset_version = 9 # --------------------------------------------------------------------- @@ -981,11 +981,21 @@ def full_like(g, input, fill_value): return add(g, zeros_like(g, input), fill_value, g.op("Constant", value_t=torch.tensor(1))) -@parse_args('v', 'i', 'i', 'i', 'i') +@parse_args('v', 'v', 'v', 'v', 'i') def slice(g, self, dim, start, end, step): if step != 1: _unimplemented("slice", "step!=1 is currently not supported") - return g.op("Slice", self, axes_i=[dim], starts_i=[start], ends_i=[end]) + if start.node().kind() != 'onnx::Constant' or \ + end.node().kind() != 'onnx::Constant' or dim.node().kind() != 'onnx::Constant': + start_unsqueezed = g.op("Unsqueeze", start, axes_i=[0]) + end_unsqueezed = g.op("Unsqueeze", end, axes_i=[0]) + dim_unsqueezed = g.op("Unsqueeze", dim, axes_i=[0]) + return g.op("DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed) + else: + start = _parse_arg(start, 'i') + end = _parse_arg(end, 'i') + dim = _parse_arg(dim, 'i') + return g.op("Slice", self, axes_i=[dim], starts_i=[start], ends_i=[end]) @parse_args('v', 'f', 'f') |