summaryrefslogtreecommitdiff
path: root/libs/ARMComputeEx/arm_compute/runtime/CL/functions/CLTopKV2.h
blob: 5327e016fbbf176ca682b6de804f7e69294c2adc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright (c) 2016-2018 ARM Limited.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * @file CLTopKV2.h
 * @ingroup COM_AI_RUNTIME
 * @brief This file contains arm_compute::CLTopKV2 class
 */
#ifndef __ARM_COMPUTE_CLTOPK_V2_H__
#define __ARM_COMPUTE_CLTOPK_V2_H__

#include "arm_compute/core/CL/kernels/CLTopKV2Kernel.h"

#include "arm_compute/runtime/IFunction.h"

namespace arm_compute
{
class ICLTensor;

/**
 * @brief Class to execute TopKV2 operation.
 *
 * The function is set up with configure() and executed with run(). Three private execution
 * helpers are declared (run_on_cpu(), run_on_gpu() and run_on_gpu_single_quicksort());
 * presumably they correspond to the three modes listed in the documentation of run().
 *
 * NOTE(review): this header uses uint32_t and cl::Buffer without directly including a header
 * that declares them; presumably they arrive transitively through CLTopKV2Kernel.h - confirm.
 */
class CLTopKV2 : public IFunction
{
public:
  /**
   * @brief Construct a new CLTopKV2 object
   */
  CLTopKV2();

  /**
   * @brief Prevent instances of this class from being copied (As this class contains pointers)
   */
  CLTopKV2(const CLTopKV2 &) = delete;

  /**
   * @brief Prevent instances of this class from being copy-assigned (As this class contains
   * pointers)
   */
  CLTopKV2 &operator=(const CLTopKV2 &) = delete;

  /**
   * @brief Construct a new CLTopKV2 object by moving from another one (defaulted move
   * constructor)
   *
   * The unnamed parameter is the CLTopKV2 object to move from.
   *
   * NOTE(review): the default implementation moves/copies every member as is, including the raw
   * cl::Buffer pointers (_p_in_key_buf, ...). If those pointers refer to the cl::Buffer members
   * of the same object (not visible from this header), they would keep pointing into the
   * moved-from object afterwards - confirm against the implementation.
   */
  CLTopKV2(CLTopKV2 &&) = default;

  /**
   * @brief Move-assign a CLTopKV2 object (defaulted move assignment)
   *
   * The unnamed parameter is the CLTopKV2 object to assign from. This object will be moved.
   *
   * NOTE(review): the same caveat about the raw cl::Buffer pointer members as for the move
   * constructor applies - confirm against the implementation.
   */
  CLTopKV2 &operator=(CLTopKV2 &&) = default;

  /**
   * @brief Initialise the kernel's inputs and outputs.
   * @param[in]  input      Input tensor. Data types supported: U8/S16/F32.
   * @param[in]  k          The value of `k`, i.e. how many top values are requested.
   * @param[out] values     Top k values. Data types supported: S32 if input type is U8/S16, F32
   *                        if input type is F32.
   * @param[out] indices    Indices related to top k values. Data types supported: S32 if input
   *                        type is U8/S16, F32 if input type is F32.
   *                        NOTE(review): this type list reads like a copy of the one for
   *                        @p values - confirm which index types are actually supported.
   * @param[in]  total_bits Defaults to 32. Presumably the total number of key bits handled by
   *                        the GPU radix sort - TODO confirm against the implementation.
   * @param[in]  bits       Defaults to 4. Presumably the number of key bits processed per
   *                        radix sort pass - TODO confirm against the implementation.
   * @return N/A
   *
   * NOTE(review): @p k, @p total_bits and @p bits are signed int, while the members _k,
   * _total_bits and _bits are uint32_t; verify how negative arguments are handled.
   */
  void configure(ICLTensor *input, int k, ICLTensor *values, ICLTensor *indices,
                 int total_bits = 32, int bits = 4);

  /**
   * @brief Run the kernels contained in the function
   * Depending on the value of the following environment variables it works differently:
   *   - If the value of environment variable "ACL_TOPKV2" == "GPU_SINGLE",
   *     quick sort on GPU is used.
   *   - If the value of environment variable "ACL_TOPKV2" == "GPU",
   *     radix sort on GPU is used.
   *   - For other value, TopKV2 runs on CPU
   * @return N/A
   */
  void run() override;

private:
  // Execution helpers. Presumably run() dispatches to exactly one of them according to the
  // modes listed in its documentation (CPU, GPU radix sort, GPU single quick sort) - the
  // dispatch itself is not visible in this header.
  void run_on_cpu();
  void run_on_gpu();
  void run_on_gpu_single_quicksort();

  // Scalar state. Presumably _k, _total_bits and _bits hold the matching configure()
  // arguments and the remaining values are derived from them and from the input - TODO
  // confirm against the implementation.
  uint32_t _k;
  uint32_t _total_bits;
  uint32_t _bits;
  uint32_t _radix;
  uint32_t _hist_buf_size;
  uint32_t _glob_sum_buf_size;
  uint32_t _n;

  // Raw tensor pointers; presumably the tensors handed to configure(). Whether this class
  // owns them cannot be told from this header.
  ICLTensor *_input;
  ICLTensor *_values;
  ICLTensor *_indices;

  // OpenCL buffers held by value. Judging by the names, the _qs_* buffers belong to the quick
  // sort path and the others to the radix sort path - verify against the implementation.
  cl::Buffer _qs_idx_buf;
  cl::Buffer _qs_temp_buf;
  cl::Buffer _hist_buf;
  cl::Buffer _glob_sum_buf;
  cl::Buffer _temp_buf;
  cl::Buffer _first_negative_idx_buf;
  cl::Buffer _in_key_buf;
  cl::Buffer _out_key_buf;
  cl::Buffer _in_ind_buf;
  cl::Buffer _out_ind_buf;

  // Raw pointers to cl::Buffer objects. Presumably they refer to the in/out key and index
  // buffers above so that the roles can be swapped between passes - TODO confirm. See the
  // note on the move constructor regarding their validity after a move.
  cl::Buffer *_p_in_key_buf;
  cl::Buffer *_p_out_key_buf;
  cl::Buffer *_p_in_ind_buf;
  cl::Buffer *_p_out_ind_buf;

  // Kernel objects held by value; their types are presumably declared in the included
  // CLTopKV2Kernel.h.
  CLTopKV2Single _qs_kernel;
  CLTopKV2Init _init_kernel;
  CLRadixSortHistogram _hist_kernel;
  CLRadixSortScanHistogram _scan_hist_kernel;
  CLRadixSortGlobalScanHistogram _glob_scan_hist_kernel;
  CLRadixSortPasteHistogram _paste_hist_kernel;
  CLRadixSortReorder _reorder_kernel;
  CLTopKV2FindFirstNegative _find_first_negative_kernel;
  CLTopKV2ReorderNegatives _reorder_negatives_kernel;
  CLTopKV2Store _store_kernel;
};
}
#endif // __ARM_COMPUTE_CLTOPK_V2_H__