diff options
Diffstat (limited to 'libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp')
-rw-r--r-- | libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp | 21 |
1 files changed, 4 insertions, 17 deletions
diff --git a/libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp b/libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp index 23efafa6a..ae2801e2b 100644 --- a/libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp +++ b/libs/ARMComputeEx/src/core/CL/kernels/CLGatherKernel.cpp @@ -17,26 +17,14 @@ #include "arm_compute/core/CL/kernels/CLGatherKernel.h" #include "arm_compute/core/CL/CLHelpers.h" -#include "arm_compute/core/CL/CLKernelLibrary.h" #include "arm_compute/core/CL/CLKernelLibraryEx.h" #include "arm_compute/core/CL/ICLTensor.h" -#include "arm_compute/core/CL/OpenCL.h" -#include "arm_compute/core/Error.h" -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/TensorInfo.h" -#include "arm_compute/core/Validate.h" -#include "arm_compute/core/Window.h" - -#include <cmath> -#include <cstdlib> -#include <set> -#include <string> using namespace arm_compute; namespace { -constexpr unsigned int num_elems_processed_per_iteration = 16; +constexpr unsigned int num_elems_processed_per_iteration = 1; Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output) @@ -46,6 +34,7 @@ Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::S32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S32, DataType::F32); + ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output); return Status{}; } @@ -57,8 +46,7 @@ CLGatherKernel::CLGatherKernel() : _input1(nullptr), _input2(nullptr), _output(n void CLGatherKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output) { ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::S32); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info())); _input1 = input1; _input2 = input2; @@ -89,11 +77,10 @@ void CLGatherKernel::configure(const ICLTensor *input1, const ICLTensor *input2, static_cast<cl::Kernel>(CLKernelLibraryEx::get().create_kernel(kernel_name, build_opts)); // Configure kernel window - const unsigned int num_elems_processed_per_iteration = 1; Window win = calculate_max_window(*input2->info(), Steps(num_elems_processed_per_iteration)); output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape())); - ICLKernel::configure(win); + ICLKernel::configure_internal(win); } Status CLGatherKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, |