diff options
Diffstat (limited to 'libs/ARMComputeEx/arm_compute/core/CL/kernels')
7 files changed, 773 insertions, 0 deletions
diff --git a/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLCastKernel.h b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLCastKernel.h new file mode 100644 index 000000000..6bd33bf8f --- /dev/null +++ b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLCastKernel.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2016-2018 ARM Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __ARM_COMPUTE_CLCASTKERNEL_H__ +#define __ARM_COMPUTE_CLCASTKERNEL_H__ + +#include "arm_compute/core/CL/ICLKernel.h" + +namespace arm_compute +{ +class ICLTensor; + +/** OpenCL kernel to perform a cast operation */ +class CLCastKernel : public ICLKernel +{ +public: + /** Default constructor */ + CLCastKernel(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLCastKernel(const CLCastKernel &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLCastKernel &operator=(const CLCastKernel &) = delete; + /** Allow instances of this class to be moved */ + CLCastKernel(CLCastKernel &&) = default; + /** Allow instances of this class to be moved */ + CLCastKernel &operator=(CLCastKernel &&) = default; + /** Default destructor */ + ~CLCastKernel() = default; + /** Initialise the kernel's input and output. + * + * @param[in] input Input tensor. Data types supported: U8/QASYMM8/S16/S32/F16/F32. + * @param[in] output Output tensor. Data types supported: U8/QASYMM8/S16/S32/F16/F32. + */ + void configure(const ICLTensor *input, ICLTensor *output); + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + +private: + const ICLTensor *_input; /**< Source tensor */ + ICLTensor *_output; /**< Destination tensor */ +}; +} // namespace arm_compute +#endif /* __ARM_COMPUTE_CLCASTKERNEL_H__ */ diff --git a/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLGatherKernel.h b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLGatherKernel.h new file mode 100644 index 000000000..a51441aca --- /dev/null +++ b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLGatherKernel.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2016-2018 ARM Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __ARM_COMPUTE_CLGATHERKERNEL_H__ +#define __ARM_COMPUTE_CLGATHERKERNEL_H__ + +#include "arm_compute/core/CL/ICLKernel.h" +#include "arm_compute/core/Types.h" + +namespace arm_compute +{ +class ICLTensor; + +/** Interface for the gather kernel. + * + */ +class CLGatherKernel : public ICLKernel +{ +public: + /** Default constructor.*/ + CLGatherKernel(); + /** Prevent instances of this class from being copied (As this class contains pointers). */ + CLGatherKernel(const CLGatherKernel &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers). */ + CLGatherKernel &operator=(const CLGatherKernel &) = delete; + /** Allow instances of this class to be moved */ + CLGatherKernel(CLGatherKernel &&) = default; + /** Allow instances of this class to be moved */ + CLGatherKernel &operator=(CLGatherKernel &&) = default; + /** Initialise the kernel's input, output and border mode. + * + * @param[in] input1 An input tensor. Data types supported: U8/S32/F32. + * @param[in] input2 An input tensor. Data types supported: S32. + * @param[out] output The output tensor, Data types supported: same as @p input1. + */ + void configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output); + /** Static function to check if given info will lead to a valid configuration of @ref + * CLGatherKernel + * + * @param[in] input1 An input tensor. Data types supported: U8/S32/F32. + * @param[in] input2 An input tensor. Data types supported: S32. + * @param[out] output The output tensor, Data types supported: same as @p input1. + * + * @return a status + */ + static Status validate(const ITensorInfo *input1, const ITensorInfo *input2, + const ITensorInfo *output); + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + +private: + const ICLTensor *_input1; + const ICLTensor *_input2; + ICLTensor *_output; +}; +} // namespace arm_compute +#endif /*__ARM_COMPUTE_CLGATHERKERNEL_H__ */ diff --git a/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLPixelWiseDivisionKernel.h b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLPixelWiseDivisionKernel.h new file mode 100644 index 000000000..cd2b255bc --- /dev/null +++ b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLPixelWiseDivisionKernel.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2016-2018 ARM Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __ARM_COMPUTE_CLPIXELWISEDIVISIONKERNEL_H__ +#define __ARM_COMPUTE_CLPIXELWISEDIVISIONKERNEL_H__ + +#include "arm_compute/core/CL/ICLKernel.h" +#include "arm_compute/core/Types.h" + +namespace arm_compute +{ +class ICLTensor; + +/** Interface for the pixelwise division kernel. + * + */ +class CLPixelWiseDivisionKernel : public ICLKernel +{ +public: + /** Default constructor.*/ + CLPixelWiseDivisionKernel(); + /** Prevent instances of this class from being copied (As this class contains pointers). */ + CLPixelWiseDivisionKernel(const CLPixelWiseDivisionKernel &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers). */ + CLPixelWiseDivisionKernel &operator=(const CLPixelWiseDivisionKernel &) = delete; + /** Allow instances of this class to be moved */ + CLPixelWiseDivisionKernel(CLPixelWiseDivisionKernel &&) = default; + /** Allow instances of this class to be moved */ + CLPixelWiseDivisionKernel &operator=(CLPixelWiseDivisionKernel &&) = default; + /** Initialise the kernel's input, output and border mode. + * + * @param[in] input1 An input tensor. Data types supported: U8/QS8/QS16/S16/F16/F32. + * @param[in] input2 An input tensor. Data types supported: same as @p input1. + * @param[out] output The output tensor, Data types supported: same as @p input1. Note: + * U8 (QS8, QS16) requires both inputs to be U8 (QS8, QS16). + * @param[in] scale Scale to apply after division. + * Scale must be positive and its value must be either 1/255 or 1/2^n + * where n is between 0 and 15. For QS8 and QS16 scale must be 1. + * @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate + * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest + * even. + */ + void configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale, + ConvertPolicy overflow_policy, RoundingPolicy rounding_policy); + /** Static function to check if given info will lead to a valid configuration of @ref + * CLPixelWiseDivisionKernel + * + * @param[in] input1 An input tensor info. Data types supported: U8/QS8/QS16/S16/F16/F32. + * @param[in] input2 An input tensor info. Data types supported: same as @p input1. + * @param[in] output The output tensor info, Data types supported: same as @p input1. + * Note: U8 (QS8, QS16) requires both inputs to be U8 (QS8, QS16). + * @param[in] scale Scale to apply after division. + * Scale must be positive and its value must be either 1/255 or 1/2^n + * where n is between 0 and 15. For QS8 and QS16 scale must be 1. + * @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate + * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even. + * + * @return a status + */ + static Status validate(const ITensorInfo *input1, const ITensorInfo *input2, + const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, + RoundingPolicy rounding_policy); + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + BorderSize border_size() const override; + +private: + const ICLTensor *_input1; + const ICLTensor *_input2; + ICLTensor *_output; +}; +} // namespace arm_compute +#endif /*__ARM_COMPUTE_CLPIXELWISEDIVISIONKERNEL_H__ */ diff --git a/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLReduceMaxKernel.h b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLReduceMaxKernel.h new file mode 100644 index 000000000..a7d96cc5c --- /dev/null +++ b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLReduceMaxKernel.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2016-2018 ARM Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __ARM_COMPUTE_CLREDUCEMAXKERNEL_H__ +#define __ARM_COMPUTE_CLREDUCEMAXKERNEL_H__ + +#include "arm_compute/core/CL/ICLKernel.h" +#include "arm_compute/core/Types.h" + +namespace arm_compute +{ +class ICLTensor; + +/** Interface for the pixelwise division kernel. + * + */ +class CLReduceMaxKernel : public ICLKernel +{ +public: + /** Default constructor.*/ + CLReduceMaxKernel(); + /** Prevent instances of this class from being copied (As this class contains pointers). */ + CLReduceMaxKernel(const CLReduceMaxKernel &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers). */ + CLReduceMaxKernel &operator=(const CLReduceMaxKernel &) = delete; + /** Allow instances of this class to be moved */ + CLReduceMaxKernel(CLReduceMaxKernel &&) = default; + /** Allow instances of this class to be moved */ + CLReduceMaxKernel &operator=(CLReduceMaxKernel &&) = default; + /** Initialise the kernel's input, output and border mode. + * + * @param[in] input An input tensor. Data types supported: U8/QS8/QS16/S16/F16/F32. + * @param[in] axis Axis to reduce + * @param[out] output The output tensor, Data types supported: same as @p input1. Note: + * U8 (QS8, QS16) requires both inputs to be U8 (QS8, QS16). + */ + void configure(const ICLTensor *input, int32_t axis, ICLTensor *output); + /** Static function to check if given info will lead to a valid configuration of @ref + * CLReduceMaxKernel + * + * @param[in] input An input tensor info. Data types supported: U8/QS8/QS16/S16/F16/F32. + * @param[in] axis Axis to reduce + * @param[in] output The output tensor info, Data types supported: same as @p input1. + * Note: U8 (QS8, QS16) requires both inputs to be U8 (QS8, QS16). + * + * @return a status + */ + static Status validate(const ITensorInfo *input, int32_t axis, const ITensorInfo *output); + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + void run_on_cpu(cl::CommandQueue &queue); + +private: + const ICLTensor *_input; + ICLTensor *_output; + int32_t _axis; +}; +} // namespace arm_compute +#endif /*__ARM_COMPUTE_CLREDUCEMAXKERNEL_H__ */ diff --git a/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLReductionMeanKernel.h b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLReductionMeanKernel.h new file mode 100644 index 000000000..de9df3381 --- /dev/null +++ b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLReductionMeanKernel.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2016-2018 ARM Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __ARM_COMPUTE_CLREDUCTIONMEANKERNEL_H__ +#define __ARM_COMPUTE_CLREDUCTIONMEANKERNEL_H__ + +#include "arm_compute/core/CL/ICLKernel.h" +#include "arm_compute/core/Types.h" + +namespace arm_compute +{ +class ICLTensor; + +/** Interface for the reduction operation kernel */ +class CLReductionMeanKernel : public ICLKernel +{ +public: + /** Default constructor */ + CLReductionMeanKernel(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLReductionMeanKernel(const CLReductionMeanKernel &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLReductionMeanKernel &operator=(const CLReductionMeanKernel &) = delete; + /** Allow instances of this class to be moved */ + CLReductionMeanKernel(CLReductionMeanKernel &&) = default; + /** Allow instances of this class to be moved */ + CLReductionMeanKernel &operator=(CLReductionMeanKernel &&) = default; + /** Default destructor */ + ~CLReductionMeanKernel() = default; + + /** Set the input and output tensors. + * + * @param[in] input Source tensor. Data types supported: F32. Data layouts supported: NCHW. + * @param[out] output Destination tensor. Data types and data layouts supported: Same as @p input. + * Output will have the same number of dimensions as input. + * @param[in] axis Axis along which to reduce. Supported reduction axis : 0, 1 + */ + void configure(const ICLTensor *input, ICLTensor *output, std::vector<uint32_t> axis); + + /** Static function to check if given info will lead to a valid configuration of @ref + * CLReductionMeanKernel. + * + * @param[in] input Source tensor info. Data types supported: F32. Data layouts supported: NCHW. + * @param[in] output Destination tensor info. Data types and data layouts supported: Same as @p + * input. + * Output will have the same number of dimensions as input. + * @param[in] axis Axis along which to reduce. Supported reduction axis : 0, 1 + * + * @return a status + */ + static Status validate(const ITensorInfo *input, const ITensorInfo *output, + std::vector<uint32_t> axis); + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + BorderSize border_size() const override; + +private: + const ICLTensor *_input; + ICLTensor *_output; + std::vector<uint32_t> _reduction_axis; + BorderSize _border_size; +}; +} // namespace arm_compute +#endif /*__ARM_COMPUTE_CLREDUCTIONMEANKERNEL_H__ */ diff --git a/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLStridedSliceKernel.h b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLStridedSliceKernel.h new file mode 100644 index 000000000..248ae6635 --- /dev/null +++ b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLStridedSliceKernel.h @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2016-2018 ARM Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __ARM_COMPUTE_CLSTRIDEDSLICEKERNEL_H__ +#define __ARM_COMPUTE_CLSTRIDEDSLICEKERNEL_H__ + +#include "arm_compute/core/CL/ICLKernel.h" +#include "arm_compute/core/Types.h" + +namespace arm_compute +{ +class ICLTensor; + +/** Interface for the kernel to extract a strided slice of a tensor */ +class CLStridedSliceKernel : public ICLKernel +{ +public: + /** Default constructor */ + CLStridedSliceKernel(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLStridedSliceKernel(const CLStridedSliceKernel &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLStridedSliceKernel &operator=(const CLStridedSliceKernel &) = delete; + /** Allow instances of this class to be moved */ + CLStridedSliceKernel(CLStridedSliceKernel &&) = default; + /** Allow instances of this class to be moved */ + CLStridedSliceKernel &operator=(CLStridedSliceKernel &&) = default; + /** Default destructor */ + ~CLStridedSliceKernel() = default; + /** Set the input and output of the kernel + * + * @param[in] input Source tensor. Data type supported: + * U8/S8/QS8/QASYMM8/U16/S16/QS16/U32/S32/F16/F32 + * @param[out] output Destination tensor. Data type supported: Same as @p input + * @param[in] beginData The begin tensor. Data types supported: S32. + * The number of dimensions must be 1. + * The length must be the same as the number of dimensions of input. + * @param[in] endData The end tensor. Data types supported: S32. + * The number of dimensions must be 1. + * The length must be the same as the number of dimensions of input. + * @param[in] strideData The stride tensor. Data types supported: S32. + * The number of dimensions must be 1. + * The length must be the same as the number of dimensions of input. + * @param[in] beginMask Mask for begin + * @param[in] endMask Mask for end + * @param[in] shrinkAxisMask Mask for shrink axis. + * + */ + void configure(const ICLTensor *input, ICLTensor *output, ICLTensor *beginData, + ICLTensor *endData, ICLTensor *stridesData, int32_t beginMask, int32_t endMask, + int32_t shrinkAxisMask); + + /** Static function to check if given info will lead to a valid configuration of @ref + * CLStridedSliceKernel + * + * @param[in] input The input tensor info. Data types supported: + * U8/S8/QS8/QASYMM8/U16/S16/QS16/U32/S32/F16/F32 + * @param[in] output The output tensor info, Data types supported: same as @p input1. + * @param[in] begin The begin tensor info. Data types supported: S32. + * The number of dimensions must be 1. + * The length must be the same as the number of dimensions of input. + * @param[in] end The end tensor info. Data types supported: S32. + * The number of dimensions must be 1. + * The length must be the same as the number of dimensions of input. + * @param[in] stride The stride tensor info. Data types supported: S32. + * The number of dimensions must be 1. + * The length must be the same as the number of dimensions of input. + * @param[in] beginMask Mask for begin + * @param[in] endMask Mask for end + * @param[in] shrinkAxisMask Mask for shrink axis. + * + * @return a status + */ + static Status validate(const ITensorInfo *input, const ITensorInfo *output, + const ITensorInfo *begin, const ITensorInfo *end, + const ITensorInfo *stride, int32_t beginMask, int32_t endMask, + int32_t shrinkAxisMask); + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + +private: + const ICLTensor *_input; /** Source tensor */ + ICLTensor *_output; /** Destination tensor */ + ICLTensor *_beginData; /** Start indices of input tensor */ + ICLTensor *_endData; /** Stop indices of input tensor */ + ICLTensor *_stridesData; /** Strides tensor */ + int32_t _beginMask; /** Begin mask */ + int32_t _endMask; /** End mask */ + int32_t _shrinkAxisMask; /** Shrink axis mask */ +}; +} // namespace arm_compute +#endif /*__ARM_COMPUTE_CLSTRIDEDSLICEKERNEL_H__ */ diff --git a/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLTopKV2Kernel.h b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLTopKV2Kernel.h new file mode 100644 index 000000000..5c567f38e --- /dev/null +++ b/libs/ARMComputeEx/arm_compute/core/CL/kernels/CLTopKV2Kernel.h @@ -0,0 +1,301 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2016-2018 ARM Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __ARM_COMPUTE_CLTOPKV2KERNEL_H__ +#define __ARM_COMPUTE_CLTOPKV2KERNEL_H__ + +#include "arm_compute/core/CL/ICLArray.h" +#include "arm_compute/core/CL/ICLKernel.h" + +#include <array> + +// these parameters can be changed +#define _ITEMS 16 // number of items in a group +#define _GROUPS 4 // the number of virtual processors is _ITEMS * _GROUPS +#define _HISTOSPLIT (_ITEMS * _GROUPS / 2) // number of splits of the histogram +#define PERMUT // store the final permutation +//////////////////////////////////////////////////////// + +namespace arm_compute +{ +class ICLTensor; + +class CLTopKV2Single : public ICLKernel +{ +public: + /** Constructor */ + CLTopKV2Single(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLTopKV2Single(const CLTopKV2Single &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLTopKV2Single &operator=(const CLTopKV2Single &) = delete; + /** Allow instances of this class to be moved */ + CLTopKV2Single(CLTopKV2Single &&) = default; + /** Allow instances of this class to be moved */ + CLTopKV2Single &operator=(CLTopKV2Single &&) = default; + + void configure(ICLTensor *input, ICLTensor *topk_values, ICLTensor *topk_indices, + cl::Buffer *indices, cl::Buffer *temp_stack, int k, int n); + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + +private: + ICLTensor *_input; + ICLTensor *_topk_values; + ICLTensor *_topk_indices; +}; + +class CLTopKV2Init : public ICLKernel +{ +public: + /** Constructor */ + CLTopKV2Init(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLTopKV2Init(const CLTopKV2Init &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLTopKV2Init &operator=(const CLTopKV2Init &) = delete; + /** Allow instances of this class to be moved */ + CLTopKV2Init(CLTopKV2Init &&) = default; + /** Allow instances of this class to be moved */ + CLTopKV2Init &operator=(CLTopKV2Init &&) = default; + + void configure(ICLTensor *input, cl::Buffer *in_key_buf, cl::Buffer *in_ind_buf, int n); + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + +private: + ICLTensor *_input; +}; + +class CLRadixSortHistogram : public ICLKernel +{ +public: + /** Constructor */ + CLRadixSortHistogram(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLRadixSortHistogram(const CLRadixSortHistogram &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLRadixSortHistogram &operator=(const CLRadixSortHistogram &) = delete; + /** Allow instances of this class to be moved */ + CLRadixSortHistogram(CLRadixSortHistogram &&) = default; + /** Allow instances of this class to be moved */ + CLRadixSortHistogram &operator=(CLRadixSortHistogram &&) = default; + + void configure(cl::Buffer *hist_buf, int bits, int n); + + void setPass(int pass, cl::Buffer *in_key_buf) + { + _pass = pass; + _in_key_buf = in_key_buf; + } + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + +private: + int _pass; + cl::Buffer *_in_key_buf; +}; + +class CLRadixSortScanHistogram : public ICLKernel +{ +public: + /** Constructor */ + CLRadixSortScanHistogram(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLRadixSortScanHistogram(const CLRadixSortScanHistogram &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLRadixSortScanHistogram &operator=(const CLRadixSortScanHistogram &) = delete; + /** Allow instances of this class to be moved */ + CLRadixSortScanHistogram(CLRadixSortScanHistogram &&) = default; + /** Allow instances of this class to be moved */ + CLRadixSortScanHistogram &operator=(CLRadixSortScanHistogram &&) = default; + + void configure(cl::Buffer *hist_buf, cl::Buffer *glob_sum_buf, int bits); + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; +}; + +class CLRadixSortGlobalScanHistogram : public ICLKernel +{ +public: + /** Constructor */ + CLRadixSortGlobalScanHistogram(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLRadixSortGlobalScanHistogram(const CLRadixSortGlobalScanHistogram &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLRadixSortGlobalScanHistogram &operator=(const CLRadixSortGlobalScanHistogram &) = delete; + /** Allow instances of this class to be moved */ + CLRadixSortGlobalScanHistogram(CLRadixSortGlobalScanHistogram &&) = default; + /** Allow instances of this class to be moved */ + CLRadixSortGlobalScanHistogram &operator=(CLRadixSortGlobalScanHistogram &&) = default; + + void configure(cl::Buffer *glob_sum_buf, cl::Buffer *temp_buf, int bits); + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; +}; + +class CLRadixSortPasteHistogram : public ICLKernel +{ +public: + /** Constructor */ + CLRadixSortPasteHistogram(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLRadixSortPasteHistogram(const CLRadixSortPasteHistogram &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLRadixSortPasteHistogram &operator=(const CLRadixSortPasteHistogram &) = delete; + /** Allow instances of this class to be moved */ + CLRadixSortPasteHistogram(CLRadixSortPasteHistogram &&) = default; + /** Allow instances of this class to be moved */ + CLRadixSortPasteHistogram &operator=(CLRadixSortPasteHistogram &&) = default; + + void configure(cl::Buffer *hist_buf, cl::Buffer *glob_sum_buf, int bits); + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; +}; + +class CLRadixSortReorder : public ICLKernel +{ +public: + /** Constructor */ + CLRadixSortReorder(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLRadixSortReorder(const CLRadixSortReorder &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLRadixSortReorder &operator=(const CLRadixSortReorder &) = delete; + /** Allow instances of this class to be moved */ + CLRadixSortReorder(CLRadixSortReorder &&) = default; + /** Allow instances of this class to be moved */ + CLRadixSortReorder &operator=(CLRadixSortReorder &&) = default; + + void configure(cl::Buffer *hist_buf, int bits, int n); + + void setPass(int pass, cl::Buffer *in_key_buf, cl::Buffer *out_key_buf, cl::Buffer *in_ind_buf, + cl::Buffer *out_ind_buf) + { + _pass = pass; + _in_key_buf = in_key_buf; + _out_key_buf = out_key_buf; + _in_ind_buf = in_ind_buf; + _out_ind_buf = out_ind_buf; + } + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + +private: + int _pass; + cl::Buffer *_in_key_buf; + cl::Buffer *_out_key_buf; + cl::Buffer *_in_ind_buf; + cl::Buffer *_out_ind_buf; +}; + +class CLTopKV2FindFirstNegative : public ICLKernel +{ +public: + /** Constructor */ + CLTopKV2FindFirstNegative(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLTopKV2FindFirstNegative(const CLTopKV2FindFirstNegative &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLTopKV2FindFirstNegative &operator=(const CLTopKV2FindFirstNegative &) = delete; + /** Allow instances of this class to be moved */ + CLTopKV2FindFirstNegative(CLTopKV2FindFirstNegative &&) = default; + /** Allow instances of this class to be moved */ + CLTopKV2FindFirstNegative &operator=(CLTopKV2FindFirstNegative &&) = default; + + void configure(cl::Buffer *first_negative_idx_buf, int n); + + void setOutputBuffer(cl::Buffer *out_key_buf) { _out_key_buf = out_key_buf; } + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + +private: + cl::Buffer *_out_key_buf; +}; + +class CLTopKV2ReorderNegatives : public ICLKernel +{ +public: + /** Constructor */ + CLTopKV2ReorderNegatives(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLTopKV2ReorderNegatives(const CLTopKV2ReorderNegatives &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLTopKV2ReorderNegatives &operator=(const CLTopKV2ReorderNegatives &) = delete; + /** Allow instances of this class to be moved */ + CLTopKV2ReorderNegatives(CLTopKV2ReorderNegatives &&) = default; + /** Allow instances of this class to be moved */ + CLTopKV2ReorderNegatives &operator=(CLTopKV2ReorderNegatives &&) = default; + + void configure(cl::Buffer *first_negative_idx_buf, int n); + + void setBuffers(cl::Buffer *in_key_buf, cl::Buffer *out_key_buf, cl::Buffer *in_ind_buf, + cl::Buffer *out_ind_buf) + { + _in_key_buf = in_key_buf; + _out_key_buf = out_key_buf; + _in_ind_buf = in_ind_buf; + _out_ind_buf = out_ind_buf; + } + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + +private: + cl::Buffer *_in_key_buf; + cl::Buffer *_out_key_buf; + cl::Buffer *_in_ind_buf; + cl::Buffer *_out_ind_buf; +}; + +class CLTopKV2Store : public ICLKernel +{ +public: + /** Constructor */ + CLTopKV2Store(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLTopKV2Store(const CLTopKV2Store &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLTopKV2Store &operator=(const CLTopKV2Store &) = delete; + /** Allow instances of this class to be moved */ + CLTopKV2Store(CLTopKV2Store &&) = default; + /** Allow instances of this class to be moved */ + CLTopKV2Store &operator=(CLTopKV2Store &&) = default; + + void configure(ICLTensor *values, ICLTensor *indices, int k, int n); + + void setOutputBuffers(cl::Buffer *out_key_buf, cl::Buffer *out_ind_buf); + + // Inherited methods overridden: + void run(const Window &window, cl::CommandQueue &queue) override; + +private: + ICLTensor *_values; + ICLTensor *_indices; + cl::Buffer *_out_key_buf; + cl::Buffer *_out_ind_buf; +}; + +} // namespace arm_compute + +#endif // __ARM_COMPUTE_CLTOPKV2KERNEL_H__ |