diff options
author | 오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com> | 2019-04-16 08:44:31 +0900 |
---|---|---|
committer | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2019-04-16 08:44:31 +0900 |
commit | 60a3e226cade5676616649b07231ab48bb9960e4 (patch) | |
tree | 19f07a716840ad9c9f8d3f933e647e26ef9bdcf8 /runtimes | |
parent | 83d66c0551e0ecf01316b88fc672f8155fdef939 (diff) | |
download | nnfw-60a3e226cade5676616649b07231ab48bb9960e4.tar.gz nnfw-60a3e226cade5676616649b07231ab48bb9960e4.tar.bz2 nnfw-60a3e226cade5676616649b07231ab48bb9960e4.zip |
Remove axis conversion in libs/ARMComputeEx (#4992)
Remove axis conversion in libs/ARMComputeEx because conversion should done before configuration
Fix axis conversion in pack/unpack acl-cl extend kernel
Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
Diffstat (limited to 'runtimes')
-rw-r--r-- | runtimes/pure_arm_compute/src/compilation.cc | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index ae7384037..fd200203f 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -5154,6 +5154,13 @@ void Planner::visit(const ::internal::tflite::op::Unpack::Node &node) int32_t axis = _ctx.at(::internal::tflite::operand::Index{node.param().axis_index}).asScalar<int32_t>(); + // Negatige axis is supported, -1 implies R-1 axis where R is input rank + if (axis < 0) + { + axis += input_rank; + assert(axis >= 0); + } + uint32_t axis_uint = ToARMComputeAxis(input_rank, axis).value(); // int32_t num_split = // _ctx.at(::internal::tflite::operand::Index{node.param().num_split_index}).asScalar<int32_t>(); @@ -5168,14 +5175,14 @@ void Planner::visit(const ::internal::tflite::op::Unpack::Node &node) { std::vector<int32_t> ofm_indexes; int ifm_index; - int axis; + uint32_t axis; }; if (input_rank == 4) { Param param; param.ifm_index = ifm_index.asInt(); - param.axis = axis; + param.axis = axis_uint; for (const auto &index : node.param().ofm_indexes) { param.ofm_indexes.push_back(index); @@ -5241,19 +5248,28 @@ void Planner::visit(const ::internal::tflite::op::Pack::Node &node) int32_t axis = _ctx.at(::internal::tflite::operand::Index{node.param().axis_index}).asScalar<int32_t>(); + // A negative axis implies axis from the end. + // For example, axis = -1 implies the first axis from the end, i.e. axis = Rank - 1. + // Similarly, axis = -2 imples second axis from the end, i.e. axis = Rank - 2. + if (axis < 0) + { + axis += output_rank; + assert(axis >= 0); + } + uint32_t axis_uint = ToARMComputeAxis(output_rank, axis).value(); struct Param { std::vector<int32_t> ifm_indexes; int ofm_index; - int axis; + uint32_t axis; }; if (input_rank == 3) { Param param; param.ofm_index = ofm_index.asInt(); - param.axis = axis; + param.axis = axis_uint; // TODO: Fix this once all permutations are present. if (param.axis != 0) |