author    Xiaomeng Yang <yangxm@fb.com>  2019-02-07 18:19:46 -0800
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-02-07 18:38:26 -0800
commit    2db847b3a7edc48652e144e7c9d7aa0bbed66aaa (patch)
tree      3d95b2f625ca878f038092dea671a2bffed088c6 /caffe2/operators
parent    22477c6a7fc327cdf913751ab31b788ec710caa9 (diff)
Separate elementwise level2 math functions (#16753)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16753

Separate elementwise level2 math functions

i-am-not-moving-c2-to-c10

Reviewed By: houseroad
Differential Revision: D13954928
fbshipit-source-id: 1ca7a5d3da96e32510f502e5e4e79168854bee67
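For context, the refactor below drops the per-operator Compute() loops and has MaxOp/MinOp reduce their inputs pairwise through the split-out elementwise helpers math::Max / math::Min. A minimal sketch of that pairwise-reduction pattern in plain C++, with a hand-rolled elementwise_max standing in for math::Max<T, Context> (the helper and function names here are illustrative only, not the Caffe2 API):

#include <algorithm>
#include <cassert>
#include <vector>

// Stand-in for math::Max<T, Context>(N, A, B, Y, ctx): Y[i] = max(A[i], B[i]).
template <typename T>
void elementwise_max(int n, const T* a, const T* b, T* y) {
  for (int i = 0; i < n; ++i) {
    y[i] = std::max(a[i], b[i]);
  }
}

// Mirrors the loop structure of the refactored MaxOp::RunOnDevice():
// Y starts as X0, then Y = max(Y, Xi) for each remaining input.
template <typename T>
std::vector<T> reduce_max(const std::vector<std::vector<T>>& inputs) {
  assert(!inputs.empty());
  std::vector<T> y = inputs[0];
  const int n = static_cast<int>(y.size());
  for (std::size_t i = 1; i < inputs.size(); ++i) {
    assert(static_cast<int>(inputs[i].size()) == n);
    elementwise_max(n, y.data(), inputs[i].data(), y.data());
  }
  return y;
}

The refactored operators apply the same pairwise idea to device buffers, letting the Context template parameter (CPU or CUDA) select the actual elementwise kernel.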
Diffstat (limited to 'caffe2/operators')
-rw-r--r--   caffe2/operators/conv_transpose_op_mobile_impl.h     2
-rw-r--r--   caffe2/operators/layer_norm_op.cu                    2
-rw-r--r--   caffe2/operators/minmax_gradient_ops.cc             74
-rw-r--r--   caffe2/operators/minmax_ops.cc                      33
-rw-r--r--   caffe2/operators/minmax_ops.cu                      56
-rw-r--r--   caffe2/operators/minmax_ops.h                      115
-rw-r--r--   caffe2/operators/rsqrt_op.cu                         2
-rw-r--r--   caffe2/operators/utility_ops.cu                    104
8 files changed, 179 insertions, 209 deletions
diff --git a/caffe2/operators/conv_transpose_op_mobile_impl.h b/caffe2/operators/conv_transpose_op_mobile_impl.h
index f586453d5c..6869f5178d 100644
--- a/caffe2/operators/conv_transpose_op_mobile_impl.h
+++ b/caffe2/operators/conv_transpose_op_mobile_impl.h
@@ -18,7 +18,7 @@
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/fixed_divisor.h"
#include "caffe2/utils/math.h"
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math/utils.h"
C10_DECLARE_bool(caffe2_force_shared_col_buffer);
diff --git a/caffe2/operators/layer_norm_op.cu b/caffe2/operators/layer_norm_op.cu
index c3465c1d2d..440783c6eb 100644
--- a/caffe2/operators/layer_norm_op.cu
+++ b/caffe2/operators/layer_norm_op.cu
@@ -4,7 +4,7 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math.h"
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math/utils.h"
namespace caffe2 {
diff --git a/caffe2/operators/minmax_gradient_ops.cc b/caffe2/operators/minmax_gradient_ops.cc
index 9c7df22d55..d288eb0946 100644
--- a/caffe2/operators/minmax_gradient_ops.cc
+++ b/caffe2/operators/minmax_gradient_ops.cc
@@ -1,66 +1,66 @@
#include "caffe2/operators/minmax_ops.h"
-#include "caffe2/utils/eigen_utils.h"
-namespace caffe2 {
+#include <string>
+#include <vector>
-REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
+#include "caffe2/utils/eigen_utils.h"
-OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
-OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
+namespace caffe2 {
template <typename T, class Context>
bool SelectGradientOpBase<T, Context>::RunOnDevice() {
- auto& output = Input(0);
- auto& grad_output = Input(1);
- const int kInputStartOffset = 2;
-
- const T* data = output.template data<T>();
- ConstEigenArrayMap<T> output_array(
- output.template data<T>(), 1, output.numel());
- ConstEigenArrayMap<T> grad_out_array(
- grad_output.template data<T>(), 1, grad_output.numel());
-
+ const auto& Y = Input(0);
+ const auto& dY = Input(1);
+ const int N = Y.numel();
+ ConstEigenVectorArrayMap<T> Y_arr(Y.template data<T>(), N);
+ ConstEigenVectorArrayMap<T> dY_arr(dY.template data<T>(), N);
for (int i = 0; i < OutputSize(); i++) {
- auto& input = Input(i + kInputStartOffset);
- ConstEigenArrayMap<T> input_array(
- input.template data<T>(), 1, input.numel());
-
- auto* grad_input = Output(i, input.sizes(), at::dtype<T>());
- EigenArrayMap<T> grad_in_array(
- grad_input->template mutable_data<T>(), 1, grad_input->numel());
- grad_in_array = grad_out_array *
- input_array.cwiseEqual(output_array).template cast<T>();
+ const auto& Xi = Input(i + 2);
+ auto* dXi = Output(i, Xi.sizes(), at::dtype<T>());
+ ConstEigenVectorArrayMap<T> Xi_arr(Xi.template data<T>(), N);
+ EigenVectorArrayMap<T> dXi_arr(dXi->template mutable_data<T>(), N);
+ dXi_arr = (Xi_arr == Y_arr).template cast<T>() * dY_arr;
}
return true;
}
+REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
+
+OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
+OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
+
+namespace {
+
class GetMaxGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
- vector<OperatorDef> GetGradientDefs() override {
- auto gradInputs = vector<string>();
- auto inputs = vector<string>{O(0), GO(0)};
- for (int i = 0; i < def_.input_size(); i++) {
- gradInputs.push_back(GI(i));
+ std::vector<OperatorDef> GetGradientDefs() override {
+ std::vector<std::string> inputs = {O(0), GO(0)};
+ std::vector<std::string> grad_inputs;
+ for (int i = 0; i < def_.input_size(); ++i) {
inputs.push_back(I(i));
+ grad_inputs.push_back(GI(i));
}
- return SingleGradientDef("MaxGradient", "", inputs, gradInputs);
+ return SingleGradientDef("MaxGradient", "", inputs, grad_inputs);
}
};
-REGISTER_GRADIENT(Max, GetMaxGradient);
class GetMinGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
- auto gradInputs = vector<string>();
- auto inputs = vector<string>{O(0), GO(0)};
- for (int i = 0; i < def_.input_size(); i++) {
- gradInputs.push_back(GI(i));
+ std::vector<std::string> inputs = {O(0), GO(0)};
+ std::vector<std::string> grad_inputs;
+ for (int i = 0; i < def_.input_size(); ++i) {
inputs.push_back(I(i));
+ grad_inputs.push_back(GI(i));
}
- return SingleGradientDef("MinGradient", "", inputs, gradInputs);
+ return SingleGradientDef("MinGradient", "", inputs, grad_inputs);
}
};
+
+} // namespace
+
+REGISTER_GRADIENT(Max, GetMaxGradient);
REGISTER_GRADIENT(Min, GetMinGradient);
} // namespace caffe2
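The rewritten gradient above implements the usual select-gradient rule for Max/Min: an input element receives the upstream gradient only where it equals the elementwise output, i.e. dXi = (Xi == Y) ? dY : 0, which is exactly what dXi_arr = (Xi_arr == Y_arr).template cast<T>() * dY_arr evaluates. A minimal scalar sketch of the same rule, independent of Eigen (the helper name is hypothetical and not part of this diff):

#include <vector>

// Select-gradient rule used by MaxGradient/MinGradient:
// gradient flows to an input element only where it matches the output.
template <typename T>
std::vector<T> select_gradient(const std::vector<T>& y,    // forward output Y
                               const std::vector<T>& dy,   // upstream gradient dY
                               const std::vector<T>& xi) { // one forward input Xi
  std::vector<T> dxi(xi.size());
  for (std::size_t i = 0; i < xi.size(); ++i) {
    dxi[i] = (xi[i] == y[i]) ? dy[i] : T(0);
  }
  return dxi;
}

As in both the CPU and CUDA implementations in this commit, any input that ties with the output receives the full upstream gradient at that position.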
diff --git a/caffe2/operators/minmax_ops.cc b/caffe2/operators/minmax_ops.cc
index 1a105770d2..3dd2487335 100644
--- a/caffe2/operators/minmax_ops.cc
+++ b/caffe2/operators/minmax_ops.cc
@@ -1,10 +1,9 @@
#include "caffe2/operators/minmax_ops.h"
-#include "caffe2/utils/eigen_utils.h"
namespace caffe2 {
-REGISTER_CPU_OPERATOR(Max, MaxOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(Min, MinOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(Max, MaxOp<float, CPUContext>);
OPERATOR_SCHEMA(Max)
.NumInputs(1, INT_MAX)
@@ -155,34 +154,4 @@ Min:
"Contains the minimum valued element at each location.")
.InheritOnnxSchema();
-template <typename T, class Context>
-bool MaxOp<T, Context>::Compute() {
- auto& input0 = Input(0);
- const int N = input0.numel();
- T* output_data = Output(0)->template mutable_data<T>();
-
- for (int i = 1; i < InputSize(); i++) {
- auto input_data = Input(i).template data<T>();
- EigenVectorMap<T> output_vec(output_data, N);
- output_vec = output_vec.cwiseMax(ConstEigenVectorMap<T>(input_data, N));
- }
-
- return true;
-}
-
-template <typename T, class Context>
-bool MinOp<T, Context>::Compute() {
- auto& input0 = Input(0);
- const int N = input0.numel();
- T* output_data = Output(0)->template mutable_data<T>();
-
- for (int i = 1; i < InputSize(); i++) {
- auto input_data = Input(i).template data<T>();
- EigenVectorMap<T> output_vec(output_data, N);
- output_vec = output_vec.cwiseMin(ConstEigenVectorMap<T>(input_data, N));
- }
-
- return true;
-}
-
} // namespace caffe2
diff --git a/caffe2/operators/minmax_ops.cu b/caffe2/operators/minmax_ops.cu
new file mode 100644
index 0000000000..a853b700d1
--- /dev/null
+++ b/caffe2/operators/minmax_ops.cu
@@ -0,0 +1,56 @@
+#include "caffe2/operators/minmax_ops.h"
+
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+namespace {
+
+template <typename T>
+__global__ void SelectGradientCUDAKernel(
+ const int N,
+ const T* dY,
+ const T* X,
+ const T* Y,
+ T* dX) {
+ const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ if (i < N) {
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+ dX[i] = __ldg(X + i) == __ldg(Y + i) ? __ldg(dY + i) : T(0);
+#else
+ dX[i] = X[i] == Y[i] ? dY[i] : T(0);
+#endif
+ }
+}
+
+} // namespace
+
+template <>
+bool SelectGradientOpBase<float, CUDAContext>::RunOnDevice() {
+ const auto& Y = Input(0);
+ const auto& dY = Input(1);
+ const int N = Y.numel();
+ const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
+ const float* dY_data = dY.data<float>();
+ const float* Y_data = Y.data<float>();
+ for (int i = 0; i < OutputSize(); i++) {
+ const auto& Xi = Input(i + 2);
+ auto* dXi = Output(i, Xi.sizes(), at::dtype<float>());
+ const float* Xi_data = Xi.data<float>();
+ float* dXi_data = dXi->mutable_data<float>();
+ if (N > 0) {
+ SelectGradientCUDAKernel<float>
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ N, dY_data, Xi_data, Y_data, dXi_data);
+ }
+ }
+ return true;
+}
+
+REGISTER_CUDA_OPERATOR(Min, MinOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(MinGradient, MinGradientOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(Max, MaxOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(MaxGradient, MaxGradientOp<float, CUDAContext>);
+
+} // namespace caffe2
diff --git a/caffe2/operators/minmax_ops.h b/caffe2/operators/minmax_ops.h
index db02e0dcb0..5668460682 100644
--- a/caffe2/operators/minmax_ops.h
+++ b/caffe2/operators/minmax_ops.h
@@ -1,7 +1,6 @@
#ifndef CAFFE2_OPERATORS_MINMAX_OPS_H_
#define CAFFE2_OPERATORS_MINMAX_OPS_H_
-#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
@@ -11,49 +10,99 @@
namespace caffe2 {
template <typename T, class Context>
-class MaxMinOpBase : public Operator<Context> {
+class MaxOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
- USE_SIMPLE_CTOR_DTOR(MaxMinOpBase)
- bool RunOnDevice() override {
- auto& input0 = Input(0);
- auto* output = Output(0);
-
- output->ResizeLike(input0);
- output->CopyFrom(input0, /* async */ true);
+ USE_SIMPLE_CTOR_DTOR(MaxOp)
+ bool RunOnDevice() override {
+ const auto& X0 = Input(0);
+ auto* Y = Output(0);
+ Y->ResizeLike(X0);
+ const T* X0_data = X0.template data<T>();
+ T* Y_data = Y->template mutable_data<T>();
+ const int N = X0.numel();
if (InputSize() == 1) {
+ if (Y != &X0) {
+ context_.template CopySameDevice<T>(N, X0_data, Y_data);
+ }
return true;
}
-
- // Dimension checking
- for (int i = 1; i < InputSize(); ++i) {
+ const auto& X1 = Input(1);
+ CAFFE_ENFORCE_EQ(
+ X0.sizes(),
+ Y->sizes(),
+ "Description: Input #1, input dimension:",
+ X1.sizes(),
+ " should match output dimension: ",
+ Y->sizes());
+ const T* X1_data = X1.template data<T>();
+ math::Max<T, Context>(N, X0_data, X1_data, Y_data, &context_);
+ for (int i = 2; i < InputSize(); ++i) {
+ const auto& Xi = Input(i);
CAFFE_ENFORCE_EQ(
- output->sizes(),
- Input(i).sizes(),
+ Xi.sizes(),
+ Y->sizes(),
"Description: Input #",
i,
", input dimension:",
Input(i).sizes(),
" should match output dimension: ",
- output->sizes());
+ Y->sizes());
+ const T* Xi_data = Xi.template data<T>();
+ math::Max<T, Context>(N, Y_data, Xi_data, Y_data, &context_);
}
-
- return this->Compute();
+ return true;
}
-
- virtual bool Compute() = 0;
};
template <typename T, class Context>
-class MaxOp : public MaxMinOpBase<T, Context> {
+class MinOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
- MaxOp(const OperatorDef& operator_def, Workspace* ws)
- : MaxMinOpBase<T, Context>(operator_def, ws) {}
- virtual ~MaxOp() noexcept {}
- bool Compute() override;
+
+ USE_SIMPLE_CTOR_DTOR(MinOp)
+
+ bool RunOnDevice() override {
+ const auto& X0 = Input(0);
+ auto* Y = Output(0);
+ Y->ResizeLike(X0);
+ const T* X0_data = X0.template data<T>();
+ T* Y_data = Y->template mutable_data<T>();
+ const int N = X0.numel();
+ if (InputSize() == 1) {
+ if (Y != &X0) {
+ context_.template CopySameDevice<T>(N, X0_data, Y_data);
+ }
+ return true;
+ }
+ const auto& X1 = Input(1);
+ CAFFE_ENFORCE_EQ(
+ X0.sizes(),
+ Y->sizes(),
+ "Description: Input #1, input dimension:",
+ X1.sizes(),
+ " should match output dimension: ",
+ Y->sizes());
+ const T* X1_data = X1.template data<T>();
+ math::Min<T, Context>(N, X0_data, X1_data, Y_data, &context_);
+ for (int i = 2; i < InputSize(); ++i) {
+ const auto& Xi = Input(i);
+ CAFFE_ENFORCE_EQ(
+ Xi.sizes(),
+ Y->sizes(),
+ "Description: Input #",
+ i,
+ ", input dimension:",
+ Input(i).sizes(),
+ " should match output dimension: ",
+ Y->sizes());
+ const T* Xi_data = Xi.template data<T>();
+ math::Min<T, Context>(N, Y_data, Xi_data, Y_data, &context_);
+ }
+ return true;
+ }
};
template <typename T, class Context>
@@ -66,29 +115,21 @@ class SelectGradientOpBase : public Operator<Context> {
};
template <typename T, class Context>
-class MaxGradientOp : public SelectGradientOpBase<T, Context> {
+class MaxGradientOp final : public SelectGradientOpBase<T, Context> {
public:
MaxGradientOp(const OperatorDef& operator_def, Workspace* ws)
: SelectGradientOpBase<T, Context>(operator_def, ws) {}
- virtual ~MaxGradientOp() noexcept {}
-};
-template <typename T, class Context>
-class MinOp : public MaxMinOpBase<T, Context> {
- public:
- USE_OPERATOR_CONTEXT_FUNCTIONS;
- MinOp(const OperatorDef& operator_def, Workspace* ws)
- : MaxMinOpBase<T, Context>(operator_def, ws) {}
- virtual ~MinOp() noexcept {}
- bool Compute() override;
+ ~MaxGradientOp() = default;
};
template <typename T, class Context>
-class MinGradientOp : public SelectGradientOpBase<T, Context> {
+class MinGradientOp final : public SelectGradientOpBase<T, Context> {
public:
MinGradientOp(const OperatorDef& operator_def, Workspace* ws)
: SelectGradientOpBase<T, Context>(operator_def, ws) {}
- virtual ~MinGradientOp() noexcept {}
+
+ ~MinGradientOp() = default;
};
} // namespace caffe2
diff --git a/caffe2/operators/rsqrt_op.cu b/caffe2/operators/rsqrt_op.cu
index 378d131a6f..eae8dc0f44 100644
--- a/caffe2/operators/rsqrt_op.cu
+++ b/caffe2/operators/rsqrt_op.cu
@@ -4,7 +4,7 @@
#include <functional>
#include "caffe2/core/context_gpu.h"
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math.h"
namespace caffe2 {
diff --git a/caffe2/operators/utility_ops.cu b/caffe2/operators/utility_ops.cu
index b7d17ba64f..f767a9ae11 100644
--- a/caffe2/operators/utility_ops.cu
+++ b/caffe2/operators/utility_ops.cu
@@ -1,8 +1,4 @@
-#include "caffe2/core/context_gpu.h"
-#include "caffe2/operators/flatten_op.h"
-#include "caffe2/operators/minmax_ops.h"
#include "caffe2/operators/utility_ops.h"
-#include "caffe2/utils/math.h"
#include <thrust/device_vector.h>
#include <thrust/sequence.h>
@@ -10,6 +6,10 @@
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/unique.h>
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/flatten_op.h"
+#include "caffe2/utils/math.h"
+
namespace caffe2 {
template <>
@@ -137,102 +137,6 @@ bool NanCheckOp<CUDAContext>::RunOnDevice() {
REGISTER_CUDA_OPERATOR(NanCheck, NanCheckOp<CUDAContext>);
-__global__ void
-ElwiseMaxKernel(const float* X, const float* Y, float* maxout, const int N) {
- CUDA_1D_KERNEL_LOOP(i, N) {
- maxout[i] = fmaxf(X[i], Y[i]);
- }
-}
-
-template <>
-bool MaxOp<float, CUDAContext>::Compute() {
- float* output_data = Output(0)->template mutable_data<float>();
- const int N = Input(0).numel();
-
- // Run pairwise-maxes
- for (int i = 1; i < InputSize(); ++i) {
- ElwiseMaxKernel<<<
- CAFFE_GET_BLOCKS(N),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- (i == 0 ? Input(0).data<float>() : Output(0)->data<float>()),
- Input(i).data<float>(),
- output_data,
- N);
- }
-
- return true;
-}
-
-REGISTER_CUDA_OPERATOR(Max, MaxOp<float, CUDAContext>);
-REGISTER_CUDA_OPERATOR(MaxGradient, MaxGradientOp<float, CUDAContext>);
-
-__global__ void
-ElwiseMinKernel(const float* X, const float* Y, float* minout, const int N) {
- CUDA_1D_KERNEL_LOOP(i, N) {
- minout[i] = fminf(X[i], Y[i]);
- }
-}
-
-template <>
-bool MinOp<float, CUDAContext>::Compute() {
- float* output_data = Output(0)->template mutable_data<float>();
- const int N = Input(0).numel();
-
- // Run pairwise-mines
- for (int i = 1; i < InputSize(); ++i) {
- ElwiseMinKernel<<<
- CAFFE_GET_BLOCKS(N),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- (i == 0 ? Input(0).data<float>() : Output(0)->data<float>()),
- Input(i).data<float>(),
- output_data,
- N);
- }
-
- return true;
-}
-
-REGISTER_CUDA_OPERATOR(Min, MinOp<float, CUDAContext>);
-REGISTER_CUDA_OPERATOR(MinGradient, MinGradientOp<float, CUDAContext>);
-
-template <typename T>
-__global__ void
-MaxMinGradKernel(int N, const T* mx, const T* x, const T* go, T* gi) {
- CUDA_1D_KERNEL_LOOP(i, N) {
- gi[i] = go[i] * (mx[i] == x[i]);
- }
-}
-
-template <>
-bool SelectGradientOpBase<float, CUDAContext>::RunOnDevice() {
- auto& output = Input(0);
- auto& grad_output = Input(1);
- const int kInputStartOffset = 2;
-
- const float* data = output.data<float>();
-
- for (int i = 0; i < OutputSize(); i++) {
- auto& input = Input(i + kInputStartOffset);
-
- auto* grad_input = Output(i, input.sizes(), at::dtype<float>());
- MaxMinGradKernel<<<
- CAFFE_GET_BLOCKS(input.numel()),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- input.numel(),
- output.data<float>(),
- input.data<float>(),
- grad_output.data<float>(),
- grad_input->template mutable_data<float>());
- }
- return true;
-}
-
/**
* @brief Update slices of Y in-place with a batch of weighted X's.
* Y[idx] = alpha[b] * X[b][i] + Y[idx]