author     Xiaomeng Yang <yangxm@fb.com>   2019-02-01 23:45:38 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>   2019-02-01 23:49:07 -0800
commit     7d4a81cbb217938529d558b6fab9d101db525d25 (patch)
tree       c1ab858ce5232c032b8bdda91742e72982db2c10 /caffe2/utils
parent     f36f3cce9adfc24b3a662b16545416b9b4df719a (diff)
Use macro for reduce on 2d blocks (#16344)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16344

Use macro for reduce on 2d blocks

i-am-not-moving-c2-to-c10

Reviewed By: houseroad

Differential Revision: D13808988

fbshipit-source-id: b68c0fb6079c1b6e203a072083aba7a95c202bc2
Diffstat (limited to 'caffe2/utils')
-rw-r--r--  caffe2/utils/math/elementwise.cu  18
-rw-r--r--  caffe2/utils/math/reduce.cu       37
-rw-r--r--  caffe2/utils/math/reduce.cuh      35
3 files changed, 57 insertions, 33 deletions
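
At a glance: the commit hoists the cub::BlockReduce aliases out of reduce.cu into a shared header, caffe2/utils/math/reduce.cuh, and introduces DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK, which picks one of four 2D block shapes from the size of the reduced dimension. A hypothetical helper making that selection explicit (the shapes come from the macro in the diff below; the function itself is illustrative, not part of the commit):

#include <utility>

// Illustrative only: the (blockDim.x, blockDim.y) choice encoded in
// DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK. Every shape totals 128 threads;
// larger reduced sizes get a deeper y dimension.
inline std::pair<int, int> Choose2DBlockShape(int size) {
  if (size >= 128) return {1, 128};
  if (size >= 64)  return {2, 64};
  if (size >= 32)  return {4, 32};
  return {8, 16};
}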
diff --git a/caffe2/utils/math/elementwise.cu b/caffe2/utils/math/elementwise.cu
index 0798b6f640..cc7613cda9 100644
--- a/caffe2/utils/math/elementwise.cu
+++ b/caffe2/utils/math/elementwise.cu
@@ -23,8 +23,8 @@ __global__ void SinCosCUDAKernel(const int N, const T* X, T* S, T* C) {
template <typename T>
__global__ void AffineChannelNCHWCUDAKernel(
const int C,
+ const int M,
const int HxW,
- const int K,
const T* X,
const T* scale,
const T* bias,
@@ -33,15 +33,15 @@ __global__ void AffineChannelNCHWCUDAKernel(
template <>
__global__ void AffineChannelNCHWCUDAKernel<float>(
const int C,
+ const int M,
const int HxW,
- const int K,
const float* X,
const float* scale,
const float* bias,
float* Y) {
- const int nc = blockIdx.x / K;
+ const int nc = blockIdx.x / M;
const int c = nc % C;
- const int w = blockIdx.x % K * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ const int w = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (w < HxW) {
const int index = nc * HxW + w;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
@@ -180,10 +180,10 @@ CAFFE2_SPECIALIZED_CUDA_SINCOS(double)
const T* bias, \
T* Y, \
CUDAContext* context) { \
- const int K = DivUp(HxW, CAFFE_CUDA_NUM_THREADS); \
+ const int M = DivUp(HxW, CAFFE_CUDA_NUM_THREADS); \
AffineChannelNCHWCUDAKernel<T> \
- <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( \
- C, HxW, K, X, scale, bias, Y); \
+ <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( \
+ C, M, HxW, X, scale, bias, Y); \
} \
template <> \
CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, StorageOrder::NHWC>( \
@@ -195,9 +195,9 @@ CAFFE2_SPECIALIZED_CUDA_SINCOS(double)
const T* bias, \
T* Y, \
CUDAContext* context) { \
- const int K = DivUp(C, CAFFE_CUDA_NUM_THREADS); \
+ const int M = DivUp(C, CAFFE_CUDA_NUM_THREADS); \
AffineChannelNHWCCUDAKernel<T> \
- <<<dim3(N* HxW, K), \
+ <<<dim3(N* HxW, M), \
CAFFE_CUDA_NUM_THREADS, \
0, \
context->cuda_stream()>>>(C, X, scale, bias, Y); \
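
For reference, the K -> M rename above presumably keeps the per-plane block count from being confused with the K that the reduction code below uses for the reduced dimension. The launch arithmetic it labels, in a standalone sketch (DivUp mirrors the caffe2 utility; 128 is an assumed stand-in for CAFFE_CUDA_NUM_THREADS):

constexpr int kNumThreads = 128;  // assumed stand-in for CAFFE_CUDA_NUM_THREADS

inline int DivUp(int a, int b) {
  return (a + b - 1) / b;  // ceiling division, as in caffe2/utils
}

inline int NCHWGridSize(int N, int C, int HxW) {
  const int M = DivUp(HxW, kNumThreads);  // blocks per (n, c) plane
  return N * C * M;  // the kernel inverts this: nc = blockIdx.x / M,
                     // w = blockIdx.x % M * kNumThreads + threadIdx.x
}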
diff --git a/caffe2/utils/math/reduce.cu b/caffe2/utils/math/reduce.cu
index f597ec789d..31a653930e 100644
--- a/caffe2/utils/math/reduce.cu
+++ b/caffe2/utils/math/reduce.cu
@@ -10,6 +10,7 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/fixed_divisor.h"
+#include "caffe2/utils/math/reduce.cuh"
#include "caffe2/utils/math_utils.h"
namespace caffe2 {
@@ -18,13 +19,6 @@ namespace math {
namespace {
template <typename T>
-using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
-
-template <typename T, int kBlockDimX, int kBlockDimY>
-using BlockReduce2D = cub::
- BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
-
-template <typename T>
__global__ void
RowwiseMomentsCUDAKernel(const int cols, const T* X, T* mean, T* var) {
__shared__ typename BlockReduce<T>::TempStorage m_storage;
@@ -229,23 +223,18 @@ CAFFE2_CUDA_EXPORT void MomentsCUDA(
int N;
int K;
if (utils::IsBothEndsReduce(ndim, X_dims, Y_dims, &M, &N, &K)) {
- if (K >= 128) {
- BothEndsMomentsCUDAKernel<T, 1, 128>
- <<<N, dim3(1, 128), 0, context->cuda_stream()>>>(
- M, N, K, X, mean, var);
- } else if (K >= 64) {
- BothEndsMomentsCUDAKernel<T, 2, 64>
- <<<N, dim3(2, 64), 0, context->cuda_stream()>>>(
- M, N, K, X, mean, var);
- } else if (K >= 32) {
- BothEndsMomentsCUDAKernel<T, 4, 32>
- <<<N, dim3(4, 32), 0, context->cuda_stream()>>>(
- M, N, K, X, mean, var);
- } else {
- BothEndsMomentsCUDAKernel<T, 8, 16>
- <<<N, dim3(8, 16), 0, context->cuda_stream()>>>(
- M, N, K, X, mean, var);
- }
+ DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+ K,
+ BothEndsMomentsCUDAKernel,
+ T,
+ N,
+ context->cuda_stream(),
+ M,
+ N,
+ K,
+ X,
+ mean,
+ var);
return;
}
std::vector<int> axes(ndim);
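
For readers tracing the macro: with these arguments the call expands to exactly the ladder it deletes. The K >= 128 branch, for example, becomes

BothEndsMomentsCUDAKernel<T, 1, 128>
    <<<N, dim3(1, 128), 0, context->cuda_stream()>>>(
        M, N, K, X, mean, var);

and the other three branches follow the same pattern with shapes (2, 64), (4, 32), and (8, 16).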
diff --git a/caffe2/utils/math/reduce.cuh b/caffe2/utils/math/reduce.cuh
new file mode 100644
index 0000000000..d191cbce8b
--- /dev/null
+++ b/caffe2/utils/math/reduce.cuh
@@ -0,0 +1,35 @@
+#ifndef CAFFE2_UTILS_MATH_REDUCE_CUH_
+#define CAFFE2_UTILS_MATH_REDUCE_CUH_
+
+#include <cub/block/block_reduce.cuh>
+#include <cub/cub.cuh>
+
+#include "caffe2/core/common_gpu.h"
+
+namespace caffe2 {
+
+template <typename T>
+using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
+
+template <typename T, int kBlockDimX, int kBlockDimY>
+using BlockReduce2D = cub::
+ BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
+
+#define DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK( \
+ size, Func, T, grid_dim, cuda_stream, ...) \
+ do { \
+ if (size >= 128) { \
+ Func<T, 1, 128> \
+ <<<grid_dim, dim3(1, 128), 0, cuda_stream>>>(__VA_ARGS__); \
+ } else if (size >= 64) { \
+ Func<T, 2, 64><<<grid_dim, dim3(2, 64), 0, cuda_stream>>>(__VA_ARGS__); \
+ } else if (size >= 32) { \
+ Func<T, 4, 32><<<grid_dim, dim3(4, 32), 0, cuda_stream>>>(__VA_ARGS__); \
+ } else { \
+ Func<T, 8, 16><<<grid_dim, dim3(8, 16), 0, cuda_stream>>>(__VA_ARGS__); \
+ } \
+ } while (false)
+
+} // namespace caffe2
+
+#endif // CAFFE2_UTILS_MATH_REDUCE_CUH_
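
To round out the picture, a hypothetical caller of the new header (not part of this commit; BothEndsSumCUDAKernel and BothEndsSum are invented names mirroring how BothEndsMomentsCUDAKernel is dispatched above):

#include <cub/block/block_reduce.cuh>

#include "caffe2/utils/math/reduce.cuh"

namespace caffe2 {

// Sums an M x N x K tensor over its outer (M) and inner (K) axes,
// writing N results; one block per middle index n.
template <typename T, int kBlockDimX, int kBlockDimY>
__global__ void BothEndsSumCUDAKernel(
    const int M, const int N, const int K, const T* X, T* Y) {
  __shared__
      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage storage;
  const int n = blockIdx.x;
  T sum = T(0);
  // Stride the 2D block over the M (x threads) and K (y threads) axes.
  for (int m = threadIdx.x; m < M; m += blockDim.x) {
    for (int k = threadIdx.y; k < K; k += blockDim.y) {
      sum += X[(m * N + n) * K + k];
    }
  }
  sum = BlockReduce2D<T, kBlockDimX, kBlockDimY>(storage).Sum(sum);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    Y[n] = sum;  // cub leaves the block aggregate in thread (0, 0)
  }
}

template <typename T>
void BothEndsSum(
    const int M, const int N, const int K,
    const T* X, T* Y, cudaStream_t stream) {
  // One macro call replaces the four-way if/else over block shapes.
  DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
      K, BothEndsSumCUDAKernel, T, N, stream, M, N, K, X, Y);
}

} // namespace caffe2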