diff options
Diffstat (limited to 'runtimes/neurun/backend/acl_neon/KernelGenerator.cc')
-rw-r--r-- | runtimes/neurun/backend/acl_neon/KernelGenerator.cc | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/runtimes/neurun/backend/acl_neon/KernelGenerator.cc b/runtimes/neurun/backend/acl_neon/KernelGenerator.cc index 0293b8368..a05fa8b30 100644 --- a/runtimes/neurun/backend/acl_neon/KernelGenerator.cc +++ b/runtimes/neurun/backend/acl_neon/KernelGenerator.cc @@ -1030,6 +1030,49 @@ void KernelGenerator::visit(const model::operation::PReLUNode &node) _execution_builder->append(std::move(acl_fn)); } +void KernelGenerator::visit(const model::operation::ReduceSumNode &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(model::operation::ReduceSumNode::Input::INPUT)}; + const auto axis_index{node.param().axis_index}; + + const auto axis_base = _ctx.at(axis_index).data().base(); + const auto axis_size = _ctx.at(axis_index).shape().num_elements(); + const auto input_rank = _ctx.at(input_index).shape().rank(); + + auto output_alloc = _tensor_builder->at(output_index).get(); + auto input_alloc = _tensor_builder->at(input_index).get(); + const auto frontend_layout = _current_subg_layout; + const auto backend_layout = input_alloc->layout(); + // The axis's data must exist as constant values + assert(axis_base != nullptr); + std::set<int32_t> axes; + for (size_t n = 0; n < axis_size; ++n) + { + int32_t axis_value = *(reinterpret_cast<const int32_t *>(axis_base) + n); + if (axis_value < 0) + { + axis_value += input_rank; + } + axes.insert(::neurun::backend::acl_common::ToARMComputeAxis(input_rank, axis_value, + frontend_layout, backend_layout) + .value()); + } + arm_compute::Coordinates fixed_axes; + for (const auto &a : axes) + { + fixed_axes.set(fixed_axes.num_dimensions(), a); + } + + auto fn = nnfw::cpp14::make_unique<::arm_compute::NEReduceSum>(); + + fn->configure(input_alloc->handle(), fixed_axes, false, output_alloc->handle()); + + auto acl_fn = asAclFunction(std::move(fn)); + + _execution_builder->append(std::move(acl_fn)); +} + void KernelGenerator::visit(const model::operation::ReLUNode &node) { const auto output_index{node.getOutputs().at(0)}; |