diff options
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactoryBase.hpp | 6 | ||||
-rw-r--r-- | src/backends/cl/ClLayerSupport.cpp | 9 | ||||
-rw-r--r-- | src/backends/cl/ClLayerSupport.hpp | 5 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 6 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.hpp | 3 | ||||
-rw-r--r-- | src/backends/cl/workloads/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClSumWorkload.cpp | 52 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClSumWorkload.hpp | 30 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloads.hpp | 1 |
9 files changed, 113 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactoryBase.hpp b/src/backends/backendsCommon/WorkloadFactoryBase.hpp index 960dbd341..0436a5349 100644 --- a/src/backends/backendsCommon/WorkloadFactoryBase.hpp +++ b/src/backends/backendsCommon/WorkloadFactoryBase.hpp @@ -270,9 +270,13 @@ public: const WorkloadInfo& /*info*/) const override { return nullptr; } + std::unique_ptr<IWorkload> CreateReduceSum(const ReduceSumQueueDescriptor& /*descriptor*/, + const WorkloadInfo& /*info*/) const override + { return nullptr; } + std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const override { return nullptr; } }; -} //namespace armnn
\ No newline at end of file +} //namespace armnn diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index 7418dbd9e..07fd6d94e 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -61,6 +61,7 @@ #include "workloads/ClStackWorkload.hpp" #include "workloads/ClStridedSliceWorkload.hpp" #include "workloads/ClSubtractionWorkload.hpp" +#include "workloads/ClSumWorkload.hpp" #include "workloads/ClTransposeConvolution2dWorkload.hpp" #include "workloads/ClTransposeWorkload.hpp" #endif @@ -865,4 +866,12 @@ bool ClLayerSupport::IsTransposeSupported(const TensorInfo& input, FORWARD_WORKLOAD_VALIDATE_FUNC(ClTransposeWorkloadValidate, reasonIfUnsupported, input, output, descriptor); } +bool ClLayerSupport::IsReduceSumSupported(const TensorInfo& input, + const TensorInfo& output, + const ReduceSumDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) const +{ + FORWARD_WORKLOAD_VALIDATE_FUNC(ClReduceSumWorkloadValidate, reasonIfUnsupported, input, output, descriptor); +} + } // namespace armnn diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp index d785f5438..872cd5775 100644 --- a/src/backends/cl/ClLayerSupport.hpp +++ b/src/backends/cl/ClLayerSupport.hpp @@ -302,6 +302,11 @@ public: const TransposeDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsReduceSumSupported(const TensorInfo& input, + const TensorInfo& output, + const ReduceSumDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + }; } // namespace armnn diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 50a867ca2..3a78092b6 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -541,6 +541,12 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateTranspose(const TransposeQue return MakeWorkload<ClTransposeWorkload>(descriptor, info); } +std::unique_ptr<IWorkload> ClWorkloadFactory::CreateReduceSum(const ReduceSumQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return MakeWorkload<ClReduceSumWorkload>(descriptor, info); +} + std::unique_ptr<IWorkload> ClWorkloadFactory::CreateTransposeConvolution2d( const TransposeConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp index 3f92e8647..c27fb6a41 100644 --- a/src/backends/cl/ClWorkloadFactory.hpp +++ b/src/backends/cl/ClWorkloadFactory.hpp @@ -218,6 +218,9 @@ public: std::unique_ptr<IWorkload> CreateTranspose(const TransposeQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateReduceSum(const ReduceSumQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const override; diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt index 6d0aa792a..632070dc1 100644 --- a/src/backends/cl/workloads/CMakeLists.txt +++ b/src/backends/cl/workloads/CMakeLists.txt @@ -72,6 +72,8 @@ list(APPEND armnnClBackendWorkloads_sources ClQuantizedLstmWorkload.hpp ClQuantizeWorkload.cpp ClQuantizeWorkload.hpp + ClSumWorkload.cpp + ClSumWorkload.hpp ClReshapeWorkload.cpp ClReshapeWorkload.hpp ClResizeWorkload.cpp diff --git a/src/backends/cl/workloads/ClSumWorkload.cpp b/src/backends/cl/workloads/ClSumWorkload.cpp new file mode 100644 index 000000000..4fb415a1f --- /dev/null +++ b/src/backends/cl/workloads/ClSumWorkload.cpp @@ -0,0 +1,52 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClSumWorkload.hpp" + +#include <cl/ClTensorHandle.hpp> +#include <aclCommon/ArmComputeTensorUtils.hpp> + +#include "ClWorkloadUtils.hpp" + +namespace armnn +{ +using namespace armcomputetensorutils; + +arm_compute::Status ClReduceSumWorkloadValidate(const TensorInfo& input, + const TensorInfo& output, + const ReduceSumDescriptor& desc) +{ + const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input); + const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output); + + arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(aclInputInfo.num_dimensions(), + input.GetNumDimensions(), + desc.m_vAxis); + + return arm_compute::CLReduceSum::validate(&aclInputInfo, coords, desc.m_KeepDims, &aclOutputInfo); +} + +ClReduceSumWorkload::ClReduceSumWorkload(const ReduceSumQueueDescriptor& descriptor, const WorkloadInfo& info) + : BaseWorkload<ReduceSumQueueDescriptor>(descriptor, info) +{ + m_Data.ValidateInputsOutputs("ClSumWorkload", 1, 1); + + arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor(); + + arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(input.info()->num_dimensions(), + info.m_InputTensorInfos[0].GetNumDimensions(), + m_Data.m_Parameters.m_vAxis); + + m_Layer.configure(&input, coords, m_Data.m_Parameters.m_KeepDims, &output); +} + +void ClReduceSumWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT_CL("ClSumWorkload_Execute"); + m_Layer.run(); +} + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClSumWorkload.hpp b/src/backends/cl/workloads/ClSumWorkload.hpp new file mode 100644 index 000000000..2e8abdb4e --- /dev/null +++ b/src/backends/cl/workloads/ClSumWorkload.hpp @@ -0,0 +1,30 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <backendsCommon/Workload.hpp> + +#include <arm_compute/runtime/CL/functions/CLReduceSum.h> + +namespace armnn +{ + +arm_compute::Status ClReduceSumWorkloadValidate(const TensorInfo& input, + const TensorInfo& output, + const ReduceSumDescriptor& desc); + +class ClReduceSumWorkload : public BaseWorkload<ReduceSumQueueDescriptor> +{ +public: + ClReduceSumWorkload(const ReduceSumQueueDescriptor& descriptor, const WorkloadInfo& info); + + void Execute() const override; + +private: + mutable arm_compute::CLReduceSum m_Layer; +}; + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp index 7b3ce439b..9bc17c557 100644 --- a/src/backends/cl/workloads/ClWorkloads.hpp +++ b/src/backends/cl/workloads/ClWorkloads.hpp @@ -36,6 +36,7 @@ #include "ClQLstmWorkload.hpp" #include "ClQuantizeWorkload.hpp" #include "ClQuantizedLstmWorkload.hpp" +#include "ClSumWorkload.hpp" #include "ClReshapeWorkload.hpp" #include "ClResizeWorkload.hpp" #include "ClRsqrtWorkload.hpp" |