#ifndef CAFFE2_OPERATORS_MATMUL_OP_H_ #define CAFFE2_OPERATORS_MATMUL_OP_H_ #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { template class MatMulOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; MatMulOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), axis_a_(OperatorBase::GetSingleArgument("axis_a", 1)), axis_b_(OperatorBase::GetSingleArgument("axis_b", 1)), trans_a_(OperatorBase::GetSingleArgument("trans_a", 0)), trans_b_(OperatorBase::GetSingleArgument("trans_b", 0)) {} ~MatMulOp() {} bool RunOnDevice() override { const auto& A = Input(0); const auto& B = Input(1); auto* Y = Output(0); const auto canonical_axis_a = A.canonical_axis_index(axis_a_); const auto canonical_axis_b = B.canonical_axis_index(axis_b_); int A_dim0 = A.size_to_dim(canonical_axis_a); int A_dim1 = A.size_from_dim(canonical_axis_a); int B_dim0 = B.size_to_dim(canonical_axis_b); int B_dim1 = B.size_from_dim(canonical_axis_b); int a_dim0, a_dim1, b_dim0, b_dim1; if (trans_a_) { a_dim0 = A_dim1; a_dim1 = A_dim0; } else { a_dim0 = A_dim0; a_dim1 = A_dim1; } if (trans_b_) { b_dim0 = B_dim1; b_dim1 = B_dim0; } else { b_dim0 = B_dim0; b_dim1 = B_dim1; } auto dimErrorString = [&]() { return MakeString( "Dimension mismatch: ", trans_a_ ? "trans(A): " : "A: ", a_dim0, " ", a_dim1, trans_b_ ? ", trans(B): " : ", B: ", b_dim0, " ", b_dim1); }; // Error checking CAFFE_ENFORCE(a_dim1 == b_dim0, dimErrorString()); Y_shape_cache_[0] = a_dim0; Y_shape_cache_[1] = b_dim1; Y->Resize(Y_shape_cache_); CAFFE_ENFORCE(a_dim0 * b_dim1 == Y->size(), dimErrorString()); // Y = A * B math::Gemm( trans_a_ ? CblasTrans : CblasNoTrans, trans_b_ ? CblasTrans : CblasNoTrans, a_dim0, b_dim1, a_dim1, 1, A.template data(), B.template data(), 0, Y->template mutable_data(), &context_); if (InputSize() == 3) { // In gradient op, resize to input Y->ResizeLike(Input(2)); } return true; } protected: // A local vector to cache the output shape so we don't need to recreate // a vector object every time we run Run(). vector Y_shape_cache_{0, 0}; int axis_a_{1}; int axis_b_{1}; bool trans_a_; bool trans_b_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_MATMUL_OP_H_