diff options
Diffstat (limited to 'compute/cker/include/cker/operation/StridedSlice.h')
-rw-r--r-- | compute/cker/include/cker/operation/StridedSlice.h | 35 |
1 files changed, 32 insertions, 3 deletions
diff --git a/compute/cker/include/cker/operation/StridedSlice.h b/compute/cker/include/cker/operation/StridedSlice.h index c57b4daa0..2f1089575 100644 --- a/compute/cker/include/cker/operation/StridedSlice.h +++ b/compute/cker/include/cker/operation/StridedSlice.h @@ -260,12 +260,41 @@ template <typename T> inline void StridedSlice(const StridedSliceParams &op_params, const Shape &unextended_input_shape, const T *input_data, const Shape &unextended_output_shape, T *output_data) { - // Note that the output_shape is not used herein. - StridedSliceParams params_copy = op_params; - assert(unextended_input_shape.DimensionsCount() <= 4); assert(unextended_output_shape.DimensionsCount() <= 4); + bool optimize = true; + int st_count = op_params.strides_count; + for (int idx = 0; idx < st_count - 1; idx++) + { + const int axis_size = unextended_input_shape.Dims(idx); + const int start = StartForAxis(op_params, unextended_input_shape, idx); + const int stop = StopForAxis(op_params, unextended_input_shape, idx, start); + if ((axis_size != 1) && (start != 0 || stop != 0)) + { + optimize = false; + break; + } + } + + if (optimize) + { + if (op_params.strides[st_count - 1] == 1) + { + const int start = StartForAxis(op_params, unextended_input_shape, st_count - 1); + const int end = StopForAxis(op_params, unextended_input_shape, st_count - 1, start); + + for (int idx = 0; idx < end - start; idx++) + { + output_data[idx] = input_data[idx + start]; + } + return; + } + } + + // Note that the output_shape is not used herein. + StridedSliceParams params_copy = op_params; + const Shape input_shape = Shape::ExtendedShape(4, unextended_input_shape); const Shape output_shape = Shape::ExtendedShape(4, unextended_output_shape); |