#pragma once #include "caffe2/core/operator.h" namespace caffe2 { template struct GFtrlParams { explicit GFtrlParams(OperatorBase* op) : alphaInv(1.0 / op->GetSingleArgument("alpha", 0.005f)), beta(op->GetSingleArgument("beta", 1.0f)), lambda1(op->GetSingleArgument("lambda1", 0.001f)), lambda2(op->GetSingleArgument("lambda2", 0.001f)) {} T alphaInv; T beta; T lambda1; T lambda2; }; template class GFtrlOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; GFtrlOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), params_(this) { CAFFE_ENFORCE( !HasArgument("alpha") || ALPHA >= InputSize(), "Cannot specify alpha by both input and argument"); } bool RunOnDevice() override; protected: GFtrlParams params_; INPUT_TAGS(VAR, N_Z, GRAD, ALPHA); OUTPUT_TAGS(OUTPUT_VAR, OUTPUT_N_Z); }; } // namespace caffe2