#pragma once #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { template class SwishGradientOp final : public Operator { public: USE_SIMPLE_CTOR_DTOR(SwishGradientOp) USE_OPERATOR_CONTEXT_FUNCTIONS; template bool DoRunWithType(); bool RunOnDevice() override { return DispatchHelper>::call(this, Input(X)); } protected: INPUT_TAGS(X, Y, DY); OUTPUT_TAGS(DX); }; class GetSwishGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; vector GetGradientDefs() override { return SingleGradientDef( "SwishGradient", "", vector{I(0), O(0), GO(0)}, vector{GI(0)}); } }; } // namespace caffe2