#ifndef CAFFE2_OPERATORS_MINMAX_OPS_H_ #define CAFFE2_OPERATORS_MINMAX_OPS_H_ #include "caffe2/core/common_omp.h" #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/core/types.h" #include "caffe2/utils/math.h" namespace caffe2 { template class MaxMinOpBase : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_SIMPLE_CTOR_DTOR(MaxMinOpBase) bool RunOnDevice() override { auto& input0 = Input(0); auto* output = Output(0); output->ResizeLike(input0); output->CopyFrom(input0, /* async */ true); if (InputSize() == 1) { return true; } // Dimension checking for (int i = 1; i < InputSize(); ++i) { CAFFE_ENFORCE_EQ( output->sizes(), Input(i).sizes(), "Description: Input #", i, ", input dimension:", Input(i).sizes(), " should match output dimension: ", output->sizes()); } return this->Compute(); } virtual bool Compute() = 0; }; template class MaxOp : public MaxMinOpBase { public: USE_OPERATOR_CONTEXT_FUNCTIONS; MaxOp(const OperatorDef& operator_def, Workspace* ws) : MaxMinOpBase(operator_def, ws) {} virtual ~MaxOp() noexcept {} bool Compute() override; }; template class SelectGradientOpBase : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_SIMPLE_CTOR_DTOR(SelectGradientOpBase) bool RunOnDevice() override; }; template class MaxGradientOp : public SelectGradientOpBase { public: MaxGradientOp(const OperatorDef& operator_def, Workspace* ws) : SelectGradientOpBase(operator_def, ws) {} virtual ~MaxGradientOp() noexcept {} }; template class MinOp : public MaxMinOpBase { public: USE_OPERATOR_CONTEXT_FUNCTIONS; MinOp(const OperatorDef& operator_def, Workspace* ws) : MaxMinOpBase(operator_def, ws) {} virtual ~MinOp() noexcept {} bool Compute() override; }; template class MinGradientOp : public SelectGradientOpBase { public: MinGradientOp(const OperatorDef& operator_def, Workspace* ws) : SelectGradientOpBase(operator_def, ws) {} virtual ~MinGradientOp() noexcept {} }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_MINMAX_OPS_H_