summaryrefslogtreecommitdiff
path: root/torch/onnx
diff options
context:
space:
mode:
authorJames Reed <jamesreed@fb.com>2018-09-13 12:32:41 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-09-13 12:39:52 -0700
commit0f1ca569ceae07b800f037805aa60691b8a2e801 (patch)
tree2086524fd4263b2dfb65f5c1083cec8b881a1157 /torch/onnx
parentacb6f18bab4bf7c801e445cf9b438cec827829ae (diff)
downloadpytorch-0f1ca569ceae07b800f037805aa60691b8a2e801.tar.gz
pytorch-0f1ca569ceae07b800f037805aa60691b8a2e801.tar.bz2
pytorch-0f1ca569ceae07b800f037805aa60691b8a2e801.zip
End-to-end dynamic slicing with ONNX DynamicSlice experimental operator (#11255)
Summary: Requires https://github.com/onnx/onnx/pull/1377 This PR makes it so that slices with dynamic boundary values can be exported from pytorch and run in caffe2 via ONNX. Pull Request resolved: https://github.com/pytorch/pytorch/pull/11255 Differential Revision: D9790216 Pulled By: jamesr66a fbshipit-source-id: 6adfcddc5788df4d34d7ca98341077140402a3e2
Diffstat (limited to 'torch/onnx')
-rw-r--r--torch/onnx/symbolic.py16
1 files changed, 13 insertions, 3 deletions
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 30e8672be6..d5b586c384 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -151,7 +151,7 @@ def _unimplemented(op, msg):
# increasing this number. This includes symbolic definitions NOT in this
# file, so grep for "OpName" (with quotes)
-_onnx_opset_version = 7
+_onnx_opset_version = 9
# ---------------------------------------------------------------------
@@ -981,11 +981,21 @@ def full_like(g, input, fill_value):
return add(g, zeros_like(g, input), fill_value, g.op("Constant", value_t=torch.tensor(1)))
-@parse_args('v', 'i', 'i', 'i', 'i')
+@parse_args('v', 'v', 'v', 'v', 'i')
def slice(g, self, dim, start, end, step):
if step != 1:
_unimplemented("slice", "step!=1 is currently not supported")
- return g.op("Slice", self, axes_i=[dim], starts_i=[start], ends_i=[end])
+ if start.node().kind() != 'onnx::Constant' or \
+ end.node().kind() != 'onnx::Constant' or dim.node().kind() != 'onnx::Constant':
+ start_unsqueezed = g.op("Unsqueeze", start, axes_i=[0])
+ end_unsqueezed = g.op("Unsqueeze", end, axes_i=[0])
+ dim_unsqueezed = g.op("Unsqueeze", dim, axes_i=[0])
+ return g.op("DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed)
+ else:
+ start = _parse_arg(start, 'i')
+ end = _parse_arg(end, 'i')
+ dim = _parse_arg(dim, 'i')
+ return g.op("Slice", self, axes_i=[dim], starts_i=[start], ends_i=[end])
@parse_args('v', 'f', 'f')