blob: 29628cd2654ca081f0cb1f50499cb2281f22b17f (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
#include "tflite/TensorShapeUtils.h"

#include <algorithm>
#include <cstdint>
namespace nnfw
{
namespace tflite
{

namespace
{

// Combines one pair of right-aligned extents under NumPy-style broadcasting.
//
// An extent of 1 stretches to match the other operand, so it yields the other extent.
// This matters for zero-sized axes: broadcasting 0 against 1 must produce 0, whereas a
// plain std::max(lhs, rhs) would return 1.
//
// NOTE(review): incompatible extents (e.g. 2 vs 3) are not rejected here; the larger one
// is returned, matching the previous std::max-based behaviour. Callers that need
// validation must check compatibility themselves.
int32_t broadcast_extent(int32_t lhs, int32_t rhs)
{
  if (lhs == 1)
  {
    return rhs;
  }
  if (rhs == 1)
  {
    return lhs;
  }
  // Equal extents (including 0 vs 0) map to themselves.
  return std::max(lhs, rhs);
}

} // namespace

/**
 * @brief Compute the broadcast shape of two operand shapes.
 *
 * The shapes are right-aligned: the lower-rank shape is treated as if it had leading
 * extents of 1 until both ranks equal the larger input rank. Each output extent is then
 * derived from the pair of aligned input extents via broadcast_extent().
 *
 * @param lhs_shape Shape of the left-hand operand
 * @param rhs_shape Shape of the right-hand operand
 * @return Shape whose rank is the larger of the two input ranks
 */
nnfw::misc::tensor::Shape broadcast(const nnfw::misc::tensor::Shape &lhs_shape,
                                    const nnfw::misc::tensor::Shape &rhs_shape)
{
  const uint32_t lhs_rank = lhs_shape.rank();
  const uint32_t rhs_rank = rhs_shape.rank();
  const uint32_t out_rank = std::max(lhs_rank, rhs_rank);

  // Number of virtual leading extent-1 axes prepended to each operand.
  const uint32_t lhs_rank_diff = out_rank - lhs_rank;
  const uint32_t rhs_rank_diff = out_rank - rhs_rank;

  nnfw::misc::tensor::Shape out_shape(out_rank);
  for (uint32_t axis = 0; axis < out_rank; ++axis)
  {
    const int32_t lhs_dim = axis < lhs_rank_diff ? 1 : lhs_shape.dim(axis - lhs_rank_diff);
    const int32_t rhs_dim = axis < rhs_rank_diff ? 1 : rhs_shape.dim(axis - rhs_rank_diff);
    out_shape.dim(axis) = broadcast_extent(lhs_dim, rhs_dim);
  }
  return out_shape;
}

} // namespace tflite
} // namespace nnfw
|