author     Xiaomeng Yang <yangxm@fb.com>  2019-02-07 18:19:46 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-02-07 18:38:26 -0800
commit     2db847b3a7edc48652e144e7c9d7aa0bbed66aaa (patch)
tree       3d95b2f625ca878f038092dea671a2bffed088c6 /caffe2/operators
parent     22477c6a7fc327cdf913751ab31b788ec710caa9 (diff)
download   pytorch-2db847b3a7edc48652e144e7c9d7aa0bbed66aaa.tar.gz
           pytorch-2db847b3a7edc48652e144e7c9d7aa0bbed66aaa.tar.bz2
           pytorch-2db847b3a7edc48652e144e7c9d7aa0bbed66aaa.zip
Separate elementwise level2 math functions (#16753)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16753
Separate elementwise level2 math functions
i-am-not-moving-c2-to-c10
Reviewed By: houseroad
Differential Revision: D13954928
fbshipit-source-id: 1ca7a5d3da96e32510f502e5e4e79168854bee67
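
For context on what the refactor does: the CPU and CUDA Max/Min operators no longer carry their own Eigen and CUDA kernels. MaxOp::RunOnDevice and MinOp::RunOnDevice (see minmax_ops.h below) now fold the N-ary reduction into repeated calls to the shared binary (level-2) elementwise helpers math::Max<T, Context> and math::Min<T, Context>. A minimal standalone sketch of that pairwise fold, with std::max standing in for the caffe2 math helper (illustrative only, not the caffe2 API):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Y = X0, then Y = max(Y, Xi) for i = 1..n-1 -- the same pairwise chaining
// the new MaxOp::RunOnDevice performs with math::Max<T, Context>.
std::vector<float> ElementwiseMax(const std::vector<std::vector<float>>& inputs) {
  assert(!inputs.empty());
  std::vector<float> Y = inputs[0];
  for (std::size_t i = 1; i < inputs.size(); ++i) {
    assert(inputs[i].size() == Y.size());
    for (std::size_t j = 0; j < Y.size(); ++j) {
      Y[j] = std::max(Y[j], inputs[i][j]);  // one binary elementwise op per step
    }
  }
  return Y;
}

int main() {
  // Elementwise max of the three inputs is {2, 5, 9}.
  const std::vector<float> y = ElementwiseMax({{1, 5, 3}, {2, 0, 9}, {0, 4, 1}});
  return (y[0] == 2 && y[1] == 5 && y[2] == 9) ? 0 : 1;
}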
Diffstat (limited to 'caffe2/operators')
-rw-r--r--  caffe2/operators/conv_transpose_op_mobile_impl.h |   2
-rw-r--r--  caffe2/operators/layer_norm_op.cu                |   2
-rw-r--r--  caffe2/operators/minmax_gradient_ops.cc          |  74
-rw-r--r--  caffe2/operators/minmax_ops.cc                   |  33
-rw-r--r--  caffe2/operators/minmax_ops.cu                   |  56
-rw-r--r--  caffe2/operators/minmax_ops.h                    | 115
-rw-r--r--  caffe2/operators/rsqrt_op.cu                     |   2
-rw-r--r--  caffe2/operators/utility_ops.cu                  | 104
8 files changed, 179 insertions, 209 deletions
diff --git a/caffe2/operators/conv_transpose_op_mobile_impl.h b/caffe2/operators/conv_transpose_op_mobile_impl.h
index f586453d5c..6869f5178d 100644
--- a/caffe2/operators/conv_transpose_op_mobile_impl.h
+++ b/caffe2/operators/conv_transpose_op_mobile_impl.h
@@ -18,7 +18,7 @@
 #include "caffe2/utils/eigen_utils.h"
 #include "caffe2/utils/fixed_divisor.h"
 #include "caffe2/utils/math.h"
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math/utils.h"
 
 C10_DECLARE_bool(caffe2_force_shared_col_buffer);
 
diff --git a/caffe2/operators/layer_norm_op.cu b/caffe2/operators/layer_norm_op.cu
index c3465c1d2d..440783c6eb 100644
--- a/caffe2/operators/layer_norm_op.cu
+++ b/caffe2/operators/layer_norm_op.cu
@@ -4,7 +4,7 @@
 
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/utils/math.h"
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math/utils.h"
 
 namespace caffe2 {
 
diff --git a/caffe2/operators/minmax_gradient_ops.cc b/caffe2/operators/minmax_gradient_ops.cc
index 9c7df22d55..d288eb0946 100644
--- a/caffe2/operators/minmax_gradient_ops.cc
+++ b/caffe2/operators/minmax_gradient_ops.cc
@@ -1,66 +1,66 @@
 #include "caffe2/operators/minmax_ops.h"
-#include "caffe2/utils/eigen_utils.h"
 
-namespace caffe2 {
+#include <string>
+#include <vector>
 
-REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
+#include "caffe2/utils/eigen_utils.h"
 
-OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
-OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
+namespace caffe2 {
 
 template <typename T, class Context>
 bool SelectGradientOpBase<T, Context>::RunOnDevice() {
-  auto& output = Input(0);
-  auto& grad_output = Input(1);
-  const int kInputStartOffset = 2;
-
-  const T* data = output.template data<T>();
-  ConstEigenArrayMap<T> output_array(
-      output.template data<T>(), 1, output.numel());
-  ConstEigenArrayMap<T> grad_out_array(
-      grad_output.template data<T>(), 1, grad_output.numel());
-
+  const auto& Y = Input(0);
+  const auto& dY = Input(1);
+  const int N = Y.numel();
+  ConstEigenVectorArrayMap<T> Y_arr(Y.template data<T>(), N);
+  ConstEigenVectorArrayMap<T> dY_arr(dY.template data<T>(), N);
   for (int i = 0; i < OutputSize(); i++) {
-    auto& input = Input(i + kInputStartOffset);
-    ConstEigenArrayMap<T> input_array(
-        input.template data<T>(), 1, input.numel());
-
-    auto* grad_input = Output(i, input.sizes(), at::dtype<T>());
-    EigenArrayMap<T> grad_in_array(
-        grad_input->template mutable_data<T>(), 1, grad_input->numel());
-    grad_in_array = grad_out_array *
-        input_array.cwiseEqual(output_array).template cast<T>();
+    const auto& Xi = Input(i + 2);
+    auto* dXi = Output(i, Xi.sizes(), at::dtype<T>());
+    ConstEigenVectorArrayMap<T> Xi_arr(Xi.template data<T>(), N);
+    EigenVectorArrayMap<T> dXi_arr(dXi->template mutable_data<T>(), N);
+    dXi_arr = (Xi_arr == Y_arr).template cast<T>() * dY_arr;
   }
   return true;
 }
 
+REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
+
+OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
+OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
+
+namespace {
+
 class GetMaxGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
-  vector<OperatorDef> GetGradientDefs() override {
-    auto gradInputs = vector<string>();
-    auto inputs = vector<string>{O(0), GO(0)};
-    for (int i = 0; i < def_.input_size(); i++) {
-      gradInputs.push_back(GI(i));
+  std::vector<OperatorDef> GetGradientDefs() override {
+    std::vector<std::string> inputs = {O(0), GO(0)};
+    std::vector<std::string> grad_inputs;
+    for (int i = 0; i < def_.input_size(); ++i) {
       inputs.push_back(I(i));
+      grad_inputs.push_back(GI(i));
     }
-    return SingleGradientDef("MaxGradient", "", inputs, gradInputs);
+    return SingleGradientDef("MaxGradient", "", inputs, grad_inputs);
   }
 };
-REGISTER_GRADIENT(Max, GetMaxGradient);
 
 class GetMinGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
   vector<OperatorDef> GetGradientDefs() override {
-    auto gradInputs = vector<string>();
-    auto inputs = vector<string>{O(0), GO(0)};
-    for (int i = 0; i < def_.input_size(); i++) {
-      gradInputs.push_back(GI(i));
+    std::vector<std::string> inputs = {O(0), GO(0)};
+    std::vector<std::string> grad_inputs;
+    for (int i = 0; i < def_.input_size(); ++i) {
       inputs.push_back(I(i));
+      grad_inputs.push_back(GI(i));
     }
-    return SingleGradientDef("MinGradient", "", inputs, gradInputs);
+    return SingleGradientDef("MinGradient", "", inputs, grad_inputs);
   }
 };
+
+} // namespace
+
+REGISTER_GRADIENT(Max, GetMaxGradient);
 REGISTER_GRADIENT(Min, GetMinGradient);
 
 } // namespace caffe2
diff --git a/caffe2/operators/minmax_ops.cc b/caffe2/operators/minmax_ops.cc
index 1a105770d2..3dd2487335 100644
--- a/caffe2/operators/minmax_ops.cc
+++ b/caffe2/operators/minmax_ops.cc
@@ -1,10 +1,9 @@
 #include "caffe2/operators/minmax_ops.h"
-#include "caffe2/utils/eigen_utils.h"
 
 namespace caffe2 {
 
-REGISTER_CPU_OPERATOR(Max, MaxOp<float, CPUContext>);
 REGISTER_CPU_OPERATOR(Min, MinOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(Max, MaxOp<float, CPUContext>);
 
 OPERATOR_SCHEMA(Max)
     .NumInputs(1, INT_MAX)
@@ -155,34 +154,4 @@ Min:
         "Contains the minimum valued element at each location.")
     .InheritOnnxSchema();
 
-template <typename T, class Context>
-bool MaxOp<T, Context>::Compute() {
-  auto& input0 = Input(0);
-  const int N = input0.numel();
-  T* output_data = Output(0)->template mutable_data<T>();
-
-  for (int i = 1; i < InputSize(); i++) {
-    auto input_data = Input(i).template data<T>();
-    EigenVectorMap<T> output_vec(output_data, N);
-    output_vec = output_vec.cwiseMax(ConstEigenVectorMap<T>(input_data, N));
-  }
-
-  return true;
-}
-
-template <typename T, class Context>
-bool MinOp<T, Context>::Compute() {
-  auto& input0 = Input(0);
-  const int N = input0.numel();
-  T* output_data = Output(0)->template mutable_data<T>();
-
-  for (int i = 1; i < InputSize(); i++) {
-    auto input_data = Input(i).template data<T>();
-    EigenVectorMap<T> output_vec(output_data, N);
-    output_vec = output_vec.cwiseMin(ConstEigenVectorMap<T>(input_data, N));
-  }
-
-  return true;
-}
-
 } // namespace caffe2
diff --git a/caffe2/operators/minmax_ops.cu b/caffe2/operators/minmax_ops.cu
new file mode 100644
index 0000000000..a853b700d1
--- /dev/null
+++ b/caffe2/operators/minmax_ops.cu
@@ -0,0 +1,56 @@
+#include "caffe2/operators/minmax_ops.h"
+
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+namespace {
+
+template <typename T>
+__global__ void SelectGradientCUDAKernel(
+    const int N,
+    const T* dY,
+    const T* X,
+    const T* Y,
+    T* dX) {
+  const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (i < N) {
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    dX[i] = __ldg(X + i) == __ldg(Y + i) ? __ldg(dY + i) : T(0);
+#else
+    dX[i] = X[i] == Y[i] ? dY[i] : T(0);
+#endif
+  }
+}
+
+} // namespace
+
+template <>
+bool SelectGradientOpBase<float, CUDAContext>::RunOnDevice() {
+  const auto& Y = Input(0);
+  const auto& dY = Input(1);
+  const int N = Y.numel();
+  const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
+  const float* dY_data = dY.data<float>();
+  const float* Y_data = Y.data<float>();
+  for (int i = 0; i < OutputSize(); i++) {
+    const auto& Xi = Input(i + 2);
+    auto* dXi = Output(i, Xi.sizes(), at::dtype<float>());
+    const float* Xi_data = Xi.data<float>();
+    float* dXi_data = dXi->mutable_data<float>();
+    if (N > 0) {
+      SelectGradientCUDAKernel<float>
+          <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+              N, dY_data, Xi_data, Y_data, dXi_data);
+    }
+  }
+  return true;
+}
+
+REGISTER_CUDA_OPERATOR(Min, MinOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(MinGradient, MinGradientOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(Max, MaxOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(MaxGradient, MaxGradientOp<float, CUDAContext>);
+
+} // namespace caffe2
diff --git a/caffe2/operators/minmax_ops.h b/caffe2/operators/minmax_ops.h
index db02e0dcb0..5668460682 100644
--- a/caffe2/operators/minmax_ops.h
+++ b/caffe2/operators/minmax_ops.h
@@ -1,7 +1,6 @@
 #ifndef CAFFE2_OPERATORS_MINMAX_OPS_H_
 #define CAFFE2_OPERATORS_MINMAX_OPS_H_
 
-#include "caffe2/core/common_omp.h"
 #include "caffe2/core/context.h"
 #include "caffe2/core/logging.h"
 #include "caffe2/core/operator.h"
@@ -11,49 +10,99 @@ namespace caffe2 {
 
 template <typename T, class Context>
-class MaxMinOpBase : public Operator<Context> {
+class MaxOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
-  USE_SIMPLE_CTOR_DTOR(MaxMinOpBase)
-  bool RunOnDevice() override {
-    auto& input0 = Input(0);
-    auto* output = Output(0);
-
-    output->ResizeLike(input0);
-    output->CopyFrom(input0, /* async */ true);
+  USE_SIMPLE_CTOR_DTOR(MaxOp)
+
+  bool RunOnDevice() override {
+    const auto& X0 = Input(0);
+    auto* Y = Output(0);
+    Y->ResizeLike(X0);
+    const T* X0_data = X0.template data<T>();
+    T* Y_data = Y->template mutable_data<T>();
+    const int N = X0.numel();
     if (InputSize() == 1) {
+      if (Y != &X0) {
+        context_.template CopySameDevice<T>(N, X0_data, Y_data);
+      }
       return true;
     }
-
-    // Dimension checking
-    for (int i = 1; i < InputSize(); ++i) {
+    const auto& X1 = Input(1);
+    CAFFE_ENFORCE_EQ(
+        X0.sizes(),
+        Y->sizes(),
+        "Description: Input #1, input dimension:",
+        X1.sizes(),
+        " should match output dimension: ",
+        Y->sizes());
+    const T* X1_data = X1.template data<T>();
+    math::Max<T, Context>(N, X0_data, X1_data, Y_data, &context_);
+    for (int i = 2; i < InputSize(); ++i) {
+      const auto& Xi = Input(i);
       CAFFE_ENFORCE_EQ(
-          output->sizes(),
-          Input(i).sizes(),
+          Xi.sizes(),
+          Y->sizes(),
           "Description: Input #",
           i,
           ", input dimension:",
           Input(i).sizes(),
           " should match output dimension: ",
-          output->sizes());
+          Y->sizes());
+      const T* Xi_data = Xi.template data<T>();
+      math::Max<T, Context>(N, Y_data, Xi_data, Y_data, &context_);
     }
-
-    return this->Compute();
+    return true;
   }
-
-  virtual bool Compute() = 0;
 };
 
 template <typename T, class Context>
-class MaxOp : public MaxMinOpBase<T, Context> {
+class MinOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
-  MaxOp(const OperatorDef& operator_def, Workspace* ws)
-      : MaxMinOpBase<T, Context>(operator_def, ws) {}
-  virtual ~MaxOp() noexcept {}
-  bool Compute() override;
+
+  USE_SIMPLE_CTOR_DTOR(MinOp)
+
+  bool RunOnDevice() override {
+    const auto& X0 = Input(0);
+    auto* Y = Output(0);
+    Y->ResizeLike(X0);
+    const T* X0_data = X0.template data<T>();
+    T* Y_data = Y->template mutable_data<T>();
+    const int N = X0.numel();
+    if (InputSize() == 1) {
+      if (Y != &X0) {
+        context_.template CopySameDevice<T>(N, X0_data, Y_data);
+      }
+      return true;
+    }
+    const auto& X1 = Input(1);
+    CAFFE_ENFORCE_EQ(
+        X0.sizes(),
+        Y->sizes(),
+        "Description: Input #1, input dimension:",
+        X1.sizes(),
+        " should match output dimension: ",
+        Y->sizes());
+    const T* X1_data = X1.template data<T>();
+    math::Min<T, Context>(N, X0_data, X1_data, Y_data, &context_);
+    for (int i = 2; i < InputSize(); ++i) {
+      const auto& Xi = Input(i);
+      CAFFE_ENFORCE_EQ(
+          Xi.sizes(),
+          Y->sizes(),
+          "Description: Input #",
+          i,
+          ", input dimension:",
+          Input(i).sizes(),
+          " should match output dimension: ",
+          Y->sizes());
+      const T* Xi_data = Xi.template data<T>();
+      math::Min<T, Context>(N, Y_data, Xi_data, Y_data, &context_);
+    }
+    return true;
+  }
 };
 
 template <typename T, class Context>
@@ -66,29 +115,21 @@ class SelectGradientOpBase : public Operator<Context> {
 };
 
 template <typename T, class Context>
-class MaxGradientOp : public SelectGradientOpBase<T, Context> {
+class MaxGradientOp final : public SelectGradientOpBase<T, Context> {
  public:
   MaxGradientOp(const OperatorDef& operator_def, Workspace* ws)
       : SelectGradientOpBase<T, Context>(operator_def, ws) {}
-  virtual ~MaxGradientOp() noexcept {}
-};
-
-template <typename T, class Context>
-class MinOp : public MaxMinOpBase<T, Context> {
- public:
-  USE_OPERATOR_CONTEXT_FUNCTIONS;
-  MinOp(const OperatorDef& operator_def, Workspace* ws)
-      : MaxMinOpBase<T, Context>(operator_def, ws) {}
-  virtual ~MinOp() noexcept {}
-  bool Compute() override;
+
+  ~MaxGradientOp() = default;
 };
 
 template <typename T, class Context>
-class MinGradientOp : public SelectGradientOpBase<T, Context> {
+class MinGradientOp final : public SelectGradientOpBase<T, Context> {
  public:
   MinGradientOp(const OperatorDef& operator_def, Workspace* ws)
       : SelectGradientOpBase<T, Context>(operator_def, ws) {}
-  virtual ~MinGradientOp() noexcept {}
+
+  ~MinGradientOp() = default;
 };
 
 } // namespace caffe2
diff --git a/caffe2/operators/rsqrt_op.cu b/caffe2/operators/rsqrt_op.cu
index 378d131a6f..eae8dc0f44 100644
--- a/caffe2/operators/rsqrt_op.cu
+++ b/caffe2/operators/rsqrt_op.cu
@@ -4,7 +4,7 @@
 #include <functional>
 
 #include "caffe2/core/context_gpu.h"
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math.h"
 
 namespace caffe2 {
 
diff --git a/caffe2/operators/utility_ops.cu b/caffe2/operators/utility_ops.cu
index b7d17ba64f..f767a9ae11 100644
--- a/caffe2/operators/utility_ops.cu
+++ b/caffe2/operators/utility_ops.cu
@@ -1,8 +1,4 @@
-#include "caffe2/core/context_gpu.h"
-#include "caffe2/operators/flatten_op.h"
-#include "caffe2/operators/minmax_ops.h"
 #include "caffe2/operators/utility_ops.h"
-#include "caffe2/utils/math.h"
 
 #include <thrust/device_vector.h>
 #include <thrust/sequence.h>
@@ -10,6 +6,10 @@
 #include <thrust/system/cuda/execution_policy.h>
 #include <thrust/unique.h>
 
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/flatten_op.h"
+#include "caffe2/utils/math.h"
+
 namespace caffe2 {
 
 template <>
@@ -137,102 +137,6 @@ bool NanCheckOp<CUDAContext>::RunOnDevice() {
 
 REGISTER_CUDA_OPERATOR(NanCheck, NanCheckOp<CUDAContext>);
 
-__global__ void
-ElwiseMaxKernel(const float* X, const float* Y, float* maxout, const int N) {
-  CUDA_1D_KERNEL_LOOP(i, N) {
-    maxout[i] = fmaxf(X[i], Y[i]);
-  }
-}
-
-template <>
-bool MaxOp<float, CUDAContext>::Compute() {
-  float* output_data = Output(0)->template mutable_data<float>();
-  const int N = Input(0).numel();
-
-  // Run pairwise-maxes
-  for (int i = 1; i < InputSize(); ++i) {
-    ElwiseMaxKernel<<<
-        CAFFE_GET_BLOCKS(N),
-        CAFFE_CUDA_NUM_THREADS,
-        0,
-        context_.cuda_stream()>>>(
-        (i == 0 ? Input(0).data<float>() : Output(0)->data<float>()),
-        Input(i).data<float>(),
-        output_data,
-        N);
-  }
-
-  return true;
-}
-
-REGISTER_CUDA_OPERATOR(Max, MaxOp<float, CUDAContext>);
-REGISTER_CUDA_OPERATOR(MaxGradient, MaxGradientOp<float, CUDAContext>);
-
-__global__ void
-ElwiseMinKernel(const float* X, const float* Y, float* minout, const int N) {
-  CUDA_1D_KERNEL_LOOP(i, N) {
-    minout[i] = fminf(X[i], Y[i]);
-  }
-}
-
-template <>
-bool MinOp<float, CUDAContext>::Compute() {
-  float* output_data = Output(0)->template mutable_data<float>();
-  const int N = Input(0).numel();
-
-  // Run pairwise-mines
-  for (int i = 1; i < InputSize(); ++i) {
-    ElwiseMinKernel<<<
-        CAFFE_GET_BLOCKS(N),
-        CAFFE_CUDA_NUM_THREADS,
-        0,
-        context_.cuda_stream()>>>(
-        (i == 0 ? Input(0).data<float>() : Output(0)->data<float>()),
-        Input(i).data<float>(),
-        output_data,
-        N);
-  }
-
-  return true;
-}
-
-REGISTER_CUDA_OPERATOR(Min, MinOp<float, CUDAContext>);
-REGISTER_CUDA_OPERATOR(MinGradient, MinGradientOp<float, CUDAContext>);
-
-template <typename T>
-__global__ void
-MaxMinGradKernel(int N, const T* mx, const T* x, const T* go, T* gi) {
-  CUDA_1D_KERNEL_LOOP(i, N) {
-    gi[i] = go[i] * (mx[i] == x[i]);
-  }
-}
-
-template <>
-bool SelectGradientOpBase<float, CUDAContext>::RunOnDevice() {
-  auto& output = Input(0);
-  auto& grad_output = Input(1);
-  const int kInputStartOffset = 2;
-
-  const float* data = output.data<float>();
-
-  for (int i = 0; i < OutputSize(); i++) {
-    auto& input = Input(i + kInputStartOffset);
-
-    auto* grad_input = Output(i, input.sizes(), at::dtype<float>());
-    MaxMinGradKernel<<<
-        CAFFE_GET_BLOCKS(input.numel()),
-        CAFFE_CUDA_NUM_THREADS,
-        0,
-        context_.cuda_stream()>>>(
-        input.numel(),
-        output.data<float>(),
-        input.data<float>(),
-        grad_output.data<float>(),
-        grad_input->template mutable_data<float>());
-  }
-  return true;
-}
-
 /**
  * @brief Update slices of Y in-place with a batch of weighted X's.
  * Y[idx] = alpha[b] * X[b][i] + Y[idx]
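
Both gradient operators share one "select" rule, visible in the new CPU path (dXi_arr = (Xi_arr == Y_arr).template cast<T>() * dY_arr) and in SelectGradientCUDAKernel above: an input element receives the upstream gradient only where it equals the forward output, and zero elsewhere. A standalone sketch of that rule (illustrative helper name, not the caffe2 API):

#include <cstddef>
#include <vector>

// dX[j] = dY[j] where X[j] equals the forward output Y[j] (it was the
// selected min/max element at that position), and 0 otherwise.
std::vector<float> SelectGradient(
    const std::vector<float>& Y,   // forward output of Min/Max
    const std::vector<float>& dY,  // upstream gradient
    const std::vector<float>& X) { // one of the forward inputs
  std::vector<float> dX(X.size());
  for (std::size_t j = 0; j < X.size(); ++j) {
    dX[j] = (X[j] == Y[j]) ? dY[j] : 0.0f;
  }
  return dX;
}

Note that when several inputs tie for the extremum at a position, each tied input receives the full upstream gradient; that matches both the Eigen expression and the CUDA kernel in this commit.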