author     rohithkrn <rohith.nallamaddi@gmail.com>  2018-12-10 17:25:46 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-12-10 17:54:06 -0800
commit     7e2b074219fad6d2b09b379423e83b2295b29df2 (patch)
tree       0f47bb2e53291946e2ce96397d21efe7498faa69
parent     92f3616f3695ca0ec79e4d583b086cefbcef8aed (diff)
Integrate rocBLAS fp16 api into Caffe2 (#14882)
Summary: This PR integrates the rocBLAS half and mixed-precision APIs into Caffe2.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14882
Differential Revision: D13407840
Pulled By: bddppq
fbshipit-source-id: 75cb0d74da066776fa66575f1d255e879d36121e
-rw-r--r--  caffe2/core/common_gpu.h                             9
-rw-r--r--  caffe2/operators/fully_connected_op_gpu.cc           2
-rw-r--r--  caffe2/python/operator_test/fc_operator_test.py     11
-rw-r--r--  caffe2/python/operator_test/matmul_op_test.py        6
-rw-r--r--  caffe2/python/operator_test/momentum_sgd_test.py    10
-rw-r--r--  caffe2/sgd/fp16_momentum_sgd_op.cu                    2
-rw-r--r--  caffe2/utils/math_gpu.cu                            173
-rw-r--r--  cmake/Dependencies.cmake                              1
-rw-r--r--  tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py      2
9 files changed, 159 insertions, 57 deletions
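At its core, the change routes the fp16-input GEMM path through rocblas_gemm_ex when building for HIP, instead of the cublasSgemmEx call used on CUDA. The sketch below is illustrative and not part of the commit; it mirrors the calling pattern adopted in the caffe2/utils/math_gpu.cu diff further down, assuming a valid rocblas_handle and fp16 device buffers. rocblas_gemm_ex writes its result to a separate output D, so the same pointer is passed for C and D to reproduce the in-place C = alpha*op(A)*op(B) + beta*C behaviour of cublasSgemmEx, while the f32_r compute type keeps the accumulation in fp32. (The real Caffe2 call additionally swaps A/B and the transpose flags to account for row-major storage; the wrapper name here is hypothetical.)

    #include <rocblas.h>

    // Hypothetical wrapper mirroring the argument order of the rocblas_gemm_ex
    // call added in math_gpu.cu: C = alpha*A*B + beta*C in place, fp16 data,
    // fp32 accumulation, no transposes, column-major layout.
    rocblas_status Fp16GemmWithFp32Compute(
        rocblas_handle handle,
        int M, int N, int K,
        const float* alpha, const float* beta,
        const void* A, int lda,
        const void* B, int ldb,
        void* C, int ldc) {
      return rocblas_gemm_ex(
          handle,
          rocblas_operation_none,         // op(A)
          rocblas_operation_none,         // op(B)
          M, N, K,
          alpha,
          A, rocblas_datatype_f16_r, lda,
          B, rocblas_datatype_f16_r, ldb,
          beta,
          C, rocblas_datatype_f16_r, ldc,
          C, rocblas_datatype_f16_r, ldc, // D aliases C -> in-place update
          rocblas_datatype_f32_r,         // compute type: accumulate in fp32
          rocblas_gemm_algo_standard,
          0, 0,                           // solution index / flags, reserved
          NULL, NULL);                    // workspace size / workspace
    }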
diff --git a/caffe2/core/common_gpu.h b/caffe2/core/common_gpu.h
index 1af674e8a3..db87887fc8 100644
--- a/caffe2/core/common_gpu.h
+++ b/caffe2/core/common_gpu.h
@@ -69,7 +69,7 @@
// CAFFE_HAS_CUDA_FP16 manually.
#ifndef CAFFE_HAS_CUDA_FP16
-#if CUDA_VERSION >= 7050
+#if CUDA_VERSION >= 7050 || defined(__HIP_PLATFORM_HCC__)
#define CAFFE_HAS_CUDA_FP16
#endif // CUDA_VERSION >= 7050
#endif // CAFFE_HAS_CUDA_FP16
@@ -78,6 +78,13 @@
#include <cuda_fp16.h>
#endif
+// CUDA major revision number below which fp16 compute is not supported
+#ifndef __HIP_PLATFORM_HCC__
+constexpr int kFp16CUDADevicePropMajor = 6;
+#else
+constexpr int kFp16CUDADevicePropMajor = 3;
+#endif
+
// Re-enable strict aliasing diagnostic if it was disabled.
#if CUDA_VERSION >= 9000
#ifdef __GNUC__
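The kFp16CUDADevicePropMajor constant added above replaces a local definition in fully_connected_op_gpu.cc (removed in the next file's diff) and centralizes the capability gate for fp16 kernels: compute capability 6.x is required on CUDA, while 3.x is sufficient for the HIP/ROCm targets this PR enables. Below is a minimal sketch of such a gate, assuming only the header above; the helper name is hypothetical, but the check mirrors the one in fp16_momentum_sgd_op.cu later in this diff.

    #include "caffe2/core/common_gpu.h"

    namespace caffe2 {
    // Hypothetical helper: true if the given GPU can run the fp16 compute paths.
    inline bool DeviceSupportsFp16Compute(const int device_id) {
      const cudaDeviceProp& prop = GetDeviceProperty(device_id);
      return prop.major >= kFp16CUDADevicePropMajor;
    }
    } // namespace caffe2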
diff --git a/caffe2/operators/fully_connected_op_gpu.cc b/caffe2/operators/fully_connected_op_gpu.cc
index 3f82283f48..4762692ee2 100644
--- a/caffe2/operators/fully_connected_op_gpu.cc
+++ b/caffe2/operators/fully_connected_op_gpu.cc
@@ -6,8 +6,6 @@ namespace caffe2 {
namespace {
-constexpr int kFp16CUDADevicePropMajor = 6;
-
template <class FullyConnectedOp>
bool RunFullyConnectedOpOnCUDADevice(
const bool float16_compute,
diff --git a/caffe2/python/operator_test/fc_operator_test.py b/caffe2/python/operator_test/fc_operator_test.py
index d42e00cda2..466453c9aa 100644
--- a/caffe2/python/operator_test/fc_operator_test.py
+++ b/caffe2/python/operator_test/fc_operator_test.py
@@ -16,9 +16,9 @@ import unittest
class TestFcOperator(serial.SerializedTestCase):
def _run_test(self, n, m, k, transposed, multi_dim, dtype, engine, gc, dc):
if dtype == np.float16:
- # fp16 only supported with CUDA
- assume(gc.device_type == caffe2_pb2.CUDA)
- dc = [d for d in dc if d.device_type == caffe2_pb2.CUDA]
+ # fp16 only supported with CUDA/HIP
+ assume(core.IsGPUDeviceType(gc.device_type))
+ dc = [d for d in dc if core.IsGPUDeviceType(d.device_type)]
if engine == 'TENSORCORE':
# TensorCore only makes sense with CUDA
@@ -54,18 +54,21 @@ class TestFcOperator(serial.SerializedTestCase):
engine=engine,
)
- if dtype == np.float16 and gc.device_type == caffe2_pb2.CUDA:
+ if dtype == np.float16 and core.IsGPUDeviceType(gc.device_type):
a = caffe2_pb2.Argument()
a.i = 1
a.name = "float16_compute"
op.arg.extend([a])
# Check against numpy reference
+ # ReferenceChecks is flaky on ROCm with a threshold of 1e-4 for fp16. Relaxing to 1e-3.
+ threshold = 1e-3 if (gc.device_type == caffe2_pb2.HIP and dtype == np.float16) else 1e-4
self.assertReferenceChecks(
device_option=gc,
op=op,
inputs=[X, W, b],
reference=fc_tranposed_op if transposed else fc_op,
+ threshold=threshold
)
# Check over multiple devices
self.assertDeviceChecks(dc, op, [X, W, b], [0])
diff --git a/caffe2/python/operator_test/matmul_op_test.py b/caffe2/python/operator_test/matmul_op_test.py
index 1872a129e5..64e0e51051 100644
--- a/caffe2/python/operator_test/matmul_op_test.py
+++ b/caffe2/python/operator_test/matmul_op_test.py
@@ -140,9 +140,9 @@ class TestBatchMatMul(serial.SerializedTestCase):
)
def test_batch_matmul(self, C, M, K, N, trans_a, trans_b, dtype, gc, dc):
if dtype == np.float16:
- # fp16 is only supported with CUDA
- assume(gc.device_type == caffe2_pb2.CUDA)
- dc = [d for d in dc if d.device_type == caffe2_pb2.CUDA]
+ # fp16 is only supported with CUDA/HIP
+ assume(core.IsGPUDeviceType(gc.device_type))
+ dc = [d for d in dc if core.IsGPUDeviceType(d.device_type)]
batch_dims = np.random.randint(
low=1,
diff --git a/caffe2/python/operator_test/momentum_sgd_test.py b/caffe2/python/operator_test/momentum_sgd_test.py
index 39e358f30d..27dcb78c14 100644
--- a/caffe2/python/operator_test/momentum_sgd_test.py
+++ b/caffe2/python/operator_test/momentum_sgd_test.py
@@ -7,8 +7,7 @@ from caffe2.python import core, workspace
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
-import hypothesis
-from hypothesis import given
+from hypothesis import given, assume
import hypothesis.strategies as st
import numpy as np
import unittest
@@ -95,7 +94,7 @@ class TestMomentumSGD(serial.SerializedTestCase):
)
# Verify that the generated indices are unique
- hypothesis.assume(
+ assume(
np.array_equal(
np.unique(indices.flatten()),
np.sort(indices.flatten())))
@@ -139,9 +138,10 @@ class TestMomentumSGD(serial.SerializedTestCase):
[grad, m, lr, w, indices],
sparse)
- @given(n=st.integers(4, 8), nesterov=st.booleans(), **hu.gcs_gpu_only)
- @unittest.skipIf(not workspace.has_gpu_support, "No gpu support.")
+ @unittest.skipIf(not workspace.has_gpu_support and not workspace.has_hip_support, "No gpu support.")
+ @given(n=st.integers(4, 8), nesterov=st.booleans(), **hu.gcs)
def test_fp16momentum_sgd(self, n, nesterov, gc, dc):
+ assume(core.IsGPUDeviceType(gc.device_type))
gpuvers = workspace.GetDeviceProperties(0)["major"]
if gpuvers < 6:
print("No FP16 support because major version {} < 6".format(gpuvers))
diff --git a/caffe2/sgd/fp16_momentum_sgd_op.cu b/caffe2/sgd/fp16_momentum_sgd_op.cu
index 4da36da220..b7ac0a7b76 100644
--- a/caffe2/sgd/fp16_momentum_sgd_op.cu
+++ b/caffe2/sgd/fp16_momentum_sgd_op.cu
@@ -198,7 +198,7 @@ void fp16_momentum_sgd_update<CUDAContext>(
at::Half* param,
CUDAContext* context) {
const cudaDeviceProp& prop = GetDeviceProperty(0);
- if (prop.major >= 6) {
+ if (prop.major >= kFp16CUDADevicePropMajor) {
if (!fp32_update) {
FP16MomentumSGDKernel<<<
CAFFE_GET_BLOCKS(N / 2),
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 12abf4289f..dc7cb22efc 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -35,6 +35,12 @@
#define FIXED_DIVISOR_DIV_MOD(d, n, q, r) (d.DivMod(n, q, r))
#endif // __HIP_PLATFORM_HCC__
+#ifdef __HIP_PLATFORM_HCC__
+using CUBLAS_HALF_TYPE = rocblas_half;
+#else // __HIP_PLATFORM_HCC__
+using CUBLAS_HALF_TYPE = __half;
+#endif // __HIP_PLATFORM_HCC__
+
#include "caffe2/utils/math_utils.h"
#if THRUST_VERSION >= 100800
@@ -743,9 +749,6 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
at::Half* C,
CUDAContext* context,
TensorProto::DataType math_type) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
-#else
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
const int lda = (trans_A == CblasNoTrans) ? K : M;
@@ -757,6 +760,39 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
if (math_type == TensorProto_DataType_FLOAT) {
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#ifdef __HIP_PLATFORM_HCC__
+ // rocBLAS doesn't support a cublasSgemmEx-style API yet.
+ // It has the more general rocblas_gemm_ex API, which is closer to cublasGemmEx.
+ // rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C, whereas
+ // cublasGemmEx does C = alpha*op( A )*op( B ) + beta*C
+ ROCBLAS_ENFORCE(rocblas_gemm_ex(
+ context->rocblashandle(),
+ cu_trans_B,
+ cu_trans_A,
+ N,
+ M,
+ K,
+ &alpha,
+ B,
+ rocblas_datatype_f16_r,
+ ldb,
+ A,
+ rocblas_datatype_f16_r,
+ lda,
+ &beta,
+ C,
+ rocblas_datatype_f16_r,
+ N,
+ C, // D
+ rocblas_datatype_f16_r, // D type
+ N, // ldd
+ rocblas_datatype_f32_r, // compute type
+ rocblas_gemm_algo_standard, // rocblas_gemm_algo
+ 0, // solution index, reserved for future use
+ 0, // flags, reserved for future use
+ NULL, // size of workspace
+ NULL)); // workspace
+#else
CUBLAS_ENFORCE(cublasSgemmEx(
context->cublas_handle(),
cu_trans_B,
@@ -775,6 +811,7 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
C,
CUDA_R_16F,
N));
+#endif // __HIP_PLATFORM_HCC__
} else if (math_type == TensorProto_DataType_FLOAT16) {
// convert alpha, beta from float -> __half
const __half alpha_fp16 = at::Half(alpha);
@@ -789,19 +826,18 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
N,
M,
K,
- &alpha_fp16,
- (const __half*)B,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(B),
ldb,
- (const __half*)A,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(A),
lda,
- &beta_fp16,
- (__half*)C,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
+ reinterpret_cast<CUBLAS_HALF_TYPE*>(C),
N));
} else {
// fail
CAFFE_THROW("Unsupported math type");
}
-#endif
}
template <>
@@ -968,9 +1004,6 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
at::Half** C,
CUDAContext* context,
TensorProto::DataType math_type) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
-#else
#if __CUDACC_VER_MAJOR__ < 9
// loop over matrices in the batch
for (int i = 0; i < batch_size; ++i) {
@@ -1083,7 +1116,6 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
CAFFE_THROW("Unsupported math type");
}
#endif
-#endif
}
template <>
@@ -1104,10 +1136,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
const int C_stride,
CUDAContext* context,
TensorProto::DataType math_type) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
-#else
-#if __CUDACC_VER_MAJOR__ < 8
+#if __CUDACC_VER_MAJOR__ < 8 && !defined(__HIP_PLATFORM_HCC__)
// loop over matrices in the batch
for (int i = 0; i < batch_size; ++i) {
Gemm<at::Half, CUDAContext>(
@@ -1127,7 +1156,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
const cublasOperation_t cu_trans_B =
(trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
if (math_type == TensorProto_DataType_FLOAT) {
-#if CUDA_VERSION < 9010
+#if CUDA_VERSION < 9010 && !defined(__HIP_PLATFORM_HCC__)
// loop over matrices in the batch
for (int i = 0; i < batch_size; ++i) {
Gemm<at::Half, CUDAContext>(
@@ -1139,6 +1168,42 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
#else
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#ifdef __HIP_PLATFORM_HCC__
+ // D[i*stride_d] = alpha*op(A[i*stride_a])*op(B[i*stride_b]) + beta*C[i*stride_c],
+ // for i in [0,batch_count-1]
+ ROCBLAS_ENFORCE(rocblas_gemm_strided_batched_ex(
+ context->rocblashandle(),
+ cu_trans_B,
+ cu_trans_A,
+ N,
+ M,
+ K,
+ &alpha,
+ B,
+ rocblas_datatype_f16_r,
+ ldb,
+ B_stride,
+ A,
+ rocblas_datatype_f16_r,
+ lda,
+ A_stride,
+ &beta,
+ C,
+ rocblas_datatype_f16_r,
+ ldc,
+ C_stride,
+ C, // D
+ rocblas_datatype_f16_r, // D type
+ ldc, // ldd
+ C_stride, // D stride
+ batch_size,
+ rocblas_datatype_f32_r, // compute type
+ rocblas_gemm_algo_standard, // rocblas_gemm_algo
+ 0, // solution index, reserved for future use
+ 0, // flags, reserved for future use
+ NULL, // size of workspace
+ NULL)); // workspace
+#else
CUBLAS_ENFORCE(cublasGemmStridedBatchedEx(
context->cublas_handle(),
cu_trans_B,
@@ -1163,6 +1228,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
batch_size,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+#endif // __HIP_PLATFORM_HCC__
#endif
} else if (math_type == TensorProto_DataType_FLOAT16) {
// Convert alpha, beta from float -> __half
@@ -1177,15 +1243,15 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
N,
M,
K,
- &alpha_fp16,
- (const __half*)B,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(B),
ldb,
B_stride,
- (const __half*)A,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(A),
lda,
A_stride,
- &beta_fp16,
- (__half*)C,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
+ reinterpret_cast<CUBLAS_HALF_TYPE*>(C),
ldc,
C_stride,
batch_size));
@@ -1193,7 +1259,6 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
CAFFE_THROW("Unsupported math type");
}
#endif
-#endif
}
#if CUDA_VERSION >= 9000
@@ -1479,9 +1544,6 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
at::Half* y,
CUDAContext* context,
TensorProto::DataType math_type) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
-#else
const cublasOperation_t cu_trans_A =
(trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
@@ -1494,6 +1556,39 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
if (math_type == TensorProto_DataType_FLOAT) {
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#ifdef __HIP_PLATFORM_HCC__
+ // rocBLAS doesn't support a cublasSgemmEx-style API yet.
+ // It has the more general rocblas_gemm_ex API, which is closer to cublasGemmEx.
+ // rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C, whereas
+ // cublasGemmEx does C = alpha*op( A )*op( B ) + beta*C
+ ROCBLAS_ENFORCE(rocblas_gemm_ex(
+ context->rocblashandle(),
+ cu_trans_A,
+ rocblas_operation_none,
+ m,
+ 1,
+ k,
+ &alpha,
+ A,
+ rocblas_datatype_f16_r,
+ lda,
+ x,
+ rocblas_datatype_f16_r,
+ k,
+ &beta,
+ y,
+ rocblas_datatype_f16_r,
+ ldc,
+ y, // D
+ rocblas_datatype_f16_r, // D type
+ ldc, // ldd
+ rocblas_datatype_f32_r, // compute type
+ rocblas_gemm_algo_standard, // rocblas_gemm_algo
+ 0, // solution index, reserved for future use
+ 0, // flags, reserved for future use
+ NULL, // size of workspace
+ NULL)); // workspace
+#else
CUBLAS_ENFORCE(cublasSgemmEx(
context->cublas_handle(),
cu_trans_A,
@@ -1512,6 +1607,7 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
y,
CUDA_R_16F,
ldc));
+#endif // __HIP_PLATFORM_HCC__
} else if (math_type == TensorProto_DataType_FLOAT16) {
const __half alpha_fp16 = at::Half(alpha);
const __half beta_fp16 = at::Half(beta);
@@ -1524,19 +1620,18 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
m,
1,
k,
- &alpha_fp16,
- (const __half*)A,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(A),
lda,
- (const __half*)x,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(x),
k,
- &beta_fp16,
- (__half*)y,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
+ reinterpret_cast<CUBLAS_HALF_TYPE*>(y),
ldc));
} else {
// fail
CAFFE_THROW("Unsupported math type");
}
-#endif
}
namespace {
@@ -1727,8 +1822,8 @@ CAFFE2_CUDA_EXPORT void Dot<at::Half, CUDAContext>(
const at::Half* b,
at::Half* y,
CUDAContext* context) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
+#if defined(__HIP_PLATFORM_HCC__)
+ CAFFE_THROW("HIP currently does not support FP16 completely yet.");
#else
// execute with 32-bit math
CUBLAS_ENFORCE(cublasSetPointerMode(
@@ -2358,8 +2453,8 @@ CAFFE2_CUDA_EXPORT void Axpy<at::Half, CUDAContext>(
const at::Half* X,
at::Half* Y,
CUDAContext* context) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
+#if defined(__HIP_PLATFORM_HCC__)
+ CAFFE_THROW("HIP currently does not support FP16 completely yet.");
#else
CUBLAS_ENFORCE(
cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
@@ -2397,8 +2492,8 @@ CAFFE2_CUDA_EXPORT void Axpy<at::Half, CUDAContext>(
const at::Half* X,
at::Half* Y,
CUDAContext* context) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
+#if defined(__HIP_PLATFORM_HCC__)
+ CAFFE_THROW("HIP currently does not support FP16 completely yet.");
#else
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE));
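For the pure-fp16 math path (TensorProto_DataType_FLOAT16), the math_gpu.cu diff above keeps the existing cublasHgemm / cublasHgemmStridedBatched call sites and only changes how the at::Half pointers are cast, using the new CUBLAS_HALF_TYPE alias: under HIP, the hipify pass maps those calls to their rocBLAS hgemm counterparts, which take rocblas_half rather than __half. Here is a small illustrative sketch of that aliasing; the casting helpers below are hypothetical and not in the commit.

    #ifdef __HIP_PLATFORM_HCC__
    #include <rocblas.h>
    using CUBLAS_HALF_TYPE = rocblas_half; // ROCm / hipified build
    #else
    #include <cuda_fp16.h>
    using CUBLAS_HALF_TYPE = __half;       // plain CUDA build
    #endif

    // Hypothetical helpers showing the reinterpret_cast pattern used at the
    // hgemm call sites: any 16-bit half storage type (e.g. at::Half) is
    // reinterpreted as the half type the active BLAS backend expects.
    template <typename T>
    const CUBLAS_HALF_TYPE* AsBlasHalf(const T* p) {
      static_assert(sizeof(T) == sizeof(CUBLAS_HALF_TYPE), "expected 16-bit half storage");
      return reinterpret_cast<const CUBLAS_HALF_TYPE*>(p);
    }
    template <typename T>
    CUBLAS_HALF_TYPE* AsBlasHalf(T* p) {
      static_assert(sizeof(T) == sizeof(CUBLAS_HALF_TYPE), "expected 16-bit half storage");
      return reinterpret_cast<CUBLAS_HALF_TYPE*>(p);
    }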
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 07cab9630a..1b84bf1d76 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -746,7 +746,6 @@ if(USE_ROCM)
list(APPEND HIP_CXX_FLAGS -Wno-unused-command-line-argument)
list(APPEND HIP_CXX_FLAGS -Wno-duplicate-decl-specifier)
list(APPEND HIP_CXX_FLAGS -DCAFFE2_USE_MIOPEN)
- list(APPEND HIP_CXX_FLAGS -DROCBLAS_FP16=0)
set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS})
# Ask hcc to generate device code during compilation so we can use
diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
index c530c50dc3..22aa9721f8 100644
--- a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
+++ b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
@@ -1729,7 +1729,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict([
("cublasCgemmStridedBatched", ("rocblas_cgemm_strided_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)),
("cublasCgemm3mStridedBatched", ("rocblas_cgemm_3m_strided_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)),
("cublasZgemmStridedBatched", ("rocblas_zgemm_strided_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)),
- ("cublasHgemmStridedBatched", ("rocblas_hgemm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)),
+ ("cublasHgemmStridedBatched", ("rocblas_hgemm_strided_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)),
("cublasSgemm", ("rocblas_sgemm", CONV_MATH_FUNC, API_BLAS)),
("cublasDgemm", ("rocblas_dgemm", CONV_MATH_FUNC, API_BLAS)),
("cublasCgemm", ("rocblas_cgemm", CONV_MATH_FUNC, API_BLAS)),