summaryrefslogtreecommitdiff
path: root/runtimes/libs/ARMComputeEx/src/core/CL/kernels/CLGatherExKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/libs/ARMComputeEx/src/core/CL/kernels/CLGatherExKernel.cpp')
-rw-r--r--runtimes/libs/ARMComputeEx/src/core/CL/kernels/CLGatherExKernel.cpp181
1 files changed, 181 insertions, 0 deletions
diff --git a/runtimes/libs/ARMComputeEx/src/core/CL/kernels/CLGatherExKernel.cpp b/runtimes/libs/ARMComputeEx/src/core/CL/kernels/CLGatherExKernel.cpp
new file mode 100644
index 000000000..c83ece0e9
--- /dev/null
+++ b/runtimes/libs/ARMComputeEx/src/core/CL/kernels/CLGatherExKernel.cpp
@@ -0,0 +1,181 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2016-2018 ARM Limited.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "arm_compute/core/CL/kernels/CLGatherExKernel.h"
+
+#include "arm_compute/core/CL/CLHelpers.h"
+#include "arm_compute/core/CL/CLKernelLibraryEx.h"
+#include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/UtilsEx.h"
+
+using namespace arm_compute;
+
+namespace
+{
+
+inline TensorShape compute_gather_shape(const TensorShape &input_shape,
+ const TensorShape &indices_shape, uint32_t actual_axis)
+{
+ ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 3);
+ ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() > 4);
+ ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() + indices_shape.num_dimensions() - 1 > 4);
+ ARM_COMPUTE_ERROR_ON(actual_axis >= input_shape.num_dimensions());
+
+ TensorShape output_shape = input_shape;
+ if (indices_shape.num_dimensions() == 1)
+ {
+ output_shape[actual_axis] = indices_shape[0];
+ }
+ else if (indices_shape.num_dimensions() > 1)
+ {
+ output_shape.shift_right(indices_shape.num_dimensions() - 1);
+
+ for (uint32_t i = 0, o = 0; o < output_shape.num_dimensions(); ++o, ++i)
+ {
+ if (o == actual_axis)
+ {
+ ++i;
+ for (uint32_t in = 0; in < indices_shape.num_dimensions(); ++in, ++o)
+ {
+ output_shape[o] = indices_shape[in];
+ }
+ }
+ else
+ {
+ output_shape[o] = input_shape[i];
+ }
+ }
+ }
+ return output_shape;
+}
+
+/** Wrap-around a number within the range 0 <= x < m
+ *
+ * @param[in] x Input value
+ * @param[in] m Range
+ *
+ * @return the wrapped-around number
+ */
+template <typename T> inline T wrap_around(T x, T m) { return x >= 0 ? x % m : (x % m + m) % m; }
+
+inline Status validate_arguments(const ITensorInfo *input, const ITensorInfo *indices,
+ const ITensorInfo *output, int axis)
+{
+ const uint32_t actual_axis = wrap_around(axis, static_cast<int>(input->num_dimensions()));
+ ARM_COMPUTE_RETURN_ERROR_ON(indices->num_dimensions() > 3);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
+ ARM_COMPUTE_ERROR_ON(input->num_dimensions() + indices->num_dimensions() - 1 > 4);
+ ARM_COMPUTE_RETURN_ERROR_ON(actual_axis >= input->num_dimensions());
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(
+ input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16,
+ DataType::U32, DataType::S32, DataType::F16, DataType::F32);
+
+ if (output->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
+ TensorShape output_shape =
+ compute_gather_shape(input->tensor_shape(), indices->tensor_shape(), actual_axis);
+ ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() != output->tensor_shape().total_size());
+ }
+
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32, DataType::S32);
+
+ return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *indices,
+ ITensorInfo *output, int axis)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output, indices);
+ const uint32_t actual_axis = wrap_around(axis, static_cast<int>(input->num_dimensions()));
+ std::unique_ptr<ITensorInfo> output_info = input->clone();
+ output_info->set_tensor_shape(
+ compute_gather_shape(input->tensor_shape(), indices->tensor_shape(), actual_axis));
+ // Output auto initialization if not yet initialized
+ auto_init_if_empty((*output), output_info->tensor_shape(), 1, input->data_type());
+
+ // Create window
+ Window win = calculate_max_window(*output, Steps());
+ output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
+
+ return std::make_pair(Status{}, win);
+}
+
+} // namespace
+
+CLGatherExKernel::CLGatherExKernel()
+ : _input(nullptr), _indices(nullptr), _output(nullptr), _axis(0)
+{
+}
+
+void CLGatherExKernel::configure(const ICLTensor *input, const ICLTensor *indices,
+ ICLTensor *output, int axis)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output, indices);
+ ARM_COMPUTE_ERROR_THROW_ON(
+ validate_arguments(input->info(), indices->info(), output->info(), axis));
+
+ // Configure kernel window
+ auto win_config =
+ validate_and_configure_window(input->info(), indices->info(), output->info(), axis);
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+
+ _input = input;
+ _output = output;
+ _indices = indices;
+ _axis = wrap_around(axis, static_cast<int>(input->info()->num_dimensions()));
+
+ // Set build options
+ CLBuildOptions build_opts;
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
+ build_opts.add_option("-DOUTPUT_DIM_Z=" +
+ support::cpp11::to_string(output->info()->dimension(2)));
+ build_opts.add_option("-DINPUT_DIM_Z=" + support::cpp11::to_string(input->info()->dimension(2)));
+ build_opts.add_option("-DAXIS=" + support::cpp11::to_string(_axis));
+ build_opts.add_option("-DINDICES_DIM=" +
+ support::cpp11::to_string(indices->info()->num_dimensions()));
+
+ // Create kernel
+ _kernel = static_cast<cl::Kernel>(
+ CLKernelLibraryEx::get().create_kernel("gather_ex", build_opts.options()));
+ ICLKernel::configure_internal(win_config.second);
+}
+
+Status CLGatherExKernel::validate(const ITensorInfo *input, const ITensorInfo *indices,
+ const ITensorInfo *output, int axis)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, indices, output, axis));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
+ indices->clone().get(),
+ output->clone().get(), axis)
+ .first);
+ return Status{};
+}
+
+void CLGatherExKernel::run(const Window &window, cl::CommandQueue &queue)
+{
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
+
+ Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ, 4);
+ unsigned int idx = 0;
+ add_4D_tensor_argument(idx, _input, window_collapsed);
+ add_3D_tensor_argument(idx, _indices, window_collapsed);
+ add_4D_tensor_argument(idx, _output, window_collapsed);
+ enqueue(queue, *this, window_collapsed, lws_hint());
+}