author     Xiaomeng Yang <yangxm@fb.com>    2019-04-23 15:24:03 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2019-04-23 15:34:59 -0700
commit     fb9fc42a0c00a04aaf5574ae4b932dc90221d147 (patch)
tree       084016897ca45de9ef73d378cd16b31113329a90
parent     176bdc0722951c42ca83aa4d9e1e49762e7df039 (diff)
optimize BatchMatmulOp (#18612)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18612
optimize BatchMatmulOp
Reviewed By: houseroad
Differential Revision: D14681665
fbshipit-source-id: cf5ea4909ace58fd44fe6fa634531102ac84e851
-rw-r--r--  caffe2/operators/batch_matmul_op.cc |   4
-rw-r--r--  caffe2/operators/batch_matmul_op.cu |   1
-rw-r--r--  caffe2/operators/batch_matmul_op.h  | 479
-rw-r--r--  caffe2/utils/math/utils.cc          |  74
-rw-r--r--  caffe2/utils/math/utils.h           |  17
-rw-r--r--  caffe2/utils/math_gpu.cu            | 364
6 files changed, 533 insertions(+), 406 deletions(-)
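At a high level, this change rewrites BatchMatMulOp::DoRunWithType to dispatch on input shape instead of always looping over GemmStridedBatched: a dot product for two vectors, Gemv when one operand is 1-D, a single Gemm when there is no batch, GemmStridedBatched when only one side is batched or the batch dimensions match exactly, and pointer-array GemmBatched for general broadcasting. The output shape comes from math::utils::ComputeBroadcastBinaryOpDims, which this diff also templates over int32_t/int64_t. The snippet below is a minimal standalone sketch of that shape computation for inputs that are at least 2-D; batch_matmul_output_shape is an illustrative helper, not part of the caffe2 API.

// Standalone sketch (not caffe2 code): compute the output shape of a
// broadcasted batch matmul, mirroring what the new operator does via
// math::utils::ComputeBroadcastBinaryOpDims. Assumes both inputs are at
// least 2-D; the operator handles 1-D inputs separately via Dot/Gemv.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<std::int64_t> batch_matmul_output_shape(
    std::vector<std::int64_t> A_dims,
    std::vector<std::int64_t> B_dims,
    bool trans_a,
    bool trans_b) {
  const int ndim = static_cast<int>(std::max(A_dims.size(), B_dims.size()));
  // Left-pad the shorter shape with 1s so both shapes have `ndim` dims.
  A_dims.insert(A_dims.begin(), ndim - A_dims.size(), 1);
  B_dims.insert(B_dims.begin(), ndim - B_dims.size(), 1);

  // GEMM dims respecting the transpose flags, as in the operator.
  const std::int64_t M = trans_a ? A_dims[ndim - 1] : A_dims[ndim - 2];
  const std::int64_t K = trans_a ? A_dims[ndim - 2] : A_dims[ndim - 1];
  const std::int64_t K_b = trans_b ? B_dims[ndim - 1] : B_dims[ndim - 2];
  const std::int64_t N = trans_b ? B_dims[ndim - 2] : B_dims[ndim - 1];
  if (K != K_b) {
    throw std::invalid_argument("inner dimensions do not match");
  }

  std::vector<std::int64_t> Y_dims(ndim);
  for (int i = 0; i < ndim - 2; ++i) {
    // Standard broadcasting rule: batch dims must match or one must be 1.
    if (A_dims[i] != B_dims[i] && A_dims[i] != 1 && B_dims[i] != 1) {
      throw std::invalid_argument("batch dimensions are not broadcastable");
    }
    Y_dims[i] = std::max(A_dims[i], B_dims[i]);
  }
  Y_dims[ndim - 2] = M;
  Y_dims[ndim - 1] = N;
  return Y_dims;
}

int main() {
  // [4, M=2, K=3] x [5, 4, K=3, N=6] -> [5, 4, 2, 6]
  for (std::int64_t d : batch_matmul_output_shape(
           {4, 2, 3}, {5, 4, 3, 6}, /*trans_a=*/false, /*trans_b=*/false)) {
    std::cout << d << ' ';
  }
  std::cout << '\n';
  return 0;
}

Compiled with any C++11 compiler, the example prints "5 4 2 6", matching the [5, 4, 2, 6] output shape the operator would produce for those inputs.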
diff --git a/caffe2/operators/batch_matmul_op.cc b/caffe2/operators/batch_matmul_op.cc index 80fc6029d0..62ad5ad571 100644 --- a/caffe2/operators/batch_matmul_op.cc +++ b/caffe2/operators/batch_matmul_op.cc @@ -1,4 +1,5 @@ #include "caffe2/operators/batch_matmul_op.h" + #include "caffe2/core/operator_schema.h" namespace caffe2 { @@ -28,7 +29,8 @@ vector<TensorShape> TensorInferenceForBatchMatMul( b_dim1 = in[1].dims(ndim - 1); } - auto output_dims = vector<int64_t>{in[0].dims().begin(), in[0].dims().end()}; + auto output_dims = + vector<int64_t>{in[0].dims().begin(), in[0].dims().end()}; output_dims[ndim - 2] = a_dim0; output_dims[ndim - 1] = b_dim1; diff --git a/caffe2/operators/batch_matmul_op.cu b/caffe2/operators/batch_matmul_op.cu index 801e4a6d04..b0ce9a31a1 100644 --- a/caffe2/operators/batch_matmul_op.cu +++ b/caffe2/operators/batch_matmul_op.cu @@ -22,6 +22,7 @@ REGISTER_CUDA_OPERATOR_WITH_ENGINE( BatchMatMul, TENSORCORE, BatchMatMulOp<CUDAContext, TensorCoreEngine>); + #endif } // namespace caffe2 diff --git a/caffe2/operators/batch_matmul_op.h b/caffe2/operators/batch_matmul_op.h index 662229a4a5..fa5800aec1 100644 --- a/caffe2/operators/batch_matmul_op.h +++ b/caffe2/operators/batch_matmul_op.h @@ -1,7 +1,11 @@ -#ifndef CAFFE2_OPERATORS_MATMUL_OP_H_ -#define CAFFE2_OPERATORS_MATMUL_OP_H_ +#ifndef CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_ +#define CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_ -#include <sstream> +#include <algorithm> +#include <functional> +#include <numeric> +#include <string> +#include <vector> #include "caffe2/core/context.h" #include "caffe2/core/operator.h" @@ -13,14 +17,13 @@ template <class Context, class Engine = DefaultEngine> class BatchMatMulOp final : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; + template <class... Args> explicit BatchMatMulOp(Args&&... args) : Operator<Context>(std::forward<Args>(args)...), - trans_a_(this->template GetSingleArgument<int>("trans_a", 0)), - trans_b_(this->template GetSingleArgument<int>("trans_b", 0)), - broadcast_(this->template GetSingleArgument<int>("broadcast", 0)) {} - - ~BatchMatMulOp() {} + OP_SINGLE_ARG(bool, "trans_a", trans_a_, false), + OP_SINGLE_ARG(bool, "trans_b", trans_b_, false), + OP_SINGLE_ARG(bool, "broadcast", broadcast_, false) {} bool RunOnDevice() override { return DispatchHelper<TensorTypes<float>>::call(this, Input(0)); @@ -30,253 +33,265 @@ class BatchMatMulOp final : public Operator<Context> { bool DoRunWithType() { const auto& A = Input(0); const auto& B = Input(1); + const int A_ndim = A.dim(); + const int B_ndim = B.dim(); + const std::vector<std::int64_t> A_dims = A.sizes().vec(); + const std::vector<std::int64_t> B_dims = B.sizes().vec(); + const T* A_data = A.template data<T>(); + const T* B_data = B.template data<T>(); - auto ndims_A = A.dim(); - auto dims_A = A.sizes().vec(); - auto ndims_B = B.dim(); - auto dims_B = B.sizes().vec(); - - auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) { - std::stringstream ss; - ss << "Inputs with dimensions A = "; - ss << dim1; - ss << " and B = "; - ss << dim2; - ss << " is not supported with broadcast=0. Did you forget to set the " - "broadcast flag?"; - return ss.str(); - }; - - // These should all be false if we're not broadcasting. 
- bool dimMismatch = ndims_A != ndims_B; - bool dimsLessThan1D = ndims_A < 2; - CAFFE_ENFORCE( - broadcast_ || (!dimMismatch && !dimsLessThan1D), - noBroadcastErrorMsg(ndims_A, ndims_B)); - - auto* data_A = A.template data<T>(); - auto* data_B = B.template data<T>(); - - auto dimMismatchErrorString = [](size_t dimnum1, - size_t dim1, - size_t dimnum2, - size_t dim2, - bool trans_a, - bool trans_b) { - std::stringstream ss; - ss << "Expected dimension "; - ss << dimnum1; - ss << " of tensor A with value "; - ss << dim1; - ss << " to match dimension "; - ss << dimnum2; - ss << " of tensor B with value "; - ss << dim2; - ss << ". trans_a = "; - ss << trans_a; - ss << " trans_b = "; - ss << trans_b; - return ss.str(); - }; - - if (ndims_A == 1 && ndims_B == 1) { - // vector-vector - CAFFE_ENFORCE_EQ( - dims_A[0], - dims_B[0], - "Vector-vector product requires each of the vectors to " - "be the same size."); + if (A_ndim == 1 && B_ndim == 1) { + CAFFE_ENFORCE_EQ(A.numel(), B.numel()); auto* Y = Output(0, {1}, at::dtype<T>()); - math::Dot<T, Context>( - dims_A[0], data_A, data_B, Y->template mutable_data<T>(), &context_); - } else { - bool A_broadcasted = false, B_broadcasted = false; - if (ndims_A == 1) { - dims_A.insert(dims_A.begin(), 1); - ndims_A = 2; - A_broadcasted = true; - } - if (ndims_B == 1) { - dims_B.push_back(1); - ndims_B = 2; - B_broadcasted = true; - } - // matrix-matrix with batches - // [B1..., M, K] * [B2..., K, N] -> [B..., M, N] - // In the event that A or B are one-dimensional, the trailing or leading - // 1 is not added to the output tensor's size. - - // First step: partition the tensors into inner and outer blocks. - // Ignoring the last two dimensions of A and B, ensure that one of the - // tensors' dimensions is a suffix of the other. For example, - // [4, x, x] is a suffix of [2, 3, 4, x, x]. In this example, the - // dimensions of size 2 and 3 will be broadcasted, so we partition into - // 2*3=6 individual instances of batched GEMM with A and B \in [4, x, x]. 
- size_t num_inner_dims = std::min(ndims_A, ndims_B); - for (size_t i = 2; i < num_inner_dims; ++i) { - auto first_r_itr = dims_A.rbegin(); - auto second_r_itr = dims_B.rbegin(); - CAFFE_ENFORCE_EQ( - *(first_r_itr + i), - *(second_r_itr + i), - dimMismatchErrorString( - ndims_A - i - 1, - *(first_r_itr + i), - ndims_B - i - 1, - *(second_r_itr + i), - trans_a_, - trans_b_)); - } - size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims; - - // Standard M, N, and K parameters respecting GEMM API and transpose - // flags - size_t M, N, K, K_dim; - if (trans_a_) { - M = dims_A[ndims_A - 1]; - K = dims_A[ndims_A - 2]; - K_dim = ndims_A - 2; - } else { - M = dims_A[ndims_A - 2]; - K = dims_A[ndims_A - 1]; - K_dim = ndims_A - 1; - } + T* Y_data = Y->template mutable_data<T>(); + math::Dot<T, Context>(A.numel(), A_data, B_data, Y_data, &context_); + return true; + } + if (A_ndim == 1) { + const int N = A.numel(); if (trans_b_) { - N = dims_B[ndims_B - 2]; - CAFFE_ENFORCE_EQ( - K, - dims_B[ndims_B - 1], - dimMismatchErrorString( - K_dim, - K, - ndims_B - 1, - dims_B[ndims_B - 1], - trans_a_, - trans_b_)); + CAFFE_ENFORCE_EQ(B_dims[B_ndim - 1], N); } else { - N = dims_B[ndims_B - 1]; - CAFFE_ENFORCE_EQ( - K, - dims_B[ndims_B - 2], - dimMismatchErrorString( - K_dim, - K, - ndims_B - 2, - dims_B[ndims_B - 2], - trans_a_, - trans_b_)); + CAFFE_ENFORCE_EQ(B_dims[B_ndim - 2], N); } - - // Calculate output tensor shapes [B..., (M), (N)] - // Batch dimensions will be broadcasted out to those of the longer tensor - // A or B. Either M or N are optional if A or B, respectively are 1-D. - std::vector<int64_t> new_dims; - if (ndims_A >= ndims_B) { - new_dims.assign(dims_A.begin(), dims_A.end() - 2); + std::vector<std::int64_t> Y_dims(B_ndim - 1); + if (trans_b_) { + std::copy_n(B_dims.cbegin(), B_ndim - 1, Y_dims.begin()); } else { - new_dims.assign(dims_B.begin(), dims_B.end() - 2); + std::copy_n(B_dims.cbegin(), B_ndim - 2, Y_dims.begin()); + Y_dims.back() = B_dims.back(); } - if (!A_broadcasted) { - new_dims.push_back(M); + auto* Y = Output(0, Y_dims, at::dtype<T>()); + T* Y_data = Y->template mutable_data<T>(); + if (trans_b_) { + const int M = B.numel() / N; + math::Gemv<T, Context, Engine>( + CblasNoTrans, M, N, 1.0f, B_data, A_data, 0.0f, Y_data, &context_); } else { - new_dims.push_back(1); + const int M = B_dims[B_ndim - 1]; + const int batch_size = B.numel() / (M * N); + if (batch_size == 1) { + math::Gemv<T, Context, Engine>( + CblasTrans, N, M, 1.0f, B_data, A_data, 0.0f, Y_data, &context_); + } else { + math::GemmStridedBatched<T, Context, Engine>( + CblasTrans, + CblasNoTrans, + batch_size, + M, + 1, + N, + 1.0f, + B_data, + M * N, + A_data, + 0, + 0.0f, + Y_data, + M, + &context_); + } } - if (!B_broadcasted) { - new_dims.push_back(N); + return true; + } + if (B_ndim == 1) { + const int N = B.numel(); + if (trans_a_) { + CAFFE_ENFORCE_EQ(A_dims[A_ndim - 2], N); } else { - new_dims.push_back(1); + CAFFE_ENFORCE_EQ(A_dims[A_ndim - 1], N); } - - // Calculate strides. Continuing our example above, - // [4, M, K] * [2, 3, 4, K, N] = [2, 3, 4, M, N] - // We calculate this as follows: - // 1) Treat the outer batch dimensions as flattened, i.e. view the B - // tensor here as [6, 4, K, N] and Y as [6, 4, M, N]. The same rea- - // soning is analogous for the case where # dims A >= # dims B. 
- // 2) Perform this operation: - // for i in range(6): - // Y[i, :, :, :] = BatchMatMul(A, B[i, :, :, :]) - size_t A_stride = 1; // How far to increment A pointer each itr - size_t B_stride = 1; // How far to increment B pointer each itr - size_t Y_stride = 1; // How far to increment Y pointer each itr - // How many "inner batches" we have. That is, the product of sizes for - // the slices excluding M, K, and N, for their respective matrices. - size_t num_sub_batches = 1; - if (ndims_A >= ndims_B) { - auto first_r_itr = dims_A.rbegin(); - auto output_r_itr = new_dims.rbegin(); - for (size_t i = 0; i < num_inner_dims; ++i) { - A_stride *= *(first_r_itr + i); - Y_stride *= *(output_r_itr + i); - if (i >= 2) { - num_sub_batches *= *(first_r_itr + i); - } + const std::vector<std::int64_t> Y_dims( + A_dims.cbegin(), A_dims.cbegin() + A_ndim - 1); + auto* Y = Output(0, Y_dims, at::dtype<T>()); + T* Y_data = Y->template mutable_data<T>(); + if (trans_a_) { + const int M = A_dims[A_ndim - 1]; + const int batch_size = A.numel() / (M * N); + if (batch_size == 1) { + math::Gemv<T, Context, Engine>( + CblasTrans, N, M, 1.0f, A_data, B_data, 0.0f, Y_data, &context_); + } else { + math::GemmStridedBatched<T, Context, Engine>( + CblasTrans, + CblasNoTrans, + batch_size, + M, + 1, + N, + 1.0f, + A_data, + M * N, + B_data, + 0, + 0.0f, + Y_data, + M, + &context_); } - B_stride = 0; } else { - A_stride = 0; - auto second_r_itr = dims_B.rbegin(); - auto output_r_itr = new_dims.rbegin(); - for (size_t i = 0; i < num_inner_dims; ++i) { - B_stride *= *(second_r_itr + i); - Y_stride *= *(output_r_itr + i); - if (i >= 2) { - num_sub_batches *= *(second_r_itr + i); - } - } - } - - size_t num_outer_batches = 1; - for (size_t i = 0; i < num_outer_dims; ++i) { - num_outer_batches *= new_dims[i]; - } - - // Mutually exclusive since otherwise we would've taken the vector-vector - // path above - if (A_broadcasted) { - new_dims.erase(new_dims.end() - 2); - } else if (B_broadcasted) { - new_dims.erase(new_dims.end() - 1); + const int M = A.numel() / N; + math::Gemv<T, Context, Engine>( + CblasNoTrans, M, N, 1.0f, A_data, B_data, 0.0f, Y_data, &context_); } + return true; + } - // Allocate output tensor - auto* Y = Output(0, new_dims, at::dtype<T>()); - auto* Y_data = Y->template mutable_data<T>(); + const int M = trans_a_ ? A_dims[A_ndim - 1] : A_dims[A_ndim - 2]; + const int K = trans_a_ ? A_dims[A_ndim - 2] : A_dims[A_ndim - 1]; + if (trans_b_) { + CAFFE_ENFORCE_EQ(B_dims[B_ndim - 1], K); + } else { + CAFFE_ENFORCE_EQ(B_dims[B_ndim - 2], K); + } + const int N = trans_b_ ? 
B_dims[B_ndim - 2] : B_dims[B_ndim - 1]; + const int ndim = std::max(A_ndim, B_ndim); + std::vector<std::int64_t> A_broadcast_dims(ndim); + std::vector<std::int64_t> B_broadcast_dims(ndim); + std::vector<std::int64_t> Y_broadcast_dims(ndim); + math::utils::ComputeBroadcastBinaryOpDims( + A_ndim - 2, + A_dims.data(), + B_ndim - 2, + B_dims.data(), + A_broadcast_dims.data(), + B_broadcast_dims.data(), + Y_broadcast_dims.data()); + Y_broadcast_dims[ndim - 2] = M; + Y_broadcast_dims[ndim - 1] = N; + auto* Y = Output(0, Y_broadcast_dims, at::dtype<T>()); + T* Y_data = Y->template mutable_data<T>(); - // Zero batch dimension indicates no elements - if (num_sub_batches == 0 || num_outer_batches == 0) { - return true; - } + const int batch_dim = ndim - 2; + const bool is_broadcast_dims = !std::equal( + A_broadcast_dims.cbegin(), + A_broadcast_dims.cbegin() + batch_dim, + B_broadcast_dims.cbegin()); + if (is_broadcast_dims) { + CAFFE_ENFORCE(broadcast_); + } - // TODO(T23893772): doing this in a loop is likely going to be slow on GPU - for (size_t p = 0; p < num_outer_batches; ++p) { - math::GemmStridedBatched<T, Context, Engine>( - trans_a_ ? CblasTrans : CblasNoTrans, - trans_b_ ? CblasTrans : CblasNoTrans, - num_sub_batches, - M, - N, - K, - 1.0f, - data_A + p * A_stride, - M * K, - data_B + p * B_stride, - K * N, - 0.0f, - Y_data + p * Y_stride, - M * N, - &context_); + const std::int64_t A_batch_size = std::accumulate( + A_broadcast_dims.cbegin(), + A_broadcast_dims.cbegin() + batch_dim, + 1LL, + std::multiplies<std::int64_t>()); + const std::int64_t B_batch_size = std::accumulate( + B_broadcast_dims.cbegin(), + B_broadcast_dims.cbegin() + batch_dim, + 1LL, + std::multiplies<std::int64_t>()); + const std::int64_t Y_batch_size = std::accumulate( + Y_broadcast_dims.cbegin(), + Y_broadcast_dims.cbegin() + batch_dim, + 1LL, + std::multiplies<std::int64_t>()); + if (Y_batch_size == 0) { + return true; + } + if (A_batch_size == 1 && B_batch_size == 1) { + math::Gemm<T, Context, Engine>( + trans_a_ ? CblasTrans : CblasNoTrans, + trans_b_ ? CblasTrans : CblasNoTrans, + M, + N, + K, + 1.0f, + A_data, + B_data, + 0.0f, + Y_data, + &context_); + } else if (A_batch_size == 1) { + math::GemmStridedBatched<T, Context, Engine>( + trans_a_ ? CblasTrans : CblasNoTrans, + trans_b_ ? CblasTrans : CblasNoTrans, + Y_batch_size, + M, + N, + K, + 1.0f, + A_data, + 0, + B_data, + K * N, + 0.0f, + Y_data, + M * N, + &context_); + } else if (B_batch_size == 1) { + math::GemmStridedBatched<T, Context, Engine>( + trans_a_ ? CblasTrans : CblasNoTrans, + trans_b_ ? CblasTrans : CblasNoTrans, + Y_batch_size, + M, + N, + K, + 1.0f, + A_data, + M * K, + B_data, + 0, + 0.0f, + Y_data, + M * N, + &context_); + } else if (!is_broadcast_dims) { + math::GemmStridedBatched<T, Context, Engine>( + trans_a_ ? CblasTrans : CblasNoTrans, + trans_b_ ? 
CblasTrans : CblasNoTrans, + Y_batch_size, + M, + N, + K, + 1.0f, + A_data, + M * K, + B_data, + K * N, + 0.0f, + Y_data, + M * N, + &context_); + } else { + std::vector<const T*> A_ptr(Y_batch_size); + std::vector<const T*> B_ptr(Y_batch_size); + std::vector<T*> Y_ptr(Y_batch_size); + std::vector<std::int64_t> index(batch_dim); + for (std::int64_t i = 0; i < Y_batch_size; ++i) { + const std::int64_t A_index = math::utils::GetIndexFromDims( + batch_dim, A_broadcast_dims.data(), index.data()); + const std::int64_t B_index = math::utils::GetIndexFromDims( + batch_dim, B_broadcast_dims.data(), index.data()); + A_ptr[i] = A_data + A_index * M * K; + B_ptr[i] = B_data + B_index * K * N; + Y_ptr[i] = Y_data + i * M * N; + math::utils::IncreaseIndexInDims( + batch_dim, Y_broadcast_dims.data(), index.data()); } + math::GemmBatched<T, Context, Engine>( + trans_a_ ? CblasTrans : CblasNoTrans, + trans_b_ ? CblasTrans : CblasNoTrans, + Y_batch_size, + M, + N, + K, + 1.0f, + A_ptr.data(), + B_ptr.data(), + 0.0f, + Y_ptr.data(), + &context_); } return true; } - protected: - bool trans_a_; - bool trans_b_; - bool broadcast_; + private: + const bool trans_a_; + const bool trans_b_; + const bool broadcast_; }; } // namespace caffe2 -#endif /* CAFFE2_OPERATORS_MATMUL_OP_H_ */ +#endif // CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_ diff --git a/caffe2/utils/math/utils.cc b/caffe2/utils/math/utils.cc index fdbb479e4b..86573b30e5 100644 --- a/caffe2/utils/math/utils.cc +++ b/caffe2/utils/math/utils.cc @@ -28,15 +28,21 @@ CAFFE2_SPECIALIZED_INCREASE_INDEX_IN_DIMS(std::int32_t) CAFFE2_SPECIALIZED_INCREASE_INDEX_IN_DIMS(std::int64_t) #undef CAFFE2_SPECIALIZED_INCREASE_INDEX_IN_DIMS -int GetIndexFromDims(const int n, const int* dims, const int* index) { - int sum = 0; - for (int i = 0; i < n; ++i) { - if (dims[i] > 1) { - sum = sum * dims[i] + index[i]; - } +#define CAFFE2_SPECIALIZED_GET_INDEX_FROM_DIMS(TIndex) \ + template <> \ + C10_EXPORT TIndex GetIndexFromDims( \ + const int n, const TIndex* dims, const TIndex* index) { \ + TIndex sum = 0; \ + for (int i = 0; i < n; ++i) { \ + if (dims[i] > 1) { \ + sum = sum * dims[i] + index[i]; \ + } \ + } \ + return sum; \ } - return sum; -} +CAFFE2_SPECIALIZED_GET_INDEX_FROM_DIMS(std::int32_t) +CAFFE2_SPECIALIZED_GET_INDEX_FROM_DIMS(std::int64_t) +#undef CAFFE2_SPECIALIZED_GET_INDEX_FROM_DIMS bool IsIdentityPermutation(const int n, const int* perm) { for (int i = 0; i < n; ++i) { @@ -125,30 +131,36 @@ bool IsBothEndsReduce( return true; } -void ComputeBroadcastBinaryOpDims( - const int A_ndim, - const int* A_dims, - const int B_ndim, - const int* B_dims, - int* A_broadcast_dims, - int* B_broadcast_dims, - int* C_broadcast_dims) { - const int ndim = std::max(A_ndim, B_ndim); - std::fill(A_broadcast_dims, A_broadcast_dims + ndim - A_ndim, 1); - std::fill(B_broadcast_dims, B_broadcast_dims + ndim - B_ndim, 1); - std::copy(A_dims, A_dims + A_ndim, A_broadcast_dims + ndim - A_ndim); - std::copy(B_dims, B_dims + B_ndim, B_broadcast_dims + ndim - B_ndim); - for (int i = 0; i < ndim; ++i) { - CAFFE_ENFORCE( - A_broadcast_dims[i] == B_broadcast_dims[i] || - A_broadcast_dims[i] <= 1 || B_broadcast_dims[i] <= 1); - if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) { - C_broadcast_dims[i] = 0; - } else { - C_broadcast_dims[i] = std::max(A_broadcast_dims[i], B_broadcast_dims[i]); - } +#define CAFFE2_SPECIALIZED_COMPUTE_BROADCAST_BINARY_OP_DIMS(TIndex) \ + template <> \ + C10_EXPORT void ComputeBroadcastBinaryOpDims( \ + const int A_ndim, \ + const TIndex* A_dims, \ + const int 
B_ndim, \ + const TIndex* B_dims, \ + TIndex* A_broadcast_dims, \ + TIndex* B_broadcast_dims, \ + TIndex* C_broadcast_dims) { \ + const int ndim = std::max(A_ndim, B_ndim); \ + std::fill(A_broadcast_dims, A_broadcast_dims + ndim - A_ndim, 1); \ + std::fill(B_broadcast_dims, B_broadcast_dims + ndim - B_ndim, 1); \ + std::copy(A_dims, A_dims + A_ndim, A_broadcast_dims + ndim - A_ndim); \ + std::copy(B_dims, B_dims + B_ndim, B_broadcast_dims + ndim - B_ndim); \ + for (int i = 0; i < ndim; ++i) { \ + CAFFE_ENFORCE( \ + A_broadcast_dims[i] == B_broadcast_dims[i] || \ + A_broadcast_dims[i] <= 1 || B_broadcast_dims[i] <= 1); \ + if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) { \ + C_broadcast_dims[i] = 0; \ + } else { \ + C_broadcast_dims[i] = \ + std::max(A_broadcast_dims[i], B_broadcast_dims[i]); \ + } \ + } \ } -} +CAFFE2_SPECIALIZED_COMPUTE_BROADCAST_BINARY_OP_DIMS(std::int32_t) +CAFFE2_SPECIALIZED_COMPUTE_BROADCAST_BINARY_OP_DIMS(std::int64_t) +#undef CAFFE2_SPECIALIZED_COMPUTE_BROADCAST_BINARY_OP_DIMS bool IsRowwiseBroadcastBinaryOp( const int ndim, diff --git a/caffe2/utils/math/utils.h b/caffe2/utils/math/utils.h index af239cd2fa..9346d454ab 100644 --- a/caffe2/utils/math/utils.h +++ b/caffe2/utils/math/utils.h @@ -56,7 +56,7 @@ MATH_UTILS_DECL T Cube(const T x) { // 0x800... // The casting allows to use one condition instead of two. MATH_UTILS_DECL bool IsAGeZeroAndALtB(const int a, const int b) { - return static_cast<unsigned int>(a) < static_cast<unsigned>(b); + return static_cast<unsigned int>(a) < static_cast<unsigned int>(b); } // Increase the index digits by one based on dims. @@ -65,7 +65,9 @@ CAFFE2_API void IncreaseIndexInDims(int ndim, const TIndex* dims, TIndex* index); // Get index value from dims and index digits. -CAFFE2_API int GetIndexFromDims(const int n, const int* dims, const int* index); +template <typename TIndex> +CAFFE2_API TIndex +GetIndexFromDims(const int n, const TIndex* dims, const TIndex* index); // Checks if the input permutation is an identity permutation; CAFFE2_API bool IsIdentityPermutation(const int n, const int* perm); @@ -96,14 +98,15 @@ CAFFE2_API bool IsBothEndsReduce( int* nxt); // Computest the broadcast binary operation dims. +template <typename TIndex> CAFFE2_API void ComputeBroadcastBinaryOpDims( const int A_ndim, - const int* A_dims, + const TIndex* A_dims, const int B_ndim, - const int* B_dims, - int* A_broadcast_dims, - int* B_broadcast_dims, - int* C_broadcast_dims); + const TIndex* B_dims, + TIndex* A_broadcast_dims, + TIndex* B_broadcast_dims, + TIndex* C_broadcast_dims); CAFFE2_API bool IsRowwiseBroadcastBinaryOp( const int ndim, diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu index a228527920..836e0b7502 100644 --- a/caffe2/utils/math_gpu.cu +++ b/caffe2/utils/math_gpu.cu @@ -1095,6 +1095,139 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>( #endif } +template <> +CAFFE2_CUDA_EXPORT void Gemv<float, CUDAContext>( + const CBLAS_TRANSPOSE trans_A, + const int M, + const int N, + const float alpha, + const float* A, + const float* x, + const float beta, + float* y, + CUDAContext* context, + TensorProto::DataType math_type) { + const cublasOperation_t cu_trans_A = + (trans_A == CblasNoTrans) ? 
CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_ENFORCE( + cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); + CUBLAS_ENFORCE(cublasSgemv( + context->cublas_handle(), + cu_trans_A, + N, + M, + &alpha, + A, + N, + x, + 1, + &beta, + y, + 1)); +} + +template <> +CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>( + const CBLAS_TRANSPOSE trans_A, + const int M, + const int N, + const float alpha, + const at::Half* A, + const at::Half* x, + const float beta, + at::Half* y, + CUDAContext* context, + TensorProto::DataType math_type) { + const cublasOperation_t cu_trans_A = + (trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; + + // sort out what we need to call cublasSgemmEx / cublasHgemm + const int m = (cu_trans_A == CUBLAS_OP_N) ? N : M; + const int k = (cu_trans_A == CUBLAS_OP_N) ? M : N; + const int lda = (cu_trans_A == CUBLAS_OP_N) ? m : k; + const int ldc = m; + + if (math_type == TensorProto_DataType_FLOAT) { + CUBLAS_ENFORCE(cublasSetPointerMode( + context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); +#ifdef __HIP_PLATFORM_HCC__ + // rocblas doesn't support cublasSgemmEx type API yet. + // It has more general rocblas_gemm_ex API which is more close to + // cublasGemmEx rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C, + // whereas cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C + ROCBLAS_ENFORCE(rocblas_gemm_ex( + context->rocblashandle(), + cu_trans_A, + rocblas_operation_none, + m, + 1, + k, + &alpha, + A, + rocblas_datatype_f16_r, + lda, + x, + rocblas_datatype_f16_r, + k, + &beta, + y, + rocblas_datatype_f16_r, + ldc, + y, // D + rocblas_datatype_f16_r, // D type + ldc, // ldd + rocblas_datatype_f32_r, // compute type + rocblas_gemm_algo_standard, // rocblas_gemm_algo + 0, // solution index, reserved for future use + 0, // flags, reserved for future use + NULL, // size of workspace + NULL)); // workspace +#else + CUBLAS_ENFORCE(cublasSgemmEx( + context->cublas_handle(), + cu_trans_A, + CUBLAS_OP_N, + m, + 1, + k, + &alpha, + A, + CUDA_R_16F, + lda, + x, + CUDA_R_16F, + k, + &beta, + y, + CUDA_R_16F, + ldc)); +#endif // __HIP_PLATFORM_HCC__ + } else if (math_type == TensorProto_DataType_FLOAT16) { + const __half alpha_fp16 = at::Half(alpha); + const __half beta_fp16 = at::Half(beta); + CUBLAS_ENFORCE(cublasSetPointerMode( + context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); + CUBLAS_ENFORCE(cublasHgemm( + context->cublas_handle(), + cu_trans_A, + CUBLAS_OP_N, + m, + 1, + k, + reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16), + reinterpret_cast<const CUBLAS_HALF_TYPE*>(A), + lda, + reinterpret_cast<const CUBLAS_HALF_TYPE*>(x), + k, + reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16), + reinterpret_cast<CUBLAS_HALF_TYPE*>(y), + ldc)); + } else { + // fail + CAFFE_THROW("Unsupported math type"); + } +} + #if CUDA_VERSION >= 9000 // No change, but required. 
Defer to default CUDA engine @@ -1176,6 +1309,68 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext, TensorCoreEngine>( } template <> +CAFFE2_CUDA_EXPORT void GemmBatched<float, CUDAContext, TensorCoreEngine>( + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, + const int batch_size, + const int M, + const int N, + const int K, + const float alpha, + const float** A, + const float** B, + const float beta, + float** C, + CUDAContext* context, + TensorProto::DataType math_type) { + GemmBatched<float, CUDAContext, DefaultEngine>( + trans_A, + trans_B, + batch_size, + M, + N, + K, + alpha, + A, + B, + beta, + C, + context, + math_type); +} + +template <> +CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext, TensorCoreEngine>( + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, + const int batch_size, + const int M, + const int N, + const int K, + const float alpha, + const at::Half** A, + const at::Half** B, + const float beta, + at::Half** C, + CUDAContext* context, + TensorProto::DataType math_type) { + GemmBatched<at::Half, CUDAContext, DefaultEngine>( + trans_A, + trans_B, + batch_size, + M, + N, + K, + alpha, + A, + B, + beta, + C, + context, + math_type); +} + +template <> CAFFE2_CUDA_EXPORT void GemmStridedBatched<float, CUDAContext, TensorCoreEngine>( const CBLAS_TRANSPOSE trans_A, @@ -1194,7 +1389,7 @@ GemmStridedBatched<float, CUDAContext, TensorCoreEngine>( const int C_stride, CUDAContext* context, TensorProto::DataType math_type) { - return GemmStridedBatched<float, CUDAContext, DefaultEngine>( + GemmStridedBatched<float, CUDAContext, DefaultEngine>( trans_A, trans_B, batch_size, @@ -1232,7 +1427,7 @@ GemmStridedBatched<at::Half, CUDAContext, TensorCoreEngine>( const int C_stride, CUDAContext* context, TensorProto::DataType math_type) { - return GemmStridedBatched<at::Half, CUDAContext, DefaultEngine>( + GemmStridedBatched<at::Half, CUDAContext, DefaultEngine>( trans_A, trans_B, batch_size, @@ -1251,6 +1446,38 @@ GemmStridedBatched<at::Half, CUDAContext, TensorCoreEngine>( math_type); } +template <> +CAFFE2_CUDA_EXPORT void Gemv<float, CUDAContext, TensorCoreEngine>( + const CBLAS_TRANSPOSE trans_A, + const int M, + const int N, + const float alpha, + const float* A, + const float* x, + const float beta, + float* y, + CUDAContext* context, + TensorProto::DataType math_type) { + Gemv<float, CUDAContext, DefaultEngine>( + trans_A, M, N, alpha, A, x, beta, y, context, math_type); +} + +template <> +CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext, TensorCoreEngine>( + const CBLAS_TRANSPOSE trans_A, + const int M, + const int N, + const float alpha, + const at::Half* A, + const at::Half* x, + const float beta, + at::Half* y, + CUDAContext* context, + TensorProto::DataType math_type) { + Gemv<at::Half, CUDAContext, DefaultEngine>( + trans_A, M, N, alpha, A, x, beta, y, context, math_type); +} + #endif // CUDA_VERSION >= 9000 template <> @@ -1294,37 +1521,6 @@ CAFFE2_CUDA_EXPORT void GemmEx<float, CUDAContext>( ldc)); } -template <> -CAFFE2_CUDA_EXPORT void Gemv<float, CUDAContext>( - const CBLAS_TRANSPOSE trans_A, - const int M, - const int N, - const float alpha, - const float* A, - const float* x, - const float beta, - float* y, - CUDAContext* context, - TensorProto::DataType math_type) { - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? 
CUBLAS_OP_T : CUBLAS_OP_N; - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSgemv( - context->cublas_handle(), - cu_trans_A, - N, - M, - &alpha, - A, - N, - x, - 1, - &beta, - y, - 1)); -} - // Batched Add variants namespace { @@ -1366,108 +1562,6 @@ CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(float); CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(at::Half); #undef CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH -template <> -CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>( - const CBLAS_TRANSPOSE trans_A, - const int M, - const int N, - const float alpha, - const at::Half* A, - const at::Half* x, - const float beta, - at::Half* y, - CUDAContext* context, - TensorProto::DataType math_type) { - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; - - // sort out what we need to call cublasSgemmEx / cublasHgemm - const int m = (cu_trans_A == CUBLAS_OP_N) ? N : M; - const int k = (cu_trans_A == CUBLAS_OP_N) ? M : N; - const int lda = (cu_trans_A == CUBLAS_OP_N) ? m : k; - const int ldc = m; - - if (math_type == TensorProto_DataType_FLOAT) { - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); -#ifdef __HIP_PLATFORM_HCC__ - // rocblas doesn't support cublasSgemmEx type API yet. - // It has more general rocblas_gemm_ex API which is more close to - // cublasGemmEx rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C, - // whereas cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C - ROCBLAS_ENFORCE(rocblas_gemm_ex( - context->rocblashandle(), - cu_trans_A, - rocblas_operation_none, - m, - 1, - k, - &alpha, - A, - rocblas_datatype_f16_r, - lda, - x, - rocblas_datatype_f16_r, - k, - &beta, - y, - rocblas_datatype_f16_r, - ldc, - y, // D - rocblas_datatype_f16_r, // D type - ldc, // ldd - rocblas_datatype_f32_r, // compute type - rocblas_gemm_algo_standard, // rocblas_gemm_algo - 0, // solution index, reserved for future use - 0, // flags, reserved for future use - NULL, // size of workspace - NULL)); // workspace -#else - CUBLAS_ENFORCE(cublasSgemmEx( - context->cublas_handle(), - cu_trans_A, - CUBLAS_OP_N, - m, - 1, - k, - &alpha, - A, - CUDA_R_16F, - lda, - x, - CUDA_R_16F, - k, - &beta, - y, - CUDA_R_16F, - ldc)); -#endif // __HIP_PLATFORM_HCC__ - } else if (math_type == TensorProto_DataType_FLOAT16) { - const __half alpha_fp16 = at::Half(alpha); - const __half beta_fp16 = at::Half(beta); - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasHgemm( - context->cublas_handle(), - cu_trans_A, - CUBLAS_OP_N, - m, - 1, - k, - reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16), - reinterpret_cast<const CUBLAS_HALF_TYPE*>(A), - lda, - reinterpret_cast<const CUBLAS_HALF_TYPE*>(x), - k, - reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16), - reinterpret_cast<CUBLAS_HALF_TYPE*>(y), - ldc)); - } else { - // fail - CAFFE_THROW("Unsupported math type"); - } -} - namespace { template <typename T> __global__ void |
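For the fully broadcast case, the new operator cannot use a single stride per input, so it builds per-batch pointer arrays with math::utils::GetIndexFromDims and math::utils::IncreaseIndexInDims (both templated over the index type in this diff) and hands them to math::GemmBatched. Below is a minimal standalone sketch, assuming 2-D batch shapes chosen for illustration, of how an output batch index maps back to the (possibly size-1) batch indices of A and B; index_from_dims and increase_index are illustrative stand-ins for the caffe2 helpers, not the real functions.

// Standalone sketch (not caffe2 code) of the pointer-array setup used by the
// broadcast path: walk the output batch index in row-major order and map it
// back into A and B, where broadcast dimensions of size 1 are skipped.
#include <cstdint>
#include <iostream>
#include <vector>

// Row-major linear offset that ignores dimensions of size 1 (broadcast dims),
// analogous to math::utils::GetIndexFromDims in the diff.
std::int64_t index_from_dims(const std::vector<std::int64_t>& dims,
                             const std::vector<std::int64_t>& index) {
  std::int64_t sum = 0;
  for (std::size_t i = 0; i < dims.size(); ++i) {
    if (dims[i] > 1) {
      sum = sum * dims[i] + index[i];
    }
  }
  return sum;
}

// Advance a multi-dimensional index by one in row-major order, analogous to
// math::utils::IncreaseIndexInDims.
void increase_index(const std::vector<std::int64_t>& dims,
                    std::vector<std::int64_t>& index) {
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    if (++index[i] < dims[i]) {
      return;
    }
    index[i] = 0;
  }
}

int main() {
  // Example batch dims: A broadcasts as [1, 4], B as [2, 4], output as [2, 4].
  const std::vector<std::int64_t> A_batch = {1, 4};
  const std::vector<std::int64_t> B_batch = {2, 4};
  const std::vector<std::int64_t> Y_batch = {2, 4};

  std::vector<std::int64_t> index(Y_batch.size(), 0);
  for (std::int64_t i = 0; i < 2 * 4; ++i) {
    // In the real operator these indices are scaled by M*K and K*N and added
    // to the data pointers collected for math::GemmBatched.
    std::cout << "Y batch " << i << " reads A batch "
              << index_from_dims(A_batch, index) << ", B batch "
              << index_from_dims(B_batch, index) << '\n';
    increase_index(Y_batch, index);
  }
  return 0;
}

Running it shows each of the eight output batches pairing a repeating A batch index (0 through 3, twice) with a distinct B batch index (0 through 7), which mirrors the pointer pattern the operator would feed to GemmBatched for those shapes.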