author     Xiaomeng Yang <yangxm@fb.com>                                       2019-01-18 22:37:12 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-01-18 22:40:16 -0800
commit     b436f94b53e261177422fe92680d42f19195d3d0 (patch)
tree       37d276fb86c4f7ee5ba474308c6037d1802af786
parent     e8b872abe225dfc9f622eaa847e82ad7296ea17b (diff)
Separate affine_channel from math and optimize it (#16135)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16135

Separate affine_channel from math and optimize it

i-am-not-moving-c2-to-c10

Reviewed By: houseroad

Differential Revision: D13727606

fbshipit-source-id: 8980af4afadaf964a18a9da581106fe30896a7e9
-rw-r--r--  caffe2/operators/length_split_op.h            2
-rw-r--r--  caffe2/operators/stylizer_ops.cc              2
-rw-r--r--  caffe2/operators/top_k.cu                     2
-rw-r--r--  caffe2/operators/top_k_heap_selection.cuh     2
-rw-r--r--  caffe2/operators/top_k_radix_selection.cuh    4
-rw-r--r--  caffe2/utils/CMakeLists.txt                  21
-rw-r--r--  caffe2/utils/GpuScanUtils.cuh                 2
-rw-r--r--  caffe2/utils/math.h                          38
-rw-r--r--  caffe2/utils/math/elementwise.cc             55
-rw-r--r--  caffe2/utils/math/elementwise.cu            105
-rw-r--r--  caffe2/utils/math/elementwise.h              24
-rw-r--r--  caffe2/utils/math_cpu.cc                     55
-rw-r--r--  caffe2/utils/math_gpu.cu                     61
-rw-r--r--  caffe2/utils/math_utils.h                    29
14 files changed, 246 insertions, 156 deletions
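
For orientation before the diff: AffineChannel applies a per-channel affine transform, Y = X * scale[c] + bias[c], where c is the channel of each element. The plain-C++ sketch below is illustrative only (the helper name AffineChannelNCHWRef is not part of the commit); the commit's actual implementations use Eigen on CPU and fused multiply-add CUDA kernels on GPU.

    // Reference semantics for the NCHW layout: element (n, c, i) of X is
    // scaled by scale[c] and shifted by bias[c]. A minimal sketch, not the
    // optimized code added in caffe2/utils/math/elementwise.{cc,cu}.
    void AffineChannelNCHWRef(
        int N, int C, int HxW,
        const float* X, const float* scale, const float* bias, float* Y) {
      for (int n = 0; n < N; ++n) {
        for (int c = 0; c < C; ++c) {
          const int base = (n * C + c) * HxW;
          for (int i = 0; i < HxW; ++i) {
            Y[base + i] = X[base + i] * scale[c] + bias[c];
          }
        }
      }
    }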
diff --git a/caffe2/operators/length_split_op.h b/caffe2/operators/length_split_op.h
index 7270c0274e..87eba854ea 100644
--- a/caffe2/operators/length_split_op.h
+++ b/caffe2/operators/length_split_op.h
@@ -57,7 +57,7 @@ class LengthsSplitOp final : public Operator<Context> {
for (int i = 0; i < M; i++) {
int32_t mod = Ldata[i] % n_split_;
int32_t res =
- mod != 0 ? math::divUp(Ldata[i], n_split_) : Ldata[i] / n_split_ + 1;
+ mod != 0 ? math::DivUp(Ldata[i], n_split_) : Ldata[i] / n_split_ + 1;
for (int j = 0; j < n_split_; j++) {
Ydata[(i * n_split_) + j] = mod-- > 0 ? res : res - 1;
}
diff --git a/caffe2/operators/stylizer_ops.cc b/caffe2/operators/stylizer_ops.cc
index d731ace03c..76bdbce45b 100644
--- a/caffe2/operators/stylizer_ops.cc
+++ b/caffe2/operators/stylizer_ops.cc
@@ -138,7 +138,7 @@ class PackedInt8BGRANHWCToNCHWCStylizerPreprocessOp
// For ARM NEON, we read in multiples of kNeonNoiseReadSize since
// the inner loop is vectorized. Round up to the next highest
// multiple of kNeonNoiseReadSize
- size = math::roundUp(size, kNeonNoiseReadSize) + size;
+ size = math::RoundUp(size, kNeonNoiseReadSize) + size;
noise->Resize(size);
math::RandGaussian<float, CPUContext>(
diff --git a/caffe2/operators/top_k.cu b/caffe2/operators/top_k.cu
index 52a59d01ce..85810547e8 100644
--- a/caffe2/operators/top_k.cu
+++ b/caffe2/operators/top_k.cu
@@ -56,7 +56,7 @@ void RunRadixSelectionImpl(
int64_t* indices,
CUDAContext* context) {
const int block = std::min(
- math::roundUp(static_cast<int>(inner_size), kWarpSize),
+ math::RoundUp(static_cast<int>(inner_size), kWarpSize),
CAFFE_CUDA_NUM_THREADS);
gatherTopK<T, kSelectMax, int64_t>
<<<outer_size, block, 0, context->cuda_stream()>>>(
diff --git a/caffe2/operators/top_k_heap_selection.cuh b/caffe2/operators/top_k_heap_selection.cuh
index 674d317be4..e9c5c0f6d2 100644
--- a/caffe2/operators/top_k_heap_selection.cuh
+++ b/caffe2/operators/top_k_heap_selection.cuh
@@ -72,7 +72,7 @@ __device__ inline void warpHeapInsert(K k, V v, K* keyHeap, V* valueHeap) {
// log2(8 / 2) = 2 levels of interior nodes for heap size 8 (0 and 12)
int i = 0;
#pragma unroll
- for (int levels = 0; levels < math::integerLog2(HeapSize / 2); ++levels) {
+ for (int levels = 0; levels < math::IntegerLog2(HeapSize / 2); ++levels) {
int leftChild = i * 2 + 1;
int rightChild = leftChild + 1;
K leftKey = keyHeap[leftChild];
diff --git a/caffe2/operators/top_k_radix_selection.cuh b/caffe2/operators/top_k_radix_selection.cuh
index 0bf38e4005..adc9ff141c 100644
--- a/caffe2/operators/top_k_radix_selection.cuh
+++ b/caffe2/operators/top_k_radix_selection.cuh
@@ -218,7 +218,7 @@ __device__ DataType findPattern(DataType* smem,
__syncthreads();
// All threads participate in the loop, in order to sync on the flag
- int numIterations = math::roundUp(sliceSize, (int) blockDim.x);
+ int numIterations = math::RoundUp(sliceSize, (int) blockDim.x);
for (int i = threadIdx.x; i < numIterations; i += blockDim.x) {
bool inRange = (i < sliceSize);
DataType v = inRange ? data[i] : (DataType)0;
@@ -388,7 +388,7 @@ __global__ void gatherTopK(const T* inputPtr,
// All threads need to participate in the loop and the prefix sum,
// but not necessarily in the load; hence loop bounds being rounded
// up to a multiple of the block dim.
- int numIterations = math::roundUp(inputSliceSize, (int) blockDim.x);
+ int numIterations = math::RoundUp(inputSliceSize, (int) blockDim.x);
int writeIndexStart = 0;
for (int i = threadIdx.x; i < numIterations; i += blockDim.x) {
diff --git a/caffe2/utils/CMakeLists.txt b/caffe2/utils/CMakeLists.txt
index a99fd7fe76..fafe6f1923 100644
--- a/caffe2/utils/CMakeLists.txt
+++ b/caffe2/utils/CMakeLists.txt
@@ -1,17 +1,18 @@
list(APPEND Caffe2_CPU_SRCS
+ utils/bench_utils.cc
+ utils/cpuid.cc
+ utils/math/elementwise.cc
+ utils/math_cpu.cc
+ utils/math_utils.cc
+ utils/murmur_hash3.cc
utils/proto_convert.cc
- utils/proto_wrap.cc
utils/proto_utils.cc
- utils/murmur_hash3.cc
- utils/smart_tensor_printer.cc
+ utils/proto_wrap.cc
utils/signal_handler.cc
+ utils/smart_tensor_printer.cc
utils/string_utils.cc
- utils/threadpool/ThreadPool.cc
- utils/cpuid.cc
- utils/bench_utils.cc
- utils/math_cpu.cc
- utils/math_utils.cc
- utils/thread_name.cc)
+ utils/thread_name.cc
+ utils/threadpool/ThreadPool.cc)
# ---[ threadpool/pthreadpool* is a local modification of the NNPACK
# pthreadpool with a very similar interface. Neither NNPACK, nor this
@@ -24,10 +25,12 @@ if (NOT MSVC)
endif()
set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS}
+ utils/math/elementwise.cu
utils/math_gpu.cu
)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS}
+ utils/math/hip/elementwise.hip
utils/hip/math_gpu.hip
)
diff --git a/caffe2/utils/GpuScanUtils.cuh b/caffe2/utils/GpuScanUtils.cuh
index c6003825ff..24ae38c076 100644
--- a/caffe2/utils/GpuScanUtils.cuh
+++ b/caffe2/utils/GpuScanUtils.cuh
@@ -123,7 +123,7 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi
// The outgoing carry for all threads is the last warp's sum
#if defined(__HIP_PLATFORM_HCC__)
- *carry = smem[math::divUp<int>(blockDim.x, kWarpSize) - 1];
+ *carry = smem[math::DivUp<int>(blockDim.x, kWarpSize) - 1];
#else
*carry = smem[(blockDim.x / kWarpSize) - 1];
#endif // __HIP_PLATFORM_HCC__
diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h
index cb949eec5c..af0e66823e 100644
--- a/caffe2/utils/math.h
+++ b/caffe2/utils/math.h
@@ -15,6 +15,7 @@ extern "C" {
#include "caffe2/core/common.h"
#include "caffe2/core/types.h"
+#include "caffe2/utils/math/elementwise.h"
#include "caffe2/utils/math_utils.h"
namespace caffe2 {
@@ -662,17 +663,6 @@ CAFFE2_API void CopyMatrix(
template <typename T, class Context>
CAFFE2_API void CopyVector(const int N, const T* A, T* B, Context* context);
-template <typename T, class Context, StorageOrder kOrder>
-CAFFE2_API void AffineChannel(
- const int N,
- const int C,
- const int HxW,
- const T* X,
- const T* scale,
- const T* bias,
- T* Y,
- Context* context);
-
template <typename T, class Context>
CAFFE2_API void NCHW2NHWC(
const int N,
@@ -691,32 +681,6 @@ CAFFE2_API void NHWC2NCHW(
T* Y,
Context* context);
-// Calculates ceil(a / b). User must be careful to ensure that there
-// is no overflow or underflow in the calculation.
-template <typename T>
-constexpr T divUp(T a, T b) {
- return (a + b - (T)1) / b;
-}
-
-// Rounds a up to the next highest multiple of b. User must be careful
-// to ensure that there is no overflow or underflow in the calculation
-// of divUp.
-template <typename T>
-constexpr T roundUp(T a, T b) {
- return divUp<T>(a, b) * b;
-}
-
-// Returns log2(n) for a positive integer type
-template <typename T>
-constexpr int integerLog2(T n, int p = 0) {
- return (n <= 1) ? p : integerLog2(n / 2, p + 1);
-}
-
-// Returns the next highest power-of-2 for an integer type
-template <typename T>
-constexpr T integerNextHighestPowerOf2(T v) {
- return (integerIsPowerOf2(v) ? (T)2 * v : ((T)1 << (integerLog2(v) + 1)));
-}
} // namespace math
} // namespace caffe2
diff --git a/caffe2/utils/math/elementwise.cc b/caffe2/utils/math/elementwise.cc
new file mode 100644
index 0000000000..5634fb876b
--- /dev/null
+++ b/caffe2/utils/math/elementwise.cc
@@ -0,0 +1,55 @@
+#include "caffe2/utils/math/elementwise.h"
+
+#include "caffe2/core/context.h"
+#include "caffe2/utils/eigen_utils.h"
+
+namespace caffe2 {
+namespace math {
+
+#define CAFFE2_SPECIALIZED_AFFINE_CHANNEL(T) \
+ template <> \
+ void AffineChannel<T, CPUContext, StorageOrder::NCHW>( \
+ const int N, \
+ const int C, \
+ const int HxW, \
+ const T* X, \
+ const T* scale, \
+ const T* bias, \
+ T* Y, \
+ CPUContext* /* context */) { \
+ ConstEigenVectorArrayMap<T> scale_arr(scale, C); \
+ ConstEigenVectorArrayMap<T> bias_arr(bias, C); \
+ const int stride = C * HxW; \
+ const T* X_ptr = X; \
+ T* Y_ptr = Y; \
+ for (int i = 0; i < N; ++i) { \
+ EigenArrayMap<T>(Y_ptr, HxW, C) = \
+ (ConstEigenArrayMap<T>(X_ptr, HxW, C).rowwise() * \
+ scale_arr.transpose()) \
+ .rowwise() + \
+ bias_arr.transpose(); \
+ X_ptr += stride; \
+ Y_ptr += stride; \
+ } \
+ } \
+ template <> \
+ void AffineChannel<T, CPUContext, StorageOrder::NHWC>( \
+ const int N, \
+ const int C, \
+ const int HxW, \
+ const T* X, \
+ const T* scale, \
+ const T* bias, \
+ T* Y, \
+ CPUContext* /* context */) { \
+ EigenArrayMap<T>(Y, C, N * HxW) = \
+ (ConstEigenArrayMap<T>(X, C, N * HxW).colwise() * \
+ ConstEigenVectorArrayMap<T>(scale, C)) \
+ .colwise() + \
+ ConstEigenVectorArrayMap<T>(bias, C); \
+ }
+CAFFE2_SPECIALIZED_AFFINE_CHANNEL(float)
+#undef CAFFE2_SPECIALIZED_AFFINE_CHANNEL
+
+} // namespace math
+} // namespace caffe2
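
The NCHW specialization above leans on Eigen broadcasting: mapping one image as a column-major HxW-by-C array puts each channel in its own column, so a rowwise product applies scale[c] to column c. A standalone sketch of the same trick using raw Eigen instead of caffe2's EigenArrayMap aliases (the function name AffineOneImage is illustrative):

    #include <Eigen/Core>

    // Same broadcasting pattern as the NCHW specialization. Mapping the
    // HxW*C buffer as an (HxW x C) column-major array makes each channel a
    // column; rowwise() then broadcasts the length-C scale and bias.
    void AffineOneImage(int C, int HxW, const float* X,
                        const float* scale, const float* bias, float* Y) {
      using Arr = Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic>;
      Eigen::Map<const Arr> x(X, HxW, C);
      Eigen::Map<const Eigen::ArrayXf> s(scale, C);
      Eigen::Map<const Eigen::ArrayXf> b(bias, C);
      Eigen::Map<Arr> y(Y, HxW, C);
      y = (x.rowwise() * s.transpose()).rowwise() + b.transpose();
    }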
diff --git a/caffe2/utils/math/elementwise.cu b/caffe2/utils/math/elementwise.cu
new file mode 100644
index 0000000000..de28176e9f
--- /dev/null
+++ b/caffe2/utils/math/elementwise.cu
@@ -0,0 +1,105 @@
+#include "caffe2/utils/math/elementwise.h"
+
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math_utils.h"
+
+namespace caffe2 {
+namespace math {
+
+namespace {
+
+template <typename T>
+__global__ void AffineChannelNCHWCUDAKernel(
+ const int C,
+ const int HxW,
+ const int K,
+ const T* X,
+ const T* scale,
+ const T* bias,
+ T* Y);
+
+template <>
+__global__ void AffineChannelNCHWCUDAKernel<float>(
+ const int C,
+ const int HxW,
+ const int K,
+ const float* X,
+ const float* scale,
+ const float* bias,
+ float* Y) {
+ const int nc = blockIdx.x / K;
+ const int block = blockIdx.x % K;
+ const int c = nc % C;
+ const int w = block * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ if (w < HxW) {
+ const int index = nc * HxW + w;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+ Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c));
+#else
+ Y[index] = fmaf(X[index], scale[c], bias[c]);
+#endif
+ }
+}
+
+template <typename T>
+__global__ void AffineChannelNHWCCUDAKernel(
+ const int C,
+ const T* X,
+ const T* scale,
+ const T* bias,
+ T* Y);
+
+template <>
+__global__ void AffineChannelNHWCCUDAKernel<float>(
+ const int C,
+ const float* X,
+ const float* scale,
+ const float* bias,
+ float* Y) {
+ for (int c = threadIdx.x; c < C; c += blockDim.x) {
+ const int index = blockIdx.x * C + c;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+ Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c));
+#else
+ Y[index] = fmaf(X[index], scale[c], bias[c]);
+#endif
+ }
+}
+
+} // namespace
+
+#define CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(T) \
+ template <> \
+ CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, StorageOrder::NCHW>( \
+ const int N, \
+ const int C, \
+ const int HxW, \
+ const T* X, \
+ const T* scale, \
+ const T* bias, \
+ T* Y, \
+ CUDAContext* context) { \
+ const int K = DivUp(HxW, CAFFE_CUDA_NUM_THREADS); \
+ AffineChannelNCHWCUDAKernel<T> \
+ <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( \
+ C, HxW, K, X, scale, bias, Y); \
+ } \
+ template <> \
+ CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, StorageOrder::NHWC>( \
+ const int N, \
+ const int C, \
+ const int HxW, \
+ const T* X, \
+ const T* scale, \
+ const T* bias, \
+ T* Y, \
+ CUDAContext* context) { \
+ AffineChannelNHWCCUDAKernel<T> \
+ <<<N * HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( \
+ C, X, scale, bias, Y); \
+ }
+CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float)
+#undef CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL
+
+} // namespace math
+} // namespace caffe2
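
Compared with the old single CUDA_1D_KERNEL_LOOP kernel (removed from math_gpu.cu below), the NCHW kernel now launches N * C * K blocks, so the channel index is recovered with one divide per block rather than one modulo/divide per element. A host-side sketch of the launch arithmetic, with illustrative sizes:

    // Illustrative launch-shape arithmetic for the new NCHW kernel. 128
    // stands in for CAFFE_CUDA_NUM_THREADS; the real value comes from
    // caffe2/core/context_gpu.h.
    constexpr int kThreads = 128;
    constexpr int DivUp(int a, int b) { return (a + b - 1) / b; }

    void LaunchShape(int N, int C, int HxW) {
      const int K = DivUp(HxW, kThreads);  // thread-block chunks per (n, c) slice
      const int grid = N * C * K;          // one block per (n, c, chunk) triple
      // Inside the kernel: nc = blockIdx.x / K, block = blockIdx.x % K,
      // w = block * kThreads + threadIdx.x, guarded by (w < HxW).
      (void)grid;
    }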
diff --git a/caffe2/utils/math/elementwise.h b/caffe2/utils/math/elementwise.h
new file mode 100644
index 0000000000..accfac76d8
--- /dev/null
+++ b/caffe2/utils/math/elementwise.h
@@ -0,0 +1,24 @@
+#ifndef CAFFE2_UTILS_MATH_ELEMENTWISE_H_
+#define CAFFE2_UTILS_MATH_ELEMENTWISE_H_
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/types.h"
+
+namespace caffe2 {
+namespace math {
+
+template <typename T, class Context, StorageOrder kOrder>
+CAFFE2_API void AffineChannel(
+ const int N,
+ const int C,
+ const int HxW,
+ const T* X,
+ const T* scale,
+ const T* bias,
+ T* Y,
+ Context* context);
+
+} // namespace math
+} // namespace caffe2
+
+#endif // CAFFE2_UTILS_MATH_ELEMENTWISE_H_
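
Call sites are unchanged by the move: they keep including caffe2/utils/math.h (which now pulls this header in) or include elementwise.h directly. A hedged usage sketch, assuming a Caffe2 build with the float CPU specialization added by this commit:

    #include "caffe2/core/context.h"
    #include "caffe2/utils/math/elementwise.h"

    void RunAffineChannel() {
      const int N = 1, C = 2, HxW = 3;
      const float X[6] = {1, 2, 3, 4, 5, 6};  // NCHW: channel 0, then channel 1
      const float scale[2] = {2.0f, 0.5f};
      const float bias[2] = {1.0f, -1.0f};
      float Y[6];
      caffe2::CPUContext context;
      caffe2::math::AffineChannel<float, caffe2::CPUContext,
                                  caffe2::StorageOrder::NCHW>(
          N, C, HxW, X, scale, bias, Y, &context);
      // Y is {3, 5, 7, 1, 1.5, 2}: channel 0 scaled by 2 and shifted by 1,
      // channel 1 scaled by 0.5 and shifted by -1.
    }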
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index dd782391b7..d59e0045e3 100644
--- a/caffe2/utils/math_cpu.cc
+++ b/caffe2/utils/math_cpu.cc
@@ -785,11 +785,11 @@ DELEGATE_CBRT_FUNCTION(float)
DELEGATE_CBRT_FUNCTION(double)
#undef DELEGATE_CBRT_FUNCTION
-#define DELEGATE_ERF_FUNCTION(T) \
- template <> \
- C10_EXPORT void Erf<T, CPUContext>( \
- const int N, const T* X, T* Y, CPUContext*) { \
- std::transform(X, X + N, Y, [](const T x) { return erf(x); }); \
+#define DELEGATE_ERF_FUNCTION(T) \
+ template <> \
+ C10_EXPORT void Erf<T, CPUContext>( \
+ const int N, const T* X, T* Y, CPUContext*) { \
+ std::transform(X, X + N, Y, [](const T x) { return erf(x); }); \
}
DELEGATE_ERF_FUNCTION(float)
DELEGATE_ERF_FUNCTION(double)
@@ -4141,51 +4141,6 @@ CAFFE2_SPECIALIZED_TRANSPOSE(std::uint8_t)
CAFFE2_SPECIALIZED_TRANSPOSE(std::uint16_t)
#undef CAFFE2_SPECIALIZED_TRANSPOSE
-#define CAFFE2_SPECIALIZED_AFFINE_CHANNEL(T) \
- template <> \
- void AffineChannel<T, CPUContext, StorageOrder::NCHW>( \
- const int N, \
- const int C, \
- const int HxW, \
- const T* X, \
- const T* scale, \
- const T* bias, \
- T* Y, \
- CPUContext* /* context */) { \
- ConstEigenVectorArrayMap<T> scale_arr(scale, C); \
- ConstEigenVectorArrayMap<T> bias_arr(bias, C); \
- const int stride = C * HxW; \
- const T* X_ptr = X; \
- T* Y_ptr = Y; \
- for (int i = 0; i < N; ++i) { \
- EigenArrayMap<T>(Y_ptr, HxW, C) = \
- (ConstEigenArrayMap<T>(X_ptr, HxW, C).rowwise() * \
- scale_arr.transpose()) \
- .rowwise() + \
- bias_arr.transpose(); \
- X_ptr += stride; \
- Y_ptr += stride; \
- } \
- } \
- template <> \
- void AffineChannel<T, CPUContext, StorageOrder::NHWC>( \
- const int N, \
- const int C, \
- const int HxW, \
- const T* X, \
- const T* scale, \
- const T* bias, \
- T* Y, \
- CPUContext* /* context */) { \
- EigenArrayMap<T>(Y, C, N * HxW) = \
- (ConstEigenArrayMap<T>(X, C, N * HxW).colwise() * \
- ConstEigenVectorArrayMap<T>(scale, C)) \
- .colwise() + \
- ConstEigenVectorArrayMap<T>(bias, C); \
- }
-CAFFE2_SPECIALIZED_AFFINE_CHANNEL(float)
-#undef CAFFE2_SPECIALIZED_AFFINE_CHANNEL
-
#define CAFFE2_SPECIALIZED_NCHW2NHWC(T) \
template <> \
C10_EXPORT void NCHW2NHWC<T, CPUContext>( \
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 820dfe490b..0229969e74 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -764,9 +764,9 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
#ifdef __HIP_PLATFORM_HCC__
// rocblas doesn't support cublasSgemmEx type API yet.
- // It has more general rocblas_gemm_ex API which is more close to cublasGemmEx
- // rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C, whereas
- // cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C
+ // It has more general rocblas_gemm_ex API which is more close to
+ // cublasGemmEx rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C,
+ // whereas cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C
ROCBLAS_ENFORCE(rocblas_gemm_ex(
context->rocblashandle(),
cu_trans_B,
@@ -1171,8 +1171,8 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
#ifdef __HIP_PLATFORM_HCC__
- // D[i*stride_d] = alpha*op(A[i*stride_a])*op(B[i*stride_b]) + beta*C[i*stride_c],
- // for i in [0,batch_count-1]
+ // D[i*stride_d] = alpha*op(A[i*stride_a])*op(B[i*stride_b]) +
+ // beta*C[i*stride_c], for i in [0,batch_count-1]
ROCBLAS_ENFORCE(rocblas_gemm_strided_batched_ex(
context->rocblashandle(),
cu_trans_B,
@@ -1560,9 +1560,9 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
#ifdef __HIP_PLATFORM_HCC__
// rocblas doesn't support cublasSgemmEx type API yet.
- // It has more general rocblas_gemm_ex API which is more close to cublasGemmEx
- // rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C, whereas
- // cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C
+ // It has more general rocblas_gemm_ex API which is more close to
+ // cublasGemmEx rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C,
+ // whereas cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C
ROCBLAS_ENFORCE(rocblas_gemm_ex(
context->rocblashandle(),
cu_trans_A,
@@ -4219,51 +4219,6 @@ CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(int)
CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(int64_t)
#undef CAFFE2_SPECIALIZED_CUDA_TRANSPOSE
-namespace {
-
-template <typename T, StorageOrder kOrder>
-__global__ void AffineChannelCUDAKernel(
- const int size,
- const int C,
- const int HxW,
- const T* X,
- const T* scale,
- const T* bias,
- T* Y) {
- CUDA_1D_KERNEL_LOOP(i, size) {
- const int c = kOrder == StorageOrder::NCHW ? i / HxW % C : i % C;
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
- Y[i] = __ldg(scale + c) * __ldg(X + i) + __ldg(bias + c);
-#else
- Y[i] = scale[c] * X[i] + bias[c];
-#endif
- }
-}
-
-} // namespace
-
-#define CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(T, kOrder) \
- template <> \
- CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, kOrder>( \
- const int N, \
- const int C, \
- const int HxW, \
- const T* X, \
- const T* scale, \
- const T* bias, \
- T* Y, \
- CUDAContext* context) { \
- const int size = N * C * HxW; \
- AffineChannelCUDAKernel<T, kOrder> \
- <<<CAFFE_GET_BLOCKS(size), \
- CAFFE_CUDA_NUM_THREADS, \
- 0, \
- context->cuda_stream()>>>(size, C, HxW, X, scale, bias, Y); \
- }
-CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float, StorageOrder::NCHW)
-CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float, StorageOrder::NHWC)
-#undef CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL
-
#define CAFFE2_SPECIALIZED_CUDA_NCHW2NHWC(T) \
template <> \
CAFFE2_CUDA_EXPORT void NCHW2NHWC<T, CUDAContext>( \
diff --git a/caffe2/utils/math_utils.h b/caffe2/utils/math_utils.h
index 02e327e6ac..bd53eb1ec9 100644
--- a/caffe2/utils/math_utils.h
+++ b/caffe2/utils/math_utils.h
@@ -11,6 +11,7 @@
namespace caffe2 {
namespace math {
+
namespace utils {
MATH_UTILS_DECL bool Not(const bool x) {
@@ -135,6 +136,34 @@ CAFFE2_API void ComputeTransposedStrides(
int* strides);
} // namespace utils
+
+// Calculates ceil(a / b). User must be careful to ensure that there
+// is no overflow or underflow in the calculation.
+template <typename T>
+constexpr T DivUp(const T a, const T b) {
+ return (a + b - T(1)) / b;
+}
+
+// Rounds a up to the next highest multiple of b. User must be careful
+// to ensure that there is no overflow or underflow in the calculation
+// of divUp.
+template <typename T>
+constexpr T RoundUp(const T a, const T b) {
+ return DivUp<T>(a, b) * b;
+}
+
+// Returns log2(n) for a positive integer type
+template <typename T>
+constexpr int IntegerLog2(T n, int p = 0) {
+ return (n <= 1) ? p : IntegerLog2(n / 2, p + 1);
+}
+
+// Returns the next highest power-of-2 for an integer type
+template <typename T>
+constexpr T IntegerNextHighestPowerOf2(T v) {
+ return (IntegerIsPowerOf2(v) ? T(2) * v : (T(1) << (IntegerLog2(v) + 1)));
+}
+
} // namespace math
} // namespace caffe2
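
Since the moved helpers stay constexpr, the rename is checkable at compile time. A few spot checks against the definitions above (a sketch, assuming the header is on the include path):

    #include "caffe2/utils/math_utils.h"

    // Compile-time spot checks of the relocated helpers.
    static_assert(caffe2::math::DivUp(10, 3) == 4, "ceil(10 / 3)");
    static_assert(caffe2::math::RoundUp(10, 3) == 12, "next multiple of 3");
    static_assert(caffe2::math::IntegerLog2(8) == 3, "floor(log2(8))");
    static_assert(caffe2::math::IntegerNextHighestPowerOf2(8) == 16,
                  "a power of two is doubled");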