diff options
Diffstat (limited to 'compute/cker/include/cker/Shape.h')
-rw-r--r-- | compute/cker/include/cker/Shape.h | 45 |
1 files changed, 28 insertions, 17 deletions
diff --git a/compute/cker/include/cker/Shape.h b/compute/cker/include/cker/Shape.h index 39449c68f..43b511d05 100644 --- a/compute/cker/include/cker/Shape.h +++ b/compute/cker/include/cker/Shape.h @@ -226,6 +226,11 @@ inline int Offset(const Shape &shape, int i0, int i1, int i2, int i3) return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3; } +inline int Offset(const Shape &shape, int *index) +{ + return Offset(shape, index[0], index[1], index[2], index[3]); +} + inline int FlatSizeSkipDim(const Shape &shape, int skip_dim) { const int dims_count = shape.DimensionsCount(); @@ -241,29 +246,35 @@ inline int FlatSizeSkipDim(const Shape &shape, int skip_dim) // Flat size calculation, checking that dimensions match with one or more other // arrays. -inline int MatchingFlatSize(const Shape &shape, const Shape &check_shape_0) +template <typename... Ts> inline bool checkMatching(const Shape &shape, Ts... check_shapes) { - UNUSED_RELEASE(check_shape_0); - assert(shape.DimensionsCount() == check_shape_0.DimensionsCount()); - const int dims_count = shape.DimensionsCount(); - for (int i = 0; i < dims_count; ++i) + const Shape check_shapes_array[sizeof...(Ts)] = {std::forward<Ts>(check_shapes)...}; + for (const auto &check_shape : check_shapes_array) { - assert(shape.Dims(i) == check_shape_0.Dims(i)); + if (shape.DimensionsCount() != check_shape.DimensionsCount()) + { + return false; + } + for (int i = 0; i < shape.DimensionsCount(); ++i) + { + if (shape.Dims(i) != check_shape.Dims(i)) + { + return false; + } + } } - return shape.FlatSize(); + return true; } -inline int MatchingFlatSize(const Shape &shape, const Shape &check_shape_0, - const Shape &check_shape_1) +struct UNUSED_ALL { - UNUSED_RELEASE(check_shape_0); - assert(shape.DimensionsCount() == check_shape_0.DimensionsCount()); - const int dims_count = shape.DimensionsCount(); - for (int i = 0; i < dims_count; ++i) - { - assert(shape.Dims(i) == check_shape_0.Dims(i)); - } - return MatchingFlatSize(shape, check_shape_1); + template <typename... Args> UNUSED_ALL(Args const &...) {} +}; +template <typename... Ts> inline int MatchingFlatSize(const Shape &shape, Ts... check_shapes) +{ + UNUSED_ALL{check_shapes...}; + assert(checkMatching(shape, std::forward<Ts>(check_shapes)...)); + return shape.FlatSize(); } inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0) |