diff options
Diffstat (limited to 'res/PyTorchExamples/examples/strided_slice/__init__.py')
-rw-r--r-- | res/PyTorchExamples/examples/strided_slice/__init__.py | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/res/PyTorchExamples/examples/strided_slice/__init__.py b/res/PyTorchExamples/examples/strided_slice/__init__.py new file mode 100644 index 000000000..7277da873 --- /dev/null +++ b/res/PyTorchExamples/examples/strided_slice/__init__.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + + +# model +# +# Notes: +# - This model requires opset version 10+. Previous version does not support strides. +class net_strided_slice(nn.Module): + def __init__(self, begin, end, stride): + super().__init__() + self.key = [slice(begin[i], end[i], stride[i]) for i in range(len(begin))] + + def forward(self, input): + # this is general way to do input[:, :, 1:5:2, 0:5:2] + return input[self.key] + + def onnx_opset_version(self): + return 10 + + +_model_ = net_strided_slice([0, 0, 1, 0], [1, 3, 5, 5], [1, 1, 2, 2]) + +# dummy input for onnx generation +_dummy_ = torch.randn(1, 3, 5, 5) |