#ifndef CAFFE2_OPERATORS_REDUCTION_OPS_H_
#define CAFFE2_OPERATORS_REDUCTION_OPS_H_

#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

// Reduces the input tensor to a scalar containing the sum of all elements.
// If the "average" argument is set, the sum is divided by the element count.
template <typename T, class Context>
class SumElementsOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  explicit SumElementsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        average_(this->template GetSingleArgument<bool>("average", false)) {}
  explicit SumElementsOp(
      const OperatorDef& operator_def,
      Workspace* ws,
      bool average)
      : Operator<Context>(operator_def, ws), average_(average) {}
  explicit SumElementsOp(
      const c10::FunctionSchema& schema,
      std::vector<c10::IValue> inputs,
      std::vector<c10::IValue*> outputs)
      : Operator<Context>(schema, std::move(inputs), std::move(outputs)),
        average_(this->template GetSingleArgument<bool>("average", false)) {}
  explicit SumElementsOp(
      const c10::FunctionSchema& schema,
      std::vector<c10::IValue> inputs,
      std::vector<c10::IValue*> outputs,
      bool average)
      : Operator<Context>(schema, std::move(inputs), std::move(outputs)),
        average_(average) {}
  ~SumElementsOp() {}

  bool RunOnDevice() override {
    auto& X = Input(0);

    auto* sum = Output(0, vector<int64_t>(), at::dtype<T>());
    T* data = sum->template mutable_data<T>();

    math::Sum<T, Context>(
        X.numel(), X.template data<T>(), data, &context_, &scratch_);
    if (average_ && X.numel() > 0) {
      // Convert the sum into a mean by scaling with 1 / numel.
      math::Scale<float, T, Context>(
          1,
          static_cast<float>(1.) / X.numel(),
          sum->template data<T>(),
          data,
          &context_);
    }
    return true;
  }

 private:
  bool average_;
  Tensor scratch_{Context::GetDeviceType()};
};

// Integer variant: sums all elements into a scalar output; no averaging.
template <typename T, class Context>
class SumElementsIntOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit SumElementsIntOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  ~SumElementsIntOp() {}

  bool RunOnDevice() override {
    auto& X = Input(0);
    auto* sum = Output(0, vector<int64_t>(), at::dtype<T>());
    T* data = sum->template mutable_data<T>();
    math::Sum<T, Context>(
        X.numel(), X.template data<T>(), data, &context_, &scratch_);
    return true;
  }

 private:
  Tensor scratch_{Context::GetDeviceType()};
};

// Gradient of SumElementsOp; RunOnDevice is defined in the per-device
// source files.
template <typename T, class Context>
class SumElementsGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  explicit SumElementsGradientOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<Context>(operator_def, ws),
        average_(this->template GetSingleArgument<bool>("average", false)) {}
  explicit SumElementsGradientOp(
      const OperatorDef& operator_def,
      Workspace* ws,
      bool average)
      : Operator<Context>(operator_def, ws), average_(average) {}
  explicit SumElementsGradientOp(
      const c10::FunctionSchema& schema,
      std::vector<c10::IValue> inputs,
      std::vector<c10::IValue*> outputs)
      : Operator<Context>(schema, std::move(inputs), std::move(outputs)),
        average_(this->template GetSingleArgument<bool>("average", false)) {}
  explicit SumElementsGradientOp(
      const c10::FunctionSchema& schema,
      std::vector<c10::IValue> inputs,
      std::vector<c10::IValue*> outputs,
      bool average)
      : Operator<Context>(schema, std::move(inputs), std::move(outputs)),
        average_(average) {}
  ~SumElementsGradientOp() {}

  bool RunOnDevice() override;

 private:
  bool average_;
};

// Reduces the input tensor to a scalar containing the sum of squared
// elements; the "average" argument divides the result by the element count.
template <class Context>
class SumSqrElementsOp : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(SumSqrElementsOp)
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    bool average = this->template GetSingleArgument<bool>("average", false);
    auto& X = Input(0);

    auto* sum = Output(0, vector<int64_t>(), at::dtype<T>());
    math::SumSqr<T, Context>(
        X.numel(),
        X.template data<T>(),
        sum->template mutable_data<T>(),
        &context_,
        &scratch_);
    if (average && X.numel() > 0) {
      math::Scale<float, T, Context>(
          1,
          float(1.) / X.numel(),
          sum->template data<T>(),
          sum->template mutable_data<T>(),
          &context_);
    }
    return true;
  }

 private:
  Tensor scratch_{Context::GetDeviceType()};
};

// For a 3D input of shape (batch_size, M, N), computes the per-row maximum
// (output shape {batch_size, M}) when ROWWISE is true, or the per-column
// maximum (output shape {batch_size, N}) otherwise.
template <typename T, class Context, bool ROWWISE>
class MaxReductionOp : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(MaxReductionOp)
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    auto& X = Input(0);
    CAFFE_ENFORCE_EQ(X.dim(), 3);

    const int batch_size = X.dim32(0);
    const int M = X.dim32(1);
    const int N = X.dim32(2);

    auto* Y = Output(0, {batch_size, ROWWISE ? M : N}, at::dtype<T>());

    if (ROWWISE) {
      // Treat the batched input as (batch_size * M) rows of length N.
      math::RowwiseMax<T, Context>(
          batch_size * M,
          N,
          X.template data<T>(),
          Y->template mutable_data<T>(),
          &context_);
    } else {
      const int input_size = N * M;
      for (int i = 0; i < batch_size; ++i) {
        math::ColwiseMax<T, Context>(
            M,
            N,
            X.template data<T>() + i * input_size,
            Y->template mutable_data<T>() + i * N,
            &context_);
      }
    }
    return true;
  }
};

// Gradient of MaxReductionOp; RunOnDevice is defined in the per-device
// source files.
template <typename T, class Context, bool ROWWISE>
class MaxReductionGradientOp : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(MaxReductionGradientOp)
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;
};

} // namespace caffe2

#endif