author | Xiaomeng Yang <yangxm@fb.com> | 2019-02-07 18:19:46 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-07 18:38:26 -0800 |
commit | 2db847b3a7edc48652e144e7c9d7aa0bbed66aaa (patch) | |
tree | 3d95b2f625ca878f038092dea671a2bffed088c6 /caffe2 | |
parent | 22477c6a7fc327cdf913751ab31b788ec710caa9 (diff) | |
download | pytorch-2db847b3a7edc48652e144e7c9d7aa0bbed66aaa.tar.gz pytorch-2db847b3a7edc48652e144e7c9d7aa0bbed66aaa.tar.bz2 pytorch-2db847b3a7edc48652e144e7c9d7aa0bbed66aaa.zip |
Separate elementwise level2 math functions (#16753)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16753
Separate elementwise level2 math functions
i-am-not-moving-c2-to-c10
Reviewed By: houseroad
Differential Revision: D13954928
fbshipit-source-id: 1ca7a5d3da96e32510f502e5e4e79168854bee67
Diffstat (limited to 'caffe2')
25 files changed, 1038 insertions, 737 deletions
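A note for orientation before the diff: this commit pulls the level-2 (binary) elementwise primitives out of the monolithic `caffe2/utils/math.h`. The declarations for `Add`/`Sub`/`Mul`/`Div`, `Min`/`Max`, `And`/`Or`/`Xor`, the bitwise variants, and the `EQ`/`NE`/`LT`/`LE`/`GT`/`GE` comparisons move into `caffe2/utils/math/elementwise.h`; `AffineChannel` (per-channel `Y = X * scale[c] + bias[c]` for NCHW and NHWC) moves into the new `caffe2/utils/math/broadcast.{h,cc,cu}`; and `math_utils.{h,cc}` is renamed to `math/utils.{h,cc}`, so callers switch their include from `caffe2/utils/math_utils.h` to `caffe2/utils/math/utils.h`. `MaxOp`/`MinOp` and their gradients are rewritten on top of these primitives instead of carrying their own Eigen and CUDA code. A minimal sketch of the pairwise-fold pattern the rewritten `MaxOp` uses (`FoldMax` is a hypothetical helper for illustration; the real code below also enforces that all input sizes match, which is omitted here):

```cpp
#include <vector>

#include "caffe2/utils/math/elementwise.h" // math::Max<T, Context>

// Sketch: reduce two or more same-length buffers into Y by repeated
// pairwise Max, mirroring the rewritten MaxOp<T, Context>::RunOnDevice().
template <typename T, class Context>
void FoldMax(const int N, const std::vector<const T*>& Xs, T* Y, Context* ctx) {
  // Y = max(X0, X1), then fold the rest in place: Y = max(Y, Xi).
  caffe2::math::Max<T, Context>(N, Xs[0], Xs[1], Y, ctx);
  for (size_t i = 2; i < Xs.size(); ++i) {
    caffe2::math::Max<T, Context>(N, Y, Xs[i], Y, ctx);
  }
}
```

Passing `Y` as both input and output is the same aliasing the new operator code relies on (`math::Max<T, Context>(N, Y_data, Xi_data, Y_data, &context_)`), so the fold needs no scratch buffer.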
diff --git a/caffe2/operators/conv_transpose_op_mobile_impl.h b/caffe2/operators/conv_transpose_op_mobile_impl.h index f586453d5c..6869f5178d 100644 --- a/caffe2/operators/conv_transpose_op_mobile_impl.h +++ b/caffe2/operators/conv_transpose_op_mobile_impl.h @@ -18,7 +18,7 @@ #include "caffe2/utils/eigen_utils.h" #include "caffe2/utils/fixed_divisor.h" #include "caffe2/utils/math.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/math/utils.h" C10_DECLARE_bool(caffe2_force_shared_col_buffer); diff --git a/caffe2/operators/layer_norm_op.cu b/caffe2/operators/layer_norm_op.cu index c3465c1d2d..440783c6eb 100644 --- a/caffe2/operators/layer_norm_op.cu +++ b/caffe2/operators/layer_norm_op.cu @@ -4,7 +4,7 @@ #include "caffe2/core/context_gpu.h" #include "caffe2/utils/math.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/math/utils.h" namespace caffe2 { diff --git a/caffe2/operators/minmax_gradient_ops.cc b/caffe2/operators/minmax_gradient_ops.cc index 9c7df22d55..d288eb0946 100644 --- a/caffe2/operators/minmax_gradient_ops.cc +++ b/caffe2/operators/minmax_gradient_ops.cc @@ -1,66 +1,66 @@ #include "caffe2/operators/minmax_ops.h" -#include "caffe2/utils/eigen_utils.h" -namespace caffe2 { +#include <string> +#include <vector> -REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>); -REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>); +#include "caffe2/utils/eigen_utils.h" -OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX); -OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX); +namespace caffe2 { template <typename T, class Context> bool SelectGradientOpBase<T, Context>::RunOnDevice() { - auto& output = Input(0); - auto& grad_output = Input(1); - const int kInputStartOffset = 2; - - const T* data = output.template data<T>(); - ConstEigenArrayMap<T> output_array( - output.template data<T>(), 1, output.numel()); - ConstEigenArrayMap<T> grad_out_array( - grad_output.template data<T>(), 1, grad_output.numel()); - + const auto& Y = Input(0); + const auto& dY = Input(1); + const int N = Y.numel(); + ConstEigenVectorArrayMap<T> Y_arr(Y.template data<T>(), N); + ConstEigenVectorArrayMap<T> dY_arr(dY.template data<T>(), N); for (int i = 0; i < OutputSize(); i++) { - auto& input = Input(i + kInputStartOffset); - ConstEigenArrayMap<T> input_array( - input.template data<T>(), 1, input.numel()); - - auto* grad_input = Output(i, input.sizes(), at::dtype<T>()); - EigenArrayMap<T> grad_in_array( - grad_input->template mutable_data<T>(), 1, grad_input->numel()); - grad_in_array = grad_out_array * - input_array.cwiseEqual(output_array).template cast<T>(); + const auto& Xi = Input(i + 2); + auto* dXi = Output(i, Xi.sizes(), at::dtype<T>()); + ConstEigenVectorArrayMap<T> Xi_arr(Xi.template data<T>(), N); + EigenVectorArrayMap<T> dXi_arr(dXi->template mutable_data<T>(), N); + dXi_arr = (Xi_arr == Y_arr).template cast<T>() * dY_arr; } return true; } +REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>); +REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>); + +OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX); +OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX); + +namespace { + class GetMaxGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; - vector<OperatorDef> GetGradientDefs() override { - auto gradInputs = vector<string>(); - auto inputs = vector<string>{O(0), GO(0)}; - for (int i = 0; i < 
def_.input_size(); i++) { - gradInputs.push_back(GI(i)); + std::vector<OperatorDef> GetGradientDefs() override { + std::vector<std::string> inputs = {O(0), GO(0)}; + std::vector<std::string> grad_inputs; + for (int i = 0; i < def_.input_size(); ++i) { inputs.push_back(I(i)); + grad_inputs.push_back(GI(i)); } - return SingleGradientDef("MaxGradient", "", inputs, gradInputs); + return SingleGradientDef("MaxGradient", "", inputs, grad_inputs); } }; -REGISTER_GRADIENT(Max, GetMaxGradient); class GetMinGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; vector<OperatorDef> GetGradientDefs() override { - auto gradInputs = vector<string>(); - auto inputs = vector<string>{O(0), GO(0)}; - for (int i = 0; i < def_.input_size(); i++) { - gradInputs.push_back(GI(i)); + std::vector<std::string> inputs = {O(0), GO(0)}; + std::vector<std::string> grad_inputs; + for (int i = 0; i < def_.input_size(); ++i) { inputs.push_back(I(i)); + grad_inputs.push_back(GI(i)); } - return SingleGradientDef("MinGradient", "", inputs, gradInputs); + return SingleGradientDef("MinGradient", "", inputs, grad_inputs); } }; + +} // namespace + +REGISTER_GRADIENT(Max, GetMaxGradient); REGISTER_GRADIENT(Min, GetMinGradient); } // namespace caffe2 diff --git a/caffe2/operators/minmax_ops.cc b/caffe2/operators/minmax_ops.cc index 1a105770d2..3dd2487335 100644 --- a/caffe2/operators/minmax_ops.cc +++ b/caffe2/operators/minmax_ops.cc @@ -1,10 +1,9 @@ #include "caffe2/operators/minmax_ops.h" -#include "caffe2/utils/eigen_utils.h" namespace caffe2 { -REGISTER_CPU_OPERATOR(Max, MaxOp<float, CPUContext>); REGISTER_CPU_OPERATOR(Min, MinOp<float, CPUContext>); +REGISTER_CPU_OPERATOR(Max, MaxOp<float, CPUContext>); OPERATOR_SCHEMA(Max) .NumInputs(1, INT_MAX) @@ -155,34 +154,4 @@ Min: "Contains the minimum valued element at each location.") .InheritOnnxSchema(); -template <typename T, class Context> -bool MaxOp<T, Context>::Compute() { - auto& input0 = Input(0); - const int N = input0.numel(); - T* output_data = Output(0)->template mutable_data<T>(); - - for (int i = 1; i < InputSize(); i++) { - auto input_data = Input(i).template data<T>(); - EigenVectorMap<T> output_vec(output_data, N); - output_vec = output_vec.cwiseMax(ConstEigenVectorMap<T>(input_data, N)); - } - - return true; -} - -template <typename T, class Context> -bool MinOp<T, Context>::Compute() { - auto& input0 = Input(0); - const int N = input0.numel(); - T* output_data = Output(0)->template mutable_data<T>(); - - for (int i = 1; i < InputSize(); i++) { - auto input_data = Input(i).template data<T>(); - EigenVectorMap<T> output_vec(output_data, N); - output_vec = output_vec.cwiseMin(ConstEigenVectorMap<T>(input_data, N)); - } - - return true; -} - } // namespace caffe2 diff --git a/caffe2/operators/minmax_ops.cu b/caffe2/operators/minmax_ops.cu new file mode 100644 index 0000000000..a853b700d1 --- /dev/null +++ b/caffe2/operators/minmax_ops.cu @@ -0,0 +1,56 @@ +#include "caffe2/operators/minmax_ops.h" + +#include "caffe2/core/context_gpu.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +namespace { + +template <typename T> +__global__ void SelectGradientCUDAKernel( + const int N, + const T* dY, + const T* X, + const T* Y, + T* dX) { + const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (i < N) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + dX[i] = __ldg(X + i) == __ldg(Y + i) ? __ldg(dY + i) : T(0); +#else + dX[i] = X[i] == Y[i] ? 
dY[i] : T(0); +#endif + } +} + +} // namespace + +template <> +bool SelectGradientOpBase<float, CUDAContext>::RunOnDevice() { + const auto& Y = Input(0); + const auto& dY = Input(1); + const int N = Y.numel(); + const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS); + const float* dY_data = dY.data<float>(); + const float* Y_data = Y.data<float>(); + for (int i = 0; i < OutputSize(); i++) { + const auto& Xi = Input(i + 2); + auto* dXi = Output(i, Xi.sizes(), at::dtype<float>()); + const float* Xi_data = Xi.data<float>(); + float* dXi_data = dXi->mutable_data<float>(); + if (N > 0) { + SelectGradientCUDAKernel<float> + <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>( + N, dY_data, Xi_data, Y_data, dXi_data); + } + } + return true; +} + +REGISTER_CUDA_OPERATOR(Min, MinOp<float, CUDAContext>); +REGISTER_CUDA_OPERATOR(MinGradient, MinGradientOp<float, CUDAContext>); +REGISTER_CUDA_OPERATOR(Max, MaxOp<float, CUDAContext>); +REGISTER_CUDA_OPERATOR(MaxGradient, MaxGradientOp<float, CUDAContext>); + +} // namespace caffe2 diff --git a/caffe2/operators/minmax_ops.h b/caffe2/operators/minmax_ops.h index db02e0dcb0..5668460682 100644 --- a/caffe2/operators/minmax_ops.h +++ b/caffe2/operators/minmax_ops.h @@ -1,7 +1,6 @@ #ifndef CAFFE2_OPERATORS_MINMAX_OPS_H_ #define CAFFE2_OPERATORS_MINMAX_OPS_H_ -#include "caffe2/core/common_omp.h" #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" @@ -11,49 +10,99 @@ namespace caffe2 { template <typename T, class Context> -class MaxMinOpBase : public Operator<Context> { +class MaxOp final : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - USE_SIMPLE_CTOR_DTOR(MaxMinOpBase) - bool RunOnDevice() override { - auto& input0 = Input(0); - auto* output = Output(0); - - output->ResizeLike(input0); - output->CopyFrom(input0, /* async */ true); + USE_SIMPLE_CTOR_DTOR(MaxOp) + bool RunOnDevice() override { + const auto& X0 = Input(0); + auto* Y = Output(0); + Y->ResizeLike(X0); + const T* X0_data = X0.template data<T>(); + T* Y_data = Y->template mutable_data<T>(); + const int N = X0.numel(); if (InputSize() == 1) { + if (Y != &X0) { + context_.template CopySameDevice<T>(N, X0_data, Y_data); + } return true; } - - // Dimension checking - for (int i = 1; i < InputSize(); ++i) { + const auto& X1 = Input(1); + CAFFE_ENFORCE_EQ( + X0.sizes(), + Y->sizes(), + "Description: Input #1, input dimension:", + X1.sizes(), + " should match output dimension: ", + Y->sizes()); + const T* X1_data = X1.template data<T>(); + math::Max<T, Context>(N, X0_data, X1_data, Y_data, &context_); + for (int i = 2; i < InputSize(); ++i) { + const auto& Xi = Input(i); CAFFE_ENFORCE_EQ( - output->sizes(), - Input(i).sizes(), + Xi.sizes(), + Y->sizes(), "Description: Input #", i, ", input dimension:", Input(i).sizes(), " should match output dimension: ", - output->sizes()); + Y->sizes()); + const T* Xi_data = Xi.template data<T>(); + math::Max<T, Context>(N, Y_data, Xi_data, Y_data, &context_); } - - return this->Compute(); + return true; } - - virtual bool Compute() = 0; }; template <typename T, class Context> -class MaxOp : public MaxMinOpBase<T, Context> { +class MinOp final : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - MaxOp(const OperatorDef& operator_def, Workspace* ws) - : MaxMinOpBase<T, Context>(operator_def, ws) {} - virtual ~MaxOp() noexcept {} - bool Compute() override; + + USE_SIMPLE_CTOR_DTOR(MinOp) + + bool RunOnDevice() override { + const auto& X0 = Input(0); + auto* Y = Output(0); + 
Y->ResizeLike(X0); + const T* X0_data = X0.template data<T>(); + T* Y_data = Y->template mutable_data<T>(); + const int N = X0.numel(); + if (InputSize() == 1) { + if (Y != &X0) { + context_.template CopySameDevice<T>(N, X0_data, Y_data); + } + return true; + } + const auto& X1 = Input(1); + CAFFE_ENFORCE_EQ( + X0.sizes(), + Y->sizes(), + "Description: Input #1, input dimension:", + X1.sizes(), + " should match output dimension: ", + Y->sizes()); + const T* X1_data = X1.template data<T>(); + math::Min<T, Context>(N, X0_data, X1_data, Y_data, &context_); + for (int i = 2; i < InputSize(); ++i) { + const auto& Xi = Input(i); + CAFFE_ENFORCE_EQ( + Xi.sizes(), + Y->sizes(), + "Description: Input #", + i, + ", input dimension:", + Input(i).sizes(), + " should match output dimension: ", + Y->sizes()); + const T* Xi_data = Xi.template data<T>(); + math::Min<T, Context>(N, Y_data, Xi_data, Y_data, &context_); + } + return true; + } }; template <typename T, class Context> @@ -66,29 +115,21 @@ class SelectGradientOpBase : public Operator<Context> { }; template <typename T, class Context> -class MaxGradientOp : public SelectGradientOpBase<T, Context> { +class MaxGradientOp final : public SelectGradientOpBase<T, Context> { public: MaxGradientOp(const OperatorDef& operator_def, Workspace* ws) : SelectGradientOpBase<T, Context>(operator_def, ws) {} - virtual ~MaxGradientOp() noexcept {} -}; -template <typename T, class Context> -class MinOp : public MaxMinOpBase<T, Context> { - public: - USE_OPERATOR_CONTEXT_FUNCTIONS; - MinOp(const OperatorDef& operator_def, Workspace* ws) - : MaxMinOpBase<T, Context>(operator_def, ws) {} - virtual ~MinOp() noexcept {} - bool Compute() override; + ~MaxGradientOp() = default; }; template <typename T, class Context> -class MinGradientOp : public SelectGradientOpBase<T, Context> { +class MinGradientOp final : public SelectGradientOpBase<T, Context> { public: MinGradientOp(const OperatorDef& operator_def, Workspace* ws) : SelectGradientOpBase<T, Context>(operator_def, ws) {} - virtual ~MinGradientOp() noexcept {} + + ~MinGradientOp() = default; }; } // namespace caffe2 diff --git a/caffe2/operators/rsqrt_op.cu b/caffe2/operators/rsqrt_op.cu index 378d131a6f..eae8dc0f44 100644 --- a/caffe2/operators/rsqrt_op.cu +++ b/caffe2/operators/rsqrt_op.cu @@ -4,7 +4,7 @@ #include <functional> #include "caffe2/core/context_gpu.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/math.h" namespace caffe2 { diff --git a/caffe2/operators/utility_ops.cu b/caffe2/operators/utility_ops.cu index b7d17ba64f..f767a9ae11 100644 --- a/caffe2/operators/utility_ops.cu +++ b/caffe2/operators/utility_ops.cu @@ -1,8 +1,4 @@ -#include "caffe2/core/context_gpu.h" -#include "caffe2/operators/flatten_op.h" -#include "caffe2/operators/minmax_ops.h" #include "caffe2/operators/utility_ops.h" -#include "caffe2/utils/math.h" #include <thrust/device_vector.h> #include <thrust/sequence.h> @@ -10,6 +6,10 @@ #include <thrust/system/cuda/execution_policy.h> #include <thrust/unique.h> +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/flatten_op.h" +#include "caffe2/utils/math.h" + namespace caffe2 { template <> @@ -137,102 +137,6 @@ bool NanCheckOp<CUDAContext>::RunOnDevice() { REGISTER_CUDA_OPERATOR(NanCheck, NanCheckOp<CUDAContext>); -__global__ void -ElwiseMaxKernel(const float* X, const float* Y, float* maxout, const int N) { - CUDA_1D_KERNEL_LOOP(i, N) { - maxout[i] = fmaxf(X[i], Y[i]); - } -} - -template <> -bool MaxOp<float, CUDAContext>::Compute() { - float* output_data = 
Output(0)->template mutable_data<float>(); - const int N = Input(0).numel(); - - // Run pairwise-maxes - for (int i = 1; i < InputSize(); ++i) { - ElwiseMaxKernel<<< - CAFFE_GET_BLOCKS(N), - CAFFE_CUDA_NUM_THREADS, - 0, - context_.cuda_stream()>>>( - (i == 0 ? Input(0).data<float>() : Output(0)->data<float>()), - Input(i).data<float>(), - output_data, - N); - } - - return true; -} - -REGISTER_CUDA_OPERATOR(Max, MaxOp<float, CUDAContext>); -REGISTER_CUDA_OPERATOR(MaxGradient, MaxGradientOp<float, CUDAContext>); - -__global__ void -ElwiseMinKernel(const float* X, const float* Y, float* minout, const int N) { - CUDA_1D_KERNEL_LOOP(i, N) { - minout[i] = fminf(X[i], Y[i]); - } -} - -template <> -bool MinOp<float, CUDAContext>::Compute() { - float* output_data = Output(0)->template mutable_data<float>(); - const int N = Input(0).numel(); - - // Run pairwise-mines - for (int i = 1; i < InputSize(); ++i) { - ElwiseMinKernel<<< - CAFFE_GET_BLOCKS(N), - CAFFE_CUDA_NUM_THREADS, - 0, - context_.cuda_stream()>>>( - (i == 0 ? Input(0).data<float>() : Output(0)->data<float>()), - Input(i).data<float>(), - output_data, - N); - } - - return true; -} - -REGISTER_CUDA_OPERATOR(Min, MinOp<float, CUDAContext>); -REGISTER_CUDA_OPERATOR(MinGradient, MinGradientOp<float, CUDAContext>); - -template <typename T> -__global__ void -MaxMinGradKernel(int N, const T* mx, const T* x, const T* go, T* gi) { - CUDA_1D_KERNEL_LOOP(i, N) { - gi[i] = go[i] * (mx[i] == x[i]); - } -} - -template <> -bool SelectGradientOpBase<float, CUDAContext>::RunOnDevice() { - auto& output = Input(0); - auto& grad_output = Input(1); - const int kInputStartOffset = 2; - - const float* data = output.data<float>(); - - for (int i = 0; i < OutputSize(); i++) { - auto& input = Input(i + kInputStartOffset); - - auto* grad_input = Output(i, input.sizes(), at::dtype<float>()); - MaxMinGradKernel<<< - CAFFE_GET_BLOCKS(input.numel()), - CAFFE_CUDA_NUM_THREADS, - 0, - context_.cuda_stream()>>>( - input.numel(), - output.data<float>(), - input.data<float>(), - grad_output.data<float>(), - grad_input->template mutable_data<float>()); - } - return true; -} - /** * @brief Update slices of Y in-place with a batch of weighted X's. 
* Y[idx] = alpha[b] * X[b][i] + Y[idx] diff --git a/caffe2/quantization/server/im2col_dnnlowp.h b/caffe2/quantization/server/im2col_dnnlowp.h index fdbee29556..92f7b272ac 100644 --- a/caffe2/quantization/server/im2col_dnnlowp.h +++ b/caffe2/quantization/server/im2col_dnnlowp.h @@ -6,7 +6,7 @@ #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/math/utils.h" namespace caffe2 { diff --git a/caffe2/utils/CMakeLists.txt b/caffe2/utils/CMakeLists.txt index 9208a69641..eaa10f8865 100644 --- a/caffe2/utils/CMakeLists.txt +++ b/caffe2/utils/CMakeLists.txt @@ -1,10 +1,11 @@ list(APPEND Caffe2_CPU_SRCS utils/bench_utils.cc utils/cpuid.cc + utils/math/broadcast.cc utils/math/elementwise.cc utils/math/reduce.cc + utils/math/utils.cc utils/math_cpu.cc - utils/math_utils.cc utils/murmur_hash3.cc utils/proto_convert.cc utils/proto_utils.cc @@ -26,12 +27,14 @@ if (NOT MSVC) endif() set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} + utils/math/broadcast.cu utils/math/elementwise.cu utils/math/reduce.cu utils/math_gpu.cu ) set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} + utils/math/hip/broadcast.hip utils/math/hip/elementwise.hip utils/math/hip/reduce.hip utils/hip/math_gpu.hip diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h index f870c3d0f8..2ea960176a 100644 --- a/caffe2/utils/math.h +++ b/caffe2/utils/math.h @@ -15,9 +15,10 @@ extern "C" { #include "caffe2/core/common.h" #include "caffe2/core/types.h" +#include "caffe2/utils/math/broadcast.h" #include "caffe2/utils/math/elementwise.h" #include "caffe2/utils/math/reduce.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/math/utils.h" namespace caffe2 { @@ -31,9 +32,6 @@ class CAFFE2_API DefaultEngine {}; namespace math { #define C10_DECLARE_COMPARE_OP(Comp) \ - template <typename T, class Context> \ - void Comp(const int N, const T* A, const T* B, bool* C, Context* context); \ - \ template <typename T, class Context, bool kBroadcast1st = false> \ void Rowwise##Comp( \ const int rows, \ @@ -72,37 +70,34 @@ C10_DECLARE_COMPARE_OP(GE) #undef C10_DECLARE_COMPARE_OP -#define C10_DECLARE_BINARY_OP(Func) \ - template <typename T, class Context> \ - void Func(const int N, const T* A, const T* B, T* C, Context* context); \ - \ - template <typename T, class Context, bool kBroadcast1st = false> \ - void Rowwise##Func( \ - const int rows, \ - const int cols, \ - const T* A, \ - const T* B, \ - T* C, \ - Context* context); \ - \ - template <typename T, class Context, bool kBroadcast1st = false> \ - void Colwise##Func( \ - const int rows, \ - const int cols, \ - const T* A, \ - const T* B, \ - T* C, \ - Context* context); \ - \ - template <typename T, class Context> \ - void Func( \ - const int A_ndim, \ - const int* A_dims, \ - const int B_ndim, \ - const int* B_dims, \ - const T* A, \ - const T* B, \ - T* C, \ +#define C10_DECLARE_BINARY_OP(Func) \ + template <typename T, class Context, bool kBroadcast1st = false> \ + void Rowwise##Func( \ + const int rows, \ + const int cols, \ + const T* A, \ + const T* B, \ + T* C, \ + Context* context); \ + \ + template <typename T, class Context, bool kBroadcast1st = false> \ + void Colwise##Func( \ + const int rows, \ + const int cols, \ + const T* A, \ + const T* B, \ + T* C, \ + Context* context); \ + \ + template <typename T, class Context> \ + void Func( \ + const int A_ndim, \ + const int* A_dims, \ + const int B_ndim, \ + const int* B_dims, \ + const T* A, \ + const T* B, \ + T* C, \ Context* context); C10_DECLARE_BINARY_OP(Add) @@ -238,11 
+233,6 @@ template <typename T, class Context> CAFFE2_API void ColwiseMax(const int N, const int D, const T* x, T* y, Context* context); -// Elemwise maximum of vector x and vector y. z[i] = max(x[i], y[i]) -template <typename T, class Context> -CAFFE2_API void -ElemwiseMax(const int N, const T* x, const T* y, T* z, Context* context); - // Elemwise maximum of vector x and scalar alpha. y[i] = max(x[i], alpha) template <typename T, class Context> CAFFE2_API void @@ -621,7 +611,6 @@ CAFFE2_API void NHWC2NCHW( T* Y, Context* context); - } // namespace math } // namespace caffe2 diff --git a/caffe2/utils/math/broadcast.cc b/caffe2/utils/math/broadcast.cc new file mode 100644 index 0000000000..f880bb7eeb --- /dev/null +++ b/caffe2/utils/math/broadcast.cc @@ -0,0 +1,55 @@ +#include "caffe2/utils/math/broadcast.h" + +#include "caffe2/core/context.h" +#include "caffe2/utils/eigen_utils.h" + +namespace caffe2 { +namespace math { + +#define CAFFE2_SPECIALIZED_AFFINE_CHANNEL(T) \ + template <> \ + C10_EXPORT void AffineChannel<T, CPUContext, StorageOrder::NCHW>( \ + const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + const T* scale, \ + const T* bias, \ + T* Y, \ + CPUContext* /* context */) { \ + ConstEigenVectorArrayMap<T> scale_arr(scale, C); \ + ConstEigenVectorArrayMap<T> bias_arr(bias, C); \ + const int stride = C * HxW; \ + const T* X_ptr = X; \ + T* Y_ptr = Y; \ + for (int i = 0; i < N; ++i) { \ + EigenArrayMap<T>(Y_ptr, HxW, C) = \ + (ConstEigenArrayMap<T>(X_ptr, HxW, C).rowwise() * \ + scale_arr.transpose()) \ + .rowwise() + \ + bias_arr.transpose(); \ + X_ptr += stride; \ + Y_ptr += stride; \ + } \ + } \ + template <> \ + C10_EXPORT void AffineChannel<T, CPUContext, StorageOrder::NHWC>( \ + const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + const T* scale, \ + const T* bias, \ + T* Y, \ + CPUContext* /* context */) { \ + EigenArrayMap<T>(Y, C, N * HxW) = \ + (ConstEigenArrayMap<T>(X, C, N * HxW).colwise() * \ + ConstEigenVectorArrayMap<T>(scale, C)) \ + .colwise() + \ + ConstEigenVectorArrayMap<T>(bias, C); \ + } +CAFFE2_SPECIALIZED_AFFINE_CHANNEL(float) +#undef CAFFE2_SPECIALIZED_AFFINE_CHANNEL + +} // namespace math +} // namespace caffe2 diff --git a/caffe2/utils/math/broadcast.cu b/caffe2/utils/math/broadcast.cu new file mode 100644 index 0000000000..97f7cb500f --- /dev/null +++ b/caffe2/utils/math/broadcast.cu @@ -0,0 +1,108 @@ +#include "caffe2/utils/math/broadcast.h" + +#include "caffe2/core/context_gpu.h" +#include "caffe2/utils/math/utils.h" + +namespace caffe2 { +namespace math { + +namespace { + +template <typename T> +__global__ void AffineChannelNCHWCUDAKernel( + const int C, + const int M, + const int HxW, + const T* X, + const T* scale, + const T* bias, + T* Y); + +template <> +__global__ void AffineChannelNCHWCUDAKernel<float>( + const int C, + const int M, + const int HxW, + const float* X, + const float* scale, + const float* bias, + float* Y) { + const int nc = blockIdx.x / M; + const int c = nc % C; + const int w = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (w < HxW) { + const int index = nc * HxW + w; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); +#else + Y[index] = fmaf(X[index], scale[c], bias[c]); +#endif + } +} + +template <typename T> +__global__ void AffineChannelNHWCCUDAKernel( + const int C, + const T* X, + const T* scale, + const T* bias, + T* Y); + +template <> +__global__ void AffineChannelNHWCCUDAKernel<float>( + const int 
C, + const float* X, + const float* scale, + const float* bias, + float* Y) { + const int c = blockIdx.y * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (c < C) { + const int index = blockIdx.x * C + c; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); +#else + Y[index] = fmaf(X[index], scale[c], bias[c]); +#endif + } +} + +} // namespace + +#define CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(T) \ + template <> \ + CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, StorageOrder::NCHW>( \ + const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + const T* scale, \ + const T* bias, \ + T* Y, \ + CUDAContext* context) { \ + const int M = DivUp(HxW, CAFFE_CUDA_NUM_THREADS); \ + AffineChannelNCHWCUDAKernel<T> \ + <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( \ + C, M, HxW, X, scale, bias, Y); \ + } \ + template <> \ + CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, StorageOrder::NHWC>( \ + const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + const T* scale, \ + const T* bias, \ + T* Y, \ + CUDAContext* context) { \ + const int M = DivUp(C, CAFFE_CUDA_NUM_THREADS); \ + AffineChannelNHWCCUDAKernel<T> \ + <<<dim3(N* HxW, M), \ + CAFFE_CUDA_NUM_THREADS, \ + 0, \ + context->cuda_stream()>>>(C, X, scale, bias, Y); \ + } +CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float) +#undef CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL + +} // namespace math +} // namespace caffe2 diff --git a/caffe2/utils/math/broadcast.h b/caffe2/utils/math/broadcast.h new file mode 100644 index 0000000000..67e37d1bd9 --- /dev/null +++ b/caffe2/utils/math/broadcast.h @@ -0,0 +1,24 @@ +#ifndef CAFFE2_UTILS_MATH_BROADCAST_H_ +#define CAFFE2_UTILS_MATH_BROADCAST_H_ + +#include "caffe2/core/common.h" +#include "caffe2/core/types.h" + +namespace caffe2 { +namespace math { + +template <typename T, class Context, StorageOrder kOrder> +CAFFE2_API void AffineChannel( + const int N, + const int C, + const int HxW, + const T* X, + const T* scale, + const T* bias, + T* Y, + Context* context); + +} // namespace math +} // namespace caffe2 + +#endif // CAFFE2_UTILS_MATH_BROADCAST_H_ diff --git a/caffe2/utils/math/elementwise.cc b/caffe2/utils/math/elementwise.cc index ba6fb96338..f392829ccc 100644 --- a/caffe2/utils/math/elementwise.cc +++ b/caffe2/utils/math/elementwise.cc @@ -1,5 +1,8 @@ #include "caffe2/utils/math/elementwise.h" +#include <algorithm> +#include <functional> + #ifdef CAFFE2_USE_MKL #include <mkl.h> #endif // CAFFE2_USE_MKL @@ -12,7 +15,7 @@ namespace math { //////////////////////////////////////////////////////////////////////////////// // MKL VML alternatives. -// Depending on whether we are using MKL, we will delegate the Caffe math +// Depending on whether we are using MKL, we will delegate the Caffe2 math // functions that are VML-related to either the VML call or the Eigen // implementation. 
If you are setting the flags (such as AVX) right for your CPU // architecture, usually Eigen will deliver a throughput as fast as the VML @@ -90,6 +93,22 @@ DELEGATE_POWX_FUNCTION(float, vsPowx) DELEGATE_POWX_FUNCTION(double, vdPowx) #undef DELEGATE_POWX_FUNCTION +#define DELEGATE_SIMPLE_BINARY_FUNCTION(T, Func, MKLFunc) \ + template <> \ + C10_EXPORT void Func<T, CPUContext>( \ + const int N, const T* A, const T* B, T* C, CPUContext* /* context */) { \ + MKLFunc(N, A, B, C); \ + } +DELEGATE_SIMPLE_BINARY_FUNCTION(float, Add, vsAdd) +DELEGATE_SIMPLE_BINARY_FUNCTION(double, Add, vdAdd) +DELEGATE_SIMPLE_BINARY_FUNCTION(float, Sub, vsSub) +DELEGATE_SIMPLE_BINARY_FUNCTION(double, Sub, vdSub) +DELEGATE_SIMPLE_BINARY_FUNCTION(float, Mul, vsMul) +DELEGATE_SIMPLE_BINARY_FUNCTION(double, Mul, vdMul) +DELEGATE_SIMPLE_BINARY_FUNCTION(float, Div, vsDiv) +DELEGATE_SIMPLE_BINARY_FUNCTION(double, Div, vdDiv) +#undef DELEGATE_SIMPLE_BINARY_FUNCTION + #else // CAFFE2_USE_MKL #define DELEGATE_SIMPLE_UNARY_FUNCTION(T, Func, EigenFunc) \ @@ -127,68 +146,85 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(float, Inv, inverse) DELEGATE_SIMPLE_UNARY_FUNCTION(double, Inv, inverse) #undef DELEGATE_SIMPLE_UNARY_FUNCTION -#define DELEGATE_SINH_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_SINH(T) \ template <> \ C10_EXPORT void Sinh<T, CPUContext>( \ const int N, const T* X, T* Y, CPUContext* /* context */) { \ ConstEigenVectorArrayMap<T> X_arr(X, N); \ EigenVectorArrayMap<T>(Y, N) = (X_arr.exp() - (-X_arr).exp()) / T(2); \ } -DELEGATE_SINH_FUNCTION(float) -DELEGATE_SINH_FUNCTION(double) -#undef DELEGATE_SINH_FUNCTION +CAFFE2_SPECIALIZED_SINH(float) +CAFFE2_SPECIALIZED_SINH(double) +#undef CAFFE2_SPECIALIZED_SINH -#define DELEGATE_COSH_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_COSH(T) \ template <> \ C10_EXPORT void Cosh<T, CPUContext>( \ const int N, const T* X, T* Y, CPUContext* /* context */) { \ ConstEigenVectorArrayMap<T> X_arr(X, N); \ EigenVectorArrayMap<T>(Y, N) = (X_arr.exp() + (-X_arr).exp()) / T(2); \ } -DELEGATE_COSH_FUNCTION(float) -DELEGATE_COSH_FUNCTION(double) -#undef DELEGATE_COSH_FUNCTION +CAFFE2_SPECIALIZED_COSH(float) +CAFFE2_SPECIALIZED_COSH(double) +#undef CAFFE2_SPECIALIZED_COSH -#define DELEGATE_SINCOS_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_SINCOS(T) \ template <> \ C10_EXPORT void SinCos<T, CPUContext>( \ const int N, const T* X, T* S, T* C, CPUContext* /* context */) { \ EigenVectorArrayMap<T>(S, N) = ConstEigenVectorArrayMap<T>(X, N).sin(); \ EigenVectorArrayMap<T>(C, N) = ConstEigenVectorArrayMap<T>(X, N).cos(); \ } -DELEGATE_SINCOS_FUNCTION(float) -DELEGATE_SINCOS_FUNCTION(double) -#undef DELEGATE_SINCOS_FUNCTION +CAFFE2_SPECIALIZED_SINCOS(float) +CAFFE2_SPECIALIZED_SINCOS(double) +#undef CAFFE2_SPECIALIZED_SINCOS -#define DELEGATE_POWX_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_POWX(T) \ template <> \ C10_EXPORT void Powx<T, CPUContext>( \ const int N, const T* A, const T b, T* Y, CPUContext* /* context */) { \ EigenVectorArrayMap<T>(Y, N) = ConstEigenVectorArrayMap<T>(A, N).pow(b); \ } -DELEGATE_POWX_FUNCTION(float) -DELEGATE_POWX_FUNCTION(double) -#undef DELEGATE_POWX_FUNCTION +CAFFE2_SPECIALIZED_POWX(float) +CAFFE2_SPECIALIZED_POWX(double) +#undef CAFFE2_SPECIALIZED_POWX -#define DELEGATE_CBRT_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_CBRT(T) \ template <> \ C10_EXPORT void Cbrt<T, CPUContext>( \ const int N, const T* X, T* Y, CPUContext* /* context */) { \ std::transform(X, X + N, Y, [](const T x) { return cbrt(x); }); \ } -DELEGATE_CBRT_FUNCTION(float) -DELEGATE_CBRT_FUNCTION(double) -#undef 
DELEGATE_CBRT_FUNCTION +CAFFE2_SPECIALIZED_CBRT(float) +CAFFE2_SPECIALIZED_CBRT(double) +#undef CAFFE2_SPECIALIZED_CBRT -#define DELEGATE_ERF_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_ERF(T) \ template <> \ C10_EXPORT void Erf<T, CPUContext>( \ const int N, const T* X, T* Y, CPUContext* /* context */) { \ std::transform(X, X + N, Y, [](const T x) { return erf(x); }); \ } -DELEGATE_ERF_FUNCTION(float) -DELEGATE_ERF_FUNCTION(double) -#undef DELEGATE_ERF_FUNCTION +CAFFE2_SPECIALIZED_ERF(float) +CAFFE2_SPECIALIZED_ERF(double) +#undef CAFFE2_SPECIALIZED_ERF + +#define DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(T, Func, EigenOp) \ + template <> \ + C10_EXPORT void Func<T, CPUContext>( \ + const int N, const T* A, const T* B, T* C, CPUContext* /* context */) { \ + EigenVectorMap<T>(C, N) = ConstEigenVectorArrayMap<T>(A, N) \ + EigenOp ConstEigenVectorArrayMap<T>(B, N); \ + } +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(float, Add, +) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(double, Add, +) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(float, Sub, -) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(double, Sub, -) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(float, Mul, *) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(double, Mul, *) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(float, Div, /) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(double, Div, /) +#undef DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR #endif // CAFFE2_USE_MKL @@ -202,74 +238,155 @@ DELEGATE_ERF_FUNCTION(double) // Eigen's Tanh implementation is faster than MKL, so use Eigen here. DELEGATE_SIMPLE_UNARY_FUNCTION(float, Tanh, tanh) DELEGATE_SIMPLE_UNARY_FUNCTION(double, Tanh, tanh) -DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sign, sign) -DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sign, sign) DELEGATE_SIMPLE_UNARY_FUNCTION(std::int32_t, Sign, sign) DELEGATE_SIMPLE_UNARY_FUNCTION(std::int64_t, Sign, sign) +DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sign, sign) +DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sign, sign) DELEGATE_SIMPLE_UNARY_FUNCTION(std::int32_t, Abs, abs) DELEGATE_SIMPLE_UNARY_FUNCTION(std::int64_t, Abs, abs) -DELEGATE_SIMPLE_UNARY_FUNCTION(float, Cube, cube) -DELEGATE_SIMPLE_UNARY_FUNCTION(double, Cube, cube) DELEGATE_SIMPLE_UNARY_FUNCTION(std::int32_t, Cube, cube) DELEGATE_SIMPLE_UNARY_FUNCTION(std::int64_t, Cube, cube) +DELEGATE_SIMPLE_UNARY_FUNCTION(float, Cube, cube) +DELEGATE_SIMPLE_UNARY_FUNCTION(double, Cube, cube) #undef DELEGATE_SIMPLE_UNARY_FUNCTION -#define DELEGATE_NEG_FUNCTION(T) \ +#define CAFFE2_SPECIALIZED_NEG(T) \ template <> \ C10_EXPORT void Neg<T, CPUContext>( \ const int N, const T* X, T* Y, CPUContext* /* context */) { \ EigenVectorArrayMap<T>(Y, N) = -ConstEigenVectorArrayMap<T>(X, N); \ } -DELEGATE_NEG_FUNCTION(float) -DELEGATE_NEG_FUNCTION(double) -DELEGATE_NEG_FUNCTION(std::int32_t) -DELEGATE_NEG_FUNCTION(std::int64_t) -#undef DELEGATE_NEG_FUNCTION +CAFFE2_SPECIALIZED_NEG(std::int32_t) +CAFFE2_SPECIALIZED_NEG(std::int64_t) +CAFFE2_SPECIALIZED_NEG(float) +CAFFE2_SPECIALIZED_NEG(double) +#undef CAFFE2_SPECIALIZED_NEG + +#define DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(T, Func, EigenOp) \ + template <> \ + C10_EXPORT void Func<T, CPUContext>( \ + const int N, const T* A, const T* B, T* C, CPUContext* /* context */) { \ + EigenVectorMap<T>(C, N) = ConstEigenVectorArrayMap<T>(A, N) \ + EigenOp ConstEigenVectorArrayMap<T>(B, N); \ + } +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, Add, +) 
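For readers skimming the macro soup in this hunk: on CPU, the binary primitives are generated rather than hand-written. With MKL enabled, `float`/`double` delegate to VML (`vsAdd`, `vdAdd`, ...); otherwise, and for the integer types, `DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR` maps them onto Eigen array expressions. Expanded by hand from the macro body above, the `(float, Add, +)` instantiation is:

```cpp
template <>
C10_EXPORT void Add<float, CPUContext>(
    const int N, const float* A, const float* B, float* C,
    CPUContext* /* context */) {
  EigenVectorMap<float>(C, N) = ConstEigenVectorArrayMap<float>(A, N) +
      ConstEigenVectorArrayMap<float>(B, N);
}
```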
+DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, Add, +) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, Sub, -) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, Sub, -) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, Mul, *) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, Mul, *) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, Div, /) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, Div, /) +#undef DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR -#define CAFFE2_SPECIALIZED_AFFINE_CHANNEL(T) \ - template <> \ - void AffineChannel<T, CPUContext, StorageOrder::NCHW>( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CPUContext* /* context */) { \ - ConstEigenVectorArrayMap<T> scale_arr(scale, C); \ - ConstEigenVectorArrayMap<T> bias_arr(bias, C); \ - const int stride = C * HxW; \ - const T* X_ptr = X; \ - T* Y_ptr = Y; \ - for (int i = 0; i < N; ++i) { \ - EigenArrayMap<T>(Y_ptr, HxW, C) = \ - (ConstEigenArrayMap<T>(X_ptr, HxW, C).rowwise() * \ - scale_arr.transpose()) \ - .rowwise() + \ - bias_arr.transpose(); \ - X_ptr += stride; \ - Y_ptr += stride; \ - } \ - } \ - template <> \ - void AffineChannel<T, CPUContext, StorageOrder::NHWC>( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CPUContext* /* context */) { \ - EigenArrayMap<T>(Y, C, N * HxW) = \ - (ConstEigenArrayMap<T>(X, C, N * HxW).colwise() * \ - ConstEigenVectorArrayMap<T>(scale, C)) \ - .colwise() + \ - ConstEigenVectorArrayMap<T>(bias, C); \ +#define DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_FUNCTION(T, Func, EigenFunc) \ + template <> \ + C10_EXPORT void Func<T, CPUContext>( \ + const int N, const T* A, const T* B, T* C, CPUContext* /* context */) { \ + EigenVectorMap<T>(C, N) = ConstEigenVectorArrayMap<T>(A, N).EigenFunc( \ + ConstEigenVectorArrayMap<T>(B, N)); \ + } +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_FUNCTION(float, Min, min) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_FUNCTION(double, Min, min) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_FUNCTION(float, Max, max) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_FUNCTION(double, Max, max) +#undef DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_FUNCTION + +#define DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION(T, Func, StdFunc) \ + template <> \ + C10_EXPORT void Func<T, CPUContext>( \ + const int N, const T* A, const T* B, T* C, CPUContext* /* context */) { \ + std::transform(A, A + N, B, C, StdFunc); \ + } +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + bool, + And, + std::logical_and<bool>()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + bool, + Or, + std::logical_or<bool>()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION(bool, Xor, std::bit_xor<bool>()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + bool, + BitwiseAnd, + std::bit_and<bool>()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + std::int32_t, + BitwiseAnd, + std::bit_and<std::int32_t>()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + std::int64_t, + BitwiseAnd, + std::bit_and<std::int64_t>()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + bool, + BitwiseOr, + std::bit_or<bool>()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + std::int32_t, + BitwiseOr, + std::bit_or<std::int32_t>()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + std::int64_t, + BitwiseOr, + std::bit_or<std::int64_t>()) 
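The logical and bitwise variants in this run go through `std::transform` with `<functional>` functors instead (hence the new `<algorithm>`/`<functional>` includes at the top of `elementwise.cc`). Hand-expanded for `(bool, And)`:

```cpp
template <>
C10_EXPORT void And<bool, CPUContext>(
    const int N, const bool* A, const bool* B, bool* C,
    CPUContext* /* context */) {
  std::transform(A, A + N, B, C, std::logical_and<bool>());
}
```

One subtlety: `Xor` for `bool` uses `std::bit_xor<bool>` because the standard library has no `std::logical_xor`; for `bool` operands the two would agree anyway.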
+DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + bool, + BitwiseXor, + std::bit_xor<bool>()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + std::int32_t, + BitwiseXor, + std::bit_xor<std::int32_t>()) +DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION( + std::int64_t, + BitwiseXor, + std::bit_xor<std::int64_t>()) +#undef DELEGATE_SIMPLE_BINARY_FUNCTION_BY_STD_FUNCTION + +#define DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(T, Func, EigenOp) \ + template <> \ + C10_EXPORT void Func<T, CPUContext>( \ + const int N, \ + const T* A, \ + const T* B, \ + bool* C, \ + CPUContext* /* context */) { \ + EigenVectorArrayMap<bool>(C, N) = ConstEigenVectorArrayMap<T>(A, N) \ + EigenOp ConstEigenVectorArrayMap<T>(B, N); \ } -CAFFE2_SPECIALIZED_AFFINE_CHANNEL(float) -#undef CAFFE2_SPECIALIZED_AFFINE_CHANNEL +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(bool, EQ, ==) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, EQ, ==) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, EQ, ==) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(float, EQ, ==) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(double, EQ, ==) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(bool, NE, !=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, NE, !=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, NE, !=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(float, NE, !=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(double, NE, !=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(bool, LT, <) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, LT, <) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, LT, <) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(float, LT, <) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(double, LT, <) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(bool, LE, <=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, LE, <=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, LE, <=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(float, LE, <=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(double, LE, <=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(bool, GT, >) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, GT, >) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, GT, >) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(float, GT, >) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(double, GT, >) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(bool, GE, >=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int32_t, GE, >=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(std::int64_t, GE, >=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(float, GE, >=) +DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR(double, GE, >=) +#undef DELEGATE_SIMPLE_COMPARE_FUNCTION_BY_EIGEN_OPERATOR } // namespace math } // namespace caffe2 diff --git a/caffe2/utils/math/elementwise.cu b/caffe2/utils/math/elementwise.cu index cc7613cda9..b7605deece 100644 --- a/caffe2/utils/math/elementwise.cu +++ b/caffe2/utils/math/elementwise.cu @@ -1,13 +1,77 @@ #include "caffe2/utils/math/elementwise.h" +#include <thrust/functional.h> + #include "caffe2/core/context_gpu.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/conversions.h" +#include "caffe2/utils/math/half_utils.h" +#include "caffe2/utils/math/utils.h" namespace caffe2 { namespace math { namespace { +#define 
DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(T, Func, DeviceFunc) \ + __global__ void Func##CUDAKernel(const int N, const T* X, T* Y) { \ + const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; \ + if (i < N) { \ + Y[i] = DeviceFunc(X[i]); \ + } \ + } +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Exp, expf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Log, logf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Cos, cosf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Acos, acosf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sin, sinf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Asin, asinf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Tan, tanf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Atan, atanf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sinh, sinhf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Cosh, coshf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Tanh, tanhf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Abs, fabsf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sqr, utils::Square<float>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sqrt, sqrtf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Rsqrt, rsqrtf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Cbrt, cbrtf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Erf, erff) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Erf, erf) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( + std::int32_t, + Cube, + utils::Cube<std::int32_t>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( + std::int64_t, + Cube, + utils::Cube<std::int64_t>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Cube, utils::Cube<float>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Cube, utils::Cube<double>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(bool, Not, utils::Not<bool>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( + std::int32_t, + Neg, + utils::Negate<std::int32_t>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( + std::int64_t, + Neg, + utils::Negate<std::int64_t>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Neg, utils::Negate<float>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Neg, utils::Negate<double>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( + std::int32_t, + Sign, + utils::Sign<std::int32_t>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION( + std::int64_t, + Sign, + utils::Sign<std::int64_t>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Sign, utils::Sign<float>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Sign, utils::Sign<double>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(float, Inv, utils::Inv<float>) +DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION(double, Inv, utils::Inv<double>) +#undef DELEGATE_SIMPLE_CUDA_UNARY_KERNEL_FUNCTION + template <typename T> __global__ void SinCosCUDAKernel(const int N, const T* X, T* S, T* C) { const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; @@ -20,190 +84,276 @@ __global__ void SinCosCUDAKernel(const int N, const T* X, T* S, T* C) { } } -template <typename T> -__global__ void AffineChannelNCHWCUDAKernel( - const int C, - const int M, - const int HxW, - const T* X, - const T* scale, - const T* bias, - T* Y); - -template <> -__global__ void AffineChannelNCHWCUDAKernel<float>( - const int C, - const int M, - const int HxW, - const float* X, - const float* scale, - const float* bias, - float* Y) { - const int nc = blockIdx.x / M; - const int c = nc % C; - const int w = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (w < HxW) { - const int index = nc * HxW + w; 
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) - Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); -#else - Y[index] = fmaf(X[index], scale[c], bias[c]); -#endif +template <typename T, class Func> +__global__ void SimpleBinaryCUDAKernel( + const int N, + const Func func, + const T* A, + const T* B, + T* C) { + const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (i < N) { + C[i] = func(A[i], B[i]); } } -template <typename T> -__global__ void AffineChannelNHWCCUDAKernel( - const int C, - const T* X, - const T* scale, - const T* bias, - T* Y); - -template <> -__global__ void AffineChannelNHWCCUDAKernel<float>( - const int C, - const float* X, - const float* scale, - const float* bias, - float* Y) { - const int c = blockIdx.y * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (c < C) { - const int index = blockIdx.x * C + c; -#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) - Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); -#else - Y[index] = fmaf(X[index], scale[c], bias[c]); -#endif +template <typename T, class Comp> +__global__ void SimpleCompareCUDAKernel( + const int N, + const Comp comp, + const T* A, + const T* B, + bool* C) { + const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (i < N) { + C[i] = comp(A[i], B[i]); } } } // namespace -#define DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(T, Func, KernelFunc) \ - __global__ void Func##CUDAKernel(const int N, const T* X, T* Y) { \ - const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; \ - if (i < N) { \ - Y[i] = KernelFunc(X[i]); \ - } \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Func<T, CUDAContext>( \ - const int N, const T* X, T* Y, CUDAContext* context) { \ - if (N > 0) { \ - const int K = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - Func##CUDAKernel<<< \ - K, \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>(N, X, Y); \ - } \ +#define DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(T, Func) \ + template <> \ + CAFFE2_CUDA_EXPORT void Func<T, CUDAContext>( \ + const int N, const T* X, T* Y, CUDAContext* context) { \ + if (N > 0) { \ + const int M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ + Func##CUDAKernel<<< \ + M, \ + CAFFE_CUDA_NUM_THREADS, \ + 0, \ + context->cuda_stream()>>>(N, X, Y); \ + } \ } -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log, logf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cos, cosf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Acos, acosf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sin, sinf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Asin, asinf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tan, tanf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Atan, atanf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sinh, sinhf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cosh, coshf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tanh, tanhf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Abs, fabsf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr, utils::Square<float>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqrt, sqrtf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Rsqrt, rsqrtf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cbrt, cbrtf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Erf, erff) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Erf, erf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cube, utils::Cube<float>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Cube, utils::Cube<double>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp) 
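On the CUDA side, the refactor separates each primitive into a device kernel in the anonymous namespace and a host launcher that sizes the grid by ceiling division, and all binary ops share the single functor-parameterized `SimpleBinaryCUDAKernel` above, instantiated with Thrust functors (`thrust::plus`, `thrust::minimum`, ...) in the delegations that follow. A condensed, self-contained model of that shape; `DivUp` here is assumed to match the ceiling-division helper from `caffe2/utils/math/utils.h`, and `blockDim.x` stands in for `CAFFE_CUDA_NUM_THREADS` at launch:

```cuda
#include <thrust/functional.h>

// Ceiling division: number of blocks needed to cover N elements.
constexpr int DivUp(const int a, const int b) {
  return (a + b - 1) / b;
}

// One thread per element; the functor selects the operation at compile time.
template <typename T, class Func>
__global__ void SimpleBinaryKernel(
    const int N, const Func func, const T* A, const T* B, T* C) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < N) { // the last block may run past N
    C[i] = func(A[i], B[i]);
  }
}

// Launch shape, as in the DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION launchers:
//   const int M = DivUp(N, kBlockSize);
//   SimpleBinaryKernel<<<M, kBlockSize, 0, stream>>>(
//       N, thrust::plus<float>(), A, B, C);
```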
+DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cos) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Acos) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sin) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Asin) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tan) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Atan) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sinh) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cosh) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tanh) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Abs) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqrt) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Rsqrt) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cbrt) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Erf) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Erf) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cube) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Cube) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int32_t, Cube) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int64_t, Cube) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(bool, Not) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Neg) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Neg) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int32_t, Neg) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int64_t, Neg) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sign) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sign) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int32_t, Sign) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(std::int64_t, Sign) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(float, Inv) +DEFINE_SIMPLE_CUDA_UNARY_FUNCTION(double, Inv) +#undef DEFINE_SIMPLE_CUDA_UNARY_FUNCTION + +#define CAFFE2_SPECIALIZED_CUDA_SINCOS(T) \ + template <> \ + CAFFE2_CUDA_EXPORT void SinCos<T, CUDAContext>( \ + const int N, const T* X, T* S, T* C, CUDAContext* context) { \ + if (N > 0) { \ + const int K = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ + SinCosCUDAKernel<T> \ + <<<K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( \ + N, X, S, C); \ + } \ + } +CAFFE2_SPECIALIZED_CUDA_SINCOS(float) +CAFFE2_SPECIALIZED_CUDA_SINCOS(double) +#undef CAFFE2_SPECIALIZED_CUDA_SINCOS + +#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Func, DeviceFunc) \ + template <> \ + CAFFE2_CUDA_EXPORT void Func<T, CUDAContext>( \ + const int N, const T* A, const T* B, T* C, CUDAContext* context) { \ + if (N > 0) { \ + const int M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ + SimpleBinaryCUDAKernel<<< \ + M, \ + CAFFE_CUDA_NUM_THREADS, \ + 0, \ + context->cuda_stream()>>>(N, DeviceFunc, A, B, C); \ + } \ + } +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( std::int32_t, - Cube, - utils::Cube<std::int32_t>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( + Add, + thrust::plus<std::int32_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( std::int64_t, - Cube, - utils::Cube<std::int64_t>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(bool, Not, utils::Not) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Neg, utils::Negate<float>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Neg, utils::Negate<double>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( + Add, + thrust::plus<std::int64_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Add, thrust::plus<float>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Add, thrust::plus<double>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Add, utils::HalfAddFunctor()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( std::int32_t, - Neg, - utils::Negate<std::int32_t>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( + Sub, + thrust::minus<std::int32_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( std::int64_t, - Neg, - utils::Negate<std::int64_t>) 
-DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sign, utils::Sign<float>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sign, utils::Sign<double>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( + Sub, + thrust::minus<std::int64_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Sub, thrust::minus<float>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Sub, thrust::minus<double>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Sub, utils::HalfSubFunctor()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( std::int32_t, - Sign, - utils::Sign<std::int32_t>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( + Mul, + thrust::multiplies<std::int32_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( std::int64_t, - Sign, - utils::Sign<std::int64_t>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Inv, utils::Inv<float>) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Inv, utils::Inv<double>) -#undef DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION - -#define CAFFE2_SPECIALIZED_CUDA_SINCOS(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void SinCos<T, CUDAContext>( \ - const int N, const T* X, T* S, T* C, CUDAContext* context) { \ - if (N > 0) { \ - const int K = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - SinCosCUDAKernel<<< \ - K, \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>(N, X, S, C); \ - } \ - } -CAFFE2_SPECIALIZED_CUDA_SINCOS(float) -CAFFE2_SPECIALIZED_CUDA_SINCOS(double) -#undef CAFFE2_SPECIALIZED_CUDA_SINCOS + Mul, + thrust::multiplies<std::int64_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Mul, thrust::multiplies<float>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Mul, thrust::multiplies<double>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Mul, utils::HalfMulFunctor()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int32_t, + Div, + thrust::divides<std::int32_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int64_t, + Div, + thrust::divides<std::int64_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Div, thrust::divides<float>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Div, thrust::divides<double>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Div, utils::HalfDivFunctor()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Min, thrust::minimum<float>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Min, thrust::minimum<double>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Max, thrust::maximum<float>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Max, thrust::maximum<double>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, And, thrust::logical_and<bool>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, Or, thrust::logical_or<bool>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, Xor, thrust::bit_xor<bool>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, BitwiseAnd, thrust::bit_and<bool>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int32_t, + BitwiseAnd, + thrust::bit_and<std::int32_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int64_t, + BitwiseAnd, + thrust::bit_and<std::int64_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, BitwiseOr, thrust::bit_or<bool>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int32_t, + BitwiseOr, + thrust::bit_or<std::int32_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int64_t, + BitwiseOr, + thrust::bit_or<std::int64_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, BitwiseXor, thrust::bit_xor<bool>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int32_t, + BitwiseXor, + thrust::bit_xor<std::int32_t>()) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( + std::int64_t, + BitwiseXor, + thrust::bit_xor<std::int64_t>()) +#undef DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION -#define 
CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, StorageOrder::NCHW>( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CUDAContext* context) { \ - const int M = DivUp(HxW, CAFFE_CUDA_NUM_THREADS); \ - AffineChannelNCHWCUDAKernel<T> \ - <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( \ - C, M, HxW, X, scale, bias, Y); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, StorageOrder::NHWC>( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CUDAContext* context) { \ - const int M = DivUp(C, CAFFE_CUDA_NUM_THREADS); \ - AffineChannelNHWCCUDAKernel<T> \ - <<<dim3(N* HxW, M), \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>(C, X, scale, bias, Y); \ +#define DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(T, Func, DeviceComp) \ + template <> \ + CAFFE2_CUDA_EXPORT void Func<T, CUDAContext>( \ + const int N, const T* A, const T* B, bool* C, CUDAContext* context) { \ + if (N > 0) { \ + const int M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ + SimpleCompareCUDAKernel<<< \ + M, \ + CAFFE_CUDA_NUM_THREADS, \ + 0, \ + context->cuda_stream()>>>(N, DeviceComp, A, B, C); \ + } \ } -CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float) -#undef CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, EQ, thrust::equal_to<bool>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int32_t, + EQ, + thrust::equal_to<std::int32_t>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int64_t, + EQ, + thrust::equal_to<std::int64_t>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, EQ, thrust::equal_to<float>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, EQ, thrust::equal_to<double>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, NE, thrust::not_equal_to<bool>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int32_t, + NE, + thrust::not_equal_to<std::int32_t>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int64_t, + NE, + thrust::not_equal_to<std::int64_t>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, NE, thrust::not_equal_to<float>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + double, + NE, + thrust::not_equal_to<double>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, LT, thrust::less<bool>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int32_t, + LT, + thrust::less<std::int32_t>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int64_t, + LT, + thrust::less<std::int64_t>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, LT, thrust::less<float>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, LT, thrust::less<double>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, LE, thrust::less_equal<bool>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int32_t, + LE, + thrust::less_equal<std::int32_t>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int64_t, + LE, + thrust::less_equal<std::int64_t>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, LE, thrust::less_equal<float>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, LE, thrust::less_equal<double>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, GT, thrust::greater<bool>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int32_t, + GT, + thrust::greater<std::int32_t>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( + std::int64_t, + GT, + thrust::greater<std::int64_t>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, GT, thrust::greater<float>()) +DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, 
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, EQ, thrust::equal_to<bool>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    std::int32_t,
+    EQ,
+    thrust::equal_to<std::int32_t>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    std::int64_t,
+    EQ,
+    thrust::equal_to<std::int64_t>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, EQ, thrust::equal_to<float>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, EQ, thrust::equal_to<double>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, NE, thrust::not_equal_to<bool>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    std::int32_t,
+    NE,
+    thrust::not_equal_to<std::int32_t>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    std::int64_t,
+    NE,
+    thrust::not_equal_to<std::int64_t>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, NE, thrust::not_equal_to<float>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    double,
+    NE,
+    thrust::not_equal_to<double>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, LT, thrust::less<bool>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    std::int32_t,
+    LT,
+    thrust::less<std::int32_t>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    std::int64_t,
+    LT,
+    thrust::less<std::int64_t>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, LT, thrust::less<float>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, LT, thrust::less<double>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, LE, thrust::less_equal<bool>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    std::int32_t,
+    LE,
+    thrust::less_equal<std::int32_t>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    std::int64_t,
+    LE,
+    thrust::less_equal<std::int64_t>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, LE, thrust::less_equal<float>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, LE, thrust::less_equal<double>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, GT, thrust::greater<bool>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    std::int32_t,
+    GT,
+    thrust::greater<std::int32_t>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    std::int64_t,
+    GT,
+    thrust::greater<std::int64_t>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, GT, thrust::greater<float>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, GT, thrust::greater<double>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, GE, thrust::greater_equal<bool>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    std::int32_t,
+    GE,
+    thrust::greater_equal<std::int32_t>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    std::int64_t,
+    GE,
+    thrust::greater_equal<std::int64_t>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, GE, thrust::greater_equal<float>())
+DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(
+    double,
+    GE,
+    thrust::greater_equal<double>())
+#undef DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION
 
 } // namespace math
 } // namespace caffe2
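With the .cu delegations in place, the header below (caffe2/utils/math/elementwise.h) becomes the single declaration point for the whole elementwise surface. A usage sketch against those declarations, on CPUContext (array contents and names are made up; it assumes the umbrella header caffe2/utils/math.h still re-exports these after the split):

#include "caffe2/core/context.h"
#include "caffe2/utils/math.h"

void ElementwiseExample() {
  caffe2::CPUContext context;
  const int N = 4;
  const float A[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  const float B[4] = {4.0f, 3.0f, 2.0f, 1.0f};
  float C[4];
  bool mask[4];
  // C[i] = A[i] + B[i]
  caffe2::math::Add<float, caffe2::CPUContext>(N, A, B, C, &context);
  // mask[i] = (A[i] < B[i]); note that comparisons write bool, not T
  caffe2::math::LT<float, caffe2::CPUContext>(N, A, B, mask, &context);
}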
diff --git a/caffe2/utils/math/elementwise.h b/caffe2/utils/math/elementwise.h
index 63890de498..b11d22b955 100644
--- a/caffe2/utils/math/elementwise.h
+++ b/caffe2/utils/math/elementwise.h
@@ -8,64 +8,97 @@ namespace caffe2 {
 namespace math {
 
 template <typename T, class Context>
-void Exp(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Exp(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Log(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Log(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Sin(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Sin(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Asin(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Asin(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Cos(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Cos(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Acos(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Acos(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Tan(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Tan(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Atan(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Atan(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Sinh(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Sinh(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Cosh(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Cosh(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void SinCos(const int N, const T* X, T* S, T* C, Context* context);
+CAFFE2_API void SinCos(int N, const T* X, T* S, T* C, Context* context);
 template <typename T, class Context>
-void Tanh(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Tanh(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Abs(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Abs(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Sqr(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Sqr(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Sqrt(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Sqrt(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Rsqrt(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Rsqrt(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Cube(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Cube(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Cbrt(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Cbrt(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Neg(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Neg(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Sign(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Sign(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Not(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Not(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Powx(const int N, const T* A, const T b, T* Y, Context* context);
+CAFFE2_API void Powx(int N, const T* A, const T b, T* Y, Context* context);
 template <typename T, class Context>
-void Inv(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Inv(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
-void Erf(const int N, const T* X, T* Y, Context* context);
+CAFFE2_API void Erf(int N, const T* X, T* Y, Context* context);
 
-template <typename T, class Context, StorageOrder kOrder>
-CAFFE2_API void AffineChannel(
-    const int N,
-    const int C,
-    const int HxW,
-    const T* X,
-    const T* scale,
-    const T* bias,
-    T* Y,
-    Context* context);
+template <typename T, class Context>
+CAFFE2_API void Add(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void Sub(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void Mul(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void Div(int N, const T* A, const T* B, T* C, Context* context);
+
+template <typename T, class Context>
+CAFFE2_API void Min(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void Max(int N, const T* A, const T* B, T* C, Context* context);
+
+template <typename T, class Context>
+CAFFE2_API void And(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void Or(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void Xor(int N, const T* A, const T* B, T* C, Context* context);
+
+template <typename T, class Context>
+CAFFE2_API void
+BitwiseAnd(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void
+BitwiseOr(int N, const T* A, const T* B, T* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void
+BitwiseXor(int N, const T* A, const T* B, T* C, Context* context);
+
+template <typename T, class Context>
+CAFFE2_API void EQ(int N, const T* A, const T* B, bool* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void NE(int N, const T* A, const T* B, bool* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void LT(int N, const T* A, const T* B, bool* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void LE(int N, const T* A, const T* B, bool* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void GT(int N, const T* A, const T* B, bool* C, Context* context);
+template <typename T, class Context>
+CAFFE2_API void GE(int N, const T* A, const T* B, bool* C, Context* context);
 
 } // namespace math
 } // namespace caffe2
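The next file, caffe2/utils/math/half_utils.h, is new: at::Half has no native arithmetic here, so each functor widens to float, operates, and narrows back. A standalone analogue using the CUDA fp16 conversion intrinsics (caffe2 itself goes through convert::To rather than these intrinsics):

#include <cuda_fp16.h>

// Same shape as utils::HalfAddFunctor below: half -> float, add, float -> half.
inline __half HalfAdd(const __half a, const __half b) {
  return __float2half(__half2float(a) + __half2float(b));
}

The round trip costs two conversions per element, but it keeps the functors correct on hardware without native half arithmetic.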
diff --git a/caffe2/utils/math/half_utils.h b/caffe2/utils/math/half_utils.h
new file mode 100644
index 0000000000..ac841d165a
--- /dev/null
+++ b/caffe2/utils/math/half_utils.h
@@ -0,0 +1,49 @@
+#ifndef CAFFE2_UTILS_MATH_HALF_UTILS_H_
+#define CAFFE2_UTILS_MATH_HALF_UTILS_H_
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/types.h"
+#include "caffe2/utils/conversions.h"
+#include "caffe2/utils/math/utils.h"
+
+namespace caffe2 {
+namespace math {
+namespace utils {
+
+struct HalfAddFunctor {
+  MATH_UTILS_DECL at::Half operator()(const at::Half a, const at::Half b)
+      const {
+    return convert::To<float, at::Half>(
+        convert::To<at::Half, float>(a) + convert::To<at::Half, float>(b));
+  }
+};
+
+struct HalfSubFunctor {
+  MATH_UTILS_DECL at::Half operator()(const at::Half a, const at::Half b)
+      const {
+    return convert::To<float, at::Half>(
+        convert::To<at::Half, float>(a) - convert::To<at::Half, float>(b));
+  }
+};
+
+struct HalfMulFunctor {
+  MATH_UTILS_DECL at::Half operator()(const at::Half a, const at::Half b)
+      const {
+    return convert::To<float, at::Half>(
+        convert::To<at::Half, float>(a) * convert::To<at::Half, float>(b));
+  }
+};
+
+struct HalfDivFunctor {
+  MATH_UTILS_DECL at::Half operator()(const at::Half a, const at::Half b)
+      const {
+    return convert::To<float, at::Half>(
+        convert::To<at::Half, float>(a) / convert::To<at::Half, float>(b));
+  }
+};
+
+} // namespace utils
+} // namespace math
+} // namespace caffe2
+
+#endif // CAFFE2_UTILS_MATH_HALF_UTILS_H_
diff --git a/caffe2/utils/math/reduce.cc b/caffe2/utils/math/reduce.cc
index c654834eee..4bcbb1cda3 100644
--- a/caffe2/utils/math/reduce.cc
+++ b/caffe2/utils/math/reduce.cc
@@ -8,7 +8,7 @@
 
 #include "caffe2/core/context.h"
 #include "caffe2/utils/eigen_utils.h"
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math/utils.h"
 
 namespace caffe2 {
 namespace math {
diff --git a/caffe2/utils/math/reduce.cu b/caffe2/utils/math/reduce.cu
index 31a653930e..18291d35f9 100644
--- a/caffe2/utils/math/reduce.cu
+++ b/caffe2/utils/math/reduce.cu
@@ -11,7 +11,7 @@
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/utils/fixed_divisor.h"
 #include "caffe2/utils/math/reduce.cuh"
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math/utils.h"
 
 namespace caffe2 {
 namespace math {
diff --git a/caffe2/utils/math_utils.cc b/caffe2/utils/math/utils.cc
index f4f7da99f3..3b75cedaaa 100644
--- a/caffe2/utils/math_utils.cc
+++ b/caffe2/utils/math/utils.cc
@@ -1,4 +1,4 @@
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math/utils.h"
 
 #include <algorithm>
 #include <functional>
diff --git a/caffe2/utils/math_utils.h b/caffe2/utils/math/utils.h
index b3fdb14884..b704adb188 100644
--- a/caffe2/utils/math_utils.h
+++ b/caffe2/utils/math/utils.h
@@ -3,9 +3,8 @@
 
 #include "caffe2/core/common.h"
 
-// See Note [hip-clang differences to hcc]
-
-#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) || defined(__HIP__)
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) || \
+    defined(__HIP__)
 #define MATH_UTILS_DECL inline __host__ __device__
 #else
 #define MATH_UTILS_DECL inline
@@ -16,7 +15,8 @@ namespace math {
 
 namespace utils {
 
-MATH_UTILS_DECL bool Not(const bool x) {
+template <typename T>
+MATH_UTILS_DECL T Not(const T x) {
   return !x;
 }
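Two things change in utils.h above: the header moves under caffe2/utils/math/, and Not is generalized from bool to a template so it can instantiate for other integral types. MATH_UTILS_DECL is what lets one definition serve both host and device builds; a self-contained re-creation of that pattern (MY_DECL is an illustrative stand-in for the real macro name):

#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) || \
    defined(__HIP__)
#define MY_DECL inline __host__ __device__
#else
#define MY_DECL inline
#endif

template <typename T>
MY_DECL T Not(const T x) {
  return !x; // logical negation for bool; 0/1 result for integral T
}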
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index 0adb9239a3..3bbc1ac4ed 100644
--- a/caffe2/utils/math_cpu.cc
+++ b/caffe2/utils/math_cpu.cc
@@ -623,66 +623,6 @@ C10_EXPORT void GemmStridedBatched<float, CPUContext>(
 }
 
 ////////////////////////////////////////////////////////////////////////////////
-// MKL VML alternatives.
-// Depending on whether we are using MKL, we will delegate the Caffe math
-// functions that are VML-related to either the VML call or the Eigen
-// implementation. If you are setting the flags (such as AVX) right for your CPU
-// architecture, usually Eigen will deliver a throughput as fast as the VML
-// functions.
-////////////////////////////////////////////////////////////////////////////////
-#ifdef CAFFE2_USE_MKL
-
-#define DELEGATE_SIMPLE_BINARY_FUNCTION(T, Func, FuncImpl)      \
-  template <>                                                   \
-  C10_EXPORT void Func<T, CPUContext>(                          \
-      const int N, const T* A, const T* B, T* C, CPUContext*) { \
-    FuncImpl(N, A, B, C);                                       \
-  }
-DELEGATE_SIMPLE_BINARY_FUNCTION(float, Add, vsAdd)
-DELEGATE_SIMPLE_BINARY_FUNCTION(double, Add, vdAdd)
-DELEGATE_SIMPLE_BINARY_FUNCTION(float, Sub, vsSub)
-DELEGATE_SIMPLE_BINARY_FUNCTION(double, Sub, vdSub)
-DELEGATE_SIMPLE_BINARY_FUNCTION(float, Mul, vsMul)
-DELEGATE_SIMPLE_BINARY_FUNCTION(double, Mul, vdMul)
-DELEGATE_SIMPLE_BINARY_FUNCTION(float, Div, vsDiv)
-DELEGATE_SIMPLE_BINARY_FUNCTION(double, Div, vdDiv)
-#undef DELEGATE_SIMPLE_BINARY_FUNCTION
-
-#endif // CAFFE2_USE_MKL
-
-#define EIGEN_SIMPLE_BINARY_FUNCTION(T, Func, expr)             \
-  template <>                                                   \
-  C10_EXPORT void Func<T, CPUContext>(                          \
-      const int N, const T* A, const T* B, T* C, CPUContext*) { \
-    EigenVectorMap<T>(C, N) = ConstEigenVectorArrayMap<T>(A, N) \
-        expr ConstEigenVectorArrayMap<T>(B, N);                 \
-  }
-
-#ifdef CAFFE2_USE_MKL
-
-#define DEFINE_SIMPLE_BINARY_FUNCTION(Func, expr)        \
-  EIGEN_SIMPLE_BINARY_FUNCTION(std::int32_t, Func, expr) \
-  EIGEN_SIMPLE_BINARY_FUNCTION(std::int64_t, Func, expr)
-
-#else
-
-#define DEFINE_SIMPLE_BINARY_FUNCTION(Func, expr)        \
-  EIGEN_SIMPLE_BINARY_FUNCTION(float, Func, expr)        \
-  EIGEN_SIMPLE_BINARY_FUNCTION(double, Func, expr)       \
-  EIGEN_SIMPLE_BINARY_FUNCTION(std::int32_t, Func, expr) \
-  EIGEN_SIMPLE_BINARY_FUNCTION(std::int64_t, Func, expr)
-
-#endif
-
-DEFINE_SIMPLE_BINARY_FUNCTION(Add, +)
-DEFINE_SIMPLE_BINARY_FUNCTION(Sub, -)
-DEFINE_SIMPLE_BINARY_FUNCTION(Mul, *)
-DEFINE_SIMPLE_BINARY_FUNCTION(Div, /)
-
-#undef DEFINE_SIMPLE_BINARY_FUNCTION
-#undef EIGEN_SIMPLE_BINARY_FUNCTION
-
-////////////////////////////////////////////////////////////////////////////////
 // Common math functions being used in Caffe that do not have a BLAS or MKL
 // equivalent. For all these functions, we will simply implement them either via
 // Eigen or via custom code.
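The MKL/Eigen binary delegations removed above are not deleted outright; per this PR they move out of math_cpu.cc into the new elementwise translation units (the CUDA side is visible earlier in this diff). Expanded by hand for Add<float>, the Eigen variant of the removed EIGEN_SIMPLE_BINARY_FUNCTION macro is:

#include "caffe2/utils/eigen_utils.h"

// Hand-expansion of EIGEN_SIMPLE_BINARY_FUNCTION(float, Add, +): wrap the raw
// pointers in Eigen maps and let Eigen vectorize the loop.
void AddFloat(const int N, const float* A, const float* B, float* C) {
  caffe2::EigenVectorMap<float>(C, N) =
      caffe2::ConstEigenVectorArrayMap<float>(A, N) +
      caffe2::ConstEigenVectorArrayMap<float>(B, N);
}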
@@ -1332,17 +1272,6 @@ CAFFE2_SPECIALIZED_ROWWISEMAX(float)
 CAFFE2_SPECIALIZED_COLWISEMAX(float)
 #undef CAFFE2_SPECIALIZED_COLWISEMAX
 
-#define CAFFE2_SPECIALIZED_ELEMWISEMAX(T)                                   \
-  template <>                                                               \
-  C10_EXPORT void ElemwiseMax<T, CPUContext>(                               \
-      const int N, const T* x, const T* y, T* z, CPUContext* /*context*/) { \
-    std::transform(x, x + N, y, z, [](const T& x_i, const T& y_i) {         \
-      return std::max(x_i, y_i);                                            \
-    });                                                                     \
-  }
-CAFFE2_SPECIALIZED_ELEMWISEMAX(float)
-#undef CAFFE2_SPECIALIZED_ELEMWISEMAX
-
 #define CAFFE2_SPECIALIZED_MAXIMUM(T)                                       \
   template <>                                                               \
   C10_EXPORT void Maximum<T, CPUContext>(                                   \
@@ -1609,46 +1538,6 @@ C10_EXPORT void BroadcastBinaryOpImpl(
 }
 
 } // namespace
 
-#define DELEGATE_1D_BINARY_FUNCTION(TIn, TOut, Func, Op)               \
-  template <>                                                          \
-  C10_EXPORT void Func<TIn, CPUContext>(                               \
-      const int N, const TIn* A, const TIn* B, TOut* C, CPUContext*) { \
-    std::transform(A, A + N, B, C, Op<TIn>());                         \
-  }
-
-#define DEFINE_1D_COMPARE_FUNCTION(Func, Op)                \
-  DELEGATE_1D_BINARY_FUNCTION(float, bool, Func, Op)        \
-  DELEGATE_1D_BINARY_FUNCTION(double, bool, Func, Op)       \
-  DELEGATE_1D_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \
-  DELEGATE_1D_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \
-  DELEGATE_1D_BINARY_FUNCTION(bool, bool, Func, Op)
-
-DEFINE_1D_COMPARE_FUNCTION(EQ, std::equal_to)
-DEFINE_1D_COMPARE_FUNCTION(NE, std::not_equal_to)
-DEFINE_1D_COMPARE_FUNCTION(LT, std::less)
-DEFINE_1D_COMPARE_FUNCTION(LE, std::less_equal)
-DEFINE_1D_COMPARE_FUNCTION(GT, std::greater)
-DEFINE_1D_COMPARE_FUNCTION(GE, std::greater_equal)
-
-#undef DEFINE_1D_COMPARE_FUNCTION
-
-DELEGATE_1D_BINARY_FUNCTION(bool, bool, And, std::logical_and)
-DELEGATE_1D_BINARY_FUNCTION(bool, bool, Or, std::logical_or)
-DELEGATE_1D_BINARY_FUNCTION(bool, bool, Xor, std::bit_xor)
-
-#define DEFINE_1D_BITWISE_BINARY_FUNCTION(Func, op)                 \
-  DELEGATE_1D_BINARY_FUNCTION(bool, bool, Func, op)                 \
-  DELEGATE_1D_BINARY_FUNCTION(std::int32_t, std::int32_t, Func, op) \
-  DELEGATE_1D_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, op)
-
-DEFINE_1D_BITWISE_BINARY_FUNCTION(BitwiseAnd, std::bit_and)
-DEFINE_1D_BITWISE_BINARY_FUNCTION(BitwiseOr, std::bit_or)
-DEFINE_1D_BITWISE_BINARY_FUNCTION(BitwiseXor, std::bit_xor)
-
-#undef DEFINE_1D_BITWISE_BINARY_FUNCTION
-
-#undef DELEGATE_1D_BINARY_FUNCTION
-
 #define DELEGATE_2D_BROADCAST_BINARY_FUNCTION(TIn, TOut, Func, Op)   \
   template <>                                                        \
   C10_EXPORT void Rowwise##Func<TIn, CPUContext, true>(              \
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index e93e6924a9..f8ccda7053 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -41,7 +41,7 @@ using CUBLAS_HALF_TYPE = rocblas_half;
 using CUBLAS_HALF_TYPE = __half;
 #endif // __HIP_PLATFORM_HCC
 
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math/utils.h"
 
 #if THRUST_VERSION >= 100800
 #define THRUST_SUPPORTS_PER_THREAD
@@ -302,74 +302,6 @@ CAFFE2_CUDA_EXPORT void BroadcastBinaryOp(
 }
 
 } // namespace
 
-#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(TIn, TOut, Func, Op) \
-  template <>                                                     \
-  CAFFE2_CUDA_EXPORT void Func<TIn, CUDAContext>(                 \
-      const int N,                                                \
-      const TIn* A,                                               \
-      const TIn* B,                                               \
-      TOut* C,                                                    \
-      CUDAContext* context) {                                     \
-    SimpleBinaryOpCUDAKernel<TIn, TOut, Op<TIn>>                  \
-        <<<CAFFE_GET_BLOCKS(N),                                   \
-           CAFFE_CUDA_NUM_THREADS,                                \
-           0,                                                     \
-           context->cuda_stream()>>>(N, Op<TIn>(), A, B, C);      \
-  }
-
-#define DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(Func, Op)                \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, bool, Func, Op)        \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, bool, Func, Op)       \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, bool, Func, Op)
-
-DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(EQ, thrust::equal_to)
-DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(NE, thrust::not_equal_to)
-DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(LT, thrust::less)
-DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(LE, thrust::less_equal)
-DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(GT, thrust::greater)
-DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION(GE, thrust::greater_equal)
-
-#undef DEFINE_SIMPLE_CUDA_COMPARE_FUNCTION
-
-#define DEFINE_SIMPLE_CUDA_BINARY_FUNCTION(Func, Op)                         \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(std::int32_t, std::int32_t, Func, Op) \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, Op) \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, float, Func, Op)               \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, double, Func, Op)             \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, at::Half, Func, Op)
-
-DEFINE_SIMPLE_CUDA_BINARY_FUNCTION(Add, AddFunctor)
-DEFINE_SIMPLE_CUDA_BINARY_FUNCTION(Sub, SubFunctor)
-DEFINE_SIMPLE_CUDA_BINARY_FUNCTION(Mul, MulFunctor)
-DEFINE_SIMPLE_CUDA_BINARY_FUNCTION(Div, DivFunctor)
-
-#undef DEFINE_SIMPLE_CUDA_BINARY_FUNCTION
-
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, bool, And, thrust::logical_and)
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, bool, Or, thrust::logical_or)
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, bool, Xor, thrust::bit_xor)
-
-#define DEFINE_SIMPLE_CUDA_BITWISE_BINARY_FUNCTION(Func, Op)                 \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, bool, Func, Op)                 \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(std::int32_t, std::int32_t, Func, Op) \
-  DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, Op)
-
-DEFINE_SIMPLE_CUDA_BITWISE_BINARY_FUNCTION(BitwiseAnd, thrust::bit_and)
-DEFINE_SIMPLE_CUDA_BITWISE_BINARY_FUNCTION(BitwiseOr, thrust::bit_or)
-DEFINE_SIMPLE_CUDA_BITWISE_BINARY_FUNCTION(BitwiseXor, thrust::bit_xor)
-
-#undef DEFINE_SIMPLE_CUDA_BITWISE_BINARY_FUNCTION
-
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(
-    float,
-    float,
-    ElemwiseMax,
-    thrust::maximum);
-
-#undef DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION
-
 #define DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(TIn, TOut, Func, Op) \
   template <>                                                           \
   CAFFE2_CUDA_EXPORT void Rowwise##Func<TIn, CUDAContext, true>(        \
diff --git a/caffe2/utils/math_gpu_test.cc b/caffe2/utils/math_gpu_test.cc
index 80d7ff0291..99e9f69e33 100644
--- a/caffe2/utils/math_gpu_test.cc
+++ b/caffe2/utils/math_gpu_test.cc
@@ -204,24 +204,6 @@ TEST(MathUtilGPUTest, testReduceMax) {
       [](int /*i*/) { return 17.0f; });
 }
 
-TEST(MathUtilGPUTest, testElemwiseMax) {
-  executeGpuBinaryOpTest(
-      13,
-      13,
-      13,
-      [](int i) { return 2.0f - i; },
-      [](int i) { return i - 6.0f; },
-      [](int N0,
-         int /*N1*/,
-         const float* src0,
-         const float* src1,
-         float* dst,
-         CUDAContext* context) {
-        math::ElemwiseMax<float, CUDAContext>(N0, src0, src1, dst, context);
-      },
-      [](int i) { return std::max(2.0f - i, i - 6.0f); });
-}
-
 TEST(MathUtilGPUTest, testCopyVector) {
   executeGpuBinaryOpTest(
       6,
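testElemwiseMax goes away together with ElemwiseMax itself; math::Max now covers the same elementwise maximum (see the Min/Max delegations earlier in this diff). A replacement test in the same style, assuming executeGpuBinaryOpTest keeps the signature visible in the removed hunk, would read:

TEST(MathUtilGPUTest, testMax) {
  executeGpuBinaryOpTest(
      13,
      13,
      13,
      [](int i) { return 2.0f - i; },
      [](int i) { return i - 6.0f; },
      [](int N0,
         int /*N1*/,
         const float* src0,
         const float* src1,
         float* dst,
         CUDAContext* context) {
        // Same contract as the removed ElemwiseMax: dst[i] = max(src0[i], src1[i])
        math::Max<float, CUDAContext>(N0, src0, src1, dst, context);
      },
      [](int i) { return std::max(2.0f - i, i - 6.0f); });
}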