From 2db847b3a7edc48652e144e7c9d7aa0bbed66aaa Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Thu, 7 Feb 2019 18:19:46 -0800 Subject: Separate elementwise level2 math functions (#16753) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16753 Separate elementwise level2 math functions i-am-not-moving-c2-to-c10 Reviewed By: houseroad Differential Revision: D13954928 fbshipit-source-id: 1ca7a5d3da96e32510f502e5e4e79168854bee67 --- caffe2/operators/conv_transpose_op_mobile_impl.h | 2 +- caffe2/operators/layer_norm_op.cu | 2 +- caffe2/operators/minmax_gradient_ops.cc | 74 ++-- caffe2/operators/minmax_ops.cc | 33 +- caffe2/operators/minmax_ops.cu | 56 +++ caffe2/operators/minmax_ops.h | 115 ++++-- caffe2/operators/rsqrt_op.cu | 2 +- caffe2/operators/utility_ops.cu | 104 +---- caffe2/quantization/server/im2col_dnnlowp.h | 2 +- caffe2/utils/CMakeLists.txt | 5 +- caffe2/utils/math.h | 71 ++-- caffe2/utils/math/broadcast.cc | 55 +++ caffe2/utils/math/broadcast.cu | 108 +++++ caffe2/utils/math/broadcast.h | 24 ++ caffe2/utils/math/elementwise.cc | 273 +++++++++---- caffe2/utils/math/elementwise.cu | 486 +++++++++++++++-------- caffe2/utils/math/elementwise.h | 101 +++-- caffe2/utils/math/half_utils.h | 49 +++ caffe2/utils/math/reduce.cc | 2 +- caffe2/utils/math/reduce.cu | 2 +- caffe2/utils/math/utils.cc | 347 ++++++++++++++++ caffe2/utils/math/utils.h | 178 +++++++++ caffe2/utils/math_cpu.cc | 111 ------ caffe2/utils/math_gpu.cu | 70 +--- caffe2/utils/math_gpu_test.cc | 18 - caffe2/utils/math_utils.cc | 347 ---------------- caffe2/utils/math_utils.h | 178 --------- 27 files changed, 1558 insertions(+), 1257 deletions(-) create mode 100644 caffe2/operators/minmax_ops.cu create mode 100644 caffe2/utils/math/broadcast.cc create mode 100644 caffe2/utils/math/broadcast.cu create mode 100644 caffe2/utils/math/broadcast.h create mode 100644 caffe2/utils/math/half_utils.h create mode 100644 caffe2/utils/math/utils.cc create mode 100644 caffe2/utils/math/utils.h delete mode 100644 caffe2/utils/math_utils.cc delete mode 100644 caffe2/utils/math_utils.h diff --git a/caffe2/operators/conv_transpose_op_mobile_impl.h b/caffe2/operators/conv_transpose_op_mobile_impl.h index f586453d5c..6869f5178d 100644 --- a/caffe2/operators/conv_transpose_op_mobile_impl.h +++ b/caffe2/operators/conv_transpose_op_mobile_impl.h @@ -18,7 +18,7 @@ #include "caffe2/utils/eigen_utils.h" #include "caffe2/utils/fixed_divisor.h" #include "caffe2/utils/math.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/math/utils.h" C10_DECLARE_bool(caffe2_force_shared_col_buffer); diff --git a/caffe2/operators/layer_norm_op.cu b/caffe2/operators/layer_norm_op.cu index c3465c1d2d..440783c6eb 100644 --- a/caffe2/operators/layer_norm_op.cu +++ b/caffe2/operators/layer_norm_op.cu @@ -4,7 +4,7 @@ #include "caffe2/core/context_gpu.h" #include "caffe2/utils/math.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/math/utils.h" namespace caffe2 { diff --git a/caffe2/operators/minmax_gradient_ops.cc b/caffe2/operators/minmax_gradient_ops.cc index 9c7df22d55..d288eb0946 100644 --- a/caffe2/operators/minmax_gradient_ops.cc +++ b/caffe2/operators/minmax_gradient_ops.cc @@ -1,66 +1,66 @@ #include "caffe2/operators/minmax_ops.h" -#include "caffe2/utils/eigen_utils.h" -namespace caffe2 { +#include +#include -REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp); -REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp); +#include "caffe2/utils/eigen_utils.h" -OPERATOR_SCHEMA(MaxGradient).NumInputs(3, 
INT_MAX).NumOutputs(1, INT_MAX); -OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX); +namespace caffe2 { template bool SelectGradientOpBase::RunOnDevice() { - auto& output = Input(0); - auto& grad_output = Input(1); - const int kInputStartOffset = 2; - - const T* data = output.template data(); - ConstEigenArrayMap output_array( - output.template data(), 1, output.numel()); - ConstEigenArrayMap grad_out_array( - grad_output.template data(), 1, grad_output.numel()); - + const auto& Y = Input(0); + const auto& dY = Input(1); + const int N = Y.numel(); + ConstEigenVectorArrayMap Y_arr(Y.template data(), N); + ConstEigenVectorArrayMap dY_arr(dY.template data(), N); for (int i = 0; i < OutputSize(); i++) { - auto& input = Input(i + kInputStartOffset); - ConstEigenArrayMap input_array( - input.template data(), 1, input.numel()); - - auto* grad_input = Output(i, input.sizes(), at::dtype()); - EigenArrayMap grad_in_array( - grad_input->template mutable_data(), 1, grad_input->numel()); - grad_in_array = grad_out_array * - input_array.cwiseEqual(output_array).template cast(); + const auto& Xi = Input(i + 2); + auto* dXi = Output(i, Xi.sizes(), at::dtype()); + ConstEigenVectorArrayMap Xi_arr(Xi.template data(), N); + EigenVectorArrayMap dXi_arr(dXi->template mutable_data(), N); + dXi_arr = (Xi_arr == Y_arr).template cast() * dY_arr; } return true; } +REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp); +REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp); + +OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX); +OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX); + +namespace { + class GetMaxGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; - vector GetGradientDefs() override { - auto gradInputs = vector(); - auto inputs = vector{O(0), GO(0)}; - for (int i = 0; i < def_.input_size(); i++) { - gradInputs.push_back(GI(i)); + std::vector GetGradientDefs() override { + std::vector inputs = {O(0), GO(0)}; + std::vector grad_inputs; + for (int i = 0; i < def_.input_size(); ++i) { inputs.push_back(I(i)); + grad_inputs.push_back(GI(i)); } - return SingleGradientDef("MaxGradient", "", inputs, gradInputs); + return SingleGradientDef("MaxGradient", "", inputs, grad_inputs); } }; -REGISTER_GRADIENT(Max, GetMaxGradient); class GetMinGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; vector GetGradientDefs() override { - auto gradInputs = vector(); - auto inputs = vector{O(0), GO(0)}; - for (int i = 0; i < def_.input_size(); i++) { - gradInputs.push_back(GI(i)); + std::vector inputs = {O(0), GO(0)}; + std::vector grad_inputs; + for (int i = 0; i < def_.input_size(); ++i) { inputs.push_back(I(i)); + grad_inputs.push_back(GI(i)); } - return SingleGradientDef("MinGradient", "", inputs, gradInputs); + return SingleGradientDef("MinGradient", "", inputs, grad_inputs); } }; + +} // namespace + +REGISTER_GRADIENT(Max, GetMaxGradient); REGISTER_GRADIENT(Min, GetMinGradient); } // namespace caffe2 diff --git a/caffe2/operators/minmax_ops.cc b/caffe2/operators/minmax_ops.cc index 1a105770d2..3dd2487335 100644 --- a/caffe2/operators/minmax_ops.cc +++ b/caffe2/operators/minmax_ops.cc @@ -1,10 +1,9 @@ #include "caffe2/operators/minmax_ops.h" -#include "caffe2/utils/eigen_utils.h" namespace caffe2 { -REGISTER_CPU_OPERATOR(Max, MaxOp); REGISTER_CPU_OPERATOR(Min, MinOp); +REGISTER_CPU_OPERATOR(Max, MaxOp); OPERATOR_SCHEMA(Max) .NumInputs(1, INT_MAX) @@ -155,34 +154,4 @@ Min: "Contains the minimum valued 
element at each location.") .InheritOnnxSchema(); -template -bool MaxOp::Compute() { - auto& input0 = Input(0); - const int N = input0.numel(); - T* output_data = Output(0)->template mutable_data(); - - for (int i = 1; i < InputSize(); i++) { - auto input_data = Input(i).template data(); - EigenVectorMap output_vec(output_data, N); - output_vec = output_vec.cwiseMax(ConstEigenVectorMap(input_data, N)); - } - - return true; -} - -template -bool MinOp::Compute() { - auto& input0 = Input(0); - const int N = input0.numel(); - T* output_data = Output(0)->template mutable_data(); - - for (int i = 1; i < InputSize(); i++) { - auto input_data = Input(i).template data(); - EigenVectorMap output_vec(output_data, N); - output_vec = output_vec.cwiseMin(ConstEigenVectorMap(input_data, N)); - } - - return true; -} - } // namespace caffe2 diff --git a/caffe2/operators/minmax_ops.cu b/caffe2/operators/minmax_ops.cu new file mode 100644 index 0000000000..a853b700d1 --- /dev/null +++ b/caffe2/operators/minmax_ops.cu @@ -0,0 +1,56 @@ +#include "caffe2/operators/minmax_ops.h" + +#include "caffe2/core/context_gpu.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +namespace { + +template +__global__ void SelectGradientCUDAKernel( + const int N, + const T* dY, + const T* X, + const T* Y, + T* dX) { + const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (i < N) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + dX[i] = __ldg(X + i) == __ldg(Y + i) ? __ldg(dY + i) : T(0); +#else + dX[i] = X[i] == Y[i] ? dY[i] : T(0); +#endif + } +} + +} // namespace + +template <> +bool SelectGradientOpBase::RunOnDevice() { + const auto& Y = Input(0); + const auto& dY = Input(1); + const int N = Y.numel(); + const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS); + const float* dY_data = dY.data(); + const float* Y_data = Y.data(); + for (int i = 0; i < OutputSize(); i++) { + const auto& Xi = Input(i + 2); + auto* dXi = Output(i, Xi.sizes(), at::dtype()); + const float* Xi_data = Xi.data(); + float* dXi_data = dXi->mutable_data(); + if (N > 0) { + SelectGradientCUDAKernel + <<>>( + N, dY_data, Xi_data, Y_data, dXi_data); + } + } + return true; +} + +REGISTER_CUDA_OPERATOR(Min, MinOp); +REGISTER_CUDA_OPERATOR(MinGradient, MinGradientOp); +REGISTER_CUDA_OPERATOR(Max, MaxOp); +REGISTER_CUDA_OPERATOR(MaxGradient, MaxGradientOp); + +} // namespace caffe2 diff --git a/caffe2/operators/minmax_ops.h b/caffe2/operators/minmax_ops.h index db02e0dcb0..5668460682 100644 --- a/caffe2/operators/minmax_ops.h +++ b/caffe2/operators/minmax_ops.h @@ -1,7 +1,6 @@ #ifndef CAFFE2_OPERATORS_MINMAX_OPS_H_ #define CAFFE2_OPERATORS_MINMAX_OPS_H_ -#include "caffe2/core/common_omp.h" #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" @@ -11,49 +10,99 @@ namespace caffe2 { template -class MaxMinOpBase : public Operator { +class MaxOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - USE_SIMPLE_CTOR_DTOR(MaxMinOpBase) - bool RunOnDevice() override { - auto& input0 = Input(0); - auto* output = Output(0); - - output->ResizeLike(input0); - output->CopyFrom(input0, /* async */ true); + USE_SIMPLE_CTOR_DTOR(MaxOp) + bool RunOnDevice() override { + const auto& X0 = Input(0); + auto* Y = Output(0); + Y->ResizeLike(X0); + const T* X0_data = X0.template data(); + T* Y_data = Y->template mutable_data(); + const int N = X0.numel(); if (InputSize() == 1) { + if (Y != &X0) { + context_.template CopySameDevice(N, X0_data, Y_data); + } return true; } - - // Dimension 
checking - for (int i = 1; i < InputSize(); ++i) { + const auto& X1 = Input(1); + CAFFE_ENFORCE_EQ( + X0.sizes(), + Y->sizes(), + "Description: Input #1, input dimension:", + X1.sizes(), + " should match output dimension: ", + Y->sizes()); + const T* X1_data = X1.template data(); + math::Max(N, X0_data, X1_data, Y_data, &context_); + for (int i = 2; i < InputSize(); ++i) { + const auto& Xi = Input(i); CAFFE_ENFORCE_EQ( - output->sizes(), - Input(i).sizes(), + Xi.sizes(), + Y->sizes(), "Description: Input #", i, ", input dimension:", Input(i).sizes(), " should match output dimension: ", - output->sizes()); + Y->sizes()); + const T* Xi_data = Xi.template data(); + math::Max(N, Y_data, Xi_data, Y_data, &context_); } - - return this->Compute(); + return true; } - - virtual bool Compute() = 0; }; template -class MaxOp : public MaxMinOpBase { +class MinOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - MaxOp(const OperatorDef& operator_def, Workspace* ws) - : MaxMinOpBase(operator_def, ws) {} - virtual ~MaxOp() noexcept {} - bool Compute() override; + + USE_SIMPLE_CTOR_DTOR(MinOp) + + bool RunOnDevice() override { + const auto& X0 = Input(0); + auto* Y = Output(0); + Y->ResizeLike(X0); + const T* X0_data = X0.template data(); + T* Y_data = Y->template mutable_data(); + const int N = X0.numel(); + if (InputSize() == 1) { + if (Y != &X0) { + context_.template CopySameDevice(N, X0_data, Y_data); + } + return true; + } + const auto& X1 = Input(1); + CAFFE_ENFORCE_EQ( + X0.sizes(), + Y->sizes(), + "Description: Input #1, input dimension:", + X1.sizes(), + " should match output dimension: ", + Y->sizes()); + const T* X1_data = X1.template data(); + math::Min(N, X0_data, X1_data, Y_data, &context_); + for (int i = 2; i < InputSize(); ++i) { + const auto& Xi = Input(i); + CAFFE_ENFORCE_EQ( + Xi.sizes(), + Y->sizes(), + "Description: Input #", + i, + ", input dimension:", + Input(i).sizes(), + " should match output dimension: ", + Y->sizes()); + const T* Xi_data = Xi.template data(); + math::Min(N, Y_data, Xi_data, Y_data, &context_); + } + return true; + } }; template @@ -66,29 +115,21 @@ class SelectGradientOpBase : public Operator { }; template -class MaxGradientOp : public SelectGradientOpBase { +class MaxGradientOp final : public SelectGradientOpBase { public: MaxGradientOp(const OperatorDef& operator_def, Workspace* ws) : SelectGradientOpBase(operator_def, ws) {} - virtual ~MaxGradientOp() noexcept {} -}; -template -class MinOp : public MaxMinOpBase { - public: - USE_OPERATOR_CONTEXT_FUNCTIONS; - MinOp(const OperatorDef& operator_def, Workspace* ws) - : MaxMinOpBase(operator_def, ws) {} - virtual ~MinOp() noexcept {} - bool Compute() override; + ~MaxGradientOp() = default; }; template -class MinGradientOp : public SelectGradientOpBase { +class MinGradientOp final : public SelectGradientOpBase { public: MinGradientOp(const OperatorDef& operator_def, Workspace* ws) : SelectGradientOpBase(operator_def, ws) {} - virtual ~MinGradientOp() noexcept {} + + ~MinGradientOp() = default; }; } // namespace caffe2 diff --git a/caffe2/operators/rsqrt_op.cu b/caffe2/operators/rsqrt_op.cu index 378d131a6f..eae8dc0f44 100644 --- a/caffe2/operators/rsqrt_op.cu +++ b/caffe2/operators/rsqrt_op.cu @@ -4,7 +4,7 @@ #include #include "caffe2/core/context_gpu.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/math.h" namespace caffe2 { diff --git a/caffe2/operators/utility_ops.cu b/caffe2/operators/utility_ops.cu index b7d17ba64f..f767a9ae11 100644 --- a/caffe2/operators/utility_ops.cu +++ 
b/caffe2/operators/utility_ops.cu @@ -1,8 +1,4 @@ -#include "caffe2/core/context_gpu.h" -#include "caffe2/operators/flatten_op.h" -#include "caffe2/operators/minmax_ops.h" #include "caffe2/operators/utility_ops.h" -#include "caffe2/utils/math.h" #include #include @@ -10,6 +6,10 @@ #include #include +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/flatten_op.h" +#include "caffe2/utils/math.h" + namespace caffe2 { template <> @@ -137,102 +137,6 @@ bool NanCheckOp::RunOnDevice() { REGISTER_CUDA_OPERATOR(NanCheck, NanCheckOp); -__global__ void -ElwiseMaxKernel(const float* X, const float* Y, float* maxout, const int N) { - CUDA_1D_KERNEL_LOOP(i, N) { - maxout[i] = fmaxf(X[i], Y[i]); - } -} - -template <> -bool MaxOp::Compute() { - float* output_data = Output(0)->template mutable_data(); - const int N = Input(0).numel(); - - // Run pairwise-maxes - for (int i = 1; i < InputSize(); ++i) { - ElwiseMaxKernel<<< - CAFFE_GET_BLOCKS(N), - CAFFE_CUDA_NUM_THREADS, - 0, - context_.cuda_stream()>>>( - (i == 0 ? Input(0).data() : Output(0)->data()), - Input(i).data(), - output_data, - N); - } - - return true; -} - -REGISTER_CUDA_OPERATOR(Max, MaxOp); -REGISTER_CUDA_OPERATOR(MaxGradient, MaxGradientOp); - -__global__ void -ElwiseMinKernel(const float* X, const float* Y, float* minout, const int N) { - CUDA_1D_KERNEL_LOOP(i, N) { - minout[i] = fminf(X[i], Y[i]); - } -} - -template <> -bool MinOp::Compute() { - float* output_data = Output(0)->template mutable_data(); - const int N = Input(0).numel(); - - // Run pairwise-mines - for (int i = 1; i < InputSize(); ++i) { - ElwiseMinKernel<<< - CAFFE_GET_BLOCKS(N), - CAFFE_CUDA_NUM_THREADS, - 0, - context_.cuda_stream()>>>( - (i == 0 ? Input(0).data() : Output(0)->data()), - Input(i).data(), - output_data, - N); - } - - return true; -} - -REGISTER_CUDA_OPERATOR(Min, MinOp); -REGISTER_CUDA_OPERATOR(MinGradient, MinGradientOp); - -template -__global__ void -MaxMinGradKernel(int N, const T* mx, const T* x, const T* go, T* gi) { - CUDA_1D_KERNEL_LOOP(i, N) { - gi[i] = go[i] * (mx[i] == x[i]); - } -} - -template <> -bool SelectGradientOpBase::RunOnDevice() { - auto& output = Input(0); - auto& grad_output = Input(1); - const int kInputStartOffset = 2; - - const float* data = output.data(); - - for (int i = 0; i < OutputSize(); i++) { - auto& input = Input(i + kInputStartOffset); - - auto* grad_input = Output(i, input.sizes(), at::dtype()); - MaxMinGradKernel<<< - CAFFE_GET_BLOCKS(input.numel()), - CAFFE_CUDA_NUM_THREADS, - 0, - context_.cuda_stream()>>>( - input.numel(), - output.data(), - input.data(), - grad_output.data(), - grad_input->template mutable_data()); - } - return true; -} - /** * @brief Update slices of Y in-place with a batch of weighted X's. 
* Y[idx] = alpha[b] * X[b][i] + Y[idx] diff --git a/caffe2/quantization/server/im2col_dnnlowp.h b/caffe2/quantization/server/im2col_dnnlowp.h index fdbee29556..92f7b272ac 100644 --- a/caffe2/quantization/server/im2col_dnnlowp.h +++ b/caffe2/quantization/server/im2col_dnnlowp.h @@ -6,7 +6,7 @@ #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/math/utils.h" namespace caffe2 { diff --git a/caffe2/utils/CMakeLists.txt b/caffe2/utils/CMakeLists.txt index 9208a69641..eaa10f8865 100644 --- a/caffe2/utils/CMakeLists.txt +++ b/caffe2/utils/CMakeLists.txt @@ -1,10 +1,11 @@ list(APPEND Caffe2_CPU_SRCS utils/bench_utils.cc utils/cpuid.cc + utils/math/broadcast.cc utils/math/elementwise.cc utils/math/reduce.cc + utils/math/utils.cc utils/math_cpu.cc - utils/math_utils.cc utils/murmur_hash3.cc utils/proto_convert.cc utils/proto_utils.cc @@ -26,12 +27,14 @@ if (NOT MSVC) endif() set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} + utils/math/broadcast.cu utils/math/elementwise.cu utils/math/reduce.cu utils/math_gpu.cu ) set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} + utils/math/hip/broadcast.hip utils/math/hip/elementwise.hip utils/math/hip/reduce.hip utils/hip/math_gpu.hip diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h index f870c3d0f8..2ea960176a 100644 --- a/caffe2/utils/math.h +++ b/caffe2/utils/math.h @@ -15,9 +15,10 @@ extern "C" { #include "caffe2/core/common.h" #include "caffe2/core/types.h" +#include "caffe2/utils/math/broadcast.h" #include "caffe2/utils/math/elementwise.h" #include "caffe2/utils/math/reduce.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/math/utils.h" namespace caffe2 { @@ -31,9 +32,6 @@ class CAFFE2_API DefaultEngine {}; namespace math { #define C10_DECLARE_COMPARE_OP(Comp) \ - template \ - void Comp(const int N, const T* A, const T* B, bool* C, Context* context); \ - \ template \ void Rowwise##Comp( \ const int rows, \ @@ -72,37 +70,34 @@ C10_DECLARE_COMPARE_OP(GE) #undef C10_DECLARE_COMPARE_OP -#define C10_DECLARE_BINARY_OP(Func) \ - template \ - void Func(const int N, const T* A, const T* B, T* C, Context* context); \ - \ - template \ - void Rowwise##Func( \ - const int rows, \ - const int cols, \ - const T* A, \ - const T* B, \ - T* C, \ - Context* context); \ - \ - template \ - void Colwise##Func( \ - const int rows, \ - const int cols, \ - const T* A, \ - const T* B, \ - T* C, \ - Context* context); \ - \ - template \ - void Func( \ - const int A_ndim, \ - const int* A_dims, \ - const int B_ndim, \ - const int* B_dims, \ - const T* A, \ - const T* B, \ - T* C, \ +#define C10_DECLARE_BINARY_OP(Func) \ + template \ + void Rowwise##Func( \ + const int rows, \ + const int cols, \ + const T* A, \ + const T* B, \ + T* C, \ + Context* context); \ + \ + template \ + void Colwise##Func( \ + const int rows, \ + const int cols, \ + const T* A, \ + const T* B, \ + T* C, \ + Context* context); \ + \ + template \ + void Func( \ + const int A_ndim, \ + const int* A_dims, \ + const int B_ndim, \ + const int* B_dims, \ + const T* A, \ + const T* B, \ + T* C, \ Context* context); C10_DECLARE_BINARY_OP(Add) @@ -238,11 +233,6 @@ template CAFFE2_API void ColwiseMax(const int N, const int D, const T* x, T* y, Context* context); -// Elemwise maximum of vector x and vector y. z[i] = max(x[i], y[i]) -template -CAFFE2_API void -ElemwiseMax(const int N, const T* x, const T* y, T* z, Context* context); - // Elemwise maximum of vector x and scalar alpha. 
y[i] = max(x[i], alpha) template CAFFE2_API void @@ -621,7 +611,6 @@ CAFFE2_API void NHWC2NCHW( T* Y, Context* context); - } // namespace math } // namespace caffe2 diff --git a/caffe2/utils/math/broadcast.cc b/caffe2/utils/math/broadcast.cc new file mode 100644 index 0000000000..f880bb7eeb --- /dev/null +++ b/caffe2/utils/math/broadcast.cc @@ -0,0 +1,55 @@ +#include "caffe2/utils/math/broadcast.h" + +#include "caffe2/core/context.h" +#include "caffe2/utils/eigen_utils.h" + +namespace caffe2 { +namespace math { + +#define CAFFE2_SPECIALIZED_AFFINE_CHANNEL(T) \ + template <> \ + C10_EXPORT void AffineChannel( \ + const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + const T* scale, \ + const T* bias, \ + T* Y, \ + CPUContext* /* context */) { \ + ConstEigenVectorArrayMap scale_arr(scale, C); \ + ConstEigenVectorArrayMap bias_arr(bias, C); \ + const int stride = C * HxW; \ + const T* X_ptr = X; \ + T* Y_ptr = Y; \ + for (int i = 0; i < N; ++i) { \ + EigenArrayMap(Y_ptr, HxW, C) = \ + (ConstEigenArrayMap(X_ptr, HxW, C).rowwise() * \ + scale_arr.transpose()) \ + .rowwise() + \ + bias_arr.transpose(); \ + X_ptr += stride; \ + Y_ptr += stride; \ + } \ + } \ + template <> \ + C10_EXPORT void AffineChannel( \ + const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + const T* scale, \ + const T* bias, \ + T* Y, \ + CPUContext* /* context */) { \ + EigenArrayMap(Y, C, N * HxW) = \ + (ConstEigenArrayMap(X, C, N * HxW).colwise() * \ + ConstEigenVectorArrayMap(scale, C)) \ + .colwise() + \ + ConstEigenVectorArrayMap(bias, C); \ + } +CAFFE2_SPECIALIZED_AFFINE_CHANNEL(float) +#undef CAFFE2_SPECIALIZED_AFFINE_CHANNEL + +} // namespace math +} // namespace caffe2 diff --git a/caffe2/utils/math/broadcast.cu b/caffe2/utils/math/broadcast.cu new file mode 100644 index 0000000000..97f7cb500f --- /dev/null +++ b/caffe2/utils/math/broadcast.cu @@ -0,0 +1,108 @@ +#include "caffe2/utils/math/broadcast.h" + +#include "caffe2/core/context_gpu.h" +#include "caffe2/utils/math/utils.h" + +namespace caffe2 { +namespace math { + +namespace { + +template +__global__ void AffineChannelNCHWCUDAKernel( + const int C, + const int M, + const int HxW, + const T* X, + const T* scale, + const T* bias, + T* Y); + +template <> +__global__ void AffineChannelNCHWCUDAKernel( + const int C, + const int M, + const int HxW, + const float* X, + const float* scale, + const float* bias, + float* Y) { + const int nc = blockIdx.x / M; + const int c = nc % C; + const int w = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (w < HxW) { + const int index = nc * HxW + w; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); +#else + Y[index] = fmaf(X[index], scale[c], bias[c]); +#endif + } +} + +template +__global__ void AffineChannelNHWCCUDAKernel( + const int C, + const T* X, + const T* scale, + const T* bias, + T* Y); + +template <> +__global__ void AffineChannelNHWCCUDAKernel( + const int C, + const float* X, + const float* scale, + const float* bias, + float* Y) { + const int c = blockIdx.y * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (c < C) { + const int index = blockIdx.x * C + c; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); +#else + Y[index] = fmaf(X[index], scale[c], bias[c]); +#endif + } +} + +} // namespace + +#define CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(T) \ + template <> \ + CAFFE2_CUDA_EXPORT void AffineChannel( \ + const int N, \ 
+ const int C, \ + const int HxW, \ + const T* X, \ + const T* scale, \ + const T* bias, \ + T* Y, \ + CUDAContext* context) { \ + const int M = DivUp(HxW, CAFFE_CUDA_NUM_THREADS); \ + AffineChannelNCHWCUDAKernel \ + <<cuda_stream()>>>( \ + C, M, HxW, X, scale, bias, Y); \ + } \ + template <> \ + CAFFE2_CUDA_EXPORT void AffineChannel( \ + const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + const T* scale, \ + const T* bias, \ + T* Y, \ + CUDAContext* context) { \ + const int M = DivUp(C, CAFFE_CUDA_NUM_THREADS); \ + AffineChannelNHWCCUDAKernel \ + <<cuda_stream()>>>(C, X, scale, bias, Y); \ + } +CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float) +#undef CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL + +} // namespace math +} // namespace caffe2 diff --git a/caffe2/utils/math/broadcast.h b/caffe2/utils/math/broadcast.h new file mode 100644 index 0000000000..67e37d1bd9 --- /dev/null +++ b/caffe2/utils/math/broadcast.h @@ -0,0 +1,24 @@ +#ifndef CAFFE2_UTILS_MATH_BROADCAST_H_ +#define CAFFE2_UTILS_MATH_BROADCAST_H_ + +#include "caffe2/core/common.h" +#include "caffe2/core/types.h" + +namespace caffe2 { +namespace math { + +template +CAFFE2_API void AffineChannel( + const int N, + const int C, + const int HxW, + const T* X, + const T* scale, + const T* bias, + T* Y, + Context* context); + +} // namespace math +} // namespace caffe2 + +#endif // CAFFE2_UTILS_MATH_BROADCAST_H_ diff --git a/caffe2/utils/math/elementwise.cc b/caffe2/utils/math/elementwise.cc index ba6fb96338..f392829ccc 100644 --- a/caffe2/utils/math/elementwise.cc +++ b/caffe2/utils/math/elementwise.cc @@ -1,5 +1,8 @@ #include "caffe2/utils/math/elementwise.h" +#include +#include + #ifdef CAFFE2_USE_MKL #include #endif // CAFFE2_USE_MKL @@ -12,7 +15,7 @@ namespace math { //////////////////////////////////////////////////////////////////////////////// // MKL VML alternatives. -// Depending on whether we are using MKL, we will delegate the Caffe math +// Depending on whether we are using MKL, we will delegate the Caffe2 math // functions that are VML-related to either the VML call or the Eigen // implementation. 
If you are setting the flags (such as AVX) right for your CPU // architecture, usually Eigen will deliver a throughput as fast as the VML @@ -90,6 +93,22 @@ DELEGATE_POWX_FUNCTION(float, vsPowx) DELEGATE_POWX_FUNCTION(double, vdPowx) #undef DELEGATE_POWX_FUNCTION +#define DELEGATE_SIMPLE_BINARY_FUNCTION(T, Func, MKLFunc) \ + template <> \ + C10_EXPORT void Func( \ + const int N, const T* A, const T* B, T* C, CPUContext* /* context */) { \ + MKLFunc(N, A, B, C); \ + } +DELEGATE_SIMPLE_BINARY_FUNCTION(float, Add, vsAdd) +DELEGATE_SIMPLE_BINARY_FUNCTION(double, Add, vdAdd) +DELEGATE_SIMPLE_BINARY_FUNCTION(float, Sub, vsSub) +DELEGATE_SIMPLE_BINARY_FUNCTION(double, Sub, vdSub) +DELEGATE_SIMPLE_BINARY_FUNCTION(float, Mul, vsMul) +DELEGATE_SIMPLE_BINARY_FUNCTION(double, Mul, vdMul) +DELEGATE_SIMPLE_BINARY_FUNCTION(float, Div, vsDiv) +DELEGATE_SIMPLE_BINARY_FUNCTION(double, Div, vdDiv) +#undef DELEGATE_SIMPLE_BINARY_FUNCTION + #else // CAFFE2_USE_MKL #define DELEGATE_SIMPLE_UNARY_FUNCTION(T, Func, EigenFunc) \ @@ -127,68 +146,85 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(float, Inv, inverse) DELEGATE_SIMPLE_UNARY_FUNCTION(double, Inv, inverse) #undef DELEGATE_SIMPLE_UNARY_FUNCTION -#define DELEGATE_SINH_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_SINH(T) \ template <> \ C10_EXPORT void Sinh( \ const int N, const T* X, T* Y, CPUContext* /* context */) { \ ConstEigenVectorArrayMap X_arr(X, N); \ EigenVectorArrayMap(Y, N) = (X_arr.exp() - (-X_arr).exp()) / T(2); \ } -DELEGATE_SINH_FUNCTION(float) -DELEGATE_SINH_FUNCTION(double) -#undef DELEGATE_SINH_FUNCTION +CAFFE2_SPECIALIZED_SINH(float) +CAFFE2_SPECIALIZED_SINH(double) +#undef CAFFE2_SPECIALIZED_SINH -#define DELEGATE_COSH_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_COSH(T) \ template <> \ C10_EXPORT void Cosh( \ const int N, const T* X, T* Y, CPUContext* /* context */) { \ ConstEigenVectorArrayMap X_arr(X, N); \ EigenVectorArrayMap(Y, N) = (X_arr.exp() + (-X_arr).exp()) / T(2); \ } -DELEGATE_COSH_FUNCTION(float) -DELEGATE_COSH_FUNCTION(double) -#undef DELEGATE_COSH_FUNCTION +CAFFE2_SPECIALIZED_COSH(float) +CAFFE2_SPECIALIZED_COSH(double) +#undef CAFFE2_SPECIALIZED_COSH -#define DELEGATE_SINCOS_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_SINCOS(T) \ template <> \ C10_EXPORT void SinCos( \ const int N, const T* X, T* S, T* C, CPUContext* /* context */) { \ EigenVectorArrayMap(S, N) = ConstEigenVectorArrayMap(X, N).sin(); \ EigenVectorArrayMap(C, N) = ConstEigenVectorArrayMap(X, N).cos(); \ } -DELEGATE_SINCOS_FUNCTION(float) -DELEGATE_SINCOS_FUNCTION(double) -#undef DELEGATE_SINCOS_FUNCTION +CAFFE2_SPECIALIZED_SINCOS(float) +CAFFE2_SPECIALIZED_SINCOS(double) +#undef CAFFE2_SPECIALIZED_SINCOS -#define DELEGATE_POWX_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_POWX(T) \ template <> \ C10_EXPORT void Powx( \ const int N, const T* A, const T b, T* Y, CPUContext* /* context */) { \ EigenVectorArrayMap(Y, N) = ConstEigenVectorArrayMap(A, N).pow(b); \ } -DELEGATE_POWX_FUNCTION(float) -DELEGATE_POWX_FUNCTION(double) -#undef DELEGATE_POWX_FUNCTION +CAFFE2_SPECIALIZED_POWX(float) +CAFFE2_SPECIALIZED_POWX(double) +#undef CAFFE2_SPECIALIZED_POWX -#define DELEGATE_CBRT_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_CBRT(T) \ template <> \ C10_EXPORT void Cbrt( \ const int N, const T* X, T* Y, CPUContext* /* context */) { \ std::transform(X, X + N, Y, [](const T x) { return cbrt(x); }); \ } -DELEGATE_CBRT_FUNCTION(float) -DELEGATE_CBRT_FUNCTION(double) -#undef DELEGATE_CBRT_FUNCTION +CAFFE2_SPECIALIZED_CBRT(float) +CAFFE2_SPECIALIZED_CBRT(double) +#undef CAFFE2_SPECIALIZED_CBRT 
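// [Editorial sketch — not part of this patch.] Each CAFFE2_SPECIALIZED_*
// macro above stamps out one explicit specialization of the corresponding
// math function per scalar type; the angle-bracket template arguments are
// elided in this rendering. CAFFE2_SPECIALIZED_CBRT(float), for example,
// expands to roughly:
//
//   template <>
//   C10_EXPORT void Cbrt<float, CPUContext>(
//       const int N, const float* X, float* Y, CPUContext* /* context */) {
//     std::transform(X, X + N, Y, [](const float x) { return cbrt(x); });
//   }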
-#define DELEGATE_ERF_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_ERF(T) \ template <> \ C10_EXPORT void Erf( \ const int N, const T* X, T* Y, CPUContext* /* context */) { \ std::transform(X, X + N, Y, [](const T x) { return erf(x); }); \ } -DELEGATE_ERF_FUNCTION(float) -DELEGATE_ERF_FUNCTION(double) -#undef DELEGATE_ERF_FUNCTION +CAFFE2_SPECIALIZED_ERF(float) +CAFFE2_SPECIALIZED_ERF(double) +#undef CAFFE2_SPECIALIZED_ERF + +#define DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(T, Func, EigenOp) \ + template <> \ + C10_EXPORT void Func( \ + const int N, const T* A, const T* B, T* C, CPUContext* /* context */) { \ + EigenVectorMap(C, N) = ConstEigenVectorArrayMap(A, N) \ + EigenOp ConstEigenVectorArrayMap(B, N); \ + } +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(float, Add, +) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(double, Add, +) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(float, Sub, -) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(double, Sub, -) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(float, Mul, *) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(double, Mul, *) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(float, Div, /) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(double, Div, /) +#undef DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR #endif // CAFFE2_USE_MKL @@ -202,74 +238,155 @@ DELEGATE_ERF_FUNCTION(double) // Eigen's Tanh implementation is faster than MKL, so use Eigen here. DELEGATE_SIMPLE_UNARY_FUNCTION(float, Tanh, tanh) DELEGATE_SIMPLE_UNARY_FUNCTION(double, Tanh, tanh) -DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sign, sign) -DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sign, sign) DELEGATE_SIMPLE_UNARY_FUNCTION(std::int32_t, Sign, sign) DELEGATE_SIMPLE_UNARY_FUNCTION(std::int64_t, Sign, sign) +DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sign, sign) +DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sign, sign) DELEGATE_SIMPLE_UNARY_FUNCTION(std::int32_t, Abs, abs) DELEGATE_SIMPLE_UNARY_FUNCTION(std::int64_t, Abs, abs) -DELEGATE_SIMPLE_UNARY_FUNCTION(float, Cube, cube) -DELEGATE_SIMPLE_UNARY_FUNCTION(double, Cube, cube) DELEGATE_SIMPLE_UNARY_FUNCTION(std::int32_t, Cube, cube) DELEGATE_SIMPLE_UNARY_FUNCTION(std::int64_t, Cube, cube) +DELEGATE_SIMPLE_UNARY_FUNCTION(float, Cube, cube) +DELEGATE_SIMPLE_UNARY_FUNCTION(double, Cube, cube) #undef DELEGATE_SIMPLE_UNARY_FUNCTION -#define DELEGATE_NEG_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_NEG(T) \ template <> \ C10_EXPORT void Neg( \ const int N, const T* X, T* Y, CPUContext* /* context */) { \ EigenVectorArrayMap(Y, N) = -ConstEigenVectorArrayMap(X, N); \ } -DELEGATE_NEG_FUNCTION(float) -DELEGATE_NEG_FUNCTION(double) -DELEGATE_NEG_FUNCTION(std::int32_t) -DELEGATE_NEG_FUNCTION(std::int64_t) -#undef DELEGATE_NEG_FUNCTION +CAFFE2_SPECIALIZED_NEG(std::int32_t) +CAFFE2_SPECIALIZED_NEG(std::int64_t) +CAFFE2_SPECIALIZED_NEG(float) +CAFFE2_SPECIALIZED_NEG(double) +#undef CAFFE2_SPECIALIZED_NEG + +#define DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(T, Func, EigenOp) \ + template <> \ + C10_EXPORT void Func( \ + const int N, const T* A, const T* B, T* C, CPUContext* /* context */) { \ + EigenVectorMap(C, N) = ConstEigenVectorArrayMap(A, N) \ + EigenOp ConstEigenVectorArrayMap(B, N); \ + } +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, Add, +) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, Add, +) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, Sub, -) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, Sub, -) 
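// [Editorial sketch — not part of this patch.] The instantiations in this
// list bind each binary math function to the named Eigen operator; with the
// elided template arguments restored,
// DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, Add, +)
// expands to roughly:
//
//   template <>
//   C10_EXPORT void Add<std::int32_t, CPUContext>(
//       const int N, const std::int32_t* A, const std::int32_t* B,
//       std::int32_t* C, CPUContext* /* context */) {
//     EigenVectorMap<std::int32_t>(C, N) =
//         ConstEigenVectorArrayMap<std::int32_t>(A, N) +
//         ConstEigenVectorArrayMap<std::int32_t>(B, N);
//   }
//
// Eigen allows assigning an array expression to a matrix map, so no explicit
// .matrix() conversion is needed here.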
+DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, Mul, *) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, Mul, *) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, Div, /) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, Div, /) +#undef DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR -#define CAFFE2_SPECIALIZED_AFFINE_CHANNEL(T) \ - template <> \ - void AffineChannel( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CPUContext* /* context */) { \ - ConstEigenVectorArrayMap scale_arr(scale, C); \ - ConstEigenVectorArrayMap bias_arr(bias, C); \ - const int stride = C * HxW; \ - const T* X_ptr = X; \ - T* Y_ptr = Y; \ - for (int i = 0; i < N; ++i) { \ - EigenArrayMap(Y_ptr, HxW, C) = \ - (ConstEigenArrayMap(X_ptr, HxW, C).rowwise() * \ - scale_arr.transpose()) \ - .rowwise() + \ - bias_arr.transpose(); \ - X_ptr += stride; \ - Y_ptr += stride; \ - } \ - } \ - template <> \ - void AffineChannel( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CPUContext* /* context */) { \ - EigenArrayMap(Y, C, N * HxW) = \ - (ConstEigenArrayMap(X, C, N * HxW).colwise() * \ - ConstEigenVectorArrayMap(scale, C)) \ - .colwise() + \ - ConstEigenVectorArrayMap(bias, C); \ +#define DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_FUNCTION(T, Func, EigenFunc) \ + template <> \ + C10_EXPORT void Func( \ + const int N, const T* A, const T* B, T* C, CPUContext* /* context */) { \ + EigenVectorMap(C, N) = ConstEigenVectorArrayMap(A, N).EigenFunc( \ + ConstEigenVectorArrayMap(B, N)); \ + } +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_FUNCTION(float, Min, min) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_FUNCTION(double, Min, min) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_FUNCTION(float, Max, max) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_FUNCTION(double, Max, max) +#undef DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_FUNCTION + +#define DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION(T, Func, StdFunc) \ + template <> \ + C10_EXPORT void Func( \ + const int N, const T* A, const T* B, T* C, CPUContext* /* context */) { \ + std::transform(A, A + N, B, C, StdFunc); \ + } +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + bool, + And, + std::logical_and()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + bool, + Or, + std::logical_or()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION(bool, Xor, std::bit_xor()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + bool, + BitwiseAnd, + std::bit_and()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + std::int32_t, + BitwiseAnd, + std::bit_and()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + std::int64_t, + BitwiseAnd, + std::bit_and()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + bool, + BitwiseOr, + std::bit_or()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + std::int32_t, + BitwiseOr, + std::bit_or()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + std::int64_t, + BitwiseOr, + std::bit_or()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + bool, + BitwiseXor, + std::bit_xor()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + std::int32_t, + BitwiseXor, + std::bit_xor()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + std::int64_t, + BitwiseXor, + std::bit_xor()) +#undef DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION + +#define DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(T, Func, EigenOp) \ + template <> \ + C10_EXPORT void Func( \ + const 
int N, \ + const T* A, \ + const T* B, \ + bool* C, \ + CPUContext* /* context */) { \ + EigenVectorArrayMap(C, N) = ConstEigenVectorArrayMap(A, N) \ + EigenOp ConstEigenVectorArrayMap(B, N); \ } -CAFFE2_SPECIALIZED_AFFINE_CHANNEL(float) -#undef CAFFE2_SPECIALIZED_AFFINE_CHANNEL +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(bool, EQ, ==) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, EQ, ==) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, EQ, ==) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(float, EQ, ==) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(double, EQ, ==) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(bool, NE, !=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, NE, !=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, NE, !=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(float, NE, !=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(double, NE, !=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(bool, LT, <) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, LT, <) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, LT, <) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(float, LT, <) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(double, LT, <) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(bool, LE, <=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, LE, <=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, LE, <=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(float, LE, <=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(double, LE, <=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(bool, GT, >) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, GT, >) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, GT, >) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(float, GT, >) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(double, GT, >) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(bool, GE, >=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, GE, >=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, GE, >=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(float, GE, >=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(double, GE, >=) +#undef DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR } // namespace math } // namespace caffe2 diff --git a/caffe2/utils/math/elementwise.cu b/caffe2/utils/math/elementwise.cu index cc7613cda9..b7605deece 100644 --- a/caffe2/utils/math/elementwise.cu +++ b/caffe2/utils/math/elementwise.cu @@ -1,13 +1,77 @@ #include "caffe2/utils/math/elementwise.h" +#include + #include "caffe2/core/context_gpu.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/conversions.h" +#include "caffe2/utils/math/half_utils.h" +#include "caffe2/utils/math/utils.h" namespace caffe2 { namespace math { namespace { +#define DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(T, Func, DeviceFunc) \ + __global__ void Func##CUDAKernel(const int N, const T* X, T* Y) { \ + const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; \ + if (i < N) { \ + Y[i] = DeviceFunc(X[i]); \ + } \ + } +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Exp, expf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Log, logf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Cos, cosf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Acos, acosf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sin, 
sinf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Asin, asinf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Tan, tanf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Atan, atanf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sinh, sinhf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Cosh, coshf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Tanh, tanhf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Abs, fabsf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sqr, utils::Square) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sqrt, sqrtf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Rsqrt, rsqrtf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Cbrt, cbrtf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Erf, erff) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Erf, erf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( + std::int32_t, + Cube, + utils::Cube) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( + std::int64_t, + Cube, + utils::Cube) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Cube, utils::Cube) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Cube, utils::Cube) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(bool, Not, utils::Not) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( + std::int32_t, + Neg, + utils::Negate) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( + std::int64_t, + Neg, + utils::Negate) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Neg, utils::Negate) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Neg, utils::Negate) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( + std::int32_t, + Sign, + utils::Sign) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( + std::int64_t, + Sign, + utils::Sign) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sign, utils::Sign) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Sign, utils::Sign) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Inv, utils::Inv) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Inv, utils::Inv) +#undef DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION + template __global__ void SinCosCUDAKernel(const int N, const T* X, T* S, T* C) { const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; @@ -20,190 +84,276 @@ __global__ void SinCosCUDAKernel(const int N, const T* X, T* S, T* C) { } } -template -__global__ void AffineChannelNCHWCUDAKernel( - const int C, - const int M, - const int HxW, - const T* X, - const T* scale, - const T* bias, - T* Y); - -template <> -__global__ void AffineChannelNCHWCUDAKernel( - const int C, - const int M, - const int HxW, - const float* X, - const float* scale, - const float* bias, - float* Y) { - const int nc = blockIdx.x / M; - const int c = nc % C; - const int w = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (w < HxW) { - const int index = nc * HxW + w; -#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) - Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); -#else - Y[index] = fmaf(X[index], scale[c], bias[c]); -#endif +template +__global__ void SimpleBinaryCUDAKernel( + const int N, + const Func func, + const T* A, + const T* B, + T* C) { + const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (i < N) { + C[i] = func(A[i], B[i]); } } -template -__global__ void AffineChannelNHWCCUDAKernel( - const int C, - const T* X, - const T* scale, - const T* bias, - T* Y); - -template <> -__global__ void AffineChannelNHWCCUDAKernel( - const int C, - const float* X, - const float* scale, - const float* bias, - float* Y) { - const int c = blockIdx.y * 
CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (c < C) { - const int index = blockIdx.x * C + c; -#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) - Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); -#else - Y[index] = fmaf(X[index], scale[c], bias[c]); -#endif +template +__global__ void SimpleCompareCUDAKernel( + const int N, + const Comp comp, + const T* A, + const T* B, + bool* C) { + const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (i < N) { + C[i] = comp(A[i], B[i]); } } } // namespace -#define DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(T, Func, KernelFunc) \ - __global__ void Func##CUDAKernel(const int N, const T* X, T* Y) { \ - const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; \ - if (i < N) { \ - Y[i] = KernelFunc(X[i]); \ - } \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Func( \ - const int N, const T* X, T* Y, CUDAContext* context) { \ - if (N > 0) { \ - const int K = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - Func##CUDAKernel<<< \ - K, \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>(N, X, Y); \ - } \ +#define DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(T, Func) \ + template <> \ + CAFFE2_CUDA_EXPORT void Func( \ + const int N, const T* X, T* Y, CUDAContext* context) { \ + if (N > 0) { \ + const int M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ + Func##CUDAKernel<<< \ + M, \ + CAFFE_CUDA_NUM_THREADS, \ + 0, \ + context->cuda_stream()>>>(N, X, Y); \ + } \ } -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log, logf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cos, cosf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Acos, acosf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sin, sinf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Asin, asinf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tan, tanf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Atan, atanf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sinh, sinhf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cosh, coshf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tanh, tanhf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Abs, fabsf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr, utils::Square) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqrt, sqrtf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Rsqrt, rsqrtf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cbrt, cbrtf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Erf, erff) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Erf, erf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cube, utils::Cube) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Cube, utils::Cube) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cos) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Acos) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sin) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Asin) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tan) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Atan) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sinh) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cosh) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tanh) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Abs) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqrt) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Rsqrt) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cbrt) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Erf) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Erf) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cube) 
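// [Editorial sketch — not part of this patch.] Each
// DEFINE_SIMPLE_CUDA_UNARY_FUNCTION instantiation in this list pairs a host
// wrapper with the Func##CUDAKernel overload generated earlier by
// DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION; with the elided template
// arguments restored, DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cube) expands
// to roughly:
//
//   template <>
//   CAFFE2_CUDA_EXPORT void Cube<float, CUDAContext>(
//       const int N, const float* X, float* Y, CUDAContext* context) {
//     if (N > 0) {
//       const int M = DivUp(N, CAFFE_CUDA_NUM_THREADS);
//       CubeCUDAKernel<<<M, CAFFE_CUDA_NUM_THREADS, 0,
//                        context->cuda_stream()>>>(N, X, Y);
//     }
//   }
//
// The N > 0 guard avoids launching a kernel with a zero-sized grid, which
// CUDA would reject as an invalid configuration.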
+DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Cube) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int32_t, Cube) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int64_t, Cube) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(bool, Not) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Neg) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Neg) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int32_t, Neg) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int64_t, Neg) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sign) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sign) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int32_t, Sign) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int64_t, Sign) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Inv) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Inv) +#undef DEFINE_SIMPLE_CUDA_UNARY_FUNCTION + +#define CAFFE2_SPECIALIZED_CUDA_SINCOS(T) \ + template <> \ + CAFFE2_CUDA_EXPORT void SinCos( \ + const int N, const T* X, T* S, T* C, CUDAContext* context) { \ + if (N > 0) { \ + const int K = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ + SinCosCUDAKernel \ + <<cuda_stream()>>>( \ + N, X, S, C); \ + } \ + } +CAFFE2_SPECIALIZED_CUDA_SINCOS(float) +CAFFE2_SPECIALIZED_CUDA_SINCOS(double) +#undef CAFFE2_SPECIALIZED_CUDA_SINCOS + +#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Func, DeviceFunc) \ + template <> \ + CAFFE2_CUDA_EXPORT void Func( \ + const int N, const T* A, const T* B, T* C, CUDAContext* context) { \ + if (N > 0) { \ + const int M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ + SimpleBinaryCUDAKernel<<< \ + M, \ + CAFFE_CUDA_NUM_THREADS, \ + 0, \ + context->cuda_stream()>>>(N, DeviceFunc, A, B, C); \ + } \ + } +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( std::int32_t, - Cube, - utils::Cube) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( + Add, + thrust::plus()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( std::int64_t, - Cube, - utils::Cube) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(bool, Not, utils::Not) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Neg, utils::Negate) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Neg, utils::Negate) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( + Add, + thrust::plus()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Add, thrust::plus()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Add, thrust::plus()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Add, utils::HalfAddFunctor()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( std::int32_t, - Neg, - utils::Negate) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( + Sub, + thrust::minus()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( std::int64_t, - Neg, - utils::Negate) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sign, utils::Sign) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sign, utils::Sign) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( + Sub, + thrust::minus()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Sub, thrust::minus()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Sub, thrust::minus()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Sub, utils::HalfSubFunctor()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( std::int32_t, - Sign, - utils::Sign) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( + Mul, + thrust::multiplies()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( std::int64_t, - Sign, - utils::Sign) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Inv, utils::Inv) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Inv, utils::Inv) -#undef DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION - -#define CAFFE2_SPECIALIZED_CUDA_SINCOS(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void SinCos( \ - const int N, const T* X, T* S, T* C, CUDAContext* context) { \ - if (N > 0) { \ - const int K = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - SinCosCUDAKernel<<< \ - K, \ - CAFFE_CUDA_NUM_THREADS, \ - 
0, \ - context->cuda_stream()>>>(N, X, S, C); \ - } \ - } -CAFFE2_SPECIALIZED_CUDA_SINCOS(float) -CAFFE2_SPECIALIZED_CUDA_SINCOS(double) -#undef CAFFE2_SPECIALIZED_CUDA_SINCOS + Mul, + thrust::multiplies()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Mul, thrust::multiplies()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Mul, thrust::multiplies()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Mul, utils::HalfMulFunctor()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int32_t, + Div, + thrust::divides()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int64_t, + Div, + thrust::divides()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Div, thrust::divides()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Div, thrust::divides()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Div, utils::HalfDivFunctor()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Min, thrust::minimum()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Min, thrust::minimum()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Max, thrust::maximum()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Max, thrust::maximum()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, And, thrust::logical_and()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, Or, thrust::logical_or()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, Xor, thrust::bit_xor()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, BitwiseAnd, thrust::bit_and()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int32_t, + BitwiseAnd, + thrust::bit_and()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int64_t, + BitwiseAnd, + thrust::bit_and()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, BitwiseOr, thrust::bit_or()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int32_t, + BitwiseOr, + thrust::bit_or()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int64_t, + BitwiseOr, + thrust::bit_or()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, BitwiseXor, thrust::bit_xor()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int32_t, + BitwiseXor, + thrust::bit_xor()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int64_t, + BitwiseXor, + thrust::bit_xor()) +#undef DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION -#define CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void AffineChannel( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CUDAContext* context) { \ - const int M = DivUp(HxW, CAFFE_CUDA_NUM_THREADS); \ - AffineChannelNCHWCUDAKernel \ - <<cuda_stream()>>>( \ - C, M, HxW, X, scale, bias, Y); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void AffineChannel( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CUDAContext* context) { \ - const int M = DivUp(C, CAFFE_CUDA_NUM_THREADS); \ - AffineChannelNHWCCUDAKernel \ - <<cuda_stream()>>>(C, X, scale, bias, Y); \ +#define DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(T, Func, DeviceComp) \ + template <> \ + CAFFE2_CUDA_EXPORT void Func( \ + const int N, const T* A, const T* B, bool* C, CUDAContext* context) { \ + if (N > 0) { \ + const int M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ + SimpleCompareCUDAKernel<<< \ + M, \ + CAFFE_CUDA_NUM_THREADS, \ + 0, \ + context->cuda_stream()>>>(N, DeviceComp, A, B, C); \ + } \ } -CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float) -#undef CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, EQ, thrust::equal_to()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int32_t, + EQ, + thrust::equal_to()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int64_t, + 
EQ, + thrust::equal_to()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, EQ, thrust::equal_to()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, EQ, thrust::equal_to()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, NE, thrust::not_equal_to()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int32_t, + NE, + thrust::not_equal_to()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int64_t, + NE, + thrust::not_equal_to()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, NE, thrust::not_equal_to()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + double, + NE, + thrust::not_equal_to()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, LT, thrust::less()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int32_t, + LT, + thrust::less()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int64_t, + LT, + thrust::less()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, LT, thrust::less()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, LT, thrust::less()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, LE, thrust::less_equal()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int32_t, + LE, + thrust::less_equal()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int64_t, + LE, + thrust::less_equal()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, LE, thrust::less_equal()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, LE, thrust::less_equal()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, GT, thrust::greater()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int32_t, + GT, + thrust::greater()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int64_t, + GT, + thrust::greater()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, GT, thrust::greater()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, GT, thrust::greater()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, GE, thrust::greater_equal()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int32_t, + GE, + thrust::greater_equal()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int64_t, + GE, + thrust::greater_equal()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, GE, thrust::greater_equal()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + double, + GE, + thrust::greater_equal()) +#undef DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION } // namespace math } // namespace caffe2 diff --git a/caffe2/utils/math/elementwise.h b/caffe2/utils/math/elementwise.h index 63890de498..b11d22b955 100644 --- a/caffe2/utils/math/elementwise.h +++ b/caffe2/utils/math/elementwise.h @@ -8,64 +8,97 @@ namespace caffe2 { namespace math { template -void Exp(const int N, const T* X, T* Y, Context* context); +CAFFE2_API void Exp(int N, const T* X, T* Y, Context* context); template -void Log(const int N, const T* X, T* Y, Context* context); +CAFFE2_API void Log(int N, const T* X, T* Y, Context* context); template -void Sin(const int N, const T* X, T* Y, Context* context); +CAFFE2_API void Sin(int N, const T* X, T* Y, Context* context); template -void Asin(const int N, const T* X, T* Y, Context* context); +CAFFE2_API void Asin(int N, const T* X, T* Y, Context* context); template -void Cos(const int N, const T* X, T* Y, Context* context); +CAFFE2_API void Cos(int N, const T* X, T* Y, Context* context); template -void Acos(const int N, const T* X, T* Y, Context* context); +CAFFE2_API void Acos(int N, const T* X, T* Y, Context* context); template -void Tan(const int N, const T* X, T* Y, Context* context); +CAFFE2_API void Tan(int N, const T* X, T* Y, Context* context); template -void Atan(const int N, const T* X, T* Y, Context* context); +CAFFE2_API void Atan(int N, const T* X, T* Y, Context* context); template -void Sinh(const int N, const T* 
diff --git a/caffe2/utils/math/elementwise.h b/caffe2/utils/math/elementwise.h
index 63890de498..b11d22b955 100644
--- a/caffe2/utils/math/elementwise.h
+++ b/caffe2/utils/math/elementwise.h
@@ -8,64 +8,97 @@ namespace caffe2 {
 namespace math {
 
 template <typename T, class Context>
-void Exp(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Exp(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Log(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Log(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Sin(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Sin(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Asin(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Asin(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Cos(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Cos(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Acos(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Acos(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Tan(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Tan(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Atan(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Atan(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Sinh(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Sinh(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Cosh(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Cosh(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void SinCos(const int N, const T* X, T* S, T* C, Context* context);
+CAFFE2_API void SinCos(int N, const T* X, T* S, T* C, Context* context);
 template <typename T, class Context>
-void Tanh(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Tanh(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Abs(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Abs(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Sqr(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Sqr(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Sqrt(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Sqrt(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Rsqrt(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Rsqrt(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Cube(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Cube(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Cbrt(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Cbrt(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Neg(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Neg(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Sign(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Sign(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Not(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Not(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Powx(const int N, const T* A, const T b, T* Y, Context* context);
+CAFFE2_API void Powx(int N, const T* A, const T b, T* Y, Context* context);
 template <typename T, class Context>
-void Inv(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Inv(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Erf(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Erf(int N, const T* X, T* Y, Context* context);
 
-template <typename T, StorageOrder kOrder, class Context>
-CAFFE2_API void AffineChannel(
-    const int N,
-    const int C,
-    const int HxW,
-    const T* X,
-    const T* scale,
-    const T* bias,
-    T* Y,
-    Context* context);
+template <typename T, class Context>
+CAFFE2_API void Add(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void Sub(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void Mul(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void Div(int N, const T* A, const T* B, T* C, Context* context);
+
+template <typename T, class Context>
+CAFFE2_API void Min(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void Max(int N, const T* A, const T* B, T* C, Context* context);
+
+template <typename T, class Context>
+CAFFE2_API void And(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void Or(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void Xor(int N, const T* A, const T* B, T* C, Context* context);
+
+template <typename T, class Context>
+CAFFE2_API void
+BitwiseAnd(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void
+BitwiseOr(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void
+BitwiseXor(int N, const T* A, const T* B, T* C, Context* context);
+
+template <typename T, class Context>
+CAFFE2_API void EQ(int N, const T* A, const T* B, bool* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void NE(int N, const T* A, const T* B, bool* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void LT(int N, const T* A, const T* B, bool* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void LE(int N, const T* A, const T* B, bool* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void GT(int N, const T* A, const T* B, bool* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void GE(int N, const T* A, const T* B, bool* C, Context* context);
 
 } // namespace math
 } // namespace caffe2
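
// Illustrative sketch, not from the patch: the declarations above give every
// element-wise primitive the same (N, inputs..., output, context) shape, so a
// caller only picks the element type and context. This assumes a
// default-constructed CPUContext is sufficient for synchronous CPU math, as it
// is elsewhere in caffe2:
#include "caffe2/core/context.h"
#include "caffe2/utils/math.h"

void ElementwiseExample() {
  caffe2::CPUContext context;
  const float A[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  const float B[4] = {4.0f, 3.0f, 2.0f, 1.0f};
  float C[4];
  bool D[4];
  // C[i] = A[i] + B[i]
  caffe2::math::Add<float, caffe2::CPUContext>(4, A, B, C, &context);
  // D[i] = (A[i] < B[i]); note that the comparison functions write bool.
  caffe2::math::LT<float, caffe2::CPUContext>(4, A, B, D, &context);
}
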
diff --git a/caffe2/utils/math/half_utils.h b/caffe2/utils/math/half_utils.h
new file mode 100644
index 0000000000..ac841d165a
--- /dev/null
+++ b/caffe2/utils/math/half_utils.h
@@ -0,0 +1,49 @@
+#ifndef CAFFE2_UTILS_MATH_HALF_UTILS_H_
+#define CAFFE2_UTILS_MATH_HALF_UTILS_H_
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/types.h"
+#include "caffe2/utils/conversions.h"
+#include "caffe2/utils/math/utils.h"
+
+namespace caffe2 {
+namespace math {
+namespace utils {
+
+struct HalfAddFunctor {
+  MATH_UTILS_DECL at::Half operator()(const at::Half a, const at::Half b)
+      const {
+    return convert::To<float, at::Half>(
+        convert::To<at::Half, float>(a) + convert::To<at::Half, float>(b));
+  }
+};
+
+struct HalfSubFunctor {
+  MATH_UTILS_DECL at::Half operator()(const at::Half a, const at::Half b)
+      const {
+    return convert::To<float, at::Half>(
+        convert::To<at::Half, float>(a) - convert::To<at::Half, float>(b));
+  }
+};
+
+struct HalfMulFunctor {
+  MATH_UTILS_DECL at::Half operator()(const at::Half a, const at::Half b)
+      const {
+    return convert::To<float, at::Half>(
+        convert::To<at::Half, float>(a) * convert::To<at::Half, float>(b));
+  }
+};
+
+struct HalfDivFunctor {
+  MATH_UTILS_DECL at::Half operator()(const at::Half a, const at::Half b)
+      const {
+    return convert::To<float, at::Half>(
+        convert::To<at::Half, float>(a) / convert::To<at::Half, float>(b));
+  }
+};
+
+} // namespace utils
+} // namespace math
+} // namespace caffe2
+
+#endif // CAFFE2_UTILS_MATH_HALF_UTILS_H_
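
// Illustrative sketch, not from the patch: each functor above round-trips
// through float because at::Half has no native host arithmetic; the operation
// is computed in float and the result narrowed back to half. This assumes the
// convert::To<Src, Dst> argument ordering reconstructed above:
#include "caffe2/utils/math/half_utils.h"

void HalfFunctorExample() {
  const at::Half a = caffe2::convert::To<float, at::Half>(1.5f);
  const at::Half b = caffe2::convert::To<float, at::Half>(0.25f);
  // Adds in float, then rounds the sum (1.75) back to half precision.
  const at::Half sum = caffe2::math::utils::HalfAddFunctor()(a, b);
  (void)sum;
}
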
diff --git a/caffe2/utils/math/reduce.cc b/caffe2/utils/math/reduce.cc
index c654834eee..4bcbb1cda3 100644
--- a/caffe2/utils/math/reduce.cc
+++ b/caffe2/utils/math/reduce.cc
@@ -8,7 +8,7 @@
 
 #include "caffe2/core/context.h"
 #include "caffe2/utils/eigen_utils.h"
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math/utils.h"
 
 namespace caffe2 {
 namespace math {
diff --git a/caffe2/utils/math/reduce.cu b/caffe2/utils/math/reduce.cu
index 31a653930e..18291d35f9 100644
--- a/caffe2/utils/math/reduce.cu
+++ b/caffe2/utils/math/reduce.cu
@@ -11,7 +11,7 @@
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/utils/fixed_divisor.h"
 #include "caffe2/utils/math/reduce.cuh"
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math/utils.h"
 
 namespace caffe2 {
 namespace math {
diff --git a/caffe2/utils/math/utils.cc b/caffe2/utils/math/utils.cc
new file mode 100644
index 0000000000..3b75cedaaa
--- /dev/null
+++ b/caffe2/utils/math/utils.cc
@@ -0,0 +1,347 @@
+#include "caffe2/utils/math/utils.h"
+
+#include <algorithm>
+#include <functional>
+#include <numeric>
+#include <vector>
+
+#include "caffe2/core/logging.h"
+
+namespace caffe2 {
+namespace math {
+namespace utils {
+
+void IncreaseIndexInDims(const int n, const int* dims, int* index) {
+  for (int i = n - 1; i >= 0; --i) {
+    ++index[i];
+    if (index[i] >= dims[i]) {
+      index[i] -= dims[i];
+    } else {
+      break;
+    }
+  }
+}
+
+int GetIndexFromDims(const int n, const int* dims, const int* index) {
+  int sum = 0;
+  for (int i = 0; i < n; ++i) {
+    if (dims[i] > 1) {
+      sum = sum * dims[i] + index[i];
+    }
+  }
+  return sum;
+}
+
+bool IsIdentityPermutation(const int n, const int* perm) {
+  for (int i = 0; i < n; ++i) {
+    if (perm[i] != i) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool CheckReduceDims(const int ndim, const int* X_dims, const int* Y_dims) {
+  for (int i = 0; i < ndim; ++i) {
+    if (X_dims[i] != Y_dims[i] && Y_dims[i] != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool IsRowwiseReduce(
+    const int ndim,
+    const int* A_dims,
+    const int* B_dims,
+    int* rows,
+    int* cols) {
+  *cols = 1;
+  int pivot = ndim - 1;
+  for (; pivot >= 0 && B_dims[pivot] == 1; --pivot) {
+    *cols *= A_dims[pivot];
+  }
+  *rows = 1;
+  for (int i = pivot; i >= 0; --i) {
+    if (A_dims[i] != B_dims[i]) {
+      return false;
+    }
+    *rows *= A_dims[i];
+  }
+  return true;
+}
+
+bool IsColwiseReduce(
+    const int ndim,
+    const int* A_dims,
+    const int* B_dims,
+    int* rows,
+    int* cols) {
+  *rows = 1;
+  int pivot = 0;
+  for (; pivot < ndim && B_dims[pivot] == 1; ++pivot) {
+    *rows *= A_dims[pivot];
+  }
+  *cols = 1;
+  for (int i = pivot; i < ndim; ++i) {
+    if (A_dims[i] != B_dims[i]) {
+      return false;
+    }
+    *cols *= A_dims[i];
+  }
+  return true;
+}
+
+bool IsBothEndsReduce(
+    const int ndim,
+    const int* A_dims,
+    const int* B_dims,
+    int* pre,
+    int* mid,
+    int* nxt) {
+  *nxt = 1;
+  int r = ndim - 1;
+  for (; r >= 0 && B_dims[r] == 1; --r) {
+    *nxt *= A_dims[r];
+  }
+  *pre = 1;
+  int l = 0;
+  for (; l <= r && B_dims[l] == 1; ++l) {
+    *pre *= A_dims[l];
+  }
+  *mid = 1;
+  for (int i = l; i <= r; ++i) {
+    if (A_dims[i] != B_dims[i]) {
+      return false;
+    }
+    *mid *= A_dims[i];
+  }
+  return true;
+}
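
// Illustrative sketch, not from the patch: IncreaseIndexInDims and
// GetIndexFromDims act together as an odometer over an N-D shape, and
// GetIndexFromDims skips size-1 axes, so the same index digits can also
// address a reduced or broadcast buffer:
#include "caffe2/utils/math/utils.h"

void IndexOdometerExample() {
  const int dims[3] = {2, 1, 3}; // logical shape 2 x 1 x 3
  int index[3] = {0, 0, 0}; // per-axis "digits", most significant first
  for (int i = 0; i < 2 * 1 * 3; ++i) {
    // Offsets visit 0, 1, 2, 3, 4, 5: the size-1 middle axis is skipped.
    const int offset = caffe2::math::utils::GetIndexFromDims(3, dims, index);
    (void)offset;
    // Advance the least significant digit, carrying into higher axes.
    caffe2::math::utils::IncreaseIndexInDims(3, dims, index);
  }
}
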
+
+void ComputeBroadcastBinaryOpDims(
+    const int A_ndim,
+    const int* A_dims,
+    const int B_ndim,
+    const int* B_dims,
+    int* A_broadcast_dims,
+    int* B_broadcast_dims,
+    int* C_broadcast_dims) {
+  const int ndim = std::max(A_ndim, B_ndim);
+  std::fill(A_broadcast_dims, A_broadcast_dims + ndim - A_ndim, 1);
+  std::fill(B_broadcast_dims, B_broadcast_dims + ndim - B_ndim, 1);
+  std::copy(A_dims, A_dims + A_ndim, A_broadcast_dims + ndim - A_ndim);
+  std::copy(B_dims, B_dims + B_ndim, B_broadcast_dims + ndim - B_ndim);
+  for (int i = 0; i < ndim; ++i) {
+    CAFFE_ENFORCE(
+        A_broadcast_dims[i] == B_broadcast_dims[i] ||
+        A_broadcast_dims[i] <= 1 || B_broadcast_dims[i] <= 1);
+    if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) {
+      C_broadcast_dims[i] = 0;
+    } else {
+      C_broadcast_dims[i] = std::max(A_broadcast_dims[i], B_broadcast_dims[i]);
+    }
+  }
+}
+
+bool IsRowwiseBroadcastBinaryOp(
+    const int ndim,
+    const int* A_dims,
+    const int* B_dims,
+    int* rows,
+    int* cols,
+    bool* broadcast_1st) {
+  if (ndim == 0) {
+    return false;
+  }
+  int A_pivot = 0;
+  for (; A_pivot < ndim && A_dims[A_pivot] == 1; ++A_pivot)
+    ;
+  int B_pivot = 0;
+  for (; B_pivot < ndim && B_dims[B_pivot] == 1; ++B_pivot)
+    ;
+  if (A_pivot == B_pivot) {
+    return false;
+  }
+  const int pivot = std::max(A_pivot, B_pivot);
+  if (A_pivot > B_pivot) {
+    *rows = std::accumulate(
+        B_dims + B_pivot, B_dims + pivot, 1, std::multiplies<int>());
+    *broadcast_1st = true;
+  } else {
+    *rows = std::accumulate(
+        A_dims + A_pivot, A_dims + pivot, 1, std::multiplies<int>());
+    *broadcast_1st = false;
+  }
+  *cols = 1;
+  for (int i = pivot; i < ndim; ++i) {
+    if (A_dims[i] != B_dims[i]) {
+      return false;
+    }
+    *cols *= A_dims[i];
+  }
+  return true;
+}
+
+bool IsColwiseBroadcastBinaryOp(
+    const int ndim,
+    const int* A_dims,
+    const int* B_dims,
+    int* rows,
+    int* cols,
+    bool* broadcast_1st) {
+  if (ndim == 0) {
+    return false;
+  }
+  int A_pivot = ndim - 1;
+  for (; A_pivot >= 0 && A_dims[A_pivot] == 1; --A_pivot)
+    ;
+  int B_pivot = ndim - 1;
+  for (; B_pivot >= 0 && B_dims[B_pivot] == 1; --B_pivot)
+    ;
+  if (A_pivot == B_pivot) {
+    return false;
+  }
+  ++A_pivot;
+  ++B_pivot;
+  const int pivot = std::min(A_pivot, B_pivot);
+  if (A_pivot < B_pivot) {
+    *cols = std::accumulate(
+        B_dims + pivot, B_dims + B_pivot, 1, std::multiplies<int>());
+    *broadcast_1st = true;
+  } else {
+    *cols = std::accumulate(
+        A_dims + pivot, A_dims + A_pivot, 1, std::multiplies<int>());
+    *broadcast_1st = false;
+  }
+  *rows = 1;
+  for (int i = 0; i < pivot; ++i) {
+    if (A_dims[i] != B_dims[i]) {
+      return false;
+    }
+    *rows *= A_dims[i];
+  }
+  return true;
+}
+
+bool IsBothEndsBroadcastBinaryOp(
+    const int ndim,
+    const int* A_dims,
+    const int* B_dims,
+    int* pre,
+    int* mid,
+    int* nxt,
+    bool* broadcast_1st) {
+  if (ndim == 0) {
+    return false;
+  }
+  int A_pre = 0;
+  for (; A_pre < ndim && A_dims[A_pre] == 1; ++A_pre)
+    ;
+  int B_pre = 0;
+  for (; B_pre < ndim && B_dims[B_pre] == 1; ++B_pre)
+    ;
+  int A_nxt = ndim - 1;
+  for (; A_nxt >= 0 && A_dims[A_nxt] == 1; --A_nxt)
+    ;
+  int B_nxt = ndim - 1;
+  for (; B_nxt >= 0 && B_dims[B_nxt] == 1; --B_nxt)
+    ;
+  ++A_nxt;
+  ++B_nxt;
+  if (A_pre == B_pre || A_nxt == B_nxt) {
+    return false;
+  }
+  if (A_pre > B_pre && A_nxt < B_nxt) {
+    *pre = std::accumulate(
+        B_dims + B_pre, B_dims + A_pre, 1, std::multiplies<int>());
+    *nxt = std::accumulate(
+        B_dims + A_nxt, B_dims + B_nxt, 1, std::multiplies<int>());
+    *broadcast_1st = true;
+  } else if (A_pre < B_pre && A_nxt > B_nxt) {
+    *pre = std::accumulate(
+        A_dims + A_pre, A_dims + B_pre, 1, std::multiplies<int>());
+    *nxt = std::accumulate(
+        A_dims + B_nxt, A_dims + A_nxt, 1, std::multiplies<int>());
+    *broadcast_1st = false;
+  } else {
+    return false;
+  }
+  const int l = std::max(A_pre, B_pre);
+  const int r = std::min(A_nxt, B_nxt);
+  *mid = 1;
+  for (int i = l; i < r; ++i) {
+    if (A_dims[i] != B_dims[i]) {
+      return false;
+    }
+    *mid *= A_dims[i];
+  }
+  return true;
+}
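
// Illustrative sketch, not from the patch: ComputeBroadcastBinaryOpDims pads
// the shorter shape with leading 1s and takes the per-axis maximum
// (NumPy-style broadcasting), and the Is*BroadcastBinaryOp helpers then try
// to collapse the general case into a rows-by-cols fast path:
#include "caffe2/utils/math/utils.h"

void BroadcastDimsExample() {
  const int A_dims[3] = {2, 3, 4}; // A is 2 x 3 x 4
  const int B_dims[1] = {4}; // B broadcasts over the leading axes
  int A_bcast[3];
  int B_bcast[3];
  int C_bcast[3];
  caffe2::math::utils::ComputeBroadcastBinaryOpDims(
      3, A_dims, 1, B_dims, A_bcast, B_bcast, C_bcast);
  // A_bcast = {2, 3, 4}, B_bcast = {1, 1, 4}, C_bcast = {2, 3, 4}.
  int rows = 0;
  int cols = 0;
  bool broadcast_1st = false;
  // Recognized as rowwise: 6 rows of length 4 with the 2nd input broadcast,
  // so rows == 6, cols == 4, and broadcast_1st == false.
  caffe2::math::utils::IsRowwiseBroadcastBinaryOp(
      3, A_bcast, B_bcast, &rows, &cols, &broadcast_1st);
}
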
+
+bool IsBatchTranspose2D(const int ndim, const int* axes) {
+  if (ndim < 2) {
+    return false;
+  }
+  for (int i = 0; i < ndim - 2; ++i) {
+    if (axes[i] != i) {
+      return false;
+    }
+  }
+  return axes[ndim - 2] == ndim - 1 && axes[ndim - 1] == ndim - 2;
+}
+
+void ComputeTransposeAxesForReduceOp(
+    const int num_dims,
+    const int num_reduce_axes,
+    const int* reduce_axes,
+    int* transpose_axes) {
+  const int d = num_dims - num_reduce_axes;
+  std::copy_n(reduce_axes, num_reduce_axes, transpose_axes + d);
+  std::sort(transpose_axes + d, transpose_axes + num_dims);
+  int p = 0;
+  int q = d;
+  for (int i = 0; i < num_dims; ++i) {
+    if (q < num_dims && i == transpose_axes[q]) {
+      ++q;
+    } else {
+      transpose_axes[p++] = i;
+    }
+  }
+}
+
+void ComputeTransposeAxesForReduceOp(
+    const int ndim,
+    const int* dims,
+    int* axes) {
+  const int d = ndim - std::count(dims, dims + ndim, 1);
+  int p = 0;
+  int q = d;
+  for (int i = 0; i < ndim; ++i) {
+    if (dims[i] == 1) {
+      axes[q++] = i;
+    } else {
+      axes[p++] = i;
+    }
+  }
+}
+
+void ComputeTransposedStrides(
+    const int ndim,
+    const int* dims,
+    const int* axes,
+    int* strides) {
+  std::vector<int> buff(ndim);
+  int cur_stride = 1;
+  for (int i = ndim - 1; i >= 0; --i) {
+    buff[i] = cur_stride;
+    cur_stride *= dims[i];
+  }
+  for (int i = 0; i < ndim; ++i) {
+    strides[i] = buff[axes[i]];
+  }
+}
+
+} // namespace utils
+} // namespace math
+} // namespace caffe2
diff --git a/caffe2/utils/math/utils.h b/caffe2/utils/math/utils.h
new file mode 100644
index 0000000000..b704adb188
--- /dev/null
+++ b/caffe2/utils/math/utils.h
@@ -0,0 +1,178 @@
+#ifndef CAFFE2_UTILS_MATH_UTILS_H_
+#define CAFFE2_UTILS_MATH_UTILS_H_
+
+#include "caffe2/core/common.h"
+
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) || \
+    defined(__HIP__)
+#define MATH_UTILS_DECL inline __host__ __device__
+#else
+#define MATH_UTILS_DECL inline
+#endif
+
+namespace caffe2 {
+namespace math {
+
+namespace utils {
+
+template <typename T>
+MATH_UTILS_DECL T Not(const T x) {
+  return !x;
+}
+
+template <typename T>
+MATH_UTILS_DECL T Sign(const T x) {
+  return x > 0 ? T(1) : (x < 0 ? T(-1) : T(0));
+}
+
+template <typename T>
+MATH_UTILS_DECL T Negate(const T x) {
+  return -x;
+}
+
+template <typename T>
+MATH_UTILS_DECL T Inv(const T x) {
+  return T(1) / x;
+}
+
+template <typename T>
+MATH_UTILS_DECL T Square(const T x) {
+  return x * x;
+}
+
+template <typename T>
+MATH_UTILS_DECL T Cube(const T x) {
+  return x * x * x;
+}
+
+// Uses a cast from int to unsigned to test whether the value of parameter a
+// is greater than or equal to zero and lower than the value of parameter b.
+// Parameter b is signed and always positive, so its value is always lower
+// than 0x800..., whereas casting a negative value of a converts it to a
+// value higher than 0x800.... The cast therefore lets a single comparison
+// replace the two conditions a >= 0 and a < b.
+MATH_UTILS_DECL bool IsAGeZeroAndALtB(const int a, const int b) {
+  return static_cast<unsigned int>(a) < static_cast<unsigned int>(b);
+}
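
// Illustrative sketch, not from the patch: the single unsigned comparison in
// IsAGeZeroAndALtB reproduces the two-sided check a >= 0 && a < b, because a
// negative a wraps above INT_MAX while a non-negative b stays below it:
#include <cassert>

// Reference behaviour that the unsigned-cast trick reproduces.
bool IsAGeZeroAndALtBReference(int a, int b) {
  return a >= 0 && a < b;
}

void UnsignedTrickExample() {
  // In-range value: both forms agree it is inside [0, b).
  assert((static_cast<unsigned int>(3) < static_cast<unsigned int>(5)) ==
         IsAGeZeroAndALtBReference(3, 5));
  // -1 casts to 0xFFFFFFFF, which can never be below a non-negative b.
  assert((static_cast<unsigned int>(-1) < static_cast<unsigned int>(5)) ==
         IsAGeZeroAndALtBReference(-1, 5));
}
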
+
+// Increase the index digits by one based on dims.
+CAFFE2_API void IncreaseIndexInDims(const int n, const int* dims, int* index);
+
+// Get index value from dims and index digits.
+CAFFE2_API int GetIndexFromDims(const int n, const int* dims, const int* index);
+
+// Checks if the input permutation is an identity permutation.
+CAFFE2_API bool IsIdentityPermutation(const int n, const int* perm);
+
+CAFFE2_API bool
+CheckReduceDims(const int ndim, const int* X_dims, const int* Y_dims);
+
+CAFFE2_API bool IsRowwiseReduce(
+    const int ndim,
+    const int* X_dims,
+    const int* Y_dims,
+    int* rows,
+    int* cols);
+
+CAFFE2_API bool IsColwiseReduce(
+    const int ndim,
+    const int* X_dims,
+    const int* Y_dims,
+    int* rows,
+    int* cols);
+
+CAFFE2_API bool IsBothEndsReduce(
+    const int ndim,
+    const int* X_dims,
+    const int* Y_dims,
+    int* pre,
+    int* mid,
+    int* nxt);
+
+// Computes the broadcast binary operation dims.
+CAFFE2_API void ComputeBroadcastBinaryOpDims(
+    const int A_ndim,
+    const int* A_dims,
+    const int B_ndim,
+    const int* B_dims,
+    int* A_broadcast_dims,
+    int* B_broadcast_dims,
+    int* C_broadcast_dims);
+
+CAFFE2_API bool IsRowwiseBroadcastBinaryOp(
+    const int ndim,
+    const int* A_dims,
+    const int* B_dims,
+    int* rows,
+    int* cols,
+    bool* broadcast_1st);
+
+CAFFE2_API bool IsColwiseBroadcastBinaryOp(
+    const int ndim,
+    const int* A_dims,
+    const int* B_dims,
+    int* rows,
+    int* cols,
+    bool* broadcast_1st);
+
+CAFFE2_API bool IsBothEndsBroadcastBinaryOp(
+    const int ndim,
+    const int* A_dims,
+    const int* B_dims,
+    int* pre,
+    int* mid,
+    int* nxt,
+    bool* broadcast_1st);
+
+CAFFE2_API bool IsBatchTranspose2D(const int ndim, const int* axes);
+
+CAFFE2_API void ComputeTransposeAxesForReduceOp(
+    const int num_dims,
+    const int num_reduce_axes,
+    const int* reduce_axes,
+    int* transpose_axes);
+
+CAFFE2_API void
+ComputeTransposeAxesForReduceOp(const int ndim, const int* dims, int* axes);
+
+CAFFE2_API void ComputeTransposedStrides(
+    const int ndim,
+    const int* dims,
+    const int* axes,
+    int* strides);
+
+} // namespace utils
+
+// Calculates ceil(a / b). User must be careful to ensure that there
+// is no overflow or underflow in the calculation.
+template <typename T>
+constexpr T DivUp(const T a, const T b) {
+  return (a + b - T(1)) / b;
+}
+
+// Rounds a up to the next highest multiple of b. User must be careful
+// to ensure that there is no overflow or underflow in the calculation
+// of DivUp.
+template <typename T>
+constexpr T RoundUp(const T a, const T b) {
+  return DivUp(a, b) * b;
+}
+
+// Returns log2(n) for a positive integer type
+template <typename T>
+constexpr int IntegerLog2(T n, int p = 0) {
+  return (n <= 1) ? p : IntegerLog2(n / 2, p + 1);
+}
+
+// Returns the next highest power-of-2 for an integer type
+template <typename T>
+constexpr T IntegerNextHighestPowerOf2(T v) {
+  return (IntegerIsPowerOf2(v) ? T(2) * v : (T(1) << (IntegerLog2(v) + 1)));
+}
+
+} // namespace math
+} // namespace caffe2
+
+#endif // CAFFE2_UTILS_MATH_UTILS_H_
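
// Illustrative sketch, not from the patch: DivUp and RoundUp are the usual
// grid-sizing helpers, e.g. for covering N elements with fixed-size thread
// blocks. The block size below is hypothetical, not a caffe2 constant:
#include "caffe2/utils/math/utils.h"

void GridSizingExample() {
  const int N = 1000;
  const int kBlockSize = 128; // hypothetical block size, for illustration only
  const int num_blocks = caffe2::math::DivUp(N, kBlockSize); // == 8
  const int padded = caffe2::math::RoundUp(N, kBlockSize); // == 1024
  static_assert(caffe2::math::IntegerLog2(1024) == 10, "log2(1024) is 10");
  (void)num_blocks;
  (void)padded;
}
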
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index 0adb9239a3..3bbc1ac4ed 100644
--- a/caffe2/utils/math_cpu.cc
+++ b/caffe2/utils/math_cpu.cc
@@ -622,66 +622,6 @@ C10_EXPORT void GemmStridedBatched<at::Half, CPUContext>(
 #endif
 }
 
-////////////////////////////////////////////////////////////////////////////////
-// MKL VML alternatives.
-// Depending on whether we are using MKL, we will delegate the Caffe math
-// functions that are VML-related to either the VML call or the Eigen
-// implementation. If you are setting the flags (such as AVX) right for your CPU
-// architecture, usually Eigen will deliver a throughput as fast as the VML
-// functions.
-////////////////////////////////////////////////////////////////////////////////
-#ifdef CAFFE2_USE_MKL
-
-#define DELEGATE_SIMPLE_BINARY_FUNCTION(T, Func, FuncImpl)        \
-  template <>                                                     \
-  C10_EXPORT void Func<T, CPUContext>(                            \
-      const int N, const T* A, const T* B, T* C, CPUContext*) {   \
-    FuncImpl(N, A, B, C);                                         \
-  }
-DELEGATE_SIMPLE_BINARY_FUNCTION(float, Add, vsAdd)
-DELEGATE_SIMPLE_BINARY_FUNCTION(double, Add, vdAdd)
-DELEGATE_SIMPLE_BINARY_FUNCTION(float, Sub, vsSub)
-DELEGATE_SIMPLE_BINARY_FUNCTION(double, Sub, vdSub)
-DELEGATE_SIMPLE_BINARY_FUNCTION(float, Mul, vsMul)
-DELEGATE_SIMPLE_BINARY_FUNCTION(double, Mul, vdMul)
-DELEGATE_SIMPLE_BINARY_FUNCTION(float, Div, vsDiv)
-DELEGATE_SIMPLE_BINARY_FUNCTION(double, Div, vdDiv)
-#undef DELEGATE_SIMPLE_BINARY_FUNCTION
-
-#endif // CAFFE2_USE_MKL
-
-#define EIGEN_SIMPLE_BINARY_FUNCTION(T, Func, expr)               \
-  template <>                                                     \
-  C10_EXPORT void Func<T, CPUContext>(                            \
-      const int N, const T* A, const T* B, T* C, CPUContext*) {   \
-    EigenVectorMap<T>(C, N) = ConstEigenVectorArrayMap<T>(A, N)   \
-        expr ConstEigenVectorArrayMap<T>(B, N);                   \
-  }
-
-#ifdef CAFFE2_USE_MKL
-
-#define DEFINE_SIMPLE_BINARY_FUNCTION(Func, expr)        \
-  EIGEN_SIMPLE_BINARY_FUNCTION(std::int32_t, Func, expr) \
-  EIGEN_SIMPLE_BINARY_FUNCTION(std::int64_t, Func, expr)
-
-#else
-
-#define DEFINE_SIMPLE_BINARY_FUNCTION(Func, expr)        \
-  EIGEN_SIMPLE_BINARY_FUNCTION(float, Func, expr)        \
-  EIGEN_SIMPLE_BINARY_FUNCTION(double, Func, expr)       \
-  EIGEN_SIMPLE_BINARY_FUNCTION(std::int32_t, Func, expr) \
-  EIGEN_SIMPLE_BINARY_FUNCTION(std::int64_t, Func, expr)
-
-#endif
-
-DEFINE_SIMPLE_BINARY_FUNCTION(Add, +)
-DEFINE_SIMPLE_BINARY_FUNCTION(Sub, -)
-DEFINE_SIMPLE_BINARY_FUNCTION(Mul, *)
-DEFINE_SIMPLE_BINARY_FUNCTION(Div, /)
-
-#undef DEFINE_SIMPLE_BINARY_FUNCTION
-#undef EIGEN_SIMPLE_BINARY_FUNCTION
-
 ////////////////////////////////////////////////////////////////////////////////
 // Common math functions being used in Caffe that do not have a BLAS or MKL
 // equivalent. For all these functions, we will simply implement them either via
@@ -1332,17 +1272,6 @@ CAFFE2_SPECIALIZED_ROWWISEMAX(float)
 CAFFE2_SPECIALIZED_COLWISEMAX(float)
 #undef CAFFE2_SPECIALIZED_COLWISEMAX
 
-#define CAFFE2_SPECIALIZED_ELEMWISEMAX(T)                                   \
-  template <>                                                               \
-  C10_EXPORT void ElemwiseMax<T, CPUContext>(                               \
-      const int N, const T* x, const T* y, T* z, CPUContext* /*context*/) { \
-    std::transform(x, x + N, y, z, [](const T& x_i, const T& y_i) {         \
-      return std::max(x_i, y_i);                                            \
-    });                                                                     \
-  }
-CAFFE2_SPECIALIZED_ELEMWISEMAX(float)
-#undef CAFFE2_SPECIALIZED_ELEMWISEMAX
-
 #define CAFFE2_SPECIALIZED_MAXIMUM(T)                                       \
   template <>                                                               \
   C10_EXPORT void Maximum<T, CPUContext>(                                   \
@@ -1609,46 +1538,6 @@ C10_EXPORT void BroadcastBinaryOpImpl(
 }
 
 } // namespace
 
-#define DELEGATE_1D_BINARY_FUNCTION(TIn, TOut, Func, Op)                 \
-  template <>                                                            \
-  C10_EXPORT void Func<TIn, CPUContext>(                                 \
-      const int N, const TIn* A, const TIn* B, TOut* C, CPUContext*) {   \
-    std::transform(A, A + N, B, C, Op<TIn>());                           \
-  }
-
-#define DEFINE_1D_COMPARE_FUNCTION(Func, Op)                \
-  DELEGATE_1D_BINARY_FUNCTION(float, bool, Func, Op)        \
-  DELEGATE_1D_BINARY_FUNCTION(double, bool, Func, Op)       \
-  DELEGATE_1D_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \
-  DELEGATE_1D_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \
-  DELEGATE_1D_BINARY_FUNCTION(bool, bool, Func, Op)
-
-DEFINE_1D_COMPARE_FUNCTION(EQ, std::equal_to)
-DEFINE_1D_COMPARE_FUNCTION(NE, std::not_equal_to)
-DEFINE_1D_COMPARE_FUNCTION(LT, std::less)
-DEFINE_1D_COMPARE_FUNCTION(LE, std::less_equal)
-DEFINE_1D_COMPARE_FUNCTION(GT, std::greater)
-DEFINE_1D_COMPARE_FUNCTION(GE, std::greater_equal)
-
-#undef DEFINE_1D_COMPARE_FUNCTION
-
-DELEGATE_1D_BINARY_FUNCTION(bool, bool, And, std::logical_and)
-DELEGATE_1D_BINARY_FUNCTION(bool, bool, Or, std::logical_or)
-DELEGATE_1D_BINARY_FUNCTION(bool, bool, Xor, std::bit_xor)
-
-#define DEFINE_1D_BITWISE_BINARY_FUNCTION(Func, op)                 \
-  DELEGATE_1D_BINARY_FUNCTION(bool, bool, Func, op)                 \
-  DELEGATE_1D_BINARY_FUNCTION(std::int32_t, std::int32_t, Func, op) \
-  DELEGATE_1D_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, op)
-
-DEFINE_1D_BITWISE_BINARY_FUNCTION(BitwiseAnd, std::bit_and)
-DEFINE_1D_BITWISE_BINARY_FUNCTION(BitwiseOr, std::bit_or)
-DEFINE_1D_BITWISE_BINARY_FUNCTION(BitwiseXor, std::bit_xor)
-
-#undef DEFINE_1D_BITWISE_BINARY_FUNCTION
-
-#undef DELEGATE_1D_BINARY_FUNCTION
-
 #define DELEGATE_2D_BROADCAST_BINARY_FUNCTION(TIn, TOut, Func, Op) \
   template <>                                                      \
   C10_EXPORT void Rowwise##Func<TIn, CPUContext, true>(            \
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index e93e6924a9..f8ccda7053 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -41,7 +41,7 @@ using CUBLAS_HALF_TYPE = rocblas_half;
 using CUBLAS_HALF_TYPE = __half;
 #endif // __HIP_PLATFORM_HCC
 
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math/utils.h"
 
 #if THRUST_VERSION >= 100800
 #define THRUST_SUPPORTS_PER_THREAD
@@ -302,74 +302,6 @@ CAFFE2_CUDA_EXPORT void BroadcastBinaryOp(
 
 } // namespace
 
-#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(TIn, TOut, Func, Op) \
-  template <>                                                     \
-  CAFFE2_CUDA_EXPORT void Func<TIn, CUDAContext>(                 \
-      const int N,                                                \
-      const TIn* A,                                               \
-      const TIn* B,                                               \
-      TOut* C,                                                    \
-      CUDAContext* context) {                                     \
-    SimpleBinaryOpCUDAKernel<TIn, TOut, Op<TIn>>                  \
-        <<<CAFFE_GET_BLOCKS(N),                                   \
-           CAFFE_CUDA_NUM_THREADS,                                \
-           0,                                                     \
-           context->cuda_stream()>>>(N, Op<TIn>(), A, B, C);      \
-  }
-
-#define DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(Func, Op)                \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, bool, Func, Op)        \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, bool, Func, Op)       \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, bool, Func, Op)
-
-DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(EQ, thrust::equal_to)
-DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(NE, thrust::not_equal_to)
-DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(LT, thrust::less)
-DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(LE, thrust::less_equal)
-DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(GT, thrust::greater)
-DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(GE, thrust::greater_equal)
-
-#undef DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION
-
-#define DEFINE_SIMPLE_CUDA_BINARY_FUNCTION(Func, Op)                         \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(std::int32_t, std::int32_t, Func, Op) \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, Op) \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, float, Func, Op)               \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, double, Func, Op)             \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, at::Half, Func, Op)
-
-DEFINE_SIMPLE_CUDA_BINARY_FUNCTION(Add, AddFunctor)
-DEFINE_SIMPLE_CUDA_BINARY_FUNCTION(Sub, SubFunctor)
-DEFINE_SIMPLE_CUDA_BINARY_FUNCTION(Mul, MulFunctor)
-DEFINE_SIMPLE_CUDA_BINARY_FUNCTION(Div, DivFunctor)
-
-#undef DEFINE_SIMPLE_CUDA_BINARY_FUNCTION
-
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, bool, And, thrust::logical_and)
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, bool, Or, thrust::logical_or)
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, bool, Xor, thrust::bit_xor)
-
-#define DEFINE_SIMPLE_CUDA_BITWISE_BINARY_FUNCTION(Func, Op)                 \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, bool, Func, Op)                 \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(std::int32_t, std::int32_t, Func, Op) \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, Op)
-
-DEFINE_SIMPLE_CUDA_BITWISE_BINARY_FUNCTION(BitwiseAnd, thrust::bit_and)
-DEFINE_SIMPLE_CUDA_BITWISE_BINARY_FUNCTION(BitwiseOr, thrust::bit_or)
-DEFINE_SIMPLE_CUDA_BITWISE_BINARY_FUNCTION(BitwiseXor, thrust::bit_xor)
-
-#undef DEFINE_SIMPLE_CUDA_BITWISE_BINARY_FUNCTION
-
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(
-    float,
-    float,
-    ElemwiseMax,
-    thrust::maximum);
-
-#undef DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION
-
 #define DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(TIn, TOut, Func, Op) \
   template <>                                                           \
   CAFFE2_CUDA_EXPORT void Rowwise##Func<TIn, CUDAContext, true>(        \
diff --git a/caffe2/utils/math_gpu_test.cc b/caffe2/utils/math_gpu_test.cc
index 80d7ff0291..99e9f69e33 100644
--- a/caffe2/utils/math_gpu_test.cc
+++ b/caffe2/utils/math_gpu_test.cc
@@ -204,24 +204,6 @@ TEST(MathUtilGPUTest, testReduceMax) {
       [](int /*i*/) { return 17.0f; });
 }
 
-TEST(MathUtilGPUTest, testElemwiseMax) {
-  executeGpuBinaryOpTest<float, float, float>(
-      13,
-      13,
-      13,
-      [](int i) { return 2.0f - i; },
-      [](int i) { return i - 6.0f; },
-      [](int N0,
-         int /*N1*/,
-         const float* src0,
-         const float* src1,
-         float* dst,
-         CUDAContext* context) {
-        math::ElemwiseMax<float, CUDAContext>(N0, src0, src1, dst, context);
-      },
-      [](int i) { return std::max(2.0f - i, i - 6.0f); });
-}
-
 TEST(MathUtilGPUTest, testCopyVector) {
   executeGpuBinaryOpTest<float, float, float>(
       6,
diff --git a/caffe2/utils/math_utils.cc b/caffe2/utils/math_utils.cc
deleted file mode 100644
index f4f7da99f3..0000000000
--- a/caffe2/utils/math_utils.cc
+++ /dev/null
@@ -1,347 +0,0 @@
-#include "caffe2/utils/math_utils.h"
-
-#include <algorithm>
-#include <functional>
-#include <numeric>
-#include <vector>
-
-#include "caffe2/core/logging.h"
-
-namespace caffe2 {
-namespace math {
-namespace utils {
-
-void IncreaseIndexInDims(const int n, const int* dims, int* index) {
-  for (int i = n - 1; i >= 0; --i) {
-    ++index[i];
-    if (index[i] >= dims[i]) {
-      index[i] -= dims[i];
-    } else {
-      break;
-    }
-  }
-}
-
-int GetIndexFromDims(const int n, const int* dims, const int* index) {
-  int sum = 0;
-  for (int i = 0; i < n; ++i) {
-    if (dims[i] > 1) {
-      sum = sum * dims[i] + index[i];
-    }
-  }
-  return sum;
-}
-
-bool IsIdentityPermutation(const int n, const int* perm) {
-  for (int i = 0; i < n; ++i) {
-    if (perm[i] != i) {
-      return false;
-    }
-  }
-  return true;
-}
-
-bool CheckReduceDims(const int ndim, const int* X_dims, const int* Y_dims) {
-  for (int i = 0; i < ndim; ++i) {
-    if (X_dims[i] != Y_dims[i] && Y_dims[i] != 1) {
-      return false;
-    }
-  }
-  return true;
-}
-
-bool IsRowwiseReduce(
-    const int ndim,
-    const int* A_dims,
-    const int* B_dims,
-    int* rows,
-    int* cols) {
-  *cols = 1;
-  int pivot = ndim - 1;
-  for (; pivot >= 0 && B_dims[pivot] == 1; --pivot) {
-    *cols *= A_dims[pivot];
-  }
-  *rows = 1;
-  for (int i = pivot; i >= 0; --i) {
-    if (A_dims[i] != B_dims[i]) {
-      return false;
-    }
-    *rows *= A_dims[i];
-  }
-  return true;
-}
-
-bool IsColwiseReduce(
-    const int ndim,
-    const int* A_dims,
-    const int* B_dims,
-    int* rows,
-    int* cols) {
-  *rows = 1;
-  int pivot = 0;
-  for (; pivot < ndim && B_dims[pivot] == 1; ++pivot) {
-    *rows *= A_dims[pivot];
-  }
-  *cols = 1;
-  for (int i = pivot; i < ndim; ++i) {
-    if (A_dims[i] != B_dims[i]) {
-      return false;
-    }
-    *cols *= A_dims[i];
-  }
-  return true;
-}
-
-bool IsBothEndsReduce(
-    const int ndim,
-    const int* A_dims,
-    const int* B_dims,
-    int* pre,
-    int* mid,
-    int* nxt) {
-  *nxt = 1;
-  int r = ndim - 1;
-  for (; r >= 0 && B_dims[r] == 1; --r) {
-    *nxt *= A_dims[r];
-  }
-  *pre = 1;
-  int l = 0;
-  for (; l <= r && B_dims[l] == 1; ++l) {
-    *pre *= A_dims[l];
-  }
-  *mid = 1;
-  for (int i = l; i <= r; ++i) {
-    if (A_dims[i] != B_dims[i]) {
-      return false;
-    }
-    *mid *= A_dims[i];
-  }
-  return true;
-}
-
-void ComputeBroadcastBinaryOpDims(
-    const int A_ndim,
-    const int* A_dims,
-    const int B_ndim,
-    const int* B_dims,
-    int* A_broadcast_dims,
-    int* B_broadcast_dims,
-    int* C_broadcast_dims) {
-  const int ndim = std::max(A_ndim, B_ndim);
-  std::fill(A_broadcast_dims, A_broadcast_dims + ndim - A_ndim, 1);
-  std::fill(B_broadcast_dims, B_broadcast_dims + ndim - B_ndim, 1);
-  std::copy(A_dims, A_dims + A_ndim, A_broadcast_dims + ndim - A_ndim);
-  std::copy(B_dims, B_dims + B_ndim, B_broadcast_dims + ndim - B_ndim);
-  for (int i = 0; i < ndim; ++i) {
-    CAFFE_ENFORCE(
-        A_broadcast_dims[i] == B_broadcast_dims[i] ||
-        A_broadcast_dims[i] <= 1 || B_broadcast_dims[i] <= 1);
-    if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) {
-      C_broadcast_dims[i] = 0;
-    } else {
-      C_broadcast_dims[i] = std::max(A_broadcast_dims[i], B_broadcast_dims[i]);
-    }
-  }
-}
-
-bool IsRowwiseBroadcastBinaryOp(
-    const int ndim,
-    const int* A_dims,
-    const int* B_dims,
-    int* rows,
-    int* cols,
-    bool* broadcast_1st) {
-  if (ndim == 0) {
-    return false;
-  }
-  int A_pivot = 0;
-  for (; A_pivot < ndim && A_dims[A_pivot] == 1; ++A_pivot)
-    ;
-  int B_pivot = 0;
-  for (; B_pivot < ndim && B_dims[B_pivot] == 1; ++B_pivot)
-    ;
-  if (A_pivot == B_pivot) {
-    return false;
-  }
-  const int pivot = std::max(A_pivot, B_pivot);
-  if (A_pivot > B_pivot) {
-    *rows = std::accumulate(
-        B_dims + B_pivot, B_dims + pivot, 1, std::multiplies<int>());
-    *broadcast_1st = true;
-  } else {
-    *rows = std::accumulate(
-        A_dims + A_pivot, A_dims + pivot, 1, std::multiplies<int>());
-    *broadcast_1st = false;
-  }
-  *cols = 1;
-  for (int i = pivot; i < ndim; ++i) {
-    if (A_dims[i] != B_dims[i]) {
-      return false;
-    }
-    *cols *= A_dims[i];
-  }
-  return true;
-}
-
-bool IsColwiseBroadcastBinaryOp(
-    const int ndim,
-    const int* A_dims,
-    const int* B_dims,
-    int* rows,
-    int* cols,
-    bool* broadcast_1st) {
-  if (ndim == 0) {
-    return false;
-  }
-  int A_pivot = ndim - 1;
-  for (; A_pivot >= 0 && A_dims[A_pivot] == 1; --A_pivot)
-    ;
-  int B_pivot = ndim - 1;
-  for (; B_pivot >= 0 && B_dims[B_pivot] == 1; --B_pivot)
-    ;
-  if (A_pivot == B_pivot) {
-    return false;
-  }
-  ++A_pivot;
-  ++B_pivot;
-  const int pivot = std::min(A_pivot, B_pivot);
-  if (A_pivot < B_pivot) {
-    *cols = std::accumulate(
-        B_dims + pivot, B_dims + B_pivot, 1, std::multiplies<int>());
-    *broadcast_1st = true;
-  } else {
-    *cols = std::accumulate(
-        A_dims + pivot, A_dims + A_pivot, 1, std::multiplies<int>());
-    *broadcast_1st = false;
-  }
-  *rows = 1;
-  for (int i = 0; i < pivot; ++i) {
-    if (A_dims[i] != B_dims[i]) {
-      return false;
-    }
-    *rows *= A_dims[i];
-  }
-  return true;
-}
-
-bool IsBothEndsBroadcastBinaryOp(
-    const int ndim,
-    const int* A_dims,
-    const int* B_dims,
-    int* pre,
-    int* mid,
-    int* nxt,
-    bool* broadcast_1st) {
-  if (ndim == 0) {
-    return false;
-  }
-  int A_pre = 0;
-  for (; A_pre < ndim && A_dims[A_pre] == 1; ++A_pre)
-    ;
-  int B_pre = 0;
-  for (; B_pre < ndim && B_dims[B_pre] == 1; ++B_pre)
-    ;
-  int A_nxt = ndim - 1;
-  for (; A_nxt >= 0 && A_dims[A_nxt] == 1; --A_nxt)
-    ;
-  int B_nxt = ndim - 1;
-  for (; B_nxt >= 0 && B_dims[B_nxt] == 1; --B_nxt)
-    ;
-  ++A_nxt;
-  ++B_nxt;
-  if (A_pre == B_pre || A_nxt == B_nxt) {
-    return false;
-  }
-  if (A_pre > B_pre && A_nxt < B_nxt) {
-    *pre = std::accumulate(
-        B_dims + B_pre, B_dims + A_pre, 1, std::multiplies<int>());
-    *nxt = std::accumulate(
-        B_dims + A_nxt, B_dims + B_nxt, 1, std::multiplies<int>());
-    *broadcast_1st = true;
-  } else if (A_pre < B_pre && A_nxt > B_nxt) {
-    *pre = std::accumulate(
-        A_dims + A_pre, A_dims + B_pre, 1, std::multiplies<int>());
-    *nxt = std::accumulate(
-        A_dims + B_nxt, A_dims + A_nxt, 1, std::multiplies<int>());
-    *broadcast_1st = false;
-  } else {
-    return false;
-  }
-  const int l = std::max(A_pre, B_pre);
-  const int r = std::min(A_nxt, B_nxt);
-  *mid = 1;
-  for (int i = l; i < r; ++i) {
-    if (A_dims[i] != B_dims[i]) {
-      return false;
-    }
-    *mid *= A_dims[i];
-  }
-  return true;
-}
-
-bool IsBatchTranspose2D(const int ndim, const int* axes) {
-  if (ndim < 2) {
-    return false;
-  }
-  for (int i = 0; i < ndim - 2; ++i) {
-    if (axes[i] != i) {
-      return false;
-    }
-  }
-  return axes[ndim - 2] == ndim - 1 && axes[ndim - 1] == ndim - 2;
-}
-
-void ComputeTransposeAxesForReduceOp(
-    const int num_dims,
-    const int num_reduce_axes,
-    const int* reduce_axes,
-    int* transpose_axes) {
-  const int d = num_dims - num_reduce_axes;
-  std::copy_n(reduce_axes, num_reduce_axes, transpose_axes + d);
-  std::sort(transpose_axes + d, transpose_axes + num_dims);
-  int p = 0;
-  int q = d;
-  for (int i = 0; i < num_dims; ++i) {
-    if (q < num_dims && i == transpose_axes[q]) {
-      ++q;
-    } else {
-      transpose_axes[p++] = i;
-    }
-  }
-}
-
-void ComputeTransposeAxesForReduceOp(
-    const int ndim,
-    const int* dims,
-    int* axes) {
-  const int d = ndim - std::count(dims, dims + ndim, 1);
-  int p = 0;
-  int q = d;
-  for (int i = 0; i < ndim; ++i) {
-    if (dims[i] == 1) {
-      axes[q++] = i;
-    } else {
-      axes[p++] = i;
-    }
-  }
-}
-
-void ComputeTransposedStrides(
-    const int ndim,
-    const int* dims,
-    const int* axes,
-    int* strides) {
-  std::vector<int> buff(ndim);
-  int cur_stride = 1;
-  for (int i = ndim - 1; i >= 0; --i) {
-    buff[i] = cur_stride;
-    cur_stride *= dims[i];
-  }
-  for (int i = 0; i < ndim; ++i) {
-    strides[i] = buff[axes[i]];
-  }
-}
-
-} // namespace utils
-} // namespace math
-} // namespace caffe2
diff --git a/caffe2/utils/math_utils.h b/caffe2/utils/math_utils.h
deleted file mode 100644
index b3fdb14884..0000000000
--- a/caffe2/utils/math_utils.h
+++ /dev/null
@@ -1,178 +0,0 @@
-#ifndef CAFFE2_UTILS_MATH_UTILS_H_
-#define CAFFE2_UTILS_MATH_UTILS_H_
-
-#include "caffe2/core/common.h"
-
-// See Note [hip-clang differences to hcc]
-
-#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) || defined(__HIP__)
-#define MATH_UTILS_DECL inline __host__ __device__
-#else
-#define MATH_UTILS_DECL inline
-#endif
-
-namespace caffe2 {
-namespace math {
-
-namespace utils {
-
-MATH_UTILS_DECL bool Not(const bool x) {
-  return !x;
-}
-
-template <typename T>
-MATH_UTILS_DECL T Sign(const T x) {
-  return x > 0 ? T(1) : (x < 0 ? T(-1) : T(0));
-}
-
-template <typename T>
-MATH_UTILS_DECL T Negate(const T x) {
-  return -x;
-}
-
-template <typename T>
-MATH_UTILS_DECL T Inv(const T x) {
-  return T(1) / x;
-}
-
-template <typename T>
-MATH_UTILS_DECL T Square(const T x) {
-  return x * x;
-}
-
-template <typename T>
-MATH_UTILS_DECL T Cube(const T x) {
-  return x * x * x;
-}
-
-// Function uses casting from int to unsigned to compare if value of
-// parameter a is greater or equal to zero and lower than value of
-// parameter b. The b parameter is of type signed and is always
-// positive,
-// therefore its value is always lower than 0x800... where casting
-// negative value of a parameter converts it to value higher than
-// 0x800...
-// The casting allows to use one condition instead of two.
-MATH_UTILS_DECL bool IsAGeZeroAndALtB(const int a, const int b) {
-  return static_cast<unsigned int>(a) < static_cast<unsigned int>(b);
-}
-
-// Increase the index digits by one based on dims.
-CAFFE2_API void IncreaseIndexInDims(const int n, const int* dims, int* index);
-
-// Get index value from dims and index digits.
-CAFFE2_API int GetIndexFromDims(const int n, const int* dims, const int* index);
-
-// Checks if the input permutation is an identity permutation;
-CAFFE2_API bool IsIdentityPermutation(const int n, const int* perm);
-
-CAFFE2_API bool
-CheckReduceDims(const int ndim, const int* X_dims, const int* Y_dims);
-
-CAFFE2_API bool IsRowwiseReduce(
-    const int ndim,
-    const int* X_dims,
-    const int* Y_dims,
-    int* rows,
-    int* cols);
-
-CAFFE2_API bool IsColwiseReduce(
-    const int ndim,
-    const int* X_dims,
-    const int* Y_dims,
-    int* rows,
-    int* cols);
-
-CAFFE2_API bool IsBothEndsReduce(
-    const int ndim,
-    const int* X_dims,
-    const int* Y_dims,
-    int* pre,
-    int* mid,
-    int* nxt);
-
-// Computest the broadcast binary operation dims.
-CAFFE2_API void ComputeBroadcastBinaryOpDims(
-    const int A_ndim,
-    const int* A_dims,
-    const int B_ndim,
-    const int* B_dims,
-    int* A_broadcast_dims,
-    int* B_broadcast_dims,
-    int* C_broadcast_dims);
-
-CAFFE2_API bool IsRowwiseBroadcastBinaryOp(
-    const int ndim,
-    const int* A_dims,
-    const int* B_dims,
-    int* rows,
-    int* cols,
-    bool* broadcast_1st);
-
-CAFFE2_API bool IsColwiseBroadcastBinaryOp(
-    const int ndim,
-    const int* A_dims,
-    const int* B_dims,
-    int* rows,
-    int* cols,
-    bool* broadcast_1st);
-
-CAFFE2_API bool IsBothEndsBroadcastBinaryOp(
-    const int ndim,
-    const int* A_dims,
-    const int* B_dims,
-    int* pre,
-    int* mid,
-    int* nxt,
-    bool* broadcast_1st);
-
-CAFFE2_API bool IsBatchTranspose2D(const int ndim, const int* axes);
-
-CAFFE2_API void ComputeTransposeAxesForReduceOp(
-    const int num_dims,
-    const int num_reduce_axes,
-    const int* reduce_axes,
-    int* transpose_axes);
-
-CAFFE2_API void
-ComputeTransposeAxesForReduceOp(const int ndim, const int* dims, int* axes);
-
-CAFFE2_API void ComputeTransposedStrides(
-    const int ndim,
-    const int* dims,
-    const int* axes,
-    int* strides);
-
-} // namespace utils
-
-// Calculates ceil(a / b). User must be careful to ensure that there
-// is no overflow or underflow in the calculation.
-template <typename T>
-constexpr T DivUp(const T a, const T b) {
-  return (a + b - T(1)) / b;
-}
-
-// Rounds a up to the next highest multiple of b. User must be careful
-// to ensure that there is no overflow or underflow in the calculation
-// of divUp.
-template <typename T>
-constexpr T RoundUp(const T a, const T b) {
-  return DivUp(a, b) * b;
-}
-
-// Returns log2(n) for a positive integer type
-template <typename T>
-constexpr int IntegerLog2(T n, int p = 0) {
-  return (n <= 1) ? p : IntegerLog2(n / 2, p + 1);
-}
-
-// Returns the next highest power-of-2 for an integer type
-template <typename T>
-constexpr T IntegerNextHighestPowerOf2(T v) {
-  return (IntegerIsPowerOf2(v) ? T(2) * v : (T(1) << (IntegerLog2(v) + 1)));
-}
-
-} // namespace math
-} // namespace caffe2
-
-#endif // CAFFE2_UTILS_MATH_UTILS_H_
-- 
cgit v1.2.3