author     Xiaomeng Yang <yangxm@fb.com>  2019-02-20 14:38:35 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-02-20 14:49:01 -0800
commit     2e67b34ea78f19868bd661c5e8715b4e2516a5e9 (patch)
tree       4c5921c76321805b083f839bf3eff2a9eccee118 /caffe2/utils
parent     474adf5458e1ad917548e2e8ef28014dfc029ed1 (diff)
Separate gpu reduce functions (#17146)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17146
Separate gpu reduce functions
i-am-not-moving-c2-to-c10
Reviewed By: houseroad
Differential Revision: D14097564
fbshipit-source-id: a27de340997111a794b1d083c1673d4263afb9fb
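
The patch replaces the per-operator Func##CUDAKernel launches in elementwise.cu with thrust::transform calls that run a device lambda on the context's CUDA stream. As a rough sketch, the new DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf) instantiation expands to approximately the following; this is paraphrased from the diff rather than quoted verbatim, CAFFE2_CUDA_EXPORT is omitted, and an nvcc build with extended device lambdas (--expt-extended-lambda) is assumed:

// Rough expansion of DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf)
// as introduced in elementwise.cu below. The primary template declaration
// comes from the included Caffe2 header.
#include <thrust/execution_policy.h>
#include <thrust/transform.h>

#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math/elementwise.h"

namespace caffe2 {
namespace math {

template <>
void Exp<float, CUDAContext>(
    const int N, const float* X, float* Y, CUDAContext* context) {
  if (N > 0) {
    // Y[i] = expf(X[i]) for i in [0, N), executed on the context's stream.
    thrust::transform(
        thrust::cuda::par.on(context->cuda_stream()),
        X,
        X + N,
        Y,
        [] __device__(const float x) { return expf(x); });
  }
}

}  // namespace math
}  // namespace caffe2
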
Diffstat (limited to 'caffe2/utils')
-rw-r--r-- | caffe2/utils/math/elementwise.cu | 493
-rw-r--r-- | caffe2/utils/math/reduce.cu       | 330
-rw-r--r-- | caffe2/utils/math/reduce.cuh      |  20
-rw-r--r-- | caffe2/utils/math_gpu.cu          | 633
4 files changed, 693 insertions, 783 deletions
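
The bulk of the patch moves the reduce implementations into caffe2/utils/math/reduce.cu, where ReduceTensorCUDA dispatches to row-wise, column-wise, both-ends, or generic strided kernels. A condensed sketch of the row-wise case (one thread block per row, cub::BlockReduce combining per-thread partials) follows; the kernel name and block size are illustrative choices, not copied from the diff:

// Sketch of the row-wise reduction pattern used by RowwiseReduceCUDAKernel
// in the diff: each block reduces one row of a rows x cols matrix.
#include <cub/block/block_reduce.cuh>

template <typename T, class Reducer, int kBlockSize>
__global__ void RowwiseReduceSketch(
    const int cols,
    const Reducer reducer,
    const T init,
    const T alpha,
    const T* X,
    T* Y) {
  using BlockReduceT = cub::BlockReduce<T, kBlockSize>;
  __shared__ typename BlockReduceT::TempStorage temp_storage;
  const int r = blockIdx.x;  // one block per row
  T val = init;
  for (int c = threadIdx.x; c < cols; c += blockDim.x) {
    val = reducer(val, X[r * cols + c]);  // per-thread partial over the row
  }
  val = BlockReduceT(temp_storage).Reduce(val, reducer);
  if (threadIdx.x == 0) {
    Y[r] = val * alpha;  // thread 0 writes the scaled row result
  }
}

// Illustrative launch: rows blocks of 128 threads, summing floats:
//   RowwiseReduceSketch<float, cub::Sum, 128>
//       <<<rows, 128, 0, stream>>>(cols, cub::Sum(), 0.0f, 1.0f, d_X, d_Y);
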
diff --git a/caffe2/utils/math/elementwise.cu b/caffe2/utils/math/elementwise.cu index b7605deece..006fbd0b27 100644 --- a/caffe2/utils/math/elementwise.cu +++ b/caffe2/utils/math/elementwise.cu @@ -1,6 +1,11 @@ #include "caffe2/utils/math/elementwise.h" +#include <type_traits> + +#include <thrust/execution_policy.h> +#include <thrust/fill.h> #include <thrust/functional.h> +#include <thrust/transform.h> #include "caffe2/core/context_gpu.h" #include "caffe2/utils/conversions.h" @@ -12,153 +17,126 @@ namespace math { namespace { -#define DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(T, Func, DeviceFunc) \ - __global__ void Func##CUDAKernel(const int N, const T* X, T* Y) { \ - const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; \ - if (i < N) { \ - Y[i] = DeviceFunc(X[i]); \ - } \ +template <typename T> +__global__ void SinCosCUDAKernel(const int N, const T* X, T* S, T* C) { + const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (i < N) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + c10::cuda::compat::sincos(__ldg(X + i), S + i, C + i); +#else + c10::cuda::compat::sincos(X[i], S + i, C + i); +#endif + } +} + +} // namespace + +#define CAFFE2_SPECIALIZED_CUDA_SET(T) \ + template <> \ + CAFFE2_CUDA_EXPORT void Set<T, CUDAContext>( \ + const int N, const T alpha, T* Y, CUDAContext* context) { \ + if (N == 0) { \ + return; \ + } \ + if (alpha == T(0)) { \ + cudaMemsetAsync(Y, 0, sizeof(T) * N, context->cuda_stream()); \ + } else { \ + thrust::fill( \ + thrust::cuda::par.on(context->cuda_stream()), Y, Y + N, alpha); \ + } \ } -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Exp, expf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Log, logf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Cos, cosf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Acos, acosf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sin, sinf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Asin, asinf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Tan, tanf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Atan, atanf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sinh, sinhf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Cosh, coshf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Tanh, tanhf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Abs, fabsf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sqr, utils::Square<float>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sqrt, sqrtf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Rsqrt, rsqrtf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Cbrt, cbrtf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Erf, erff) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Erf, erf) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( +CAFFE2_SPECIALIZED_CUDA_SET(bool) +CAFFE2_SPECIALIZED_CUDA_SET(char) +CAFFE2_SPECIALIZED_CUDA_SET(std::int8_t) +CAFFE2_SPECIALIZED_CUDA_SET(std::int16_t) +CAFFE2_SPECIALIZED_CUDA_SET(std::int32_t) +CAFFE2_SPECIALIZED_CUDA_SET(std::int64_t) +CAFFE2_SPECIALIZED_CUDA_SET(std::uint8_t) +CAFFE2_SPECIALIZED_CUDA_SET(std::uint16_t) +CAFFE2_SPECIALIZED_CUDA_SET(float) +CAFFE2_SPECIALIZED_CUDA_SET(double) +CAFFE2_SPECIALIZED_CUDA_SET(at::Half) +#undef CAFFE2_SPECIALIZED_CUDA_SET + +#define DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(T, Func, DeviceFunc) \ + template <> \ + CAFFE2_CUDA_EXPORT void Func<T, CUDAContext>( \ + const int N, const T* X, T* Y, CUDAContext* context) { \ + if (N > 0) { \ + thrust::transform( \ + thrust::cuda::par.on(context->cuda_stream()), \ + X, \ + 
X + N, \ + Y, \ + [] __device__(const T x) { return DeviceFunc(x); }); \ + } \ + } +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log, logf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sin, sinf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Asin, asinf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cos, cosf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Acos, acosf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tan, tanf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Atan, atanf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sinh, sinhf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cosh, coshf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tanh, tanhf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Abs, fabsf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Inv, utils::Inv<float>) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Inv, utils::Inv<double>) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr, utils::Square<float>) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqrt, sqrtf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Rsqrt, rsqrtf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( std::int32_t, Cube, utils::Cube<std::int32_t>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( std::int64_t, Cube, utils::Cube<std::int64_t>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Cube, utils::Cube<float>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Cube, utils::Cube<double>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(bool, Not, utils::Not<bool>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cube, utils::Cube<float>) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Cube, utils::Cube<double>) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cbrt, cbrtf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Erf, erff) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Erf, erf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(bool, Not, utils::Not<bool>) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( std::int32_t, Neg, utils::Negate<std::int32_t>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( std::int64_t, Neg, utils::Negate<std::int64_t>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Neg, utils::Negate<float>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Neg, utils::Negate<double>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Neg, utils::Negate<float>) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Neg, utils::Negate<double>) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( std::int32_t, Sign, utils::Sign<std::int32_t>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( std::int64_t, Sign, utils::Sign<std::int64_t>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sign, utils::Sign<float>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Sign, utils::Sign<double>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Inv, utils::Inv<float>) -DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Inv, utils::Inv<double>) -#undef DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sign, utils::Sign<float>) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sign, utils::Sign<double>) +#undef DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION -template <typename T> -__global__ void SinCosCUDAKernel(const int N, const T* X, T* S, T* C) { - const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (i < N) { -#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) - c10::cuda::compat::sincos(__ldg(X + i), S + i, 
C + i); -#else - c10::cuda::compat::sincos(X[i], S + i, C + i); -#endif +#define DELEGATE_CUDA_POWX(T, DeviceFunc) \ + template <> \ + CAFFE2_CUDA_EXPORT void Powx<T, CUDAContext>( \ + const int N, const T* A, const T b, T* Y, CUDAContext* context) { \ + thrust::transform( \ + thrust::cuda::par.on(context->cuda_stream()), \ + A, \ + A + N, \ + Y, \ + [b] __device__(const T x) { return DeviceFunc(x, b); }); \ } -} - -template <typename T, class Func> -__global__ void SimpleBinaryCUDAKernel( - const int N, - const Func func, - const T* A, - const T* B, - T* C) { - const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (i < N) { - C[i] = func(A[i], B[i]); - } -} - -template <typename T, class Comp> -__global__ void SimpleCompareCUDAKernel( - const int N, - const Comp comp, - const T* A, - const T* B, - bool* C) { - const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (i < N) { - C[i] = comp(A[i], B[i]); - } -} - -} // namespace - -#define DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(T, Func) \ - template <> \ - CAFFE2_CUDA_EXPORT void Func<T, CUDAContext>( \ - const int N, const T* X, T* Y, CUDAContext* context) { \ - if (N > 0) { \ - const int M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - Func##CUDAKernel<<< \ - M, \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>(N, X, Y); \ - } \ - } -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cos) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Acos) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sin) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Asin) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tan) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Atan) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sinh) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cosh) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tanh) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Abs) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqrt) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Rsqrt) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cbrt) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Erf) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Erf) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cube) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Cube) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int32_t, Cube) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int64_t, Cube) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(bool, Not) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Neg) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Neg) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int32_t, Neg) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int64_t, Neg) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sign) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sign) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int32_t, Sign) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int64_t, Sign) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Inv) -DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Inv) -#undef DEFINE_SIMPLE_CUDA_UNARY_FUNCTION +DELEGATE_CUDA_POWX(float, powf) +#undef DELEGATE_CUDA_POWX #define CAFFE2_SPECIALIZED_CUDA_SINCOS(T) \ template <> \ @@ -175,17 +153,255 @@ CAFFE2_SPECIALIZED_CUDA_SINCOS(float) CAFFE2_SPECIALIZED_CUDA_SINCOS(double) #undef CAFFE2_SPECIALIZED_CUDA_SINCOS +#define DELEGATE_CUDA_SCALE_BY_CUBLAS_FUNCTION(TAlpha, TData, CuBLASFunc) \ + template <> \ + CAFFE2_CUDA_EXPORT void Scale<TAlpha, TData, CUDAContext>( \ + const int N, \ + const TAlpha alpha, \ + const TData* X, \ + TData* Y, \ + CUDAContext* context) { \ + if (N == 0) { \ + return; \ + } \ + const 
TData alpha_host = static_cast<TData>(alpha); \ + if (Y == X) { \ + CUBLAS_ENFORCE(cublasSetPointerMode( \ + context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ + CUBLAS_ENFORCE( \ + CuBLASFunc(context->cublas_handle(), N, &alpha_host, Y, 1)); \ + } else { \ + thrust::transform( \ + thrust::cuda::par.on(context->cuda_stream()), \ + X, \ + X + N, \ + Y, \ + [alpha_host] __device__(const TData x) { return x * alpha_host; }); \ + } \ + } \ + template <> \ + CAFFE2_CUDA_EXPORT void Scale<TAlpha, TData, CUDAContext>( \ + const int N, \ + const TAlpha* alpha, \ + const TData* X, \ + TData* Y, \ + CUDAContext* context) { \ + if (N == 0) { \ + return; \ + } \ + if (std::is_same<TAlpha, TData>::value && Y == X) { \ + CUBLAS_ENFORCE(cublasSetPointerMode( \ + context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); \ + CUBLAS_ENFORCE(CuBLASFunc( \ + context->cublas_handle(), \ + N, \ + reinterpret_cast<const TData*>(alpha), \ + Y, \ + 1)); \ + } else { \ + thrust::transform( \ + thrust::cuda::par.on(context->cuda_stream()), \ + X, \ + X + N, \ + Y, \ + [alpha] __device__(const TData x) { \ + return x * static_cast<TData>(*alpha); \ + }); \ + } \ + } +DELEGATE_CUDA_SCALE_BY_CUBLAS_FUNCTION(float, float, cublasSscal) +DELEGATE_CUDA_SCALE_BY_CUBLAS_FUNCTION(double, double, cublasDscal) +DELEGATE_CUDA_SCALE_BY_CUBLAS_FUNCTION(float, double, cublasDscal) +#undef DELEGATE_CUDA_SCALE_BY_CUBLAS_FUNCTION + +#define CAFFE2_SPECIALIZED_CUDA_SCALE(TAlpha, TData) \ + template <> \ + CAFFE2_CUDA_EXPORT void Scale<TAlpha, TData, CUDAContext>( \ + const int N, \ + const TAlpha alpha, \ + const TData* X, \ + TData* Y, \ + CUDAContext* context) { \ + if (N > 0) { \ + thrust::transform( \ + thrust::cuda::par.on(context->cuda_stream()), \ + X, \ + X + N, \ + Y, \ + [alpha] __device__(const TData x) { \ + return x * static_cast<TData>(alpha); \ + }); \ + } \ + } \ + template <> \ + CAFFE2_CUDA_EXPORT void Scale<TAlpha, TData, CUDAContext>( \ + const int N, \ + const TAlpha* alpha, \ + const TData* X, \ + TData* Y, \ + CUDAContext* context) { \ + if (N > 0) { \ + thrust::transform( \ + thrust::cuda::par.on(context->cuda_stream()), \ + X, \ + X + N, \ + Y, \ + [alpha] __device__(const TData x) { \ + return x * static_cast<TData>(*alpha); \ + }); \ + } \ + } +CAFFE2_SPECIALIZED_CUDA_SCALE(std::int32_t, std::int32_t) +CAFFE2_SPECIALIZED_CUDA_SCALE(std::int64_t, std::int64_t) +#undef CAFFE2_SPECIALIZED_CUDA_SCALE + +#ifdef __HIP_PLATFORM_HCC__ + +#define CAFFE2_SPECIALIZED_CUDA_HALF_SCALE(TAlpha) \ + template <> \ + CAFFE2_CUDA_EXPORT void Scale<TAlpha, at::Half, CUDAContext>( \ + const int N, \ + const TAlpha alpha, \ + const at::Half* X, \ + at::Half* Y, \ + CUDAContext* context) { \ + if (N > 0) { \ + const float alpha_host = convert::To<TAlpha, float>(alpha); \ + thrust::transform( \ + thrust::cuda::par.on(context->cuda_stream()), \ + X, \ + X + N, \ + Y, \ + [alpha_host] __device__(const at::Half x) { \ + return convert::To<float, at::Half>( \ + convert::To<at::Half, float>(x) * alpha_host); \ + }); \ + } \ + } \ + template <> \ + CAFFE2_CUDA_EXPORT void Scale<TAlpha, at::Half, CUDAContext>( \ + const int N, \ + const TAlpha* alpha, \ + const at::Half* X, \ + at::Half* Y, \ + CUDAContext* context) { \ + if (N > 0) { \ + thrust::transform( \ + thrust::cuda::par.on(context->cuda_stream()), \ + X, \ + X + N, \ + Y, \ + [alpha] __device__(const at::Half x) { \ + return convert::To<float, at::Half>( \ + convert::To<at::Half, float>(x) * \ + convert::To<TAlpha, float>(*alpha)); \ + }); \ + } \ + } 
+CAFFE2_SPECIALIZED_CUDA_HALF_SCALE(at::Half) +CAFFE2_SPECIALIZED_CUDA_HALF_SCALE(float) +#undef CAFFE2_SPECIALIZED_CUDA_HALF_SCALE + +#else // __HIP_PLATFORM_HCC__ + +#define DELEGATE_CUDA_HALF_SCALE_BY_CUBLAS_FUNCTION( \ + TAlpha, CuBLASFunc, kAlphaType, kExecutionType) \ + template <> \ + CAFFE2_CUDA_EXPORT void Scale<TAlpha, at::Half, CUDAContext>( \ + const int N, \ + const TAlpha alpha, \ + const at::Half* X, \ + at::Half* Y, \ + CUDAContext* context) { \ + if (N == 0) { \ + return; \ + } \ + if (Y == X) { \ + CUBLAS_ENFORCE(cublasSetPointerMode( \ + context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ + CUBLAS_ENFORCE(cublasScalEx( \ + context->cublas_handle(), \ + N, \ + &alpha, \ + kAlphaType, \ + Y, \ + CUDA_R_16F, \ + 1, \ + kExecutionType)); \ + } else { \ + const float alpha_host = convert::To<TAlpha, float>(alpha); \ + thrust::transform( \ + thrust::cuda::par.on(context->cuda_stream()), \ + X, \ + X + N, \ + Y, \ + [alpha_host] __device__(const at::Half x) { \ + return convert::To<float, at::Half>( \ + convert::To<at::Half, float>(x) * alpha_host); \ + }); \ + } \ + } \ + template <> \ + CAFFE2_CUDA_EXPORT void Scale<TAlpha, at::Half, CUDAContext>( \ + const int N, \ + const TAlpha* alpha, \ + const at::Half* X, \ + at::Half* Y, \ + CUDAContext* context) { \ + if (N == 0) { \ + return; \ + } \ + if (Y == X) { \ + CUBLAS_ENFORCE(cublasSetPointerMode( \ + context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ + CUBLAS_ENFORCE(cublasScalEx( \ + context->cublas_handle(), \ + N, \ + alpha, \ + kAlphaType, \ + Y, \ + CUDA_R_16F, \ + 1, \ + kExecutionType)); \ + } else { \ + thrust::transform( \ + thrust::cuda::par.on(context->cuda_stream()), \ + X, \ + X + N, \ + Y, \ + [alpha] __device__(const at::Half x) { \ + return convert::To<float, at::Half>( \ + convert::To<at::Half, float>(x) * \ + convert::To<TAlpha, float>(*alpha)); \ + }); \ + } \ + } +DELEGATE_CUDA_HALF_SCALE_BY_CUBLAS_FUNCTION( + at::Half, + cublasScalEx, + CUDA_R_16F, + CUDA_R_32F) +DELEGATE_CUDA_HALF_SCALE_BY_CUBLAS_FUNCTION( + float, + cublasScalEx, + CUDA_R_32F, + CUDA_R_32F) +#undef DELEGATE_CUDA_HALF_SCALE_BY_CUBLAS_FUNCTION + +#endif // __HIP_PLATFORM_HCC__ + #define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Func, DeviceFunc) \ template <> \ CAFFE2_CUDA_EXPORT void Func<T, CUDAContext>( \ const int N, const T* A, const T* B, T* C, CUDAContext* context) { \ if (N > 0) { \ - const int M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - SimpleBinaryCUDAKernel<<< \ - M, \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>(N, DeviceFunc, A, B, C); \ + thrust::transform( \ + thrust::cuda::par.on(context->cuda_stream()), \ + A, \ + A + N, \ + B, \ + C, \ + DeviceFunc); \ } \ } DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( @@ -273,12 +489,13 @@ DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( CAFFE2_CUDA_EXPORT void Func<T, CUDAContext>( \ const int N, const T* A, const T* B, bool* C, CUDAContext* context) { \ if (N > 0) { \ - const int M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - SimpleCompareCUDAKernel<<< \ - M, \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>(N, DeviceComp, A, B, C); \ + thrust::transform( \ + thrust::cuda::par.on(context->cuda_stream()), \ + A, \ + A + N, \ + B, \ + C, \ + DeviceComp); \ } \ } DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, EQ, thrust::equal_to<bool>()) diff --git a/caffe2/utils/math/reduce.cu b/caffe2/utils/math/reduce.cu index 18291d35f9..7e9f6530ce 100644 --- a/caffe2/utils/math/reduce.cu +++ b/caffe2/utils/math/reduce.cu @@ -1,15 +1,20 @@ -#include "caffe2/utils/math.h" +#include 
"caffe2/utils/math/reduce.h" #include <algorithm> #include <functional> +#include <limits> #include <numeric> #include <vector> #include <cub/block/block_reduce.cuh> #include <cub/cub.cuh> +#include <thrust/execution_policy.h> +#include <thrust/reduce.h> +#include <thrust/transform.h> + #include "caffe2/core/context_gpu.h" -#include "caffe2/utils/fixed_divisor.h" +#include "caffe2/utils/math/elementwise.h" #include "caffe2/utils/math/reduce.cuh" #include "caffe2/utils/math/utils.h" @@ -18,6 +23,221 @@ namespace math { namespace { +template <typename T, class Reducer> +__global__ void RowwiseReduceCUDAKernel( + const int cols, + const Reducer reducer, + const T init, + const T alpha, + const T* X, + T* Y) { + __shared__ typename BlockReduce<T>::TempStorage temp_storage; + const int r = blockIdx.x; + T val = init; + for (int c = threadIdx.x; c < cols; c += blockDim.x) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + val = reducer(val, __ldg(X + r * cols + c)); +#else + val = reducer(val, X[r * cols + c]); +#endif + } + val = BlockReduce<T>(temp_storage).Reduce(val, reducer); + if (threadIdx.x == 0) { + Y[r] = val * alpha; + } +} + +template <typename T, class Reducer> +__global__ void ColwiseReduceCUDAKernel( + const int rows, + const int cols, + const Reducer reducer, + const T init, + const T alpha, + const T* X, + T* Y) { + __shared__ typename BlockReduce<T>::TempStorage temp_storage; + const int c = blockIdx.x; + T val = init; + for (int r = threadIdx.x; r < rows; r += blockDim.x) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + val = reducer(val, __ldg(X + r * cols + c)); +#else + val = reducer(val, X[r * cols + c]); +#endif + } + val = BlockReduce<T>(temp_storage).Reduce(val, reducer); + if (threadIdx.x == 0) { + Y[c] = val * alpha; + } +} + +template <typename T, class Reducer, int kBlockDimX, int kBlockDimY> +__global__ void BothEndsReduceCUDAKernel( + const int M, + const int N, + const int K, + const Reducer reducer, + const T init, + const T alpha, + const T* X, + T* Y) { + __shared__ typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage + temp_storage; + const int n = blockIdx.x; + T val = init; + for (int m = threadIdx.x; m < M; m += blockDim.x) { + for (int k = threadIdx.y; k < K; k += blockDim.y) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + val = reducer(val, __ldg(X + (m * N + n) * K + k)); +#else + val = reducer(val, X[(m * N + n) * K + k]); +#endif + } + } + val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(temp_storage) + .Reduce(val, reducer); + if (threadIdx.x == 0 && threadIdx.y == 0) { + Y[n] = val * alpha; + } +} + +template <typename T, class Reducer, int D> +__global__ void ReduceTensorCUDAKernel( + const int inner_size, + const SimpleArray<int, D> X_strides, + const SimpleArray<int, D> Y_dims, + const Reducer reducer, + const T init, + const T alpha, + const T* X, + T* Y) { + __shared__ typename BlockReduce<T>::TempStorage temp_storage; + const int x = blockIdx.x; + T val = init; + for (int y = threadIdx.x; y < inner_size; y += blockDim.x) { + int X_index = 0; + int Y_index = x * inner_size + y; +#pragma unroll + for (int d = D - 1; d >= 0; --d) { + X_index += Y_index % Y_dims.data[d] * X_strides.data[d]; + Y_index /= Y_dims.data[d]; + } +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + val = reducer(val, __ldg(X + X_index)); +#else + val = reducer(val, X[X_index]); +#endif + } + val = BlockReduce<T>(temp_storage).Reduce(val, reducer); + if (threadIdx.x == 0) { + Y[x] = val * alpha; + } +} + +template 
<typename T, class Reducer, int D> +void ReduceTensorCUDAImpl( + const int outer_size, + const int inner_size, + const int* dims, + const int* axes, + const Reducer& reducer, + const T init, + const T alpha, + const T* X, + T* Y, + CUDAContext* context) { + SimpleArray<int, D> X_strides; + SimpleArray<int, D> Y_dims; + utils::ComputeTransposedStrides(D, dims, axes, X_strides.data); + for (int i = 0; i < D; ++i) { + Y_dims.data[i] = dims[axes[i]]; + } + ReduceTensorCUDAKernel<T, Reducer, D> + <<<outer_size, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( + inner_size, X_strides, Y_dims, reducer, init, alpha, X, Y); +} + +template <typename T, class Reducer> +void ReduceTensorCUDA( + const int ndim, + const int* X_dims, + const int* Y_dims, + const Reducer& reducer, + const T init, + const T alpha, + const T* X, + T* Y, + CUDAContext* context) { + CAFFE_ENFORCE(utils::CheckReduceDims(ndim, X_dims, Y_dims)); + const int X_size = + std::accumulate(X_dims, X_dims + ndim, 1, std::multiplies<int>()); + const int Y_size = + std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies<int>()); + if (X_size == 0) { + Set<T, CUDAContext>(Y_size, init * alpha, Y, context); + return; + } + if (std::equal(X_dims, X_dims + ndim, Y_dims)) { + Scale<T, T, CUDAContext>(X_size, alpha, X, Y, context); + return; + } + int rows; + int cols; + if (utils::IsRowwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { + RowwiseReduceCUDAKernel<T, Reducer> + <<<rows, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( + cols, reducer, init, alpha, X, Y); + return; + } + if (utils::IsColwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { + ColwiseReduceCUDAKernel<T, Reducer> + <<<cols, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( + rows, cols, reducer, init, alpha, X, Y); + return; + } + int M; + int N; + int K; + if (utils::IsBothEndsReduce(ndim, X_dims, Y_dims, &M, &N, &K)) { + DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_2( + K, + BothEndsReduceCUDAKernel, + T, + Reducer, + N, + context->cuda_stream(), + M, + N, + K, + reducer, + init, + alpha, + X, + Y); + return; + } + std::vector<int> axes(ndim); + utils::ComputeTransposeAxesForReduceOp(ndim, Y_dims, axes.data()); + const int outer_size = Y_size; + const int inner_size = X_size / Y_size; + DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_2( + ndim, + ReduceTensorCUDAImpl, + T, + Reducer, + outer_size, + inner_size, + X_dims, + axes.data(), + reducer, + init, + alpha, + X, + Y, + context); +} + template <typename T> __global__ void RowwiseMomentsCUDAKernel(const int cols, const T* X, T* mean, T* var) { @@ -119,7 +339,7 @@ template <typename T, int D> __global__ void MomentsCUDAKernel( const int inner_size, const SimpleArray<int, D> X_strides, - const SimpleArray<FixedDivisor<int>, D> Y_dims, + const SimpleArray<int, D> Y_dims, const T* X, T* mean, T* var) { @@ -134,9 +354,8 @@ __global__ void MomentsCUDAKernel( int Y_index = x * inner_size + y; #pragma unroll for (int d = D - 1; d >= 0; --d) { - int r; - Y_dims.data[d].DivMod(Y_index, &Y_index, &r); - X_index += r * X_strides.data[d]; + X_index += Y_index % Y_dims.data[d] * X_strides.data[d]; + Y_index /= Y_dims.data[d]; } #if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) m_val += __ldg(X + X_index); @@ -156,7 +375,7 @@ __global__ void MomentsCUDAKernel( } template <typename T, int D> -CAFFE2_CUDA_EXPORT void MomentsCUDAImpl( +void MomentsCUDAImpl( const int outer_size, const int inner_size, const int* dims, @@ -166,10 +385,10 @@ CAFFE2_CUDA_EXPORT void MomentsCUDAImpl( T* var, CUDAContext* context) { 
SimpleArray<int, D> X_strides; - SimpleArray<FixedDivisor<int>, D> Y_dims; + SimpleArray<int, D> Y_dims; utils::ComputeTransposedStrides(D, dims, axes, X_strides.data); for (int i = 0; i < D; ++i) { - Y_dims.data[i] = FixedDivisor<int>(dims[axes[i]]); + Y_dims.data[i] = dims[axes[i]]; } MomentsCUDAKernel<T, D> <<<outer_size, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( @@ -177,7 +396,7 @@ CAFFE2_CUDA_EXPORT void MomentsCUDAImpl( } template <typename T> -CAFFE2_CUDA_EXPORT void MomentsCUDA( +void MomentsCUDA( const int ndim, const int* X_dims, const int* Y_dims, @@ -223,7 +442,7 @@ CAFFE2_CUDA_EXPORT void MomentsCUDA( int N; int K; if (utils::IsBothEndsReduce(ndim, X_dims, Y_dims, &M, &N, &K)) { - DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK( + DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1( K, BothEndsMomentsCUDAKernel, T, @@ -257,6 +476,95 @@ CAFFE2_CUDA_EXPORT void MomentsCUDA( } // namespace +#define DELEGATE_CUDA_REDUCE_FUNCTION(T, Func, Reducer, kInit) \ + template <> \ + CAFFE2_CUDA_EXPORT void Func<T, CUDAContext>( \ + const int ndim, \ + const int* X_dims, \ + const int* Y_dims, \ + const T alpha, \ + const T* X, \ + T* Y, \ + CUDAContext* context) { \ + ReduceTensorCUDA<T, Reducer>( \ + ndim, X_dims, Y_dims, Reducer(), kInit, alpha, X, Y, context); \ + } +DELEGATE_CUDA_REDUCE_FUNCTION( + std::int32_t, + ReduceMin, + cub::Min, + std::numeric_limits<std::int32_t>::max()) +DELEGATE_CUDA_REDUCE_FUNCTION( + std::int64_t, + ReduceMin, + cub::Min, + std::numeric_limits<std::int64_t>::max()) +DELEGATE_CUDA_REDUCE_FUNCTION( + float, + ReduceMin, + cub::Min, + std::numeric_limits<float>::max()) +DELEGATE_CUDA_REDUCE_FUNCTION( + double, + ReduceMin, + cub::Min, + std::numeric_limits<double>::max()) +DELEGATE_CUDA_REDUCE_FUNCTION( + std::int32_t, + ReduceMax, + cub::Max, + std::numeric_limits<std::int32_t>::lowest()) +DELEGATE_CUDA_REDUCE_FUNCTION( + std::int64_t, + ReduceMax, + cub::Max, + std::numeric_limits<std::int64_t>::lowest()) +DELEGATE_CUDA_REDUCE_FUNCTION( + float, + ReduceMax, + cub::Max, + std::numeric_limits<float>::lowest()) +DELEGATE_CUDA_REDUCE_FUNCTION( + double, + ReduceMax, + cub::Max, + std::numeric_limits<double>::lowest()) +DELEGATE_CUDA_REDUCE_FUNCTION(std::int32_t, ReduceSum, cub::Sum, 0) +DELEGATE_CUDA_REDUCE_FUNCTION(std::int64_t, ReduceSum, cub::Sum, 0LL) +DELEGATE_CUDA_REDUCE_FUNCTION(float, ReduceSum, cub::Sum, 0.0f) +DELEGATE_CUDA_REDUCE_FUNCTION(double, ReduceSum, cub::Sum, 0.0) +#undef DELEGATE_CUDA_REDUCE_FUNCTION + +#define CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN(T) \ + template <> \ + CAFFE2_CUDA_EXPORT void ReduceMean<T, CUDAContext>( \ + const int ndim, \ + const int* X_dims, \ + const int* Y_dims, \ + const T alpha, \ + const T* X, \ + T* Y, \ + CUDAContext* context) { \ + int scale = 1; \ + for (int i = 0; i < ndim; ++i) { \ + if (Y_dims[i] == 1) { \ + scale *= X_dims[i]; \ + } \ + } \ + ReduceTensorCUDA<T, cub::Sum>( \ + ndim, \ + X_dims, \ + Y_dims, \ + cub::Sum(), \ + T(0), \ + alpha / static_cast<T>(scale), \ + X, \ + Y, \ + context); \ + } +CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN(float) +#undef CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN + #define CAFFE2_SPECIALIZED_CUDA_MOMENTS(T) \ template <> \ CAFFE2_CUDA_EXPORT void Moments<T, CUDAContext>( \ diff --git a/caffe2/utils/math/reduce.cuh b/caffe2/utils/math/reduce.cuh index d191cbce8b..937cd50752 100644 --- a/caffe2/utils/math/reduce.cuh +++ b/caffe2/utils/math/reduce.cuh @@ -15,7 +15,7 @@ template <typename T, int kBlockDimX, int kBlockDimY> using BlockReduce2D = cub:: BlockReduce<T, kBlockDimX, 
cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>; -#define DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK( \ +#define DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1( \ size, Func, T, grid_dim, cuda_stream, ...) \ do { \ if (size >= 128) { \ @@ -30,6 +30,24 @@ using BlockReduce2D = cub:: } \ } while (false) +#define DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_2( \ + size, Func, T1, T2, grid_dim, cuda_stream, ...) \ + do { \ + if (size >= 128) { \ + Func<T1, T2, 1, 128> \ + <<<grid_dim, dim3(1, 128), 0, cuda_stream>>>(__VA_ARGS__); \ + } else if (size >= 64) { \ + Func<T1, T2, 2, 64> \ + <<<grid_dim, dim3(2, 64), 0, cuda_stream>>>(__VA_ARGS__); \ + } else if (size >= 32) { \ + Func<T1, T2, 4, 32> \ + <<<grid_dim, dim3(4, 32), 0, cuda_stream>>>(__VA_ARGS__); \ + } else { \ + Func<T1, T2, 8, 16> \ + <<<grid_dim, dim3(8, 16), 0, cuda_stream>>>(__VA_ARGS__); \ + } \ + } while (false) + } // namespace caffe2 #endif // CAFFE2_UTILS_MATH_REDUCE_CUH_ diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu index 7314199b4a..819edc6d33 100644 --- a/caffe2/utils/math_gpu.cu +++ b/caffe2/utils/math_gpu.cu @@ -1469,61 +1469,6 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>( } namespace { - -template <typename T> -__global__ void SetKernel(const int N, const T alpha, T* Y) { - CUDA_1D_KERNEL_LOOP(i, N) { - Y[i] = alpha; - } -} - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_SET(T) \ - template <> \ - CAFFE2_CUDA_API void Set<T, CUDAContext>( \ - const int N, const T alpha, T* Y, CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (alpha == T(0)) { \ - cudaMemsetAsync(Y, 0, sizeof(T) * N, context->cuda_stream()); \ - } else { \ - SetKernel<T> \ - <<<CAFFE_GET_BLOCKS(N), \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>(N, alpha, Y); \ - } \ - } -CAFFE2_SPECIALIZED_CUDA_SET(float); -CAFFE2_SPECIALIZED_CUDA_SET(double); -CAFFE2_SPECIALIZED_CUDA_SET(bool); -CAFFE2_SPECIALIZED_CUDA_SET(int8_t); -CAFFE2_SPECIALIZED_CUDA_SET(int16_t); -CAFFE2_SPECIALIZED_CUDA_SET(int); -CAFFE2_SPECIALIZED_CUDA_SET(int64_t); -CAFFE2_SPECIALIZED_CUDA_SET(char); -CAFFE2_SPECIALIZED_CUDA_SET(uint8_t); -CAFFE2_SPECIALIZED_CUDA_SET(uint16_t); -#undef CAFFE2_SPECIALIZED_CUDA_SET - -template <> -CAFFE2_CUDA_EXPORT void Set<at::Half, CUDAContext>( - const int N, - const at::Half alpha, - at::Half* Y, - CUDAContext* context) { - if (N > 0) { - SetKernel<at::Half> - <<<CAFFE_GET_BLOCKS(N), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(N, alpha, Y); - } -} - -namespace { template <typename T> __global__ void UniformShift(const size_t N, const float min, const float max, T* x) { @@ -1920,340 +1865,6 @@ CAFFE2_CUDA_EXPORT void Select<at::Half, CUDAContext>( context->cuda_stream()>>>(N, D, x, idx, y); } -namespace { - -template <typename TAlpha, typename TData> -__global__ void -ScaleCUDAKernel(const int n, const TAlpha alpha, const TData* x, TData* y) { - CUDA_1D_KERNEL_LOOP(i, n) { -#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) - y[i] = __ldg(x + i) * static_cast<TData>(alpha); -#else - y[i] = x[i] * static_cast<TData>(alpha); -#endif - } -} - -template <typename TAlpha, typename TData> -__global__ void -ScaleCUDAKernel(const int n, const TAlpha* alpha, const TData* x, TData* y) { - CUDA_1D_KERNEL_LOOP(i, n) { -#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) - y[i] = __ldg(x + i) * static_cast<TData>(__ldg(alpha)); -#else - y[i] = x[i] * static_cast<TData>(*alpha); -#endif - } -} - -template <typename T> -__global__ void PowKernel(const int n, const T* x, const T 
exponent, T* y) { - CUDA_1D_KERNEL_LOOP(i, n) { - y[i] = powf(x[i], exponent); - } -} - -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void Powx<float, CUDAContext>( - const int N, - const float* a, - const float b, - float* y, - CUDAContext* context) { - PowKernel<<< - CAFFE_GET_BLOCKS(N), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(N, a, b, y); -} - -#define DELEGATE_CUBLAS_SCALE_FUNCTION(TAlpha, TData, CuBLASFunc) \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale<TAlpha, TData, CUDAContext>( \ - const int N, \ - const TAlpha alpha, \ - const TData* x, \ - TData* y, \ - CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (x != y) { \ - cudaMemcpyAsync( \ - y, \ - x, \ - sizeof(TData) * N, \ - cudaMemcpyDeviceToDevice, \ - context->cuda_stream()); \ - } \ - if (alpha != TAlpha(1)) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ - CUBLAS_ENFORCE(CuBLASFunc(context->cublas_handle(), N, &alpha, y, 1)); \ - } \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale<TAlpha, TData, CUDAContext>( \ - const int N, \ - const TAlpha* alpha, \ - const TData* x, \ - TData* y, \ - CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (x != y) { \ - cudaMemcpyAsync( \ - y, \ - x, \ - sizeof(TData) * N, \ - cudaMemcpyDeviceToDevice, \ - context->cuda_stream()); \ - } \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); \ - CUBLAS_ENFORCE(CuBLASFunc(context->cublas_handle(), N, alpha, y, 1)); \ - } -DELEGATE_CUBLAS_SCALE_FUNCTION(float, float, cublasSscal) -DELEGATE_CUBLAS_SCALE_FUNCTION(double, double, cublasDscal) -#undef DELEGATE_CUBLAS_SCALE_FUNCTION - -#define CAFFE2_SPECIALIZED_CUDA_SCALE(TAlpha, TData) \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale<TAlpha, TData, CUDAContext>( \ - const int N, \ - const TAlpha alpha, \ - const TData* x, \ - TData* y, \ - CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (alpha == TAlpha(1)) { \ - if (x != y) { \ - cudaMemcpyAsync( \ - y, \ - x, \ - sizeof(TData) * N, \ - cudaMemcpyDeviceToDevice, \ - context->cuda_stream()); \ - } \ - return; \ - } \ - ScaleCUDAKernel<TAlpha, TData> \ - <<<CAFFE_GET_BLOCKS(N), \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>(N, alpha, x, y); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale<TAlpha, TData, CUDAContext>( \ - const int N, \ - const TAlpha* alpha, \ - const TData* x, \ - TData* y, \ - CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - ScaleCUDAKernel<TAlpha, TData> \ - <<<CAFFE_GET_BLOCKS(N), \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>(N, alpha, x, y); \ - } -CAFFE2_SPECIALIZED_CUDA_SCALE(std::int32_t, std::int32_t) -CAFFE2_SPECIALIZED_CUDA_SCALE(std::int64_t, std::int64_t) - -#ifndef __HIP_PLATFORM_HCC__ -template <> -CAFFE2_CUDA_EXPORT void Scale<at::Half, at::Half, CUDAContext>( - const int N, - const at::Half alpha, - const at::Half* x, - at::Half* y, - CUDAContext* context) { - if (N == 0) { - return; - } - if (x != y) { - cudaMemcpyAsync( - y, - x, - sizeof(at::Half) * N, - cudaMemcpyDeviceToDevice, - context->cuda_stream()); - } - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasScalEx( - context->cublas_handle(), - N, - &alpha, - CUDA_R_16F, - y, - CUDA_R_16F, - 1, - CUDA_R_32F)); -} - -template <> -CAFFE2_CUDA_EXPORT void Scale<at::Half, at::Half, CUDAContext>( - const int N, - const at::Half* alpha, - const at::Half* x, - 
at::Half* y, - CUDAContext* context) { - if (N == 0) { - return; - } - if (x != y) { - cudaMemcpyAsync( - y, - x, - sizeof(at::Half) * N, - cudaMemcpyDeviceToDevice, - context->cuda_stream()); - } - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); - CUBLAS_ENFORCE(cublasScalEx( - context->cublas_handle(), - N, - alpha, - CUDA_R_16F, - y, - CUDA_R_16F, - 1, - CUDA_R_32F)); -} - -template <> -CAFFE2_CUDA_EXPORT void Scale<float, at::Half, CUDAContext>( - const int N, - const float alpha, - const at::Half* x, - at::Half* y, - CUDAContext* context) { - if (N == 0) { - return; - } - if (x != y) { - cudaMemcpyAsync( - y, - x, - sizeof(at::Half) * N, - cudaMemcpyDeviceToDevice, - context->cuda_stream()); - } - if (alpha != 1.0f) { - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasScalEx( - context->cublas_handle(), - N, - &alpha, - CUDA_R_32F, - y, - CUDA_R_16F, - 1, - CUDA_R_32F)); - } -} - -template <> -CAFFE2_CUDA_EXPORT void Scale<float, at::Half, CUDAContext>( - const int N, - const float* alpha, - const at::Half* x, - at::Half* y, - CUDAContext* context) { - if (N == 0) { - return; - } - if (x != y) { - cudaMemcpyAsync( - y, - x, - sizeof(at::Half) * N, - cudaMemcpyDeviceToDevice, - context->cuda_stream()); - } - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); - CUBLAS_ENFORCE(cublasScalEx( - context->cublas_handle(), - N, - alpha, - CUDA_R_32F, - y, - CUDA_R_16F, - 1, - CUDA_R_32F)); -} - -#else // __HIP_PLATFORM_HCC__ - -namespace { -template <> -__global__ void ScaleCUDAKernel<at::Half, at::Half>( - const int n, - const at::Half alpha, - const at::Half* x, - at::Half* y) { - CUDA_1D_KERNEL_LOOP(i, n) { - y[i] = convert::To<float, at::Half>( - convert::To<at::Half, float>(x[i]) * - convert::To<at::Half, float>(alpha)); - } -} - -template <> -__global__ void ScaleCUDAKernel<at::Half, at::Half>( - const int n, - const at::Half* alpha, - const at::Half* x, - at::Half* y) { - CUDA_1D_KERNEL_LOOP(i, n) { - y[i] = convert::To<float, at::Half>( - convert::To<at::Half, float>(x[i]) * - convert::To<at::Half, float>(*alpha)); - } -} - -template <> -__global__ void ScaleCUDAKernel<float, at::Half>( - const int n, - const float alpha, - const at::Half* x, - at::Half* y) { - CUDA_1D_KERNEL_LOOP(i, n) { - y[i] = convert::To<float, at::Half>( - convert::To<at::Half, float>(x[i]) * alpha); - } -} - -template <> -__global__ void ScaleCUDAKernel<float, at::Half>( - const int n, - const float* alpha, - const at::Half* x, - at::Half* y) { - CUDA_1D_KERNEL_LOOP(i, n) { - y[i] = convert::To<float, at::Half>( - convert::To<at::Half, float>(x[i]) * (*alpha)); - } -} -} // namespace - -CAFFE2_SPECIALIZED_HIP_SCALE(at::Half, at::Half) -CAFFE2_SPECIALIZED_HIP_SCALE(float, at::Half) -#endif // __HIP_PLATFORM_HCC__ - -#undef CAFFE2_SPECIALIZED_CUDA_SCALE - template <> CAFFE2_CUDA_EXPORT void Axpy<float, CUDAContext>( const int N, @@ -3283,250 +2894,6 @@ CAFFE2_CUDA_EXPORT void Maximum( namespace { -template <typename T, class Reducer, int D> -__global__ void ReduceTensorCUDAKernel( - const int outer_size, - const int inner_size, - SimpleArray<int, D> X_strides, - SimpleArray<FIXED_DIVISOR, D> Y_dims, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce<T>::TempStorage temp_storage; - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - T val = init; - for (int j = threadIdx.x; j < 
inner_size; j += blockDim.x) { - int X_index = 0; - int Y_index = i * inner_size + j; -#pragma unroll - for (int d = D - 1; d >= 0; --d) { - int r; - FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], Y_index, &Y_index, &r); - X_index += r * X_strides.data[d]; - } -#if __CUDA_ARCH__ >= 350 - val = reducer(val, __ldg(X + X_index)); -#else - val = reducer(val, X[X_index]); -#endif - } - val = BlockReduce<T>(temp_storage).Reduce(val, reducer); - if (threadIdx.x == 0) { - Y[i] = val * alpha; - } - __syncthreads(); - } -} - -template <typename T, class Reducer, int D> -CAFFE2_CUDA_EXPORT void ReduceTensorCUDAImpl( - const int outer_size, - const int inner_size, - const int* dims, - const int* axes, - const Reducer& reducer, - const T init, - const T alpha, - const T* X, - T* Y, - CUDAContext* context) { - SimpleArray<int, D> X_strides; - SimpleArray<FIXED_DIVISOR, D> Y_dims; - utils::ComputeTransposedStrides(D, dims, axes, X_strides.data); - for (int i = 0; i < D; ++i) { - Y_dims.data[i] = FIXED_DIVISOR(dims[axes[i]]); - } - ReduceTensorCUDAKernel<T, Reducer, D> - <<<std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>( - outer_size, - inner_size, - X_strides, - Y_dims, - reducer, - init, - alpha, - X, - Y); -} - -template <typename T, class Reducer> -CAFFE2_CUDA_EXPORT void ReduceTensorCUDA( - const int ndim, - const int* X_dims, - const int* Y_dims, - const Reducer& reducer, - const T init, - const T alpha, - const T* X, - T* Y, - CUDAContext* context) { - const int X_size = - std::accumulate(X_dims, X_dims + ndim, 1, std::multiplies<int>()); - const int Y_size = - std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies<int>()); - if (X_size == 0) { - Set<T, CUDAContext>(Y_size, alpha * init, Y, context); - return; - } - if (alpha == T(0)) { - Set<T, CUDAContext>(Y_size, T(0), Y, context); - return; - } - if (std::equal(X_dims, X_dims + ndim, Y_dims)) { - Scale<T, T, CUDAContext>(X_size, alpha, X, Y, context); - return; - } - int rows; - int cols; - if (utils::IsRowwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { - RowwiseReduceKernel<T> - <<<std::min(rows, CAFFE_MAXIMUM_NUM_BLOCKS), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(rows, cols, reducer, init, alpha, X, Y); - return; - } - if (utils::IsColwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { - ColwiseReduceKernel<T> - <<<std::min(cols, CAFFE_MAXIMUM_NUM_BLOCKS), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(rows, cols, reducer, init, alpha, X, Y); - return; - } - std::vector<int> axes(ndim); - utils::ComputeTransposeAxesForReduceOp(ndim, Y_dims, axes.data()); - const int outer_size = Y_size; - const int inner_size = X_size / Y_size; - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_2( - ndim, - ReduceTensorCUDAImpl, - T, - Reducer, - outer_size, - inner_size, - X_dims, - axes.data(), - reducer, - init, - alpha, - X, - Y, - context); -} - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_REDUCE_MIN(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void ReduceMin<T, CUDAContext>( \ - const int ndim, \ - const int* X_dims, \ - const int* Y_dims, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - ReduceTensorCUDA( \ - ndim, \ - X_dims, \ - Y_dims, \ - cub::Min(), \ - std::numeric_limits<T>::max(), \ - alpha, \ - X, \ - Y, \ - context); \ - } -CAFFE2_SPECIALIZED_CUDA_REDUCE_MIN(std::int32_t) -CAFFE2_SPECIALIZED_CUDA_REDUCE_MIN(std::int64_t) -CAFFE2_SPECIALIZED_CUDA_REDUCE_MIN(float) -CAFFE2_SPECIALIZED_CUDA_REDUCE_MIN(double) -#undef 
CAFFE2_SPECIALIZED_CUDA_REDUCE_MIN - -#define CAFFE2_SPECIALIZED_CUDA_REDUCE_MAX(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void ReduceMax<T, CUDAContext>( \ - const int ndim, \ - const int* X_dims, \ - const int* Y_dims, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - ReduceTensorCUDA( \ - ndim, \ - X_dims, \ - Y_dims, \ - cub::Max(), \ - std::numeric_limits<T>::lowest(), \ - alpha, \ - X, \ - Y, \ - context); \ - } -CAFFE2_SPECIALIZED_CUDA_REDUCE_MAX(std::int32_t) -CAFFE2_SPECIALIZED_CUDA_REDUCE_MAX(std::int64_t) -CAFFE2_SPECIALIZED_CUDA_REDUCE_MAX(float) -CAFFE2_SPECIALIZED_CUDA_REDUCE_MAX(double) -#undef CAFFE2_SPECIALIZED_CUDA_REDUCE_MAX - -#define CAFFE2_SPECIALIZED_CUDA_REDUCE_SUM(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void ReduceSum<T, CUDAContext>( \ - const int ndim, \ - const int* X_dims, \ - const int* Y_dims, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - ReduceTensorCUDA( \ - ndim, X_dims, Y_dims, cub::Sum(), T(0), alpha, X, Y, context); \ - } -CAFFE2_SPECIALIZED_CUDA_REDUCE_SUM(std::int32_t) -CAFFE2_SPECIALIZED_CUDA_REDUCE_SUM(std::int64_t) -CAFFE2_SPECIALIZED_CUDA_REDUCE_SUM(float) -CAFFE2_SPECIALIZED_CUDA_REDUCE_SUM(double) -#undef CAFFE2_SPECIALIZED_CUDA_REDUCE_SUM - -#define CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void ReduceMean<T, CUDAContext>( \ - const int ndim, \ - const int* X_dims, \ - const int* Y_dims, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - int scale = 1; \ - for (int i = 0; i < ndim; ++i) { \ - if (Y_dims[i] == 1) { \ - scale *= X_dims[i]; \ - } \ - } \ - ReduceTensorCUDA( \ - ndim, \ - X_dims, \ - Y_dims, \ - cub::Sum(), \ - T(0), \ - alpha / static_cast<T>(scale), \ - X, \ - Y, \ - context); \ - } -CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN(float) -#undef CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN - -namespace { - template <typename T, int D> __global__ void BroadcastCUDAKernel( const int Y_size, |
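
The reduce functions removed from math_gpu.cu above keep the same signatures in the new caffe2/utils/math/reduce.cu, so existing callers are unaffected. A usage sketch (not part of the patch, assuming an existing CUDAContext and device buffers) of the relocated ReduceSum, summing each row of an M x N matrix:

// Usage sketch: reduce an M x N float matrix over its inner dimension.
// Dimensions where Y_dims[i] == 1 are the reduced ones, matching the
// IsRowwiseReduce dispatch shown in the diff.
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math/reduce.h"

void SumRows(
    const int M,
    const int N,
    const float* d_X,  // device pointer, M * N elements
    float* d_Y,        // device pointer, M elements
    caffe2::CUDAContext* context) {
  const int X_dims[2] = {M, N};
  const int Y_dims[2] = {M, 1};  // keep dim 0, reduce dim 1
  caffe2::math::ReduceSum<float, caffe2::CUDAContext>(
      2, X_dims, Y_dims, 1.0f, d_X, d_Y, context);
}
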