summaryrefslogtreecommitdiff
path: root/runtimes/libs/ARMComputeEx/src/core/CL/cl_kernels/topkv2.cl
blob: 50472e4f9f3ec0e8e7df17f52073d6a43a2dfe69 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright (c) 2017 ARM Limited.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "helpers.h"

/** Fills the key buffer with the input values and the index buffer with their positions.
 *
 * Each work-item starts at its own global id and advances by the total number
 * of work-items (local size times number of groups) until it reaches n, so all
 * n elements are written even when n is not a multiple of the launch size or
 * is smaller than it.
 *
 * @param[in]  input      Source vector (arguments expanded by VECTOR_DECLARATION); each element
 *                        is read as a float at byte offset idx * stride_x from input.ptr.
 * @param[out] in_key_buf Position idx receives the float value of input element idx.
 * @param[out] in_ind_buf Position idx receives idx itself.
 * @param[in]  n          Number of elements to initialise.
 */
__kernel void topkv2_init(VECTOR_DECLARATION(input), __global float *in_key_buf,
                          __global int *in_ind_buf, const int n)
{
  const int gid = get_global_id(0);
  // Total number of work-items in dimension 0; used as the loop stride.
  const int gws = get_local_size(0) * get_num_groups(0);

  // NOTE(review): presumably leaves input.ptr at the first element (no per-work-item step);
  // the macro is defined in helpers.h - confirm.
  Vector input = CONVERT_TO_VECTOR_STRUCT_NO_STEP(input);

  // Bounding the loop by n (instead of running n / gws full passes) also covers the tail
  // elements that integer division would otherwise drop.
  for (int idx = gid; idx < n; idx += gws)
  {
    in_key_buf[idx] = *(__global float *)(input.ptr + idx * input.stride_x);
    in_ind_buf[idx] = idx;
  }
}

/** Locates the position at which the negative keys begin in out_key_buf.
 *
 * One work-item inspects one key. The result is written by whichever work-item
 * sees the boundary:
 *  - a negative key at position 0 writes 0;
 *  - a negative key whose left neighbour is positive writes its own position;
 *  - a positive key at the last position writes n (no negatives found).
 *
 * NOTE(review): the comparisons are strict, so a key equal to zero matches neither side and a
 * boundary that touches a zero writes nothing - confirm whether zeros can reach this kernel and
 * whether the caller pre-initialises first_negative_idx.
 *
 * @param[in]  out_key_buf        Keys to scan; presumably already ordered so that positives
 *                                precede negatives - verify against the caller.
 * @param[out] first_negative_idx Receives the index of the first negative key, or n.
 * @param[in]  n                  Number of keys.
 */
__kernel void topkv2_find_first_negative(__global float *out_key_buf,
                                         __global int *first_negative_idx, int n)
{
  const int gid = get_global_id(0);

  // Work-items beyond the data have nothing to inspect.
  if (gid >= n)
    return;

  const float key = out_key_buf[gid];

  if (key < 0.f)
  {
    // The transition test must also run for the last work-item: when the only negative key
    // sits at position n - 1 (including n == 1), this is the sole place that can report it.
    if (gid == 0 || out_key_buf[gid - 1] > 0.f)
      *first_negative_idx = gid;
  }
  else if (gid == n - 1 && key > 0.f)
  {
    // if the last item is positive, the first negative index is n.
    *first_negative_idx = n;
  }
}

/** Rearranges keys and their indices so that the trailing block comes first, reversed.
 *
 * With num_negs = n - *first_negative_idx, output position gid takes its data from:
 *  - input position n - 1 - gid      when gid <  num_negs (tail block, read back to front);
 *  - input position gid - num_negs   when gid >= num_negs (leading block, order kept).
 *
 * The index buffers are declared as int pointers, matching how the other kernels in this file
 * declare their index buffers; copying indices as integers keeps their bit patterns exact.
 *
 * @param[in]  in_key_buf         Source keys.
 * @param[out] out_key_buf        Reordered keys.
 * @param[in]  in_ind_buf         Source indices, parallel to in_key_buf.
 * @param[out] out_ind_buf        Reordered indices, parallel to out_key_buf.
 * @param[in]  first_negative_idx Position in the input where the tail block starts.
 * @param[in]  n                  Number of elements.
 */
__kernel void topkv2_reorder_negatives(__global float *in_key_buf, __global float *out_key_buf,
                                       __global int *in_ind_buf, __global int *out_ind_buf,
                                       __global int *first_negative_idx, int n)
{
  const int gid = get_global_id(0);
  const int num_negs = n - *first_negative_idx;

  // Tail block is mirrored to the front; everything else shifts right by num_negs.
  const int in_idx = (gid < num_negs) ? (n - 1 - gid) : (gid - num_negs);

  out_key_buf[gid] = in_key_buf[in_idx];
  out_ind_buf[gid] = in_ind_buf[in_idx];
}

/** Writes the key and index buffers to the output vectors in reverse order.
 *
 * Work-item gid stores element n - 1 - gid of each buffer into element gid of
 * the corresponding output vector.
 *
 * @param[out] values      Destination vector for keys (arguments expanded by VECTOR_DECLARATION).
 * @param[out] indices     Destination vector for indices (arguments expanded by VECTOR_DECLARATION).
 * @param[in]  out_key_buf Keys to emit.
 * @param[in]  out_ind_buf Indices to emit, parallel to out_key_buf.
 * @param[in]  n           Number of elements in the buffers.
 */
__kernel void topkv2_store(VECTOR_DECLARATION(values), VECTOR_DECLARATION(indices),
                           __global float *out_key_buf, __global int *out_ind_buf, int n)
{
  Vector values = CONVERT_TO_VECTOR_STRUCT_NO_STEP(values);
  Vector indices = CONVERT_TO_VECTOR_STRUCT_NO_STEP(indices);

  const int dst = get_global_id(0);
  // Read from the opposite end of the buffers.
  const int src = (n - 1) - dst;

  __global float *value_slot = (__global float *)(values.ptr + dst * values.stride_x);
  __global int *index_slot = (__global int *)(indices.ptr + dst * indices.stride_x);

  *value_slot = out_key_buf[src];
  *index_slot = out_ind_buf[src];
}