diff options
author | Xiaomeng Yang <yangxm@fb.com> | 2019-01-18 22:37:12 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-18 22:40:16 -0800 |
commit | b436f94b53e261177422fe92680d42f19195d3d0 (patch) | |
tree | 37d276fb86c4f7ee5ba474308c6037d1802af786 | |
parent | e8b872abe225dfc9f622eaa847e82ad7296ea17b (diff) | |
download | pytorch-b436f94b53e261177422fe92680d42f19195d3d0.tar.gz pytorch-b436f94b53e261177422fe92680d42f19195d3d0.tar.bz2 pytorch-b436f94b53e261177422fe92680d42f19195d3d0.zip |
Separate affine_channel from math and optimize it (#16135)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16135
Separate affine_channel from math and optimize it
i-am-not-moving-c2-to-c10
Reviewed By: houseroad
Differential Revision: D13727606
fbshipit-source-id: 8980af4afadaf964a18a9da581106fe30896a7e9
-rw-r--r-- | caffe2/operators/length_split_op.h | 2 | ||||
-rw-r--r-- | caffe2/operators/stylizer_ops.cc | 2 | ||||
-rw-r--r-- | caffe2/operators/top_k.cu | 2 | ||||
-rw-r--r-- | caffe2/operators/top_k_heap_selection.cuh | 2 | ||||
-rw-r--r-- | caffe2/operators/top_k_radix_selection.cuh | 4 | ||||
-rw-r--r-- | caffe2/utils/CMakeLists.txt | 21 | ||||
-rw-r--r-- | caffe2/utils/GpuScanUtils.cuh | 2 | ||||
-rw-r--r-- | caffe2/utils/math.h | 38 | ||||
-rw-r--r-- | caffe2/utils/math/elementwise.cc | 55 | ||||
-rw-r--r-- | caffe2/utils/math/elementwise.cu | 105 | ||||
-rw-r--r-- | caffe2/utils/math/elementwise.h | 24 | ||||
-rw-r--r-- | caffe2/utils/math_cpu.cc | 55 | ||||
-rw-r--r-- | caffe2/utils/math_gpu.cu | 61 | ||||
-rw-r--r-- | caffe2/utils/math_utils.h | 29 |
14 files changed, 246 insertions, 156 deletions
diff --git a/caffe2/operators/length_split_op.h b/caffe2/operators/length_split_op.h index 7270c0274e..87eba854ea 100644 --- a/caffe2/operators/length_split_op.h +++ b/caffe2/operators/length_split_op.h @@ -57,7 +57,7 @@ class LengthsSplitOp final : public Operator<Context> { for (int i = 0; i < M; i++) { int32_t mod = Ldata[i] % n_split_; int32_t res = - mod != 0 ? math::divUp(Ldata[i], n_split_) : Ldata[i] / n_split_ + 1; + mod != 0 ? math::DivUp(Ldata[i], n_split_) : Ldata[i] / n_split_ + 1; for (int j = 0; j < n_split_; j++) { Ydata[(i * n_split_) + j] = mod-- > 0 ? res : res - 1; } diff --git a/caffe2/operators/stylizer_ops.cc b/caffe2/operators/stylizer_ops.cc index d731ace03c..76bdbce45b 100644 --- a/caffe2/operators/stylizer_ops.cc +++ b/caffe2/operators/stylizer_ops.cc @@ -138,7 +138,7 @@ class PackedInt8BGRANHWCToNCHWCStylizerPreprocessOp // For ARM NEON, we read in multiples of kNeonNoiseReadSize since // the inner loop is vectorized. Round up to the next highest // multiple of kNeonNoiseReadSize - size = math::roundUp(size, kNeonNoiseReadSize) + size; + size = math::RoundUp(size, kNeonNoiseReadSize) + size; noise->Resize(size); math::RandGaussian<float, CPUContext>( diff --git a/caffe2/operators/top_k.cu b/caffe2/operators/top_k.cu index 52a59d01ce..85810547e8 100644 --- a/caffe2/operators/top_k.cu +++ b/caffe2/operators/top_k.cu @@ -56,7 +56,7 @@ void RunRadixSelectionImpl( int64_t* indices, CUDAContext* context) { const int block = std::min( - math::roundUp(static_cast<int>(inner_size), kWarpSize), + math::RoundUp(static_cast<int>(inner_size), kWarpSize), CAFFE_CUDA_NUM_THREADS); gatherTopK<T, kSelectMax, int64_t> <<<outer_size, block, 0, context->cuda_stream()>>>( diff --git a/caffe2/operators/top_k_heap_selection.cuh b/caffe2/operators/top_k_heap_selection.cuh index 674d317be4..e9c5c0f6d2 100644 --- a/caffe2/operators/top_k_heap_selection.cuh +++ b/caffe2/operators/top_k_heap_selection.cuh @@ -72,7 +72,7 @@ __device__ inline void warpHeapInsert(K k, V v, K* keyHeap, V* valueHeap) { // log2(8 / 2) = 2 levels of interior nodes for heap size 8 (0 and 12) int i = 0; #pragma unroll - for (int levels = 0; levels < math::integerLog2(HeapSize / 2); ++levels) { + for (int levels = 0; levels < math::IntegerLog2(HeapSize / 2); ++levels) { int leftChild = i * 2 + 1; int rightChild = leftChild + 1; K leftKey = keyHeap[leftChild]; diff --git a/caffe2/operators/top_k_radix_selection.cuh b/caffe2/operators/top_k_radix_selection.cuh index 0bf38e4005..adc9ff141c 100644 --- a/caffe2/operators/top_k_radix_selection.cuh +++ b/caffe2/operators/top_k_radix_selection.cuh @@ -218,7 +218,7 @@ __device__ DataType findPattern(DataType* smem, __syncthreads(); // All threads participate in the loop, in order to sync on the flag - int numIterations = math::roundUp(sliceSize, (int) blockDim.x); + int numIterations = math::RoundUp(sliceSize, (int) blockDim.x); for (int i = threadIdx.x; i < numIterations; i += blockDim.x) { bool inRange = (i < sliceSize); DataType v = inRange ? data[i] : (DataType)0; @@ -388,7 +388,7 @@ __global__ void gatherTopK(const T* inputPtr, // All threads need to participate in the loop and the prefix sum, // but not necessarily in the load; hence loop bounds being rounded // up to a multiple of the block dim. - int numIterations = math::roundUp(inputSliceSize, (int) blockDim.x); + int numIterations = math::RoundUp(inputSliceSize, (int) blockDim.x); int writeIndexStart = 0; for (int i = threadIdx.x; i < numIterations; i += blockDim.x) { diff --git a/caffe2/utils/CMakeLists.txt b/caffe2/utils/CMakeLists.txt index a99fd7fe76..fafe6f1923 100644 --- a/caffe2/utils/CMakeLists.txt +++ b/caffe2/utils/CMakeLists.txt @@ -1,17 +1,18 @@ list(APPEND Caffe2_CPU_SRCS + utils/bench_utils.cc + utils/cpuid.cc + utils/math/elementwise.cc + utils/math_cpu.cc + utils/math_utils.cc + utils/murmur_hash3.cc utils/proto_convert.cc - utils/proto_wrap.cc utils/proto_utils.cc - utils/murmur_hash3.cc - utils/smart_tensor_printer.cc + utils/proto_wrap.cc utils/signal_handler.cc + utils/smart_tensor_printer.cc utils/string_utils.cc - utils/threadpool/ThreadPool.cc - utils/cpuid.cc - utils/bench_utils.cc - utils/math_cpu.cc - utils/math_utils.cc - utils/thread_name.cc) + utils/thread_name.cc + utils/threadpool/ThreadPool.cc) # ---[ threadpool/pthreadpool* is a local modification of the NNPACK # pthreadpool with a very similar interface. Neither NNPACK, nor this @@ -24,10 +25,12 @@ if (NOT MSVC) endif() set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} + utils/math/elementwise.cu utils/math_gpu.cu ) set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} + utils/math/hip/elementwise.hip utils/hip/math_gpu.hip ) diff --git a/caffe2/utils/GpuScanUtils.cuh b/caffe2/utils/GpuScanUtils.cuh index c6003825ff..24ae38c076 100644 --- a/caffe2/utils/GpuScanUtils.cuh +++ b/caffe2/utils/GpuScanUtils.cuh @@ -123,7 +123,7 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi // The outgoing carry for all threads is the last warp's sum #if defined(__HIP_PLATFORM_HCC__) - *carry = smem[math::divUp<int>(blockDim.x, kWarpSize) - 1]; + *carry = smem[math::DivUp<int>(blockDim.x, kWarpSize) - 1]; #else *carry = smem[(blockDim.x / kWarpSize) - 1]; #endif // __HIP_PLATFORM_HCC__ diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h index cb949eec5c..af0e66823e 100644 --- a/caffe2/utils/math.h +++ b/caffe2/utils/math.h @@ -15,6 +15,7 @@ extern "C" { #include "caffe2/core/common.h" #include "caffe2/core/types.h" +#include "caffe2/utils/math/elementwise.h" #include "caffe2/utils/math_utils.h" namespace caffe2 { @@ -662,17 +663,6 @@ CAFFE2_API void CopyMatrix( template <typename T, class Context> CAFFE2_API void CopyVector(const int N, const T* A, T* B, Context* context); -template <typename T, class Context, StorageOrder kOrder> -CAFFE2_API void AffineChannel( - const int N, - const int C, - const int HxW, - const T* X, - const T* scale, - const T* bias, - T* Y, - Context* context); - template <typename T, class Context> CAFFE2_API void NCHW2NHWC( const int N, @@ -691,32 +681,6 @@ CAFFE2_API void NHWC2NCHW( T* Y, Context* context); -// Calculates ceil(a / b). User must be careful to ensure that there -// is no overflow or underflow in the calculation. -template <typename T> -constexpr T divUp(T a, T b) { - return (a + b - (T)1) / b; -} - -// Rounds a up to the next highest multiple of b. User must be careful -// to ensure that there is no overflow or underflow in the calculation -// of divUp. -template <typename T> -constexpr T roundUp(T a, T b) { - return divUp<T>(a, b) * b; -} - -// Returns log2(n) for a positive integer type -template <typename T> -constexpr int integerLog2(T n, int p = 0) { - return (n <= 1) ? p : integerLog2(n / 2, p + 1); -} - -// Returns the next highest power-of-2 for an integer type -template <typename T> -constexpr T integerNextHighestPowerOf2(T v) { - return (integerIsPowerOf2(v) ? (T)2 * v : ((T)1 << (integerLog2(v) + 1))); -} } // namespace math } // namespace caffe2 diff --git a/caffe2/utils/math/elementwise.cc b/caffe2/utils/math/elementwise.cc new file mode 100644 index 0000000000..5634fb876b --- /dev/null +++ b/caffe2/utils/math/elementwise.cc @@ -0,0 +1,55 @@ +#include "caffe2/utils/math/elementwise.h" + +#include "caffe2/core/context.h" +#include "caffe2/utils/eigen_utils.h" + +namespace caffe2 { +namespace math { + +#define CAFFE2_SPECIALIZED_AFFINE_CHANNEL(T) \ + template <> \ + void AffineChannel<T, CPUContext, StorageOrder::NCHW>( \ + const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + const T* scale, \ + const T* bias, \ + T* Y, \ + CPUContext* /* context */) { \ + ConstEigenVectorArrayMap<T> scale_arr(scale, C); \ + ConstEigenVectorArrayMap<T> bias_arr(bias, C); \ + const int stride = C * HxW; \ + const T* X_ptr = X; \ + T* Y_ptr = Y; \ + for (int i = 0; i < N; ++i) { \ + EigenArrayMap<T>(Y_ptr, HxW, C) = \ + (ConstEigenArrayMap<T>(X_ptr, HxW, C).rowwise() * \ + scale_arr.transpose()) \ + .rowwise() + \ + bias_arr.transpose(); \ + X_ptr += stride; \ + Y_ptr += stride; \ + } \ + } \ + template <> \ + void AffineChannel<T, CPUContext, StorageOrder::NHWC>( \ + const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + const T* scale, \ + const T* bias, \ + T* Y, \ + CPUContext* /* context */) { \ + EigenArrayMap<T>(Y, C, N * HxW) = \ + (ConstEigenArrayMap<T>(X, C, N * HxW).colwise() * \ + ConstEigenVectorArrayMap<T>(scale, C)) \ + .colwise() + \ + ConstEigenVectorArrayMap<T>(bias, C); \ + } +CAFFE2_SPECIALIZED_AFFINE_CHANNEL(float) +#undef CAFFE2_SPECIALIZED_AFFINE_CHANNEL + +} // namespace math +} // namespace caffe2 diff --git a/caffe2/utils/math/elementwise.cu b/caffe2/utils/math/elementwise.cu new file mode 100644 index 0000000000..de28176e9f --- /dev/null +++ b/caffe2/utils/math/elementwise.cu @@ -0,0 +1,105 @@ +#include "caffe2/utils/math/elementwise.h" + +#include "caffe2/core/context_gpu.h" +#include "caffe2/utils/math_utils.h" + +namespace caffe2 { +namespace math { + +namespace { + +template <typename T> +__global__ void AffineChannelNCHWCUDAKernel( + const int C, + const int HxW, + const int K, + const T* X, + const T* scale, + const T* bias, + T* Y); + +template <> +__global__ void AffineChannelNCHWCUDAKernel<float>( + const int C, + const int HxW, + const int K, + const float* X, + const float* scale, + const float* bias, + float* Y) { + const int nc = blockIdx.x / K; + const int block = blockIdx.x % K; + const int c = nc % C; + const int w = block * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (w < HxW) { + const int index = nc * HxW + w; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); +#else + Y[index] = fmaf(X[index], scale[c], bias[c]); +#endif + } +} + +template <typename T> +__global__ void AffineChannelNHWCCUDAKernel( + const int C, + const T* X, + const T* scale, + const T* bias, + T* Y); + +template <> +__global__ void AffineChannelNHWCCUDAKernel<float>( + const int C, + const float* X, + const float* scale, + const float* bias, + float* Y) { + for (int c = threadIdx.x; c < C; c += blockDim.x) { + const int index = blockIdx.x * C + c; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); +#else + Y[index] = fmaf(X[index], scale[c], bias[c]); +#endif + } +} + +} // namespace + +#define CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(T) \ + template <> \ + CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, StorageOrder::NCHW>( \ + const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + const T* scale, \ + const T* bias, \ + T* Y, \ + CUDAContext* context) { \ + const int K = DivUp(HxW, CAFFE_CUDA_NUM_THREADS); \ + AffineChannelNCHWCUDAKernel<T> \ + <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( \ + C, HxW, K, X, scale, bias, Y); \ + } \ + template <> \ + CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, StorageOrder::NHWC>( \ + const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + const T* scale, \ + const T* bias, \ + T* Y, \ + CUDAContext* context) { \ + AffineChannelNHWCCUDAKernel<T> \ + <<<N * HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( \ + C, X, scale, bias, Y); \ + } +CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float) +#undef CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL + +} // namespace math +} // namespace caffe2 diff --git a/caffe2/utils/math/elementwise.h b/caffe2/utils/math/elementwise.h new file mode 100644 index 0000000000..accfac76d8 --- /dev/null +++ b/caffe2/utils/math/elementwise.h @@ -0,0 +1,24 @@ +#ifndef CAFFE2_UTILS_MATH_ELEMENTWISE_H_ +#define CAFFE2_UTILS_MATH_ELEMENTWISE_H_ + +#include "caffe2/core/common.h" +#include "caffe2/core/types.h" + +namespace caffe2 { +namespace math { + +template <typename T, class Context, StorageOrder kOrder> +CAFFE2_API void AffineChannel( + const int N, + const int C, + const int HxW, + const T* X, + const T* scale, + const T* bias, + T* Y, + Context* context); + +} // namespace math +} // namespace caffe2 + +#endif // CAFFE2_UTILS_MATH_ELEMENTWISE_H_ diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc index dd782391b7..d59e0045e3 100644 --- a/caffe2/utils/math_cpu.cc +++ b/caffe2/utils/math_cpu.cc @@ -785,11 +785,11 @@ DELEGATE_CBRT_FUNCTION(float) DELEGATE_CBRT_FUNCTION(double) #undef DELEGATE_CBRT_FUNCTION -#define DELEGATE_ERF_FUNCTION(T) \ - template <> \ - C10_EXPORT void Erf<T, CPUContext>( \ - const int N, const T* X, T* Y, CPUContext*) { \ - std::transform(X, X + N, Y, [](const T x) { return erf(x); }); \ +#define DELEGATE_ERF_FUNCTION(T) \ + template <> \ + C10_EXPORT void Erf<T, CPUContext>( \ + const int N, const T* X, T* Y, CPUContext*) { \ + std::transform(X, X + N, Y, [](const T x) { return erf(x); }); \ } DELEGATE_ERF_FUNCTION(float) DELEGATE_ERF_FUNCTION(double) @@ -4141,51 +4141,6 @@ CAFFE2_SPECIALIZED_TRANSPOSE(std::uint8_t) CAFFE2_SPECIALIZED_TRANSPOSE(std::uint16_t) #undef CAFFE2_SPECIALIZED_TRANSPOSE -#define CAFFE2_SPECIALIZED_AFFINE_CHANNEL(T) \ - template <> \ - void AffineChannel<T, CPUContext, StorageOrder::NCHW>( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CPUContext* /* context */) { \ - ConstEigenVectorArrayMap<T> scale_arr(scale, C); \ - ConstEigenVectorArrayMap<T> bias_arr(bias, C); \ - const int stride = C * HxW; \ - const T* X_ptr = X; \ - T* Y_ptr = Y; \ - for (int i = 0; i < N; ++i) { \ - EigenArrayMap<T>(Y_ptr, HxW, C) = \ - (ConstEigenArrayMap<T>(X_ptr, HxW, C).rowwise() * \ - scale_arr.transpose()) \ - .rowwise() + \ - bias_arr.transpose(); \ - X_ptr += stride; \ - Y_ptr += stride; \ - } \ - } \ - template <> \ - void AffineChannel<T, CPUContext, StorageOrder::NHWC>( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CPUContext* /* context */) { \ - EigenArrayMap<T>(Y, C, N * HxW) = \ - (ConstEigenArrayMap<T>(X, C, N * HxW).colwise() * \ - ConstEigenVectorArrayMap<T>(scale, C)) \ - .colwise() + \ - ConstEigenVectorArrayMap<T>(bias, C); \ - } -CAFFE2_SPECIALIZED_AFFINE_CHANNEL(float) -#undef CAFFE2_SPECIALIZED_AFFINE_CHANNEL - #define CAFFE2_SPECIALIZED_NCHW2NHWC(T) \ template <> \ C10_EXPORT void NCHW2NHWC<T, CPUContext>( \ diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu index 820dfe490b..0229969e74 100644 --- a/caffe2/utils/math_gpu.cu +++ b/caffe2/utils/math_gpu.cu @@ -764,9 +764,9 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>( context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); #ifdef __HIP_PLATFORM_HCC__ // rocblas doesn't support cublasSgemmEx type API yet. - // It has more general rocblas_gemm_ex API which is more close to cublasGemmEx - // rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C, whereas - // cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C + // It has more general rocblas_gemm_ex API which is more close to + // cublasGemmEx rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C, + // whereas cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C ROCBLAS_ENFORCE(rocblas_gemm_ex( context->rocblashandle(), cu_trans_B, @@ -1171,8 +1171,8 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>( CUBLAS_ENFORCE(cublasSetPointerMode( context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); #ifdef __HIP_PLATFORM_HCC__ - // D[i*stride_d] = alpha*op(A[i*stride_a])*op(B[i*stride_b]) + beta*C[i*stride_c], - // for i in [0,batch_count-1] + // D[i*stride_d] = alpha*op(A[i*stride_a])*op(B[i*stride_b]) + + // beta*C[i*stride_c], for i in [0,batch_count-1] ROCBLAS_ENFORCE(rocblas_gemm_strided_batched_ex( context->rocblashandle(), cu_trans_B, @@ -1560,9 +1560,9 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>( context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); #ifdef __HIP_PLATFORM_HCC__ // rocblas doesn't support cublasSgemmEx type API yet. - // It has more general rocblas_gemm_ex API which is more close to cublasGemmEx - // rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C, whereas - // cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C + // It has more general rocblas_gemm_ex API which is more close to + // cublasGemmEx rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C, + // whereas cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C ROCBLAS_ENFORCE(rocblas_gemm_ex( context->rocblashandle(), cu_trans_A, @@ -4219,51 +4219,6 @@ CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(int) CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(int64_t) #undef CAFFE2_SPECIALIZED_CUDA_TRANSPOSE -namespace { - -template <typename T, StorageOrder kOrder> -__global__ void AffineChannelCUDAKernel( - const int size, - const int C, - const int HxW, - const T* X, - const T* scale, - const T* bias, - T* Y) { - CUDA_1D_KERNEL_LOOP(i, size) { - const int c = kOrder == StorageOrder::NCHW ? i / HxW % C : i % C; -#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) - Y[i] = __ldg(scale + c) * __ldg(X + i) + __ldg(bias + c); -#else - Y[i] = scale[c] * X[i] + bias[c]; -#endif - } -} - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(T, kOrder) \ - template <> \ - CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, kOrder>( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CUDAContext* context) { \ - const int size = N * C * HxW; \ - AffineChannelCUDAKernel<T, kOrder> \ - <<<CAFFE_GET_BLOCKS(size), \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>(size, C, HxW, X, scale, bias, Y); \ - } -CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float, StorageOrder::NCHW) -CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float, StorageOrder::NHWC) -#undef CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL - #define CAFFE2_SPECIALIZED_CUDA_NCHW2NHWC(T) \ template <> \ CAFFE2_CUDA_EXPORT void NCHW2NHWC<T, CUDAContext>( \ diff --git a/caffe2/utils/math_utils.h b/caffe2/utils/math_utils.h index 02e327e6ac..bd53eb1ec9 100644 --- a/caffe2/utils/math_utils.h +++ b/caffe2/utils/math_utils.h @@ -11,6 +11,7 @@ namespace caffe2 { namespace math { + namespace utils { MATH_UTILS_DECL bool Not(const bool x) { @@ -135,6 +136,34 @@ CAFFE2_API void ComputeTransposedStrides( int* strides); } // namespace utils + +// Calculates ceil(a / b). User must be careful to ensure that there +// is no overflow or underflow in the calculation. +template <typename T> +constexpr T DivUp(const T a, const T b) { + return (a + b - T(1)) / b; +} + +// Rounds a up to the next highest multiple of b. User must be careful +// to ensure that there is no overflow or underflow in the calculation +// of divUp. +template <typename T> +constexpr T RoundUp(const T a, const T b) { + return DivUp<T>(a, b) * b; +} + +// Returns log2(n) for a positive integer type +template <typename T> +constexpr int IntegerLog2(T n, int p = 0) { + return (n <= 1) ? p : IntegerLog2(n / 2, p + 1); +} + +// Returns the next highest power-of-2 for an integer type +template <typename T> +constexpr T IntegerNextHighestPowerOf2(T v) { + return (IntegerIsPowerOf2(v) ? T(2) * v : (T(1) << (IntegerLog2(v) + 1))); +} + } // namespace math } // namespace caffe2 |