summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/backendsCommon/WorkloadFactoryBase.hpp6
-rw-r--r--src/backends/cl/ClLayerSupport.cpp9
-rw-r--r--src/backends/cl/ClLayerSupport.hpp5
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp6
-rw-r--r--src/backends/cl/ClWorkloadFactory.hpp3
-rw-r--r--src/backends/cl/workloads/CMakeLists.txt2
-rw-r--r--src/backends/cl/workloads/ClSumWorkload.cpp52
-rw-r--r--src/backends/cl/workloads/ClSumWorkload.hpp30
-rw-r--r--src/backends/cl/workloads/ClWorkloads.hpp1
9 files changed, 113 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactoryBase.hpp b/src/backends/backendsCommon/WorkloadFactoryBase.hpp
index 960dbd341..0436a5349 100644
--- a/src/backends/backendsCommon/WorkloadFactoryBase.hpp
+++ b/src/backends/backendsCommon/WorkloadFactoryBase.hpp
@@ -270,9 +270,13 @@ public:
const WorkloadInfo& /*info*/) const override
{ return nullptr; }
+ std::unique_ptr<IWorkload> CreateReduceSum(const ReduceSumQueueDescriptor& /*descriptor*/,
+ const WorkloadInfo& /*info*/) const override
+ { return nullptr; }
+
std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
const WorkloadInfo& /*info*/) const override
{ return nullptr; }
};
-} //namespace armnn \ No newline at end of file
+} //namespace armnn
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index 7418dbd9e..07fd6d94e 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -61,6 +61,7 @@
#include "workloads/ClStackWorkload.hpp"
#include "workloads/ClStridedSliceWorkload.hpp"
#include "workloads/ClSubtractionWorkload.hpp"
+#include "workloads/ClSumWorkload.hpp"
#include "workloads/ClTransposeConvolution2dWorkload.hpp"
#include "workloads/ClTransposeWorkload.hpp"
#endif
@@ -865,4 +866,12 @@ bool ClLayerSupport::IsTransposeSupported(const TensorInfo& input,
FORWARD_WORKLOAD_VALIDATE_FUNC(ClTransposeWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}
+bool ClLayerSupport::IsReduceSumSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const ReduceSumDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ FORWARD_WORKLOAD_VALIDATE_FUNC(ClReduceSumWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
+}
+
} // namespace armnn
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index d785f5438..872cd5775 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -302,6 +302,11 @@ public:
const TransposeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsReduceSumSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const ReduceSumDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
};
} // namespace armnn
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index 50a867ca2..3a78092b6 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -541,6 +541,12 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateTranspose(const TransposeQue
return MakeWorkload<ClTransposeWorkload>(descriptor, info);
}
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateReduceSum(const ReduceSumQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return MakeWorkload<ClReduceSumWorkload>(descriptor, info);
+}
+
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateTransposeConvolution2d(
const TransposeConvolution2dQueueDescriptor& descriptor,
const WorkloadInfo& info) const
diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp
index 3f92e8647..c27fb6a41 100644
--- a/src/backends/cl/ClWorkloadFactory.hpp
+++ b/src/backends/cl/ClWorkloadFactory.hpp
@@ -218,6 +218,9 @@ public:
std::unique_ptr<IWorkload> CreateTranspose(const TransposeQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateReduceSum(const ReduceSumQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt
index 6d0aa792a..632070dc1 100644
--- a/src/backends/cl/workloads/CMakeLists.txt
+++ b/src/backends/cl/workloads/CMakeLists.txt
@@ -72,6 +72,8 @@ list(APPEND armnnClBackendWorkloads_sources
ClQuantizedLstmWorkload.hpp
ClQuantizeWorkload.cpp
ClQuantizeWorkload.hpp
+ ClSumWorkload.cpp
+ ClSumWorkload.hpp
ClReshapeWorkload.cpp
ClReshapeWorkload.hpp
ClResizeWorkload.cpp
diff --git a/src/backends/cl/workloads/ClSumWorkload.cpp b/src/backends/cl/workloads/ClSumWorkload.cpp
new file mode 100644
index 000000000..4fb415a1f
--- /dev/null
+++ b/src/backends/cl/workloads/ClSumWorkload.cpp
@@ -0,0 +1,52 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClSumWorkload.hpp"
+
+#include <cl/ClTensorHandle.hpp>
+#include <aclCommon/ArmComputeTensorUtils.hpp>
+
+#include "ClWorkloadUtils.hpp"
+
+namespace armnn
+{
+using namespace armcomputetensorutils;
+
+arm_compute::Status ClReduceSumWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output,
+ const ReduceSumDescriptor& desc)
+{
+ const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
+ const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+
+ arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(aclInputInfo.num_dimensions(),
+ input.GetNumDimensions(),
+ desc.m_vAxis);
+
+ return arm_compute::CLReduceSum::validate(&aclInputInfo, coords, desc.m_KeepDims, &aclOutputInfo);
+}
+
+ClReduceSumWorkload::ClReduceSumWorkload(const ReduceSumQueueDescriptor& descriptor, const WorkloadInfo& info)
+ : BaseWorkload<ReduceSumQueueDescriptor>(descriptor, info)
+{
+ m_Data.ValidateInputsOutputs("ClSumWorkload", 1, 1);
+
+ arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+ arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+
+ arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(input.info()->num_dimensions(),
+ info.m_InputTensorInfos[0].GetNumDimensions(),
+ m_Data.m_Parameters.m_vAxis);
+
+ m_Layer.configure(&input, coords, m_Data.m_Parameters.m_KeepDims, &output);
+}
+
+void ClReduceSumWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT_CL("ClSumWorkload_Execute");
+ m_Layer.run();
+}
+
+} //namespace armnn
diff --git a/src/backends/cl/workloads/ClSumWorkload.hpp b/src/backends/cl/workloads/ClSumWorkload.hpp
new file mode 100644
index 000000000..2e8abdb4e
--- /dev/null
+++ b/src/backends/cl/workloads/ClSumWorkload.hpp
@@ -0,0 +1,30 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+
+#include <arm_compute/runtime/CL/functions/CLReduceSum.h>
+
+namespace armnn
+{
+
+arm_compute::Status ClReduceSumWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output,
+ const ReduceSumDescriptor& desc);
+
+class ClReduceSumWorkload : public BaseWorkload<ReduceSumQueueDescriptor>
+{
+public:
+ ClReduceSumWorkload(const ReduceSumQueueDescriptor& descriptor, const WorkloadInfo& info);
+
+ void Execute() const override;
+
+private:
+ mutable arm_compute::CLReduceSum m_Layer;
+};
+
+} //namespace armnn
diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp
index 7b3ce439b..9bc17c557 100644
--- a/src/backends/cl/workloads/ClWorkloads.hpp
+++ b/src/backends/cl/workloads/ClWorkloads.hpp
@@ -36,6 +36,7 @@
#include "ClQLstmWorkload.hpp"
#include "ClQuantizeWorkload.hpp"
#include "ClQuantizedLstmWorkload.hpp"
+#include "ClSumWorkload.hpp"
#include "ClReshapeWorkload.hpp"
#include "ClResizeWorkload.hpp"
#include "ClRsqrtWorkload.hpp"