summaryrefslogtreecommitdiff
path: root/runtimes/libs/tflite/src/TensorShapeUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/libs/tflite/src/TensorShapeUtils.cpp')
-rw-r--r--runtimes/libs/tflite/src/TensorShapeUtils.cpp29
1 files changed, 29 insertions, 0 deletions
diff --git a/runtimes/libs/tflite/src/TensorShapeUtils.cpp b/runtimes/libs/tflite/src/TensorShapeUtils.cpp
new file mode 100644
index 000000000..29628cd26
--- /dev/null
+++ b/runtimes/libs/tflite/src/TensorShapeUtils.cpp
@@ -0,0 +1,29 @@
+#include "tflite/TensorShapeUtils.h"
+
+namespace nnfw
+{
+namespace tflite
+{
+
+nnfw::misc::tensor::Shape broadcast(const nnfw::misc::tensor::Shape &lhs_shape,
+ const nnfw::misc::tensor::Shape &rhs_shape)
+{
+ const uint32_t lhs_rank = lhs_shape.rank();
+ const uint32_t rhs_rank = rhs_shape.rank();
+ const uint32_t out_rank = std::max(lhs_rank, rhs_rank);
+ const uint32_t lhs_rank_diff = out_rank - lhs_rank;
+ const uint32_t rhs_rank_diff = out_rank - rhs_rank;
+
+ nnfw::misc::tensor::Shape out_shape(out_rank);
+
+ for (uint32_t axis = 0; axis < out_rank; ++axis)
+ {
+ out_shape.dim(axis) = std::max(axis < lhs_rank_diff ? 1 : lhs_shape.dim(axis - lhs_rank_diff),
+ axis < rhs_rank_diff ? 1 : rhs_shape.dim(axis - rhs_rank_diff));
+ }
+
+ return out_shape;
+}
+
+} // namespace tflite
+} // namespace nnfw