summaryrefslogtreecommitdiff
path: root/runtimes/libs/ARMComputeEx/src/core/CL/cl_kernels/topkv2_quicksort.cl
blob: 9594daf193c0314b423d893fac7483e817744ba5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright (c) 2017 ARM Limited.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "helpers.h"

/* Return a pointer to the float stored at position idx of vec.
 *
 * The address is vec->ptr advanced by idx * vec->stride_x and then
 * reinterpreted as a pointer to float in global memory.
 * NOTE(review): stride_x is presumably a byte stride -- confirm against the
 * Vector definition in helpers.h.
 */
__global inline float *get_vec_elem(Vector *vec, int idx)
{
  return (__global float *)&vec->ptr[idx * vec->stride_x];
}

/* Return a pointer to the int stored at position idx of vec.
 *
 * Integer counterpart of get_vec_elem(): the address is vec->ptr advanced by
 * idx * vec->stride_x and then reinterpreted as a pointer to int in global
 * memory.
 * NOTE(review): stride_x is presumably a byte stride -- confirm against the
 * Vector definition in helpers.h.
 */
__global inline int *get_vec_elem_int(Vector *vec, int idx)
{
  return (__global int *)&vec->ptr[idx * vec->stride_x];
}

// Exchange the two float values that a and b point to.
// Passing the same address for both leaves the value unchanged.
void swap(__global float *a, __global float *b)
{
  const float saved_b = *b;
  *b = *a;
  *a = saved_b;
}

// Exchange the two int values that a and b point to (used for the index
// array that is permuted alongside the values).
void swap_idx(__global int *a, __global int *b)
{
  const int saved_b = *b;
  *b = *a;
  *a = saved_b;
}

/* Partition arr[l..h] (inclusive) around the pivot value found at arr[h].
 *
 * Every element that compares >= pivot is moved in front of the pivot, so the
 * resulting order is descending. Each swap applied to arr is mirrored on
 * indices so that indices keeps following the values.
 *
 * Returns the final position of the pivot.
 * The same partition step serves both iterative and recursive quicksort.
 */
int partition(Vector *arr, __global int *indices, int l, int h)
{
  const float pivot = *get_vec_elem(arr, h);

  // Next slot that will receive an element >= pivot.
  int store = l;

  int scan = l;
  while (scan < h)
  {
    if (*get_vec_elem(arr, scan) >= pivot)
    {
      swap(get_vec_elem(arr, store), get_vec_elem(arr, scan));
      swap_idx(indices + store, indices + scan);
      ++store;
    }
    ++scan;
  }

  // Drop the pivot right behind the elements that are >= pivot.
  swap(get_vec_elem(arr, store), get_vec_elem(arr, h));
  swap_idx(indices + store, indices + h);
  return store;
}

/* Sort arr[l..h] (inclusive) in place without recursion, permuting indices
 * in lockstep with the values. The ordering is whatever partition() produces
 * (elements >= pivot are placed first, i.e. descending).
 *
 * arr     --> Vector to be sorted,
 * indices --> index array permuted together with arr,
 * stack   --> caller-provided scratch buffer holding pending (l, h) pairs,
 * l       --> Starting index,
 * h       --> Ending index
 *
 * NOTE(review): the capacity of stack is not checked here; the caller must
 * size it for the number of pending sub-ranges -- confirm against the host
 * code that allocates it.
 */
void quickSortIterative(Vector *arr, __global int *indices, __global int *stack, int l, int h)
{
  // A range with fewer than two elements is already sorted. Without this
  // guard an empty range (e.g. l = 0, h = -1 when the caller has no
  // elements) would reach partition(), which reads and swaps the element at
  // index h, i.e. outside the range.
  if (l >= h)
  {
    return;
  }

  // initialize top of stack
  int top = -1;

  // push initial values of l and h to stack
  stack[++top] = l;
  stack[++top] = h;

  // Keep popping from stack while it is not empty
  while (top >= 0)
  {
    // Pop h and l (reverse of the push order)
    h = stack[top--];
    l = stack[top--];

    // Set pivot element at its correct position
    // in sorted array
    int p = partition(arr, indices, l, h);

    // If there are at least two elements on the left side of the pivot,
    // then push the left side to stack
    if (p - 1 > l)
    {
      stack[++top] = l;
      stack[++top] = p - 1;
    }

    // If there are at least two elements on the right side of the pivot,
    // then push the right side to stack
    if (p + 1 < h)
    {
      stack[++top] = p + 1;
      stack[++top] = h;
    }
  }
}

/* Top-k selection by fully sorting the input.
 *
 * Fills indices[0..n-1] with 0..n-1, sorts the n input elements in place with
 * quickSortIterative() (which permutes indices alongside the values), then
 * copies the first k sorted values to topk_values and the matching original
 * positions to topk_indices.
 *
 * input        --> vector to sort (modified in place),
 * topk_values  --> receives the first k values after sorting,
 * topk_indices --> receives the original positions of those values,
 * indices      --> scratch buffer of at least n ints,
 * temp_stack   --> scratch buffer used as the quicksort stack,
 * k            --> number of elements to extract,
 * n            --> number of elements in input
 *
 * NOTE(review): the kernel body never consults a work-item id, so it looks
 * like it is meant to be enqueued with a single work-item -- confirm against
 * the host code.
 */
__kernel void topkv2_quicksort(VECTOR_DECLARATION(input), VECTOR_DECLARATION(topk_values),
                               VECTOR_DECLARATION(topk_indices), __global int *indices,
                               __global int *temp_stack, int k, int n)
{
  Vector src = CONVERT_TO_VECTOR_STRUCT_NO_STEP(input);
  Vector out_values = CONVERT_TO_VECTOR_STRUCT_NO_STEP(topk_values);
  Vector out_indices = CONVERT_TO_VECTOR_STRUCT_NO_STEP(topk_indices);

  // Start from the identity permutation so that, after sorting, indices[r]
  // holds the original position of the element ranked r.
  int pos = 0;
  while (pos < n)
  {
    indices[pos] = pos;
    ++pos;
  }

  quickSortIterative(&src, indices, temp_stack, 0, n - 1);

  // Copy out the k leading entries of the sorted data.
  for (int rank = 0; rank < k; ++rank)
  {
    __global float *value_out = get_vec_elem(&out_values, rank);
    __global int *index_out = get_vec_elem_int(&out_indices, rank);

    *value_out = *get_vec_elem(&src, rank);
    *index_out = indices[rank];
  }
}