#ifndef CAFFE2_OPERATORS_ROW_MUL_H_ #define CAFFE2_OPERATORS_ROW_MUL_H_ #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { // A hacky version of Mul with broadcast // RowMul([mat, w], [output]) template class RowMulOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_SIMPLE_CTOR_DTOR(RowMulOp); bool RunOnDevice() override { auto& mat = Input(0); auto& w = Input(1); auto* output = Output(0, mat.sizes(), at::dtype()); T* output_data = output->template mutable_data(); const T* mat_data = mat.template data(); const T* w_data = w.template data(); // Dimension checking CAFFE_ENFORCE_EQ( w.numel(), mat.dim32(0), "Length of w should be equal to the first dim of mat"); auto block_size = mat.size_from_dim(1); for (int i = 0; i < w.numel(); i++) { size_t offset = i * block_size; for (int j = 0; j < block_size; j++) { output_data[offset + j] = mat_data[offset + j] * w_data[i]; } } return true; } }; // A hacky version template class ReduceTailSumOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_SIMPLE_CTOR_DTOR(ReduceTailSumOp); bool RunOnDevice() override { auto& mat = Input(0); int N = mat.dim32(0); int block_size = mat.size_from_dim(1); auto* output = Output(0, {N}, at::dtype()); T* output_data = output->template mutable_data(); const T* mat_data = mat.template data(); for (int i = 0; i < N; i++) { output_data[i] = 0; size_t offset = i * block_size; for (int j = 0; j < block_size; j++) { output_data[i] += mat_data[offset + j]; } } return true; } }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_ROW_MUL_H_