diff options
Diffstat (limited to 'boost/compute/algorithm/detail/radix_sort.hpp')
-rw-r--r-- | boost/compute/algorithm/detail/radix_sort.hpp | 50 |
1 files changed, 48 insertions, 2 deletions
diff --git a/boost/compute/algorithm/detail/radix_sort.hpp b/boost/compute/algorithm/detail/radix_sort.hpp
index c2ba4ed17c..8e6d5f9c0a 100644
--- a/boost/compute/algorithm/detail/radix_sort.hpp
+++ b/boost/compute/algorithm/detail/radix_sort.hpp
@@ -92,6 +92,8 @@ const char radix_sort_source[] =
 "#define RADIX_MASK ((((T)(1)) << K_BITS) - 1)\n"
 "#define SIGN_BIT ((sizeof(T) * CHAR_BIT) - 1)\n"

+"#if defined(ASC)\n" // asc order
+
 "inline uint radix(const T x, const uint low_bit)\n"
 "{\n"
 "#if defined(IS_FLOATING_POINT)\n"
@@ -104,6 +106,25 @@ const char radix_sort_source[] =
 "#endif\n"
 "}\n"

+"#else\n" // desc order
+
+// For signed types we just negate the x and for unsigned types we
+// subtract the x from max value of its type ((T)(-1) is a max value
+// of type T when T is an unsigned type).
+"inline uint radix(const T x, const uint low_bit)\n"
+"{\n"
+"#if defined(IS_FLOATING_POINT)\n"
+" const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
+" return (((-x) ^ mask) >> low_bit) & RADIX_MASK;\n"
+"#elif defined(IS_SIGNED)\n"
+" return (((-x) ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
+"#else\n"
+" return (((T)(-1) - x) >> low_bit) & RADIX_MASK;\n"
+"#endif\n"
+"}\n"
+
+"#endif\n" // #if defined(ASC)
+
 "__kernel void count(__global const T *input,\n"
 " const uint input_offset,\n"
 " const uint input_size,\n"
@@ -227,6 +248,7 @@ template<class T, class T2>
 inline void radix_sort_impl(const buffer_iterator<T> first,
                             const buffer_iterator<T> last,
                             const buffer_iterator<T2> values_first,
+                            const bool ascending,
                             command_queue &queue)
 {

@@ -279,6 +301,10 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
         options << enable_double<T2>();
     }

+    if(ascending){
+        options << " -DASC";
+    }
+
     // load radix sort program
     program radix_sort_program = cache->get_or_build(
         cache_key, options.str(), radix_sort_source, context
@@ -396,18 +422,38 @@ inline void radix_sort(Iterator first,
                        Iterator last,
                        command_queue &queue)
 {
-    radix_sort_impl(first, last, buffer_iterator<int>(), queue);
+    radix_sort_impl(first, last, buffer_iterator<int>(), true, queue);
+}
+
+template<class KeyIterator, class ValueIterator>
+inline void radix_sort_by_key(KeyIterator keys_first,
+                              KeyIterator keys_last,
+                              ValueIterator values_first,
+                              command_queue &queue)
+{
+    radix_sort_impl(keys_first, keys_last, values_first, true, queue);
+}
+
+template<class Iterator>
+inline void radix_sort(Iterator first,
+                       Iterator last,
+                       const bool ascending,
+                       command_queue &queue)
+{
+    radix_sort_impl(first, last, buffer_iterator<int>(), ascending, queue);
 }

 template<class KeyIterator, class ValueIterator>
 inline void radix_sort_by_key(KeyIterator keys_first,
                               KeyIterator keys_last,
                               ValueIterator values_first,
+                              const bool ascending,
                               command_queue &queue)
 {
-    radix_sort_impl(keys_first, keys_last, values_first, queue);
+    radix_sort_impl(keys_first, keys_last, values_first, ascending, queue);
 }

+
 } // end detail namespace
 } // end compute namespace
 } // end boost namespace