diff options
Diffstat (limited to 'runtimes/libs/tflite/src/TensorShapeUtils.cpp')
-rw-r--r-- | runtimes/libs/tflite/src/TensorShapeUtils.cpp | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/runtimes/libs/tflite/src/TensorShapeUtils.cpp b/runtimes/libs/tflite/src/TensorShapeUtils.cpp new file mode 100644 index 000000000..29628cd26 --- /dev/null +++ b/runtimes/libs/tflite/src/TensorShapeUtils.cpp @@ -0,0 +1,29 @@ +#include "tflite/TensorShapeUtils.h" + +namespace nnfw +{ +namespace tflite +{ + +nnfw::misc::tensor::Shape broadcast(const nnfw::misc::tensor::Shape &lhs_shape, + const nnfw::misc::tensor::Shape &rhs_shape) +{ + const uint32_t lhs_rank = lhs_shape.rank(); + const uint32_t rhs_rank = rhs_shape.rank(); + const uint32_t out_rank = std::max(lhs_rank, rhs_rank); + const uint32_t lhs_rank_diff = out_rank - lhs_rank; + const uint32_t rhs_rank_diff = out_rank - rhs_rank; + + nnfw::misc::tensor::Shape out_shape(out_rank); + + for (uint32_t axis = 0; axis < out_rank; ++axis) + { + out_shape.dim(axis) = std::max(axis < lhs_rank_diff ? 1 : lhs_shape.dim(axis - lhs_rank_diff), + axis < rhs_rank_diff ? 1 : rhs_shape.dim(axis - rhs_rank_diff)); + } + + return out_shape; +} + +} // namespace tflite +} // namespace nnfw |