summaryrefslogtreecommitdiff
path: root/libs/tflite/src/TensorShapeUtils.cpp
blob: b5d90671970cdb25f3ababd636935fcb181c8761 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include "tflite/TensorShapeUtils.h"

#include <algorithm>
#include <cassert>

namespace nnfw
{
namespace tflite
{

/**
 * @brief Compute the shape that results from broadcasting two tensor shapes together.
 *
 * Shapes are aligned at their trailing axis. An operand with a smaller rank is treated as
 * if it had extra leading dimensions of size 1. On each output axis, a dimension of 1 is
 * stretched to match the other operand's dimension.
 *
 * @param lhs_shape Shape of the left-hand operand
 * @param rhs_shape Shape of the right-hand operand
 * @return Shape of rank max(lhs_shape.rank(), rhs_shape.rank()) holding the broadcast dims
 *
 * @note On each axis the two dimensions must be equal, or one of them must be 1. A violation
 *       is asserted in debug builds. Release builds fall back to the larger dimension, which
 *       was the historical behaviour of this function.
 */
nnfw::misc::tensor::Shape broadcast(const nnfw::misc::tensor::Shape &lhs_shape,
                                    const nnfw::misc::tensor::Shape &rhs_shape)
{
  const uint32_t lhs_rank = lhs_shape.rank();
  const uint32_t rhs_rank = rhs_shape.rank();
  const uint32_t out_rank = std::max(lhs_rank, rhs_rank);

  // Output axis at which each operand's first real axis lands. Output axes before the
  // offset are implicit leading 1s, so no padded copy of the dimensions is needed.
  const uint32_t lhs_offset = out_rank - lhs_rank;
  const uint32_t rhs_offset = out_rank - rhs_rank;

  nnfw::misc::tensor::Shape out_shape(out_rank);

  for (uint32_t axis = 0; axis < out_rank; ++axis)
  {
    const int32_t lhs_dim = (axis < lhs_offset) ? 1 : lhs_shape.dim(axis - lhs_offset);
    const int32_t rhs_dim = (axis < rhs_offset) ? 1 : rhs_shape.dim(axis - rhs_offset);

    // Broadcasting is only defined when the dims match or one of them is 1.
    assert((lhs_dim == rhs_dim || lhs_dim == 1 || rhs_dim == 1) &&
           "broadcast: incompatible dimensions");

    int32_t out_dim = 0;
    if (rhs_dim == 1)
    {
      // rhs is stretched to lhs. This also yields 0 for (0, 1), where std::max would
      // wrongly produce 1.
      out_dim = lhs_dim;
    }
    else if (lhs_dim == 1)
    {
      // lhs is stretched to rhs (0 is preserved for (1, 0) as well).
      out_dim = rhs_dim;
    }
    else
    {
      // Equal dims: either one is correct. Incompatible dims (release builds only):
      // keep the legacy "larger dimension" result.
      out_dim = std::max(lhs_dim, rhs_dim);
    }

    out_shape.dim(axis) = out_dim;
  }

  return out_shape;
}

} // namespace tflite
} // namespace nnfw