author     Xiaomeng Yang <yangxm@fb.com>    2019-04-23 15:24:03 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2019-04-23 15:34:59 -0700
commit     fb9fc42a0c00a04aaf5574ae4b932dc90221d147 (patch)
tree       084016897ca45de9ef73d378cd16b31113329a90
parent     176bdc0722951c42ca83aa4d9e1e49762e7df039 (diff)
optimize BatchMatmulOp (#18612)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18612

optimize BatchMatmulOp

Reviewed By: houseroad

Differential Revision: D14681665

fbshipit-source-id: cf5ea4909ace58fd44fe6fa634531102ac84e851
-rw-r--r--  caffe2/operators/batch_matmul_op.cc  |   4
-rw-r--r--  caffe2/operators/batch_matmul_op.cu  |   1
-rw-r--r--  caffe2/operators/batch_matmul_op.h   | 479
-rw-r--r--  caffe2/utils/math/utils.cc           |  74
-rw-r--r--  caffe2/utils/math/utils.h            |  17
-rw-r--r--  caffe2/utils/math_gpu.cu             | 364

6 files changed, 533 insertions, 406 deletions
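
Taken together, the rewrite replaces the old do-everything loop over GemmStridedBatched with per-shape dispatch to the cheapest BLAS primitive. The sketch below summarizes that dispatch order; it is a simplified standalone illustration (PickKernel and MatMulKernel are invented names, and some batched vector cases in the real operator additionally collapse into a single GEMV call), not the operator code itself.

// Simplified sketch of the kernel dispatch added in
// caffe2/operators/batch_matmul_op.h (names here are hypothetical).
#include <cstdint>

enum class MatMulKernel { Dot, Gemv, Gemm, GemmStridedBatched, GemmBatched };

MatMulKernel PickKernel(
    int A_ndim,
    int B_ndim,
    std::int64_t A_batch_size,
    std::int64_t B_batch_size,
    bool batch_dims_match) {
  if (A_ndim == 1 && B_ndim == 1) {
    return MatMulKernel::Dot; // vector . vector
  }
  if (A_ndim == 1 || B_ndim == 1) {
    // vector * matrix (or matrix * vector): one GEMV when there is a single
    // matrix, a strided-batched GEMM with stride 0 on the vector otherwise.
    return (A_batch_size == 1 && B_batch_size == 1)
        ? MatMulKernel::Gemv
        : MatMulKernel::GemmStridedBatched;
  }
  if (A_batch_size == 1 && B_batch_size == 1) {
    return MatMulKernel::Gemm; // plain matrix * matrix
  }
  if (A_batch_size == 1 || B_batch_size == 1 || batch_dims_match) {
    // One side broadcast against the other, or identical batch dims:
    // a single GemmStridedBatched call covers all batches.
    return MatMulKernel::GemmStridedBatched;
  }
  // General broadcasting: build per-batch pointer arrays and call GemmBatched.
  return MatMulKernel::GemmBatched;
}
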
diff --git a/caffe2/operators/batch_matmul_op.cc b/caffe2/operators/batch_matmul_op.cc
index 80fc6029d0..62ad5ad571 100644
--- a/caffe2/operators/batch_matmul_op.cc
+++ b/caffe2/operators/batch_matmul_op.cc
@@ -1,4 +1,5 @@
#include "caffe2/operators/batch_matmul_op.h"
+
#include "caffe2/core/operator_schema.h"
namespace caffe2 {
@@ -28,7 +29,8 @@ vector<TensorShape> TensorInferenceForBatchMatMul(
b_dim1 = in[1].dims(ndim - 1);
}
- auto output_dims = vector<int64_t>{in[0].dims().begin(), in[0].dims().end()};
+ auto output_dims =
+ vector<int64_t>{in[0].dims().begin(), in[0].dims().end()};
output_dims[ndim - 2] = a_dim0;
output_dims[ndim - 1] = b_dim1;
diff --git a/caffe2/operators/batch_matmul_op.cu b/caffe2/operators/batch_matmul_op.cu
index 801e4a6d04..b0ce9a31a1 100644
--- a/caffe2/operators/batch_matmul_op.cu
+++ b/caffe2/operators/batch_matmul_op.cu
@@ -22,6 +22,7 @@ REGISTER_CUDA_OPERATOR_WITH_ENGINE(
BatchMatMul,
TENSORCORE,
BatchMatMulOp<CUDAContext, TensorCoreEngine>);
+
#endif
} // namespace caffe2
diff --git a/caffe2/operators/batch_matmul_op.h b/caffe2/operators/batch_matmul_op.h
index 662229a4a5..fa5800aec1 100644
--- a/caffe2/operators/batch_matmul_op.h
+++ b/caffe2/operators/batch_matmul_op.h
@@ -1,7 +1,11 @@
-#ifndef CAFFE2_OPERATORS_MATMUL_OP_H_
-#define CAFFE2_OPERATORS_MATMUL_OP_H_
+#ifndef CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_
+#define CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_
-#include <sstream>
+#include <algorithm>
+#include <functional>
+#include <numeric>
+#include <string>
+#include <vector>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
@@ -13,14 +17,13 @@ template <class Context, class Engine = DefaultEngine>
class BatchMatMulOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
+
template <class... Args>
explicit BatchMatMulOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
- trans_a_(this->template GetSingleArgument<int>("trans_a", 0)),
- trans_b_(this->template GetSingleArgument<int>("trans_b", 0)),
- broadcast_(this->template GetSingleArgument<int>("broadcast", 0)) {}
-
- ~BatchMatMulOp() {}
+ OP_SINGLE_ARG(bool, "trans_a", trans_a_, false),
+ OP_SINGLE_ARG(bool, "trans_b", trans_b_, false),
+ OP_SINGLE_ARG(bool, "broadcast", broadcast_, false) {}
bool RunOnDevice() override {
return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
@@ -30,253 +33,265 @@ class BatchMatMulOp final : public Operator<Context> {
bool DoRunWithType() {
const auto& A = Input(0);
const auto& B = Input(1);
+ const int A_ndim = A.dim();
+ const int B_ndim = B.dim();
+ const std::vector<std::int64_t> A_dims = A.sizes().vec();
+ const std::vector<std::int64_t> B_dims = B.sizes().vec();
+ const T* A_data = A.template data<T>();
+ const T* B_data = B.template data<T>();
- auto ndims_A = A.dim();
- auto dims_A = A.sizes().vec();
- auto ndims_B = B.dim();
- auto dims_B = B.sizes().vec();
-
- auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
- std::stringstream ss;
- ss << "Inputs with dimensions A = ";
- ss << dim1;
- ss << " and B = ";
- ss << dim2;
- ss << " is not supported with broadcast=0. Did you forget to set the "
- "broadcast flag?";
- return ss.str();
- };
-
- // These should all be false if we're not broadcasting.
- bool dimMismatch = ndims_A != ndims_B;
- bool dimsLessThan1D = ndims_A < 2;
- CAFFE_ENFORCE(
- broadcast_ || (!dimMismatch && !dimsLessThan1D),
- noBroadcastErrorMsg(ndims_A, ndims_B));
-
- auto* data_A = A.template data<T>();
- auto* data_B = B.template data<T>();
-
- auto dimMismatchErrorString = [](size_t dimnum1,
- size_t dim1,
- size_t dimnum2,
- size_t dim2,
- bool trans_a,
- bool trans_b) {
- std::stringstream ss;
- ss << "Expected dimension ";
- ss << dimnum1;
- ss << " of tensor A with value ";
- ss << dim1;
- ss << " to match dimension ";
- ss << dimnum2;
- ss << " of tensor B with value ";
- ss << dim2;
- ss << ". trans_a = ";
- ss << trans_a;
- ss << " trans_b = ";
- ss << trans_b;
- return ss.str();
- };
-
- if (ndims_A == 1 && ndims_B == 1) {
- // vector-vector
- CAFFE_ENFORCE_EQ(
- dims_A[0],
- dims_B[0],
- "Vector-vector product requires each of the vectors to "
- "be the same size.");
+ if (A_ndim == 1 && B_ndim == 1) {
+ CAFFE_ENFORCE_EQ(A.numel(), B.numel());
auto* Y = Output(0, {1}, at::dtype<T>());
- math::Dot<T, Context>(
- dims_A[0], data_A, data_B, Y->template mutable_data<T>(), &context_);
- } else {
- bool A_broadcasted = false, B_broadcasted = false;
- if (ndims_A == 1) {
- dims_A.insert(dims_A.begin(), 1);
- ndims_A = 2;
- A_broadcasted = true;
- }
- if (ndims_B == 1) {
- dims_B.push_back(1);
- ndims_B = 2;
- B_broadcasted = true;
- }
- // matrix-matrix with batches
- // [B1..., M, K] * [B2..., K, N] -> [B..., M, N]
- // In the event that A or B are one-dimensional, the trailing or leading
- // 1 is not added to the output tensor's size.
-
- // First step: partition the tensors into inner and outer blocks.
- // Ignoring the last two dimensions of A and B, ensure that one of the
- // tensors' dimensions is a suffix of the other. For example,
- // [4, x, x] is a suffix of [2, 3, 4, x, x]. In this example, the
- // dimensions of size 2 and 3 will be broadcasted, so we partition into
- // 2*3=6 individual instances of batched GEMM with A and B \in [4, x, x].
- size_t num_inner_dims = std::min(ndims_A, ndims_B);
- for (size_t i = 2; i < num_inner_dims; ++i) {
- auto first_r_itr = dims_A.rbegin();
- auto second_r_itr = dims_B.rbegin();
- CAFFE_ENFORCE_EQ(
- *(first_r_itr + i),
- *(second_r_itr + i),
- dimMismatchErrorString(
- ndims_A - i - 1,
- *(first_r_itr + i),
- ndims_B - i - 1,
- *(second_r_itr + i),
- trans_a_,
- trans_b_));
- }
- size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;
-
- // Standard M, N, and K parameters respecting GEMM API and transpose
- // flags
- size_t M, N, K, K_dim;
- if (trans_a_) {
- M = dims_A[ndims_A - 1];
- K = dims_A[ndims_A - 2];
- K_dim = ndims_A - 2;
- } else {
- M = dims_A[ndims_A - 2];
- K = dims_A[ndims_A - 1];
- K_dim = ndims_A - 1;
- }
+ T* Y_data = Y->template mutable_data<T>();
+ math::Dot<T, Context>(A.numel(), A_data, B_data, Y_data, &context_);
+ return true;
+ }
+ if (A_ndim == 1) {
+ const int N = A.numel();
if (trans_b_) {
- N = dims_B[ndims_B - 2];
- CAFFE_ENFORCE_EQ(
- K,
- dims_B[ndims_B - 1],
- dimMismatchErrorString(
- K_dim,
- K,
- ndims_B - 1,
- dims_B[ndims_B - 1],
- trans_a_,
- trans_b_));
+ CAFFE_ENFORCE_EQ(B_dims[B_ndim - 1], N);
} else {
- N = dims_B[ndims_B - 1];
- CAFFE_ENFORCE_EQ(
- K,
- dims_B[ndims_B - 2],
- dimMismatchErrorString(
- K_dim,
- K,
- ndims_B - 2,
- dims_B[ndims_B - 2],
- trans_a_,
- trans_b_));
+ CAFFE_ENFORCE_EQ(B_dims[B_ndim - 2], N);
}
-
- // Calculate output tensor shapes [B..., (M), (N)]
- // Batch dimensions will be broadcasted out to those of the longer tensor
- // A or B. Either M or N are optional if A or B, respectively are 1-D.
- std::vector<int64_t> new_dims;
- if (ndims_A >= ndims_B) {
- new_dims.assign(dims_A.begin(), dims_A.end() - 2);
+ std::vector<std::int64_t> Y_dims(B_ndim - 1);
+ if (trans_b_) {
+ std::copy_n(B_dims.cbegin(), B_ndim - 1, Y_dims.begin());
} else {
- new_dims.assign(dims_B.begin(), dims_B.end() - 2);
+ std::copy_n(B_dims.cbegin(), B_ndim - 2, Y_dims.begin());
+ Y_dims.back() = B_dims.back();
}
- if (!A_broadcasted) {
- new_dims.push_back(M);
+ auto* Y = Output(0, Y_dims, at::dtype<T>());
+ T* Y_data = Y->template mutable_data<T>();
+ if (trans_b_) {
+ const int M = B.numel() / N;
+ math::Gemv<T, Context, Engine>(
+ CblasNoTrans, M, N, 1.0f, B_data, A_data, 0.0f, Y_data, &context_);
} else {
- new_dims.push_back(1);
+ const int M = B_dims[B_ndim - 1];
+ const int batch_size = B.numel() / (M * N);
+ if (batch_size == 1) {
+ math::Gemv<T, Context, Engine>(
+ CblasTrans, N, M, 1.0f, B_data, A_data, 0.0f, Y_data, &context_);
+ } else {
+ math::GemmStridedBatched<T, Context, Engine>(
+ CblasTrans,
+ CblasNoTrans,
+ batch_size,
+ M,
+ 1,
+ N,
+ 1.0f,
+ B_data,
+ M * N,
+ A_data,
+ 0,
+ 0.0f,
+ Y_data,
+ M,
+ &context_);
+ }
}
- if (!B_broadcasted) {
- new_dims.push_back(N);
+ return true;
+ }
+ if (B_ndim == 1) {
+ const int N = B.numel();
+ if (trans_a_) {
+ CAFFE_ENFORCE_EQ(A_dims[A_ndim - 2], N);
} else {
- new_dims.push_back(1);
+ CAFFE_ENFORCE_EQ(A_dims[A_ndim - 1], N);
}
-
- // Calculate strides. Continuing our example above,
- // [4, M, K] * [2, 3, 4, K, N] = [2, 3, 4, M, N]
- // We calculate this as follows:
- // 1) Treat the outer batch dimensions as flattened, i.e. view the B
- // tensor here as [6, 4, K, N] and Y as [6, 4, M, N]. The same rea-
- // soning is analogous for the case where # dims A >= # dims B.
- // 2) Perform this operation:
- // for i in range(6):
- // Y[i, :, :, :] = BatchMatMul(A, B[i, :, :, :])
- size_t A_stride = 1; // How far to increment A pointer each itr
- size_t B_stride = 1; // How far to increment B pointer each itr
- size_t Y_stride = 1; // How far to increment Y pointer each itr
- // How many "inner batches" we have. That is, the product of sizes for
- // the slices excluding M, K, and N, for their respective matrices.
- size_t num_sub_batches = 1;
- if (ndims_A >= ndims_B) {
- auto first_r_itr = dims_A.rbegin();
- auto output_r_itr = new_dims.rbegin();
- for (size_t i = 0; i < num_inner_dims; ++i) {
- A_stride *= *(first_r_itr + i);
- Y_stride *= *(output_r_itr + i);
- if (i >= 2) {
- num_sub_batches *= *(first_r_itr + i);
- }
+ const std::vector<std::int64_t> Y_dims(
+ A_dims.cbegin(), A_dims.cbegin() + A_ndim - 1);
+ auto* Y = Output(0, Y_dims, at::dtype<T>());
+ T* Y_data = Y->template mutable_data<T>();
+ if (trans_a_) {
+ const int M = A_dims[A_ndim - 1];
+ const int batch_size = A.numel() / (M * N);
+ if (batch_size == 1) {
+ math::Gemv<T, Context, Engine>(
+ CblasTrans, N, M, 1.0f, A_data, B_data, 0.0f, Y_data, &context_);
+ } else {
+ math::GemmStridedBatched<T, Context, Engine>(
+ CblasTrans,
+ CblasNoTrans,
+ batch_size,
+ M,
+ 1,
+ N,
+ 1.0f,
+ A_data,
+ M * N,
+ B_data,
+ 0,
+ 0.0f,
+ Y_data,
+ M,
+ &context_);
}
- B_stride = 0;
} else {
- A_stride = 0;
- auto second_r_itr = dims_B.rbegin();
- auto output_r_itr = new_dims.rbegin();
- for (size_t i = 0; i < num_inner_dims; ++i) {
- B_stride *= *(second_r_itr + i);
- Y_stride *= *(output_r_itr + i);
- if (i >= 2) {
- num_sub_batches *= *(second_r_itr + i);
- }
- }
- }
-
- size_t num_outer_batches = 1;
- for (size_t i = 0; i < num_outer_dims; ++i) {
- num_outer_batches *= new_dims[i];
- }
-
- // Mutually exclusive since otherwise we would've taken the vector-vector
- // path above
- if (A_broadcasted) {
- new_dims.erase(new_dims.end() - 2);
- } else if (B_broadcasted) {
- new_dims.erase(new_dims.end() - 1);
+ const int M = A.numel() / N;
+ math::Gemv<T, Context, Engine>(
+ CblasNoTrans, M, N, 1.0f, A_data, B_data, 0.0f, Y_data, &context_);
}
+ return true;
+ }
- // Allocate output tensor
- auto* Y = Output(0, new_dims, at::dtype<T>());
- auto* Y_data = Y->template mutable_data<T>();
+ const int M = trans_a_ ? A_dims[A_ndim - 1] : A_dims[A_ndim - 2];
+ const int K = trans_a_ ? A_dims[A_ndim - 2] : A_dims[A_ndim - 1];
+ if (trans_b_) {
+ CAFFE_ENFORCE_EQ(B_dims[B_ndim - 1], K);
+ } else {
+ CAFFE_ENFORCE_EQ(B_dims[B_ndim - 2], K);
+ }
+ const int N = trans_b_ ? B_dims[B_ndim - 2] : B_dims[B_ndim - 1];
+ const int ndim = std::max(A_ndim, B_ndim);
+ std::vector<std::int64_t> A_broadcast_dims(ndim);
+ std::vector<std::int64_t> B_broadcast_dims(ndim);
+ std::vector<std::int64_t> Y_broadcast_dims(ndim);
+ math::utils::ComputeBroadcastBinaryOpDims(
+ A_ndim - 2,
+ A_dims.data(),
+ B_ndim - 2,
+ B_dims.data(),
+ A_broadcast_dims.data(),
+ B_broadcast_dims.data(),
+ Y_broadcast_dims.data());
+ Y_broadcast_dims[ndim - 2] = M;
+ Y_broadcast_dims[ndim - 1] = N;
+ auto* Y = Output(0, Y_broadcast_dims, at::dtype<T>());
+ T* Y_data = Y->template mutable_data<T>();
- // Zero batch dimension indicates no elements
- if (num_sub_batches == 0 || num_outer_batches == 0) {
- return true;
- }
+ const int batch_dim = ndim - 2;
+ const bool is_broadcast_dims = !std::equal(
+ A_broadcast_dims.cbegin(),
+ A_broadcast_dims.cbegin() + batch_dim,
+ B_broadcast_dims.cbegin());
+ if (is_broadcast_dims) {
+ CAFFE_ENFORCE(broadcast_);
+ }
- // TODO(T23893772): doing this in a loop is likely going to be slow on GPU
- for (size_t p = 0; p < num_outer_batches; ++p) {
- math::GemmStridedBatched<T, Context, Engine>(
- trans_a_ ? CblasTrans : CblasNoTrans,
- trans_b_ ? CblasTrans : CblasNoTrans,
- num_sub_batches,
- M,
- N,
- K,
- 1.0f,
- data_A + p * A_stride,
- M * K,
- data_B + p * B_stride,
- K * N,
- 0.0f,
- Y_data + p * Y_stride,
- M * N,
- &context_);
+ const std::int64_t A_batch_size = std::accumulate(
+ A_broadcast_dims.cbegin(),
+ A_broadcast_dims.cbegin() + batch_dim,
+ 1LL,
+ std::multiplies<std::int64_t>());
+ const std::int64_t B_batch_size = std::accumulate(
+ B_broadcast_dims.cbegin(),
+ B_broadcast_dims.cbegin() + batch_dim,
+ 1LL,
+ std::multiplies<std::int64_t>());
+ const std::int64_t Y_batch_size = std::accumulate(
+ Y_broadcast_dims.cbegin(),
+ Y_broadcast_dims.cbegin() + batch_dim,
+ 1LL,
+ std::multiplies<std::int64_t>());
+ if (Y_batch_size == 0) {
+ return true;
+ }
+ if (A_batch_size == 1 && B_batch_size == 1) {
+ math::Gemm<T, Context, Engine>(
+ trans_a_ ? CblasTrans : CblasNoTrans,
+ trans_b_ ? CblasTrans : CblasNoTrans,
+ M,
+ N,
+ K,
+ 1.0f,
+ A_data,
+ B_data,
+ 0.0f,
+ Y_data,
+ &context_);
+ } else if (A_batch_size == 1) {
+ math::GemmStridedBatched<T, Context, Engine>(
+ trans_a_ ? CblasTrans : CblasNoTrans,
+ trans_b_ ? CblasTrans : CblasNoTrans,
+ Y_batch_size,
+ M,
+ N,
+ K,
+ 1.0f,
+ A_data,
+ 0,
+ B_data,
+ K * N,
+ 0.0f,
+ Y_data,
+ M * N,
+ &context_);
+ } else if (B_batch_size == 1) {
+ math::GemmStridedBatched<T, Context, Engine>(
+ trans_a_ ? CblasTrans : CblasNoTrans,
+ trans_b_ ? CblasTrans : CblasNoTrans,
+ Y_batch_size,
+ M,
+ N,
+ K,
+ 1.0f,
+ A_data,
+ M * K,
+ B_data,
+ 0,
+ 0.0f,
+ Y_data,
+ M * N,
+ &context_);
+ } else if (!is_broadcast_dims) {
+ math::GemmStridedBatched<T, Context, Engine>(
+ trans_a_ ? CblasTrans : CblasNoTrans,
+ trans_b_ ? CblasTrans : CblasNoTrans,
+ Y_batch_size,
+ M,
+ N,
+ K,
+ 1.0f,
+ A_data,
+ M * K,
+ B_data,
+ K * N,
+ 0.0f,
+ Y_data,
+ M * N,
+ &context_);
+ } else {
+ std::vector<const T*> A_ptr(Y_batch_size);
+ std::vector<const T*> B_ptr(Y_batch_size);
+ std::vector<T*> Y_ptr(Y_batch_size);
+ std::vector<std::int64_t> index(batch_dim);
+ for (std::int64_t i = 0; i < Y_batch_size; ++i) {
+ const std::int64_t A_index = math::utils::GetIndexFromDims(
+ batch_dim, A_broadcast_dims.data(), index.data());
+ const std::int64_t B_index = math::utils::GetIndexFromDims(
+ batch_dim, B_broadcast_dims.data(), index.data());
+ A_ptr[i] = A_data + A_index * M * K;
+ B_ptr[i] = B_data + B_index * K * N;
+ Y_ptr[i] = Y_data + i * M * N;
+ math::utils::IncreaseIndexInDims(
+ batch_dim, Y_broadcast_dims.data(), index.data());
}
+ math::GemmBatched<T, Context, Engine>(
+ trans_a_ ? CblasTrans : CblasNoTrans,
+ trans_b_ ? CblasTrans : CblasNoTrans,
+ Y_batch_size,
+ M,
+ N,
+ K,
+ 1.0f,
+ A_ptr.data(),
+ B_ptr.data(),
+ 0.0f,
+ Y_ptr.data(),
+ &context_);
}
return true;
}
- protected:
- bool trans_a_;
- bool trans_b_;
- bool broadcast_;
+ private:
+ const bool trans_a_;
+ const bool trans_b_;
+ const bool broadcast_;
};
} // namespace caffe2
-#endif /* CAFFE2_OPERATORS_MATMUL_OP_H_ */
+#endif // CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_
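
In the fully broadcast branch of DoRunWithType above, the operator builds per-batch pointer arrays by walking the output batch index and mapping it back into A and B via GetIndexFromDims and IncreaseIndexInDims. The following self-contained example mirrors that bookkeeping with local re-implementations of those two helpers (for illustration only; it does not include the caffe2 headers).

// Sketch of the broadcast bookkeeping behind the GemmBatched fallback.
#include <cstdint>
#include <cstdio>
#include <vector>

// Map a multi-index over the output batch dims to a linear batch offset in a
// tensor whose batch dims may contain broadcast 1s (size-1 dims do not move
// the offset).
std::int64_t IndexFromDims(
    int n, const std::int64_t* dims, const std::int64_t* index) {
  std::int64_t sum = 0;
  for (int i = 0; i < n; ++i) {
    if (dims[i] > 1) {
      sum = sum * dims[i] + index[i];
    }
  }
  return sum;
}

// Advance a multi-index by one position in row-major order over dims.
void IncreaseIndex(int n, const std::int64_t* dims, std::int64_t* index) {
  for (int i = n - 1; i >= 0; --i) {
    if (++index[i] < dims[i]) {
      return;
    }
    index[i] = 0;
  }
}

int main() {
  // A has batch dims [2, 1], B has batch dims [1, 3]; the output batch dims
  // broadcast to [2, 3], i.e. 6 GEMMs with the A/B offsets printed below.
  const std::vector<std::int64_t> A_bdims{2, 1};
  const std::vector<std::int64_t> B_bdims{1, 3};
  const std::vector<std::int64_t> Y_bdims{2, 3};
  std::vector<std::int64_t> index(2, 0);
  for (std::int64_t i = 0; i < 6; ++i) {
    std::printf(
        "Y batch %lld -> A batch %lld, B batch %lld\n",
        static_cast<long long>(i),
        static_cast<long long>(IndexFromDims(2, A_bdims.data(), index.data())),
        static_cast<long long>(IndexFromDims(2, B_bdims.data(), index.data())));
    IncreaseIndex(2, Y_bdims.data(), index.data());
  }
  return 0;
}
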
diff --git a/caffe2/utils/math/utils.cc b/caffe2/utils/math/utils.cc
index fdbb479e4b..86573b30e5 100644
--- a/caffe2/utils/math/utils.cc
+++ b/caffe2/utils/math/utils.cc
@@ -28,15 +28,21 @@ CAFFE2_SPECIALIZED_INCREASE_INDEX_IN_DIMS(std::int32_t)
CAFFE2_SPECIALIZED_INCREASE_INDEX_IN_DIMS(std::int64_t)
#undef CAFFE2_SPECIALIZED_INCREASE_INDEX_IN_DIMS
-int GetIndexFromDims(const int n, const int* dims, const int* index) {
- int sum = 0;
- for (int i = 0; i < n; ++i) {
- if (dims[i] > 1) {
- sum = sum * dims[i] + index[i];
- }
+#define CAFFE2_SPECIALIZED_GET_INDEX_FROM_DIMS(TIndex) \
+ template <> \
+ C10_EXPORT TIndex GetIndexFromDims( \
+ const int n, const TIndex* dims, const TIndex* index) { \
+ TIndex sum = 0; \
+ for (int i = 0; i < n; ++i) { \
+ if (dims[i] > 1) { \
+ sum = sum * dims[i] + index[i]; \
+ } \
+ } \
+ return sum; \
}
- return sum;
-}
+CAFFE2_SPECIALIZED_GET_INDEX_FROM_DIMS(std::int32_t)
+CAFFE2_SPECIALIZED_GET_INDEX_FROM_DIMS(std::int64_t)
+#undef CAFFE2_SPECIALIZED_GET_INDEX_FROM_DIMS
bool IsIdentityPermutation(const int n, const int* perm) {
for (int i = 0; i < n; ++i) {
@@ -125,30 +131,36 @@ bool IsBothEndsReduce(
return true;
}
-void ComputeBroadcastBinaryOpDims(
- const int A_ndim,
- const int* A_dims,
- const int B_ndim,
- const int* B_dims,
- int* A_broadcast_dims,
- int* B_broadcast_dims,
- int* C_broadcast_dims) {
- const int ndim = std::max(A_ndim, B_ndim);
- std::fill(A_broadcast_dims, A_broadcast_dims + ndim - A_ndim, 1);
- std::fill(B_broadcast_dims, B_broadcast_dims + ndim - B_ndim, 1);
- std::copy(A_dims, A_dims + A_ndim, A_broadcast_dims + ndim - A_ndim);
- std::copy(B_dims, B_dims + B_ndim, B_broadcast_dims + ndim - B_ndim);
- for (int i = 0; i < ndim; ++i) {
- CAFFE_ENFORCE(
- A_broadcast_dims[i] == B_broadcast_dims[i] ||
- A_broadcast_dims[i] <= 1 || B_broadcast_dims[i] <= 1);
- if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) {
- C_broadcast_dims[i] = 0;
- } else {
- C_broadcast_dims[i] = std::max(A_broadcast_dims[i], B_broadcast_dims[i]);
- }
+#define CAFFE2_SPECIALIZED_COMPUTE_BROADCAST_BINARY_OP_DIMS(TIndex) \
+ template <> \
+ C10_EXPORT void ComputeBroadcastBinaryOpDims( \
+ const int A_ndim, \
+ const TIndex* A_dims, \
+ const int B_ndim, \
+ const TIndex* B_dims, \
+ TIndex* A_broadcast_dims, \
+ TIndex* B_broadcast_dims, \
+ TIndex* C_broadcast_dims) { \
+ const int ndim = std::max(A_ndim, B_ndim); \
+ std::fill(A_broadcast_dims, A_broadcast_dims + ndim - A_ndim, 1); \
+ std::fill(B_broadcast_dims, B_broadcast_dims + ndim - B_ndim, 1); \
+ std::copy(A_dims, A_dims + A_ndim, A_broadcast_dims + ndim - A_ndim); \
+ std::copy(B_dims, B_dims + B_ndim, B_broadcast_dims + ndim - B_ndim); \
+ for (int i = 0; i < ndim; ++i) { \
+ CAFFE_ENFORCE( \
+ A_broadcast_dims[i] == B_broadcast_dims[i] || \
+ A_broadcast_dims[i] <= 1 || B_broadcast_dims[i] <= 1); \
+ if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) { \
+ C_broadcast_dims[i] = 0; \
+ } else { \
+ C_broadcast_dims[i] = \
+ std::max(A_broadcast_dims[i], B_broadcast_dims[i]); \
+ } \
+ } \
}
-}
+CAFFE2_SPECIALIZED_COMPUTE_BROADCAST_BINARY_OP_DIMS(std::int32_t)
+CAFFE2_SPECIALIZED_COMPUTE_BROADCAST_BINARY_OP_DIMS(std::int64_t)
+#undef CAFFE2_SPECIALIZED_COMPUTE_BROADCAST_BINARY_OP_DIMS
bool IsRowwiseBroadcastBinaryOp(
const int ndim,
diff --git a/caffe2/utils/math/utils.h b/caffe2/utils/math/utils.h
index af239cd2fa..9346d454ab 100644
--- a/caffe2/utils/math/utils.h
+++ b/caffe2/utils/math/utils.h
@@ -56,7 +56,7 @@ MATH_UTILS_DECL T Cube(const T x) {
// 0x800...
// The casting allows to use one condition instead of two.
MATH_UTILS_DECL bool IsAGeZeroAndALtB(const int a, const int b) {
- return static_cast<unsigned int>(a) < static_cast<unsigned>(b);
+ return static_cast<unsigned int>(a) < static_cast<unsigned int>(b);
}
// Increase the index digits by one based on dims.
@@ -65,7 +65,9 @@ CAFFE2_API void
IncreaseIndexInDims(int ndim, const TIndex* dims, TIndex* index);
// Get index value from dims and index digits.
-CAFFE2_API int GetIndexFromDims(const int n, const int* dims, const int* index);
+template <typename TIndex>
+CAFFE2_API TIndex
+GetIndexFromDims(const int n, const TIndex* dims, const TIndex* index);
// Checks if the input permutation is an identity permutation;
CAFFE2_API bool IsIdentityPermutation(const int n, const int* perm);
@@ -96,14 +98,15 @@ CAFFE2_API bool IsBothEndsReduce(
int* nxt);
// Computest the broadcast binary operation dims.
+template <typename TIndex>
CAFFE2_API void ComputeBroadcastBinaryOpDims(
const int A_ndim,
- const int* A_dims,
+ const TIndex* A_dims,
const int B_ndim,
- const int* B_dims,
- int* A_broadcast_dims,
- int* B_broadcast_dims,
- int* C_broadcast_dims);
+ const TIndex* B_dims,
+ TIndex* A_broadcast_dims,
+ TIndex* B_broadcast_dims,
+ TIndex* C_broadcast_dims);
CAFFE2_API bool IsRowwiseBroadcastBinaryOp(
const int ndim,
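
Apart from the new templated declarations, this header also fixes the mismatched cast in IsAGeZeroAndALtB. As the comment above that function notes, casting to unsigned turns the two checks 0 <= a and a < b into a single comparison, since negative values wrap to large unsigned ones. A tiny standalone check of that trick (the main() here is illustrative only):

// Standalone illustration of the unsigned-compare trick in IsAGeZeroAndALtB.
#include <cassert>

bool IsAGeZeroAndALtB(const int a, const int b) {
  return static_cast<unsigned int>(a) < static_cast<unsigned int>(b);
}

int main() {
  assert(IsAGeZeroAndALtB(3, 5));   // 0 <= 3 < 5
  assert(!IsAGeZeroAndALtB(-1, 5)); // -1 wraps around to a huge unsigned value
  assert(!IsAGeZeroAndALtB(5, 5));  // not strictly less than b
  return 0;
}
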
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index a228527920..836e0b7502 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -1095,6 +1095,139 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
#endif
}
+template <>
+CAFFE2_CUDA_EXPORT void Gemv<float, CUDAContext>(
+ const CBLAS_TRANSPOSE trans_A,
+ const int M,
+ const int N,
+ const float alpha,
+ const float* A,
+ const float* x,
+ const float beta,
+ float* y,
+ CUDAContext* context,
+ TensorProto::DataType math_type) {
+ const cublasOperation_t cu_trans_A =
+ (trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+ CUBLAS_ENFORCE(
+ cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+ CUBLAS_ENFORCE(cublasSgemv(
+ context->cublas_handle(),
+ cu_trans_A,
+ N,
+ M,
+ &alpha,
+ A,
+ N,
+ x,
+ 1,
+ &beta,
+ y,
+ 1));
+}
+
+template <>
+CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
+ const CBLAS_TRANSPOSE trans_A,
+ const int M,
+ const int N,
+ const float alpha,
+ const at::Half* A,
+ const at::Half* x,
+ const float beta,
+ at::Half* y,
+ CUDAContext* context,
+ TensorProto::DataType math_type) {
+ const cublasOperation_t cu_trans_A =
+ (trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+ // sort out what we need to call cublasSgemmEx / cublasHgemm
+ const int m = (cu_trans_A == CUBLAS_OP_N) ? N : M;
+ const int k = (cu_trans_A == CUBLAS_OP_N) ? M : N;
+ const int lda = (cu_trans_A == CUBLAS_OP_N) ? m : k;
+ const int ldc = m;
+
+ if (math_type == TensorProto_DataType_FLOAT) {
+ CUBLAS_ENFORCE(cublasSetPointerMode(
+ context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#ifdef __HIP_PLATFORM_HCC__
+ // rocblas doesn't support cublasSgemmEx type API yet.
+ // It has more general rocblas_gemm_ex API which is more close to
+ // cublasGemmEx rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C,
+ // whereas cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C
+ ROCBLAS_ENFORCE(rocblas_gemm_ex(
+ context->rocblashandle(),
+ cu_trans_A,
+ rocblas_operation_none,
+ m,
+ 1,
+ k,
+ &alpha,
+ A,
+ rocblas_datatype_f16_r,
+ lda,
+ x,
+ rocblas_datatype_f16_r,
+ k,
+ &beta,
+ y,
+ rocblas_datatype_f16_r,
+ ldc,
+ y, // D
+ rocblas_datatype_f16_r, // D type
+ ldc, // ldd
+ rocblas_datatype_f32_r, // compute type
+ rocblas_gemm_algo_standard, // rocblas_gemm_algo
+ 0, // solution index, reserved for future use
+ 0, // flags, reserved for future use
+ NULL, // size of workspace
+ NULL)); // workspace
+#else
+ CUBLAS_ENFORCE(cublasSgemmEx(
+ context->cublas_handle(),
+ cu_trans_A,
+ CUBLAS_OP_N,
+ m,
+ 1,
+ k,
+ &alpha,
+ A,
+ CUDA_R_16F,
+ lda,
+ x,
+ CUDA_R_16F,
+ k,
+ &beta,
+ y,
+ CUDA_R_16F,
+ ldc));
+#endif // __HIP_PLATFORM_HCC__
+ } else if (math_type == TensorProto_DataType_FLOAT16) {
+ const __half alpha_fp16 = at::Half(alpha);
+ const __half beta_fp16 = at::Half(beta);
+ CUBLAS_ENFORCE(cublasSetPointerMode(
+ context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+ CUBLAS_ENFORCE(cublasHgemm(
+ context->cublas_handle(),
+ cu_trans_A,
+ CUBLAS_OP_N,
+ m,
+ 1,
+ k,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(A),
+ lda,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(x),
+ k,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
+ reinterpret_cast<CUBLAS_HALF_TYPE*>(y),
+ ldc));
+ } else {
+ // fail
+ CAFFE_THROW("Unsupported math type");
+ }
+}
+
#if CUDA_VERSION >= 9000
// No change, but required. Defer to default CUDA engine
@@ -1176,6 +1309,68 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext, TensorCoreEngine>(
}
template <>
+CAFFE2_CUDA_EXPORT void GemmBatched<float, CUDAContext, TensorCoreEngine>(
+ const CBLAS_TRANSPOSE trans_A,
+ const CBLAS_TRANSPOSE trans_B,
+ const int batch_size,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const float** A,
+ const float** B,
+ const float beta,
+ float** C,
+ CUDAContext* context,
+ TensorProto::DataType math_type) {
+ GemmBatched<float, CUDAContext, DefaultEngine>(
+ trans_A,
+ trans_B,
+ batch_size,
+ M,
+ N,
+ K,
+ alpha,
+ A,
+ B,
+ beta,
+ C,
+ context,
+ math_type);
+}
+
+template <>
+CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext, TensorCoreEngine>(
+ const CBLAS_TRANSPOSE trans_A,
+ const CBLAS_TRANSPOSE trans_B,
+ const int batch_size,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const at::Half** A,
+ const at::Half** B,
+ const float beta,
+ at::Half** C,
+ CUDAContext* context,
+ TensorProto::DataType math_type) {
+ GemmBatched<at::Half, CUDAContext, DefaultEngine>(
+ trans_A,
+ trans_B,
+ batch_size,
+ M,
+ N,
+ K,
+ alpha,
+ A,
+ B,
+ beta,
+ C,
+ context,
+ math_type);
+}
+
+template <>
CAFFE2_CUDA_EXPORT void
GemmStridedBatched<float, CUDAContext, TensorCoreEngine>(
const CBLAS_TRANSPOSE trans_A,
@@ -1194,7 +1389,7 @@ GemmStridedBatched<float, CUDAContext, TensorCoreEngine>(
const int C_stride,
CUDAContext* context,
TensorProto::DataType math_type) {
- return GemmStridedBatched<float, CUDAContext, DefaultEngine>(
+ GemmStridedBatched<float, CUDAContext, DefaultEngine>(
trans_A,
trans_B,
batch_size,
@@ -1232,7 +1427,7 @@ GemmStridedBatched<at::Half, CUDAContext, TensorCoreEngine>(
const int C_stride,
CUDAContext* context,
TensorProto::DataType math_type) {
- return GemmStridedBatched<at::Half, CUDAContext, DefaultEngine>(
+ GemmStridedBatched<at::Half, CUDAContext, DefaultEngine>(
trans_A,
trans_B,
batch_size,
@@ -1251,6 +1446,38 @@ GemmStridedBatched<at::Half, CUDAContext, TensorCoreEngine>(
math_type);
}
+template <>
+CAFFE2_CUDA_EXPORT void Gemv<float, CUDAContext, TensorCoreEngine>(
+ const CBLAS_TRANSPOSE trans_A,
+ const int M,
+ const int N,
+ const float alpha,
+ const float* A,
+ const float* x,
+ const float beta,
+ float* y,
+ CUDAContext* context,
+ TensorProto::DataType math_type) {
+ Gemv<float, CUDAContext, DefaultEngine>(
+ trans_A, M, N, alpha, A, x, beta, y, context, math_type);
+}
+
+template <>
+CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext, TensorCoreEngine>(
+ const CBLAS_TRANSPOSE trans_A,
+ const int M,
+ const int N,
+ const float alpha,
+ const at::Half* A,
+ const at::Half* x,
+ const float beta,
+ at::Half* y,
+ CUDAContext* context,
+ TensorProto::DataType math_type) {
+ Gemv<at::Half, CUDAContext, DefaultEngine>(
+ trans_A, M, N, alpha, A, x, beta, y, context, math_type);
+}
+
#endif // CUDA_VERSION >= 9000
template <>
@@ -1294,37 +1521,6 @@ CAFFE2_CUDA_EXPORT void GemmEx<float, CUDAContext>(
ldc));
}
-template <>
-CAFFE2_CUDA_EXPORT void Gemv<float, CUDAContext>(
- const CBLAS_TRANSPOSE trans_A,
- const int M,
- const int N,
- const float alpha,
- const float* A,
- const float* x,
- const float beta,
- float* y,
- CUDAContext* context,
- TensorProto::DataType math_type) {
- const cublasOperation_t cu_trans_A =
- (trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
- CUBLAS_ENFORCE(
- cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
- CUBLAS_ENFORCE(cublasSgemv(
- context->cublas_handle(),
- cu_trans_A,
- N,
- M,
- &alpha,
- A,
- N,
- x,
- 1,
- &beta,
- y,
- 1));
-}
-
// Batched Add variants
namespace {
@@ -1366,108 +1562,6 @@ CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(float);
CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(at::Half);
#undef CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH
-template <>
-CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
- const CBLAS_TRANSPOSE trans_A,
- const int M,
- const int N,
- const float alpha,
- const at::Half* A,
- const at::Half* x,
- const float beta,
- at::Half* y,
- CUDAContext* context,
- TensorProto::DataType math_type) {
- const cublasOperation_t cu_trans_A =
- (trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
-
- // sort out what we need to call cublasSgemmEx / cublasHgemm
- const int m = (cu_trans_A == CUBLAS_OP_N) ? N : M;
- const int k = (cu_trans_A == CUBLAS_OP_N) ? M : N;
- const int lda = (cu_trans_A == CUBLAS_OP_N) ? m : k;
- const int ldc = m;
-
- if (math_type == TensorProto_DataType_FLOAT) {
- CUBLAS_ENFORCE(cublasSetPointerMode(
- context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
-#ifdef __HIP_PLATFORM_HCC__
- // rocblas doesn't support cublasSgemmEx type API yet.
- // It has more general rocblas_gemm_ex API which is more close to
- // cublasGemmEx rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C,
- // whereas cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C
- ROCBLAS_ENFORCE(rocblas_gemm_ex(
- context->rocblashandle(),
- cu_trans_A,
- rocblas_operation_none,
- m,
- 1,
- k,
- &alpha,
- A,
- rocblas_datatype_f16_r,
- lda,
- x,
- rocblas_datatype_f16_r,
- k,
- &beta,
- y,
- rocblas_datatype_f16_r,
- ldc,
- y, // D
- rocblas_datatype_f16_r, // D type
- ldc, // ldd
- rocblas_datatype_f32_r, // compute type
- rocblas_gemm_algo_standard, // rocblas_gemm_algo
- 0, // solution index, reserved for future use
- 0, // flags, reserved for future use
- NULL, // size of workspace
- NULL)); // workspace
-#else
- CUBLAS_ENFORCE(cublasSgemmEx(
- context->cublas_handle(),
- cu_trans_A,
- CUBLAS_OP_N,
- m,
- 1,
- k,
- &alpha,
- A,
- CUDA_R_16F,
- lda,
- x,
- CUDA_R_16F,
- k,
- &beta,
- y,
- CUDA_R_16F,
- ldc));
-#endif // __HIP_PLATFORM_HCC__
- } else if (math_type == TensorProto_DataType_FLOAT16) {
- const __half alpha_fp16 = at::Half(alpha);
- const __half beta_fp16 = at::Half(beta);
- CUBLAS_ENFORCE(cublasSetPointerMode(
- context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
- CUBLAS_ENFORCE(cublasHgemm(
- context->cublas_handle(),
- cu_trans_A,
- CUBLAS_OP_N,
- m,
- 1,
- k,
- reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
- reinterpret_cast<const CUBLAS_HALF_TYPE*>(A),
- lda,
- reinterpret_cast<const CUBLAS_HALF_TYPE*>(x),
- k,
- reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
- reinterpret_cast<CUBLAS_HALF_TYPE*>(y),
- ldc));
- } else {
- // fail
- CAFFE_THROW("Unsupported math type");
- }
-}
-
namespace {
template <typename T>
__global__ void