summaryrefslogtreecommitdiff
path: root/caffe2/operators/slice_op.cc
blob: 535dac88a752a22ceed779620343b262ef4d70c7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include "caffe2/operators/slice_op.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

REGISTER_CPU_OPERATOR(Slice, SliceOp<CPUContext>);
REGISTER_CPU_GRADIENT_OPERATOR(SliceGradient, SliceGradientOp<CPUContext>);

// Schema for Slice. Indices come either from the `starts`/`ends` arguments
// (single input) or from two extra 1D tensor inputs (three inputs).
OPERATOR_SCHEMA(Slice)
    .NumInputs(1, 3)
    .NumOutputs(1)
    .DisallowInputFillers() // the filler cannot be enabled without output dims
    .SetDoc(R"DOC(
Produces a slice of the input tensor.

- Currently, only slicing in a single dimension is supported.

- Start and end indices are either passed as two 1D input tensors or using the `starts` and `ends` arguments.

- If a negative value is passed for any of the start or end indices, it represents the number of elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element).

Github Links:
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/slice_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "Slice",
    ["X"],
    ["Y"],
    starts=(0,1),
    ends=(-1,3)
)

workspace.FeedBlob("X", np.array([[1,2,3,4],[5,6,7,8]]))
print("X:", workspace.FetchBlob("X"))
workspace.RunOperatorOnce(op)
print("Y:", workspace.FetchBlob("Y"))

```

**Result**

```

X:
[[1 2 3 4]
 [5 6 7 8]]
Y:
[[2 3]
 [6 7]]

```

</details>

)DOC")
    .Input(0, "X", "(*Tensor*): tensor to extract slices from")
    .Input(
        1,
        "starts",
        "(*Tensor`<int>`*): 1D tensor of start-indices for each dimension of data")
    .Input(
        2,
        "ends",
        "(*Tensor`<int>`*): 1D tensor of end-indices for each dimension of data")
    .Arg("starts", "(*Tuple(int)*): list of starting indices")
    .Arg("ends", "(*Tuple(int)*): list of ending indices")
    // Infers the output shape from the `starts`/`ends` arguments. Returns an
    // empty vector (shape unknown) whenever the shape cannot be determined
    // statically or the arguments do not describe a valid slice.
    .TensorInferenceFunction([](const OperatorDef& def,
                                const vector<TensorShape>& in) {
      if (in.size() > 1) {
        // Starts/ends are supplied as tensors whose values are only known at
        // run time, so the output shape cannot be inferred here.
        return vector<TensorShape>();
      }
      auto const& data = in[0];

      ArgumentHelper helper(def);
      auto starts = helper.GetRepeatedArgument<int>("starts", vector<int>());
      auto ends = helper.GetRepeatedArgument<int>("ends", vector<int>());
      if (starts.size() != ends.size()) {
        // Every start index needs a matching end index; reading ends[i]
        // below would otherwise go out of bounds. Report the shape as
        // unknown rather than guess.
        return vector<TensorShape>();
      }

      const int ndim = data.dims_size();
      vector<int> dst_sizes(ndim);

      for (int i = 0; i < ndim; ++i) {
        const int64_t dim = data.dims(i);
        if (static_cast<size_t>(i) >= starts.size()) {
          // No indices given for this dimension: it is kept in full.
          dst_sizes[i] = static_cast<int>(dim);
          continue;
        }
        if (dim <= 0) {
          // An empty dimension stays empty regardless of the indices.
          dst_sizes[i] = 0;
          continue;
        }

        int64_t start = starts[i];
        int64_t end = ends[i];
        // Negative indices count back from one past the end, so -1 means
        // "through the last element" (see the doc string above).
        if (start < 0) {
          start = dim + 1 + start;
        }
        if (end < 0) {
          end = dim + 1 + end;
        }
        // Clamp to the dimension extent so the inferred size never exceeds
        // the input's. NOTE(review): intended to mirror the clamping in the
        // kernel (slice_op.h) -- confirm they stay in sync.
        if (start > dim) {
          start = dim;
        }
        if (end > dim) {
          end = dim;
        }
        if (start < 0 || end < start) {
          // Invalid range; do not emit a negative or bogus dimension.
          return vector<TensorShape>();
        }
        dst_sizes[i] = static_cast<int>(end - start);
      }
      return vector<TensorShape>{
          CreateTensorShape(dst_sizes, data.data_type())};
    })
    .Output(0, "Y", "(*Tensor*): sliced output tensor")
    .InheritOnnxSchema();

GRADIENT_OPERATOR_SCHEMA(SliceGradient);

namespace {
// Gradient maker for Slice: emits a single SliceGradient op whose inputs are
// the forward data input, the `starts`/`ends` tensors (only when the forward
// op received them as inputs 1 and 2), and the output gradient, and whose
// sole output is the gradient of the data input.
struct GetSliceGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    const bool indices_are_inputs = def_.input_size() > 1;

    std::vector<string> grad_op_inputs{I(0)};
    if (indices_are_inputs) {
      grad_op_inputs.push_back(I(1));
      grad_op_inputs.push_back(I(2));
    }
    grad_op_inputs.push_back(GO(0));

    std::vector<string> grad_op_outputs{GI(0)};
    return vector<OperatorDef>{CreateOperatorDef(
        "SliceGradient", "", grad_op_inputs, grad_op_outputs)};
  }
};
} // namespace
REGISTER_GRADIENT(Slice, GetSliceGradient);
} // namespace caffe2