#ifndef CAFFE2_OPERATORS_POOL_OP_H_ #define CAFFE2_OPERATORS_POOL_OP_H_ #include #include "caffe2/core/common_omp.h" #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/operators/conv_pool_op_base.h" namespace caffe2 { template class PoolOp final : public ConvPoolOpBase { public: USE_CONV_POOL_BASE_FUNCTIONS(Context); PoolOp(const OperatorDef& operator_def, Workspace* ws) : ConvPoolOpBase(operator_def, ws), functor_(*this) { const int kernel_size = kernel_.size(); for (int i = 0; i < kernel_size; ++i) { CAFFE_ENFORCE_EQ( dilation_[i], 1, "Pooling op does not support dilation right now."); } if (!global_pooling_) { for (int i = 0; i < kernel_size; ++i) { CAFFE_ENFORCE( pads_[i] < kernel_[i] && pads_[i + kernel_size] < kernel_[i], "Pad should be smaller than kernel."); } } } ~PoolOp() = default; bool RunOnDeviceWithOrderNCHW() override { const auto& X = Input(0); auto* Y = Output(0); const int N = X.dim32(0); const int C = X.dim32(1); ConvPoolOpBase::SetOutputSize(X, Y, C); const T* X_data = X.template data(); T* Y_data = Y->template mutable_data(); if (N == 0) { return true; } if (global_pooling_) { const int HxW = X.numel() / (N * C); return functor_.template GlobalPoolingForward( N, C, HxW, X_data, Y_data, &context_); } const std::vector X_HW_dims = GetDims(X); const std::vector Y_HW_dims = GetDims(*Y); return functor_.template Forward( N, C, X_HW_dims, Y_HW_dims, kernel_, dilation_, stride_, pads_, X.template data(), Y->template mutable_data(), &context_); } bool RunOnDeviceWithOrderNHWC() override { const auto& X = Input(0); auto* Y = Output(0); const int ndim = X.dim(); const int N = X.dim32(0); const int C = X.dim32(ndim - 1); ConvPoolOpBase::SetOutputSize(X, Y, C); const T* X_data = X.template data(); T* Y_data = Y->template mutable_data(); if (N == 0) { return true; } if (global_pooling_) { const int HxW = X.numel() / (N * C); return functor_.template GlobalPoolingForward( N, C, HxW, X_data, Y_data, &context_); } const std::vector X_HW_dims = GetDims(X); const std::vector Y_HW_dims = GetDims(*Y); return functor_.template Forward( N, C, X_HW_dims, Y_HW_dims, kernel_, dilation_, stride_, pads_, X.template data(), Y->template mutable_data(), &context_); } private: const Functor functor_; }; template class PoolGradientOp final : public ConvPoolOpBase { public: USE_CONV_POOL_BASE_FUNCTIONS(Context); PoolGradientOp(const OperatorDef& operator_def, Workspace* ws) : ConvPoolOpBase(operator_def, ws), functor_(*this) {} ~PoolGradientOp() = default; bool RunOnDeviceWithOrderNCHW() override { const auto& X = Input(0); const auto& Y = Input(1); const auto& dY = Input(2); auto* dX = Output(0, X.sizes(), at::dtype()); const int N = X.dim32(0); const int C = X.dim32(1); const std::vector X_HW_dims = GetDims(X); const std::vector Y_HW_dims = GetDims(Y); ConvPoolOpBase::ComputePads(X_HW_dims); const T* dY_data = dY.template data(); const T* X_data = X.template data(); const T* Y_data = Y.template data(); T* dX_data = dX->template mutable_data(); if (N == 0) { return true; } if (global_pooling_) { const int HxW = X.numel() / (N * C); return functor_.template GlobalPoolingBackward( N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_); } return functor_.template Backward( N, C, X_HW_dims, Y_HW_dims, kernel_, dilation_, stride_, pads_, dY_data, X_data, Y_data, dX_data, &context_); } bool RunOnDeviceWithOrderNHWC() override { const auto& X = Input(0); const auto& Y = Input(1); const auto& dY = Input(2); auto* dX = Output(0, X.sizes(), at::dtype()); const int ndim = X.dim(); const int N = X.dim32(0); const int C = X.dim32(ndim - 1); const std::vector X_HW_dims = GetDims(X); const std::vector Y_HW_dims = GetDims(Y); ConvPoolOpBase::ComputePads(X_HW_dims); const T* dY_data = dY.template data(); const T* X_data = X.template data(); const T* Y_data = Y.template data(); T* dX_data = dX->template mutable_data(); if (N == 0) { return true; } if (global_pooling_) { const int HxW = X.numel() / (N * C); return functor_.template GlobalPoolingBackward( N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_); } return functor_.template Backward( N, C, X_HW_dims, Y_HW_dims, kernel_, dilation_, stride_, pads_, dY_data, X_data, Y_data, dX_data, &context_); } private: const Functor functor_; }; template struct AveragePoolFunctor { explicit AveragePoolFunctor(const OperatorBase& op) : count_include_pad( op.template GetSingleArgument("count_include_pad", false)) {} template bool GlobalPoolingForward( int N, int C, int HxW, const T* X, T* Y, Context* context) const; template bool Forward( int N, int C, const std::vector& X_dims, const std::vector& Y_dims, const std::vector& kernel, const std::vector& dilation, const std::vector& stride, const std::vector& pads, const T* X, T* Y, Context* context) const; template bool GlobalPoolingBackward( int N, int C, int HxW, const T* dY, const T* X, const T* Y, T* dX, Context* context) const; template bool Backward( int N, int C, const std::vector& X_dims, const std::vector& Y_dims, const std::vector& kernel, const std::vector& dilation, const std::vector& stride, const std::vector& pads, const T* dY, const T* X, const T* Y, T* dX, Context* context) const; const bool count_include_pad; Tensor ones{Context::GetDeviceType()}; }; template struct MaxPoolFunctor { explicit MaxPoolFunctor(const OperatorBase& /* op */) {} template bool GlobalPoolingForward( int N, int C, int HxW, const T* X, T* Y, Context* context) const; template bool Forward( int N, int C, const std::vector& X_dims, const std::vector& Y_dims, const std::vector& kernel, const std::vector& dilation, const std::vector& stride, const std::vector& pads, const T* X, T* Y, Context* context) const; template bool GlobalPoolingBackward( int N, int C, int HxW, const T* dY, const T* X, const T* Y, T* dX, Context* context) const; template bool Backward( int N, int C, const std::vector& X_dims, const std::vector& Y_dims, const std::vector& kernel, const std::vector& dilation, const std::vector& stride, const std::vector& pads, const T* dY, const T* X, const T* Y, T* dX, Context* context) const; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_POOL_OP_H_