author    iotamudelta <dieterich@ogolem.org>    2018-10-23 13:42:58 -0700
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2018-10-23 13:46:15 -0700
commit    470e7660624c1f60b30df00e3c3f430ce3430482 (patch)
tree      0ab25a423cdd4d7d5860abeb8b28b14fd0bd2ae5 /caffe2/utils
parent    21285e73da4d15a8e5ffcb1a6b523feefd226755 (diff)
Fix illegal code in rocblas_handle rocblas_handle() that causes failure w/ gcc as base compiler (#12957)
Summary:
The legal function cublasHandle_t cublas_handle() was hipified to the clearly illegal rocblas_handle rocblas_handle(). Such a declaration should not work, and it correctly fails when gcc is the host compiler, because reusing the type name as the member function name makes the name rocblas_handle ambiguous inside the class scope. The function now hipifies to rocblas_handle rocblashandle() instead.

This fixes a long-standing issue we've observed in PyTorch when the base compiler is gcc.

For attention: bddppq ezyang

Tests on ROCm PyTorch/Caffe2: https://github.com/ROCmSoftwarePlatform/pytorch/pull/284

Pull Request resolved: https://github.com/pytorch/pytorch/pull/12957
Differential Revision: D10501227
Pulled By: bddppq
fbshipit-source-id: 568cb80801c0d14c9b1b61e3a7db387a5c21acf4
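For illustration, a minimal sketch of why the old hipified declaration trips gcc. The stand-in typedef and the simplified HIPContext below are hypothetical reductions for this example, not the actual ROCm or Caffe2 headers:

// Stand-in for the rocblas handle type; the real definition lives in the
// rocblas headers.
typedef struct _rocblas_handle* rocblas_handle;

struct HIPContext {
  // Ill-formed with gcc: naming the member after the type changes what the
  // name rocblas_handle refers to within the class scope (C++
  // [basic.scope.class]), so gcc reports that the declaration "changes
  // meaning of 'rocblas_handle'". Clang accepts it, since no diagnostic is
  // required for this rule.
  // rocblas_handle rocblas_handle();   // what hipify used to generate

  // Well-formed: a distinct function name leaves the type name unambiguous.
  rocblas_handle rocblashandle() { return handle_; }  // what hipify emits now

 private:
  rocblas_handle handle_;
};

Renaming the accessor to rocblashandle(), as this diff does at every call site in math_hip.cc, sidesteps the collision entirely.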
Diffstat (limited to 'caffe2/utils')
-rw-r--r--  caffe2/utils/hip/math_hip.cc  28
1 file changed, 14 insertions(+), 14 deletions(-)
diff --git a/caffe2/utils/hip/math_hip.cc b/caffe2/utils/hip/math_hip.cc
index d499393dbe..bc6ee92f46 100644
--- a/caffe2/utils/hip/math_hip.cc
+++ b/caffe2/utils/hip/math_hip.cc
@@ -759,7 +759,7 @@ void Gemm<float, HIPContext>(
? rocblas_operation_none
: rocblas_operation_transpose;
ROCBLAS_ENFORCE(rocblas_sgemm(
- context->rocblas_handle(),
+ context->rocblashandle(),
cuTransB,
cuTransA,
N,
@@ -803,7 +803,7 @@ void Gemm<at::Half, HIPContext>(
: rocblas_operation_transpose;
if (math_type == TensorProto_DataType_FLOAT) {
ROCBLAS_CHECK(rocblas_sgemmEx(
- context->rocblas_handle(),
+ context->rocblashandle(),
cuTransB,
cuTransA,
N,
@@ -828,7 +828,7 @@ void Gemm<at::Half, HIPContext>(
// call cublasHgemm
ROCBLAS_CHECK(cublasHgemm(
- context->rocblas_handle(),
+ context->rocblashandle(),
cuTransB,
cuTransA,
N,
@@ -933,7 +933,7 @@ void GemmStridedBatched<float, HIPContext>(
? rocblas_operation_none
: rocblas_operation_transpose;
ROCBLAS_ENFORCE(rocblas_sgemm_strided_batched(
- context->rocblas_handle(),
+ context->rocblashandle(),
cuTransB,
cuTransA,
N,
@@ -1004,7 +1004,7 @@ void GemmStridedBatched<at::Half, HIPContext>(
__half alpha_fp16 = at::Half(alpha);
__half beta_fp16 = at::Half(beta);
ROCBLAS_ENFORCE(cublasHgemmStridedBatched(
- context->rocblas_handle(),
+ context->rocblashandle(),
cuTransB,
cuTransA,
N,
@@ -1051,7 +1051,7 @@ void GemmEx<float, HIPContext>(
? rocblas_operation_none
: rocblas_operation_transpose;
ROCBLAS_ENFORCE(rocblas_sgemm(
- context->rocblas_handle(),
+ context->rocblashandle(),
cuTransB,
cuTransA,
N,
@@ -1083,7 +1083,7 @@ void Gemv<float, HIPContext>(
? rocblas_operation_transpose
: rocblas_operation_none;
ROCBLAS_ENFORCE(rocblas_sgemv(
- context->rocblas_handle(),
+ context->rocblashandle(),
cuTransA,
N,
M,
@@ -1170,7 +1170,7 @@ void Gemv<at::Half, HIPContext>(
if (math_type == TensorProto_DataType_FLOAT) {
ROCBLAS_CHECK(cublasSgemmEx(
- context->rocblas_handle(),
+ context->rocblashandle(),
cuTransA,
rocblas_operation_none,
m,
@@ -1192,7 +1192,7 @@ void Gemv<at::Half, HIPContext>(
__half beta_fp16 = at::Half(beta);
ROCBLAS_CHECK(cublasHgemm(
- context->rocblas_handle(),
+ context->rocblashandle(),
cuTransA,
rocblas_operation_none,
m,
@@ -1390,7 +1390,7 @@ void Dot<float, HIPContext>(
HIPContext* context) {
float result;
ROCBLAS_ENFORCE(
- rocblas_sdot(context->rocblas_handle(), n, a, 1, b, 1, &result));
+ rocblas_sdot(context->rocblashandle(), n, a, 1, b, 1, &result));
context->CopyFromCPU<float>(1, &result, y);
}
@@ -1406,7 +1406,7 @@ void Dot<at::Half, HIPContext>(
at::Half result;
// execute with 32-bit math
ROCBLAS_CHECK(cublasDotEx(
- context->rocblas_handle(),
+ context->rocblashandle(),
n,
a,
CUDA_R_16F,
@@ -1879,7 +1879,7 @@ void Axpy<float, HIPContext>(
float* Y,
HIPContext* context) {
ROCBLAS_ENFORCE(
- rocblas_saxpy(context->rocblas_handle(), N, &alpha, X, 1, Y, 1));
+ rocblas_saxpy(context->rocblashandle(), N, &alpha, X, 1, Y, 1));
}
template <>
@@ -1891,7 +1891,7 @@ void Axpy<double, HIPContext>(
HIPContext* context) {
double alpha_d{alpha};
ROCBLAS_ENFORCE(
- rocblas_daxpy(context->rocblas_handle(), N, &alpha_d, X, 1, Y, 1));
+ rocblas_daxpy(context->rocblashandle(), N, &alpha_d, X, 1, Y, 1));
}
template <>
@@ -1904,7 +1904,7 @@ void Axpy<at::Half, HIPContext>(
CAFFE_THROW("Unsupported math type");
#if ROCBLAS_FP16
ROCBLAS_CHECK(cublasAxpyEx(
- context->rocblas_handle(),
+ context->rocblashandle(),
N,
&alpha,
CUDA_R_16F,