summaryrefslogtreecommitdiff
path: root/compute/cker/include/cker/operation/StridedSlice.h
diff options
context:
space:
mode:
Diffstat (limited to 'compute/cker/include/cker/operation/StridedSlice.h')
-rw-r--r--compute/cker/include/cker/operation/StridedSlice.h35
1 files changed, 32 insertions, 3 deletions
diff --git a/compute/cker/include/cker/operation/StridedSlice.h b/compute/cker/include/cker/operation/StridedSlice.h
index c57b4daa0..2f1089575 100644
--- a/compute/cker/include/cker/operation/StridedSlice.h
+++ b/compute/cker/include/cker/operation/StridedSlice.h
@@ -260,12 +260,41 @@ template <typename T>
inline void StridedSlice(const StridedSliceParams &op_params, const Shape &unextended_input_shape,
const T *input_data, const Shape &unextended_output_shape, T *output_data)
{
- // Note that the output_shape is not used herein.
- StridedSliceParams params_copy = op_params;
-
assert(unextended_input_shape.DimensionsCount() <= 4);
assert(unextended_output_shape.DimensionsCount() <= 4);
+ bool optimize = true;
+ int st_count = op_params.strides_count;
+ for (int idx = 0; idx < st_count - 1; idx++)
+ {
+ const int axis_size = unextended_input_shape.Dims(idx);
+ const int start = StartForAxis(op_params, unextended_input_shape, idx);
+ const int stop = StopForAxis(op_params, unextended_input_shape, idx, start);
+ if ((axis_size != 1) && (start != 0 || stop != 0))
+ {
+ optimize = false;
+ break;
+ }
+ }
+
+ if (optimize)
+ {
+ if (op_params.strides[st_count - 1] == 1)
+ {
+ const int start = StartForAxis(op_params, unextended_input_shape, st_count - 1);
+ const int end = StopForAxis(op_params, unextended_input_shape, st_count - 1, start);
+
+ for (int idx = 0; idx < end - start; idx++)
+ {
+ output_data[idx] = input_data[idx + start];
+ }
+ return;
+ }
+ }
+
+ // Note that the output_shape is not used herein.
+ StridedSliceParams params_copy = op_params;
+
const Shape input_shape = Shape::ExtendedShape(4, unextended_input_shape);
const Shape output_shape = Shape::ExtendedShape(4, unextended_output_shape);