summaryrefslogtreecommitdiff
path: root/boost/compute/algorithm/detail/radix_sort.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'boost/compute/algorithm/detail/radix_sort.hpp')
-rw-r--r--boost/compute/algorithm/detail/radix_sort.hpp50
1 files changed, 48 insertions, 2 deletions
diff --git a/boost/compute/algorithm/detail/radix_sort.hpp b/boost/compute/algorithm/detail/radix_sort.hpp
index c2ba4ed17c..8e6d5f9c0a 100644
--- a/boost/compute/algorithm/detail/radix_sort.hpp
+++ b/boost/compute/algorithm/detail/radix_sort.hpp
@@ -92,6 +92,8 @@ const char radix_sort_source[] =
"#define RADIX_MASK ((((T)(1)) << K_BITS) - 1)\n"
"#define SIGN_BIT ((sizeof(T) * CHAR_BIT) - 1)\n"
+"#if defined(ASC)\n" // asc order
+
"inline uint radix(const T x, const uint low_bit)\n"
"{\n"
"#if defined(IS_FLOATING_POINT)\n"
@@ -104,6 +106,25 @@ const char radix_sort_source[] =
"#endif\n"
"}\n"
+"#else\n" // desc order
+
+// For signed types we just negate the x and for unsigned types we
+// subtract the x from max value of its type ((T)(-1) is a max value
+// of type T when T is an unsigned type).
+"inline uint radix(const T x, const uint low_bit)\n"
+"{\n"
+"#if defined(IS_FLOATING_POINT)\n"
+" const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
+" return (((-x) ^ mask) >> low_bit) & RADIX_MASK;\n"
+"#elif defined(IS_SIGNED)\n"
+" return (((-x) ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
+"#else\n"
+" return (((T)(-1) - x) >> low_bit) & RADIX_MASK;\n"
+"#endif\n"
+"}\n"
+
+"#endif\n" // #if defined(ASC)
+
"__kernel void count(__global const T *input,\n"
" const uint input_offset,\n"
" const uint input_size,\n"
@@ -227,6 +248,7 @@ template<class T, class T2>
inline void radix_sort_impl(const buffer_iterator<T> first,
const buffer_iterator<T> last,
const buffer_iterator<T2> values_first,
+ const bool ascending,
command_queue &queue)
{
@@ -279,6 +301,10 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
options << enable_double<T2>();
}
+ if(ascending){
+ options << " -DASC";
+ }
+
// load radix sort program
program radix_sort_program = cache->get_or_build(
cache_key, options.str(), radix_sort_source, context
@@ -396,18 +422,38 @@ inline void radix_sort(Iterator first,
Iterator last,
command_queue &queue)
{
- radix_sort_impl(first, last, buffer_iterator<int>(), queue);
+ radix_sort_impl(first, last, buffer_iterator<int>(), true, queue);
+}
+
+template<class KeyIterator, class ValueIterator>
+inline void radix_sort_by_key(KeyIterator keys_first,
+ KeyIterator keys_last,
+ ValueIterator values_first,
+ command_queue &queue)
+{
+ radix_sort_impl(keys_first, keys_last, values_first, true, queue);
+}
+
+template<class Iterator>
+inline void radix_sort(Iterator first,
+ Iterator last,
+ const bool ascending,
+ command_queue &queue)
+{
+ radix_sort_impl(first, last, buffer_iterator<int>(), ascending, queue);
}
template<class KeyIterator, class ValueIterator>
inline void radix_sort_by_key(KeyIterator keys_first,
KeyIterator keys_last,
ValueIterator values_first,
+ const bool ascending,
command_queue &queue)
{
- radix_sort_impl(keys_first, keys_last, values_first, queue);
+ radix_sort_impl(keys_first, keys_last, values_first, ascending, queue);
}
+
} // end detail namespace
} // end compute namespace
} // end boost namespace