summaryrefslogtreecommitdiff
path: root/res/PyTorchExamples/examples/strided_slice/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'res/PyTorchExamples/examples/strided_slice/__init__.py')
-rw-r--r--res/PyTorchExamples/examples/strided_slice/__init__.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/res/PyTorchExamples/examples/strided_slice/__init__.py b/res/PyTorchExamples/examples/strided_slice/__init__.py
new file mode 100644
index 000000000..7277da873
--- /dev/null
+++ b/res/PyTorchExamples/examples/strided_slice/__init__.py
@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+
+
+# model
+#
+# Notes:
+# - This model requires opset version 10+. Previous version does not support strides.
+class net_strided_slice(nn.Module):
+ def __init__(self, begin, end, stride):
+ super().__init__()
+ self.key = [slice(begin[i], end[i], stride[i]) for i in range(len(begin))]
+
+ def forward(self, input):
+ # this is general way to do input[:, :, 1:5:2, 0:5:2]
+ return input[self.key]
+
+ def onnx_opset_version(self):
+ return 10
+
+
+_model_ = net_strided_slice([0, 0, 1, 0], [1, 3, 5, 5], [1, 1, 2, 2])
+
+# dummy input for onnx generation
+_dummy_ = torch.randn(1, 3, 5, 5)