#include "caffe2/operators/rsqrt_op.h"

#include "caffe2/utils/eigen_utils.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

namespace caffe2 {

// CPU backward pass of Rsqrt.
//
// The forward op computes Y = 1 / sqrt(X) element-wise, so
//   dY/dX = -0.5 * X^(-3/2) = -0.5 * Y^3.
// The gradient is therefore expressed purely in terms of the forward output Y:
//   dX = dY * Y^3 * (-0.5).
//
// @param dY_dims  Dimensions of dY; their product is the number of elements
//                 processed.
// @param Y_dims   Unused. Y is read with the same element count as dY.
// @param dY       Incoming gradient w.r.t. Y.
// @param Y        Forward output, rsqrt(X).
// @param dX       Output gradient w.r.t. X. The RsqrtGradient schema allows it
//                 to share storage with dY (AllowInplace {0, 0}).
// @return Always true.
template <>
template <typename T>
bool RsqrtGradientFunctor<CPUContext>::Forward(
    const std::vector<int>& dY_dims,
    const std::vector<int>& /* Y_dims */,
    const T* dY,
    const T* Y,
    T* dX,
    CPUContext* /* context */) const {
  // Accumulate the element count in 64 bits: the product of int dims can
  // exceed INT_MAX for large tensors, which would overflow an int accumulator.
  const std::int64_t size = std::accumulate(
      dY_dims.cbegin(),
      dY_dims.cend(),
      std::int64_t{1},
      std::multiplies<std::int64_t>());
  EigenVectorMap<T>(dX, size) = ConstEigenVectorMap<T>(dY, size).array() *
      ConstEigenVectorMap<T>(Y, size).array().cube() * static_cast<T>(-0.5);
  return true;
}

REGISTER_CPU_OPERATOR(
    Rsqrt,
    UnaryElementwiseOp<
        TensorTypes<float>,
        CPUContext,
        RsqrtFunctor<CPUContext>>);
REGISTER_CPU_OPERATOR(
    RsqrtGradient,
    BinaryElementwiseOp<
        TensorTypes<float>,
        CPUContext,
        RsqrtGradientFunctor<CPUContext>>);

OPERATOR_SCHEMA(Rsqrt)
    .NumInputs(1)
    .NumOutputs(1)
    .AllowInplace({{0, 0}})
    .IdenticalTypeAndShape()
    .SetDoc("Computes the element-wise rsqrt of the input.")
    .Input(0, "X", "ND input tensor")
    .Output(0, "Y", "ND output tensor");

// Inputs: (dY, Y). Output: dX. dX may be written in place over dY.
OPERATOR_SCHEMA(RsqrtGradient)
    .NumInputs(2)
    .NumOutputs(1)
    .AllowInplace({{0, 0}});

namespace {

// Builds the RsqrtGradient op for Rsqrt. It feeds the output gradient GO(0)
// and the forward output O(0), not the input X, because the gradient formula
// only needs Y. It writes the input gradient GI(0).
class GetRsqrtGradient final : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  std::vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "RsqrtGradient",
        "",
        std::vector<std::string>{GO(0), O(0)},
        std::vector<std::string>{GI(0)});
  }
};

} // namespace

REGISTER_GRADIENT(Rsqrt, GetRsqrtGradient);

} // namespace caffe2