summaryrefslogtreecommitdiff
path: root/libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp')
-rw-r--r--libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp21
1 files changed, 4 insertions, 17 deletions
diff --git a/libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp b/libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp
index 23efafa6a..ae2801e2b 100644
--- a/libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp
+++ b/libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp
@@ -17,26 +17,14 @@
#include "arm_compute/core/CL/kernels/CLGatherKernel.h"
#include "arm_compute/core/CL/CLHelpers.h"
-#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/CLKernelLibraryEx.h"
#include "arm_compute/core/CL/ICLTensor.h"
-#include "arm_compute/core/CL/OpenCL.h"
-#include "arm_compute/core/Error.h"
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/TensorInfo.h"
-#include "arm_compute/core/Validate.h"
-#include "arm_compute/core/Window.h"
-
-#include <cmath>
-#include <cstdlib>
-#include <set>
-#include <string>
using namespace arm_compute;
namespace
{
-constexpr unsigned int num_elems_processed_per_iteration = 16;
+constexpr unsigned int num_elems_processed_per_iteration = 1;
Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2,
const ITensorInfo *output)
@@ -46,6 +34,7 @@ Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2,
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::S32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S32,
DataType::F32);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output);
return Status{};
}
@@ -57,8 +46,7 @@ CLGatherKernel::CLGatherKernel() : _input1(nullptr), _input2(nullptr), _output(n
void CLGatherKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::S32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info()));
_input1 = input1;
_input2 = input2;
@@ -89,11 +77,10 @@ void CLGatherKernel::configure(const ICLTensor *input1, const ICLTensor *input2,
static_cast<cl::Kernel>(CLKernelLibraryEx::get().create_kernel(kernel_name, build_opts));
// Configure kernel window
- const unsigned int num_elems_processed_per_iteration = 1;
Window win = calculate_max_window(*input2->info(), Steps(num_elems_processed_per_iteration));
output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
Status CLGatherKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2,