diff options
Diffstat (limited to 'driver/others/blas_server_win32.c')
-rw-r--r-- | driver/others/blas_server_win32.c | 69 |
1 files changed, 56 insertions, 13 deletions
diff --git a/driver/others/blas_server_win32.c b/driver/others/blas_server_win32.c index 5ecc4428b..d2cc91757 100644 --- a/driver/others/blas_server_win32.c +++ b/driver/others/blas_server_win32.c @@ -77,7 +77,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ if (!(mode & BLAS_COMPLEX)){ #ifdef EXPRECISION - if (mode & BLAS_XDOUBLE){ + if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ /* REAL / Extended Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, @@ -90,7 +90,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> c, args -> ldc, sb); } else #endif - if (mode & BLAS_DOUBLE){ + if ((mode & BLAS_PREC) == BLAS_DOUBLE){ /* REAL / Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, @@ -101,7 +101,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); - } else { + } else if ((mode & BLAS_PREC) == BLAS_SINGLE){ /* REAL / Single */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, @@ -112,10 +112,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); +#ifdef BUILD_HALF + } else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){ + /* REAL / BFLOAT16 */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16, + bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, + bfloat16 *, BLASLONG, void *) = func; + + afunc(args -> m, args -> n, args -> k, + ((bfloat16 *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); + } else if ((mode & BLAS_PREC) == BLAS_STOBF16){ + /* REAL / BLAS_STOBF16 */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, + float *, BLASLONG, bfloat16 *, BLASLONG, + float *, BLASLONG, void *) = func; + + afunc(args -> m, args -> n, args -> k, + ((float *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); + } else if ((mode & BLAS_PREC) == BLAS_DTOBF16){ + /* REAL / BLAS_DTOBF16 */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, + double *, BLASLONG, bfloat16 *, BLASLONG, + double *, BLASLONG, void *) = func; + + afunc(args -> m, args -> n, args -> k, + ((double *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); +#endif + } else { + /* REAL / Other types in future */ } } else { #ifdef EXPRECISION - if (mode & BLAS_XDOUBLE){ + if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ /* COMPLEX / Extended Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, @@ -129,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> c, args -> ldc, sb); } else #endif - if (mode & BLAS_DOUBLE){ + if ((mode & BLAS_PREC) == BLAS_DOUBLE){ /* COMPLEX / Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, BLASLONG, double *, BLASLONG, @@ -141,7 +178,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); - } else { + } else if ((mode & BLAS_PREC) == BLAS_SINGLE) { /* COMPLEX / Single */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, float *, BLASLONG, float *, BLASLONG, @@ -153,7 +190,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); - } + } else { + /* COMPLEX / Other types in future */ + } } } @@ -233,32 +272,36 @@ static DWORD WINAPI blas_thread_server(void *arg){ if (sb == NULL) { if (!(queue -> mode & BLAS_COMPLEX)){ #ifdef EXPRECISION - if (queue -> mode & BLAS_XDOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * sizeof(xdouble) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); } else #endif - if (queue -> mode & BLAS_DOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); - } else { + } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); + } else { + /* Other types in future */ } } else { #ifdef EXPRECISION - if (queue -> mode & BLAS_XDOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); } else #endif - if (queue -> mode & BLAS_DOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); - } else { + } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); + } else { + /* Other types in future */ } } queue->sb=sb; |