#ifndef CAFFE2_OPERATORS_FLATTEN_OP_H_ #define CAFFE2_OPERATORS_FLATTEN_OP_H_ #include "caffe2/core/operator.h" namespace caffe2 { template class FlattenOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit FlattenOp(Args&&... args) : Operator(std::forward(args)...), axis_(this->template GetSingleArgument("axis", 1)) {} bool RunOnDevice() override { auto& input = Input(0); auto* output = Output(0); CAFFE_ENFORCE_GE( input.dim(), axis_, "The rank of the tensor must be >= axis."); output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_)); context_.CopyItemsSameDevice( input.dtype(), input.numel(), input.raw_data(), output->raw_mutable_data(input.dtype())); return true; } private: int axis_; }; inline std::vector TensorInferenceForFlatten( const OperatorDef& def, const std::vector& in) { ArgumentHelper helper(def); const int axis = helper.GetSingleArgument("axis", 1); std::vector out(1); int64_t outer = 1; int64_t inner = 1; std::size_t index = 0; for (auto d : in[0].dims()) { if (index < axis) { outer *= d; } else { inner *= d; } ++index; } out[0].set_data_type(in[0].data_type()); out[0].add_dims(outer); out[0].add_dims(inner); return out; } } // namespace caffe2 #endif // CAFFE2_OPERATORS_FLATTEN_OP_H_