#ifndef CAFFE2_OPERATORS_PACK_SEGMENTS_H_ #define CAFFE2_OPERATORS_PACK_SEGMENTS_H_ #include #include #include #include #include #include "caffe2/core/operator.h" #include "caffe2/core/tensor.h" #include "caffe2/utils/math.h" namespace caffe2 { template class PackSegmentsOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_DISPATCH_HELPER; template explicit PackSegmentsOp(Args&&... args) : Operator(std::forward(args)...), max_length_(this->template GetSingleArgument("max_length", -1)), pad_minf_(this->template GetSingleArgument("pad_minf", false)), return_presence_mask_(this->template GetSingleArgument( "return_presence_mask", false)) { if (pad_minf_) { padding_ = -1.0 * std::numeric_limits::infinity(); } else { padding_ = 0; } } bool RunOnDevice() { return DispatchHelper>::call(this, Input(LENGTHS)); } template bool DoRunWithType(); template bool DoRunWithType2(); INPUT_TAGS(LENGTHS, DATA); private: int64_t max_length_; bool pad_minf_; float padding_; bool return_presence_mask_; // Scratch space required by the CUDA version Tensor dev_buffer_{Context::GetDeviceType()}; Tensor dev_lengths_prefix_sum_{Context::GetDeviceType()}; Tensor dev_max_length_{Context::GetDeviceType()}; Tensor host_max_length_{CPU}; }; template class UnpackSegmentsOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_DISPATCH_HELPER; template explicit UnpackSegmentsOp(Args&&... args) : Operator(std::forward(args)...), max_length_(this->template GetSingleArgument("max_length", -1)) {} bool RunOnDevice() override { return DispatchHelper>::call(this, Input(LENGTHS)); } template bool DoRunWithType(); template bool DoRunWithType2(); INPUT_TAGS(LENGTHS, DATA); private: int64_t max_length_; Tensor dev_buffer_{Context::GetDeviceType()}; Tensor dev_lengths_prefix_sum_{Context::GetDeviceType()}; Tensor dev_max_length_{Context::GetDeviceType()}; Tensor dev_num_cell_{Context::GetDeviceType()}; Tensor host_max_length_{CPU}; Tensor host_num_cell_{CPU}; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_PACK_SEGMENTS_H_