#ifndef CAFFE2_OPERATORS_DROPOUT_OP_H_ #define CAFFE2_OPERATORS_DROPOUT_OP_H_ #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { template class DropoutOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; DropoutOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), ratio_(this->template GetSingleArgument("ratio", 0.5)), is_test_( this->template GetSingleArgument(OpSchema::Arg_IsTest, 0)) { CAFFE_ENFORCE_GE(ratio_, 0); CAFFE_ENFORCE_LT(ratio_, 1); } bool RunOnDevice() override; protected: float ratio_; bool is_test_; // Input: X; Output: Y, mask. }; template class DropoutGradientOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; DropoutGradientOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), ratio_(this->template GetSingleArgument("ratio", 0.5)), is_test_( this->template GetSingleArgument(OpSchema::Arg_IsTest, 0)) { CAFFE_ENFORCE_GE(ratio_, 0); CAFFE_ENFORCE_LT(ratio_, 1); } bool RunOnDevice() override; protected: float ratio_; bool is_test_; // Input: dY, mask; Output: dX }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_DROPOUT_OP_H_