#ifndef CAFFE2_OPERATORS_EXPAND_OP_H_ #define CAFFE2_OPERATORS_EXPAND_OP_H_ #include #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/core/types.h" #include "caffe2/utils/math.h" namespace caffe2 { template class ExpandOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit ExpandOp(Args&&... args) : Operator(std::forward(args)...) {} bool RunOnDevice() override { return DispatchHelper::call(this, Input(0)); } template bool DoRunWithType() { const auto& X = Input(0); const auto& Y_shape_tensor = Input(1); std::vector shape_dims(Y_shape_tensor.numel()); context_.template CopyToCPU( Y_shape_tensor.numel(), Y_shape_tensor.template data(), shape_dims.data()); const int ndim = shape_dims.size(); const std::vector X_dims(X.sizes().cbegin(), X.sizes().cend()); std::vector Y_dims; Y_dims.reserve(std::max(ndim, X.dim())); // ndim, X.ndim() might equal to 0 for (int i = ndim - 1, j = X.dim() - 1; i >= 0 || j >= 0; --i, --j) { const int shape_x = (j >= 0 ? X_dims[j] : 1); // In PyTorch expand treats -1 as a special value to indicate // preserving the size of that dimension. const int shape_y = ((i >= 0 && shape_dims[i] > 0) ? shape_dims[i] : 1); CAFFE_ENFORCE( shape_x == 1 || shape_y == 1 || shape_x == shape_y, "Dimensions format invalid."); Y_dims.push_back(std::max(shape_x, shape_y)); } std::reverse(Y_dims.begin(), Y_dims.end()); // TODO: remove when the function in math are changed to use vector std::vector Y_dims_int64; std::copy(Y_dims.begin(), Y_dims.end(), std::back_inserter(Y_dims_int64)); auto* Y = Output(0, Y_dims_int64, at::dtype()); math::Broadcast( X_dims.size(), X_dims.data(), Y_dims.size(), Y_dims.data(), T(1), X.template data(), Y->template mutable_data(), &context_); return true; } }; template class ExpandGradientOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit ExpandGradientOp(Args&&... args) : Operator(std::forward(args)...) {} bool RunOnDevice() override { return DispatchHelper::call(this, Input(0)); } template bool DoRunWithType() { const auto& dY = Input(0); const auto& X = Input(1); const int ndim = dY.dim(); const std::vector dX_dims(X.sizes().cbegin(), X.sizes().cend()); const std::vector dY_dims(dY.sizes().cbegin(), dY.sizes().cend()); auto* dX = Output(0, X.sizes(), at::dtype()); std::vector axes; const int offset = ndim - X.dim(); for (int i = 0; i < ndim; i++) { if (i < offset || dX_dims[i - offset] == 1) { axes.push_back(i); } } std::vector X_dims = dY_dims; for (const int axis : axes) { X_dims[axis] = 1; } math::ReduceSum( dY_dims.size(), dY_dims.data(), X_dims.data(), T(1), dY.template data(), dX->template mutable_data(), &context_); return true; } }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_REDUCE_OPS_H_