summaryrefslogtreecommitdiff
path: root/compute/cker/include/cker/Shape.h
diff options
context:
space:
mode:
Diffstat (limited to 'compute/cker/include/cker/Shape.h')
-rw-r--r--compute/cker/include/cker/Shape.h45
1 files changed, 28 insertions, 17 deletions
diff --git a/compute/cker/include/cker/Shape.h b/compute/cker/include/cker/Shape.h
index 39449c68f..43b511d05 100644
--- a/compute/cker/include/cker/Shape.h
+++ b/compute/cker/include/cker/Shape.h
@@ -226,6 +226,11 @@ inline int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
}
+inline int Offset(const Shape &shape, int *index)
+{
+ return Offset(shape, index[0], index[1], index[2], index[3]);
+}
+
inline int FlatSizeSkipDim(const Shape &shape, int skip_dim)
{
const int dims_count = shape.DimensionsCount();
@@ -241,29 +246,35 @@ inline int FlatSizeSkipDim(const Shape &shape, int skip_dim)
// Flat size calculation, checking that dimensions match with one or more other
// arrays.
-inline int MatchingFlatSize(const Shape &shape, const Shape &check_shape_0)
+template <typename... Ts> inline bool checkMatching(const Shape &shape, Ts... check_shapes)
{
- UNUSED_RELEASE(check_shape_0);
- assert(shape.DimensionsCount() == check_shape_0.DimensionsCount());
- const int dims_count = shape.DimensionsCount();
- for (int i = 0; i < dims_count; ++i)
+ const Shape check_shapes_array[sizeof...(Ts)] = {std::forward<Ts>(check_shapes)...};
+ for (const auto &check_shape : check_shapes_array)
{
- assert(shape.Dims(i) == check_shape_0.Dims(i));
+ if (shape.DimensionsCount() != check_shape.DimensionsCount())
+ {
+ return false;
+ }
+ for (int i = 0; i < shape.DimensionsCount(); ++i)
+ {
+ if (shape.Dims(i) != check_shape.Dims(i))
+ {
+ return false;
+ }
+ }
}
- return shape.FlatSize();
+ return true;
}
-inline int MatchingFlatSize(const Shape &shape, const Shape &check_shape_0,
- const Shape &check_shape_1)
+struct UNUSED_ALL
{
- UNUSED_RELEASE(check_shape_0);
- assert(shape.DimensionsCount() == check_shape_0.DimensionsCount());
- const int dims_count = shape.DimensionsCount();
- for (int i = 0; i < dims_count; ++i)
- {
- assert(shape.Dims(i) == check_shape_0.Dims(i));
- }
- return MatchingFlatSize(shape, check_shape_1);
+ template <typename... Args> UNUSED_ALL(Args const &...) {}
+};
+template <typename... Ts> inline int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
+{
+ UNUSED_ALL{check_shapes...};
+ assert(checkMatching(shape, std::forward<Ts>(check_shapes)...));
+ return shape.FlatSize();
}
inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0)