summaryrefslogtreecommitdiff
path: root/runtimes/pure_arm_compute
diff options
context:
space:
mode:
authorPrasanna R/SNAP /SRI-Bangalore/Engineer/삼성전자 <prasanna.r@samsung.com>2019-01-17 15:07:52 +0530
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>2019-01-17 18:37:52 +0900
commit186796269d4068eb3737dd431a6fca006da2e4b7 (patch)
tree836dd843e90b5227211005b1a2f0edc32898a014 /runtimes/pure_arm_compute
parent017d638804008f5b5e9fc70603e60e84e3766a06 (diff)
downloadnnfw-186796269d4068eb3737dd431a6fca006da2e4b7.tar.gz
nnfw-186796269d4068eb3737dd431a6fca006da2e4b7.tar.bz2
nnfw-186796269d4068eb3737dd431a6fca006da2e4b7.zip
Add Broadcast support for PReLU in PACL (#4072)
This patch adds broadcast support for PReLU in PACL. Signed-off-by: prasannar <prasanna.r@samsung.com>
Diffstat (limited to 'runtimes/pure_arm_compute')
-rw-r--r--runtimes/pure_arm_compute/src/compilation.cc12
1 files changed, 11 insertions, 1 deletions
diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc
index 51f4fe4b1..9c12294a3 100644
--- a/runtimes/pure_arm_compute/src/compilation.cc
+++ b/runtimes/pure_arm_compute/src/compilation.cc
@@ -3016,10 +3016,20 @@ void Planner::visit(const ::internal::tflite::op::PReLU::Node &node)
const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
const ::internal::tflite::operand::Index alpha_index{node.param().alpha_index};
- // Set shape constraints
+ // Set Shape Constraints and TensorInfo
_builder.addShapeConstr(
ofm_index, asTensorInfo(asTensorShape(_ctx.at(ofm_index).shape()), _ctx.at(ofm_index).type(),
_ctx.at(ofm_index).scale(), _ctx.at(ofm_index).zeroPoint()));
+
+ if (!(_ctx.at(ifm_index).shape() == _ctx.at(alpha_index).shape()))
+ {
+ const auto broadcast_rank =
+ std::max(_ctx.at(ifm_index).shape().rank(), _ctx.at(alpha_index).shape().rank());
+ const_cast<::internal::tflite::operand::Shape &>(_ctx.at(ifm_index).shape())
+ .extendRank(broadcast_rank);
+ const_cast<::internal::tflite::operand::Shape &>(_ctx.at(alpha_index).shape())
+ .extendRank(broadcast_rank);
+ }
_builder.addShapeConstr(
ifm_index, asTensorInfo(asTensorShape(_ctx.at(ifm_index).shape()), _ctx.at(ifm_index).type(),
_ctx.at(ifm_index).scale(), _ctx.at(ifm_index).zeroPoint()));