summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile.tail7
-rw-r--r--cblas.h11
-rw-r--r--cmake/kernel.cmake4
-rw-r--r--common.h3
-rw-r--r--common_interface.h5
-rw-r--r--common_level1.h6
-rw-r--r--common_macro.h6
-rw-r--r--common_param.h7
-rw-r--r--common_sh.h12
-rw-r--r--common_thread.h19
-rw-r--r--common_x86_64.h23
-rw-r--r--driver/others/blas_l1_thread.c74
-rw-r--r--driver/others/blas_server.c93
-rw-r--r--driver/others/blas_server_omp.c71
-rw-r--r--driver/others/blas_server_win32.c69
-rw-r--r--driver/others/dynamic.c44
-rw-r--r--exports/gensymbol4
-rw-r--r--interface/Makefile38
-rw-r--r--interface/bf16dot.c52
-rw-r--r--interface/bf16to.c62
-rw-r--r--interface/tobf16.c61
-rw-r--r--kernel/Makefile.L136
-rw-r--r--kernel/setparam-ref.c4
-rw-r--r--kernel/x86_64/KERNEL12
-rw-r--r--kernel/x86_64/bf16to.c114
-rw-r--r--kernel/x86_64/dtobf16_microk_cooperlake.c104
-rw-r--r--kernel/x86_64/shdot.c115
-rw-r--r--kernel/x86_64/shdot_microk_cooperlake.c159
-rw-r--r--kernel/x86_64/stobf16_microk_cooperlake.c86
-rw-r--r--kernel/x86_64/tobf16.c170
-rw-r--r--openblas_config_template.h3
31 files changed, 1392 insertions, 82 deletions
diff --git a/Makefile.tail b/Makefile.tail
index 39902982b..cfc4a36fc 100644
--- a/Makefile.tail
+++ b/Makefile.tail
@@ -5,13 +5,14 @@ QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
CBLASOBJS_P = $(CBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
ZBLASOBJS_P = $(ZBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
XBLASOBJS_P = $(XBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
+SHEXTOBJS_P = $(SHEXTOBJS:.$(SUFFIX)=.$(PSUFFIX))
COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
-BLASOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
-BLASOBJS_P = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
+BLASOBJS = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
+BLASOBJS_P = $(SHEXTOBJS_P) $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
ifdef EXPRECISION
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
@@ -30,6 +31,7 @@ $(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX
$(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE -DCOMPLEX
$(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
+$(SHEXTOBJS) $(SHEXTOBJS_P) : override CFLAGS += -DHALF -UDOUBLE -UCOMPLEX
$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
@@ -38,6 +40,7 @@ $(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(CBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
+$(SHEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
libs :: $(BLASOBJS) $(COMMONOBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
diff --git a/cblas.h b/cblas.h
index 4bc5588d8..21f3958f2 100644
--- a/cblas.h
+++ b/cblas.h
@@ -382,6 +382,17 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint
void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta,
double *c, OPENBLAS_CONST blasint cldc);
+/*** BFLOAT16 and INT8 extensions ***/
+/* convert float array to BFLOAT16 array by rounding */
+void cblas_shstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout);
+/* convert double array to BFLOAT16 array by rounding */
+void cblas_shdtobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST double *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout);
+/* convert BFLOAT16 array to float array */
+void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, float *out, OPENBLAS_CONST blasint incout);
+/* convert BFLOAT16 array to double array */
+void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout);
+/* dot production of BFLOAT16 input arrays, and output as float */
+float cblas_shdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
#ifdef __cplusplus
}
diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake
index 4b505a102..79eeaae6f 100644
--- a/cmake/kernel.cmake
+++ b/cmake/kernel.cmake
@@ -126,12 +126,14 @@ if (BUILD_HALF)
set(SHAXPYKERNEL ../arm/axpy.c)
set(SHAXPBYKERNEL ../arm/axpby.c)
set(SHCOPYKERNEL ../arm/copy.c)
- set(SHDOTKERNEL ../arm/dot.c)
+ set(SHDOTKERNEL ../x86_64/shdot.c)
set(SHROTKERNEL ../arm/rot.c)
set(SHSCALKERNEL ../arm/scal.c)
set(SHNRM2KERNEL ../arm/nrm2.c)
set(SHSUMKERNEL ../arm/sum.c)
set(SHSWAPKERNEL ../arm/swap.c)
+ set(TOBF16KERNEL ../x86_64/tobf16.c)
+ set(BF16TOKERNEL ../x86_64/bf16to.c)
endif ()
endmacro ()
diff --git a/common.h b/common.h
index d6637abe4..adc162536 100644
--- a/common.h
+++ b/common.h
@@ -258,7 +258,8 @@ typedef unsigned long BLASULONG;
#endif
#ifndef BFLOAT16
-typedef unsigned short bfloat16;
+#include <stdint.h>
+typedef uint16_t bfloat16;
#define HALFCONVERSION 1
#endif
diff --git a/common_interface.h b/common_interface.h
index 78f5be6b0..35a957aa1 100644
--- a/common_interface.h
+++ b/common_interface.h
@@ -54,6 +54,11 @@ double BLASFUNC(dsdot) (blasint *, float *, blasint *, float *, blasint *);
double BLASFUNC(ddot) (blasint *, double *, blasint *, double *, blasint *);
xdouble BLASFUNC(qdot) (blasint *, xdouble *, blasint *, xdouble *, blasint *);
+float BLASFUNC(shdot) (blasint *, bfloat16 *, blasint *, bfloat16 *, blasint *);
+void BLASFUNC(shstobf16) (blasint *, float *, blasint *, bfloat16 *, blasint *);
+void BLASFUNC(shdtobf16) (blasint *, double *, blasint *, bfloat16 *, blasint *);
+void BLASFUNC(sbf16tos) (blasint *, bfloat16 *, blasint *, float *, blasint *);
+void BLASFUNC(dbf16tod) (blasint *, bfloat16 *, blasint *, double *, blasint *);
#ifdef RETURN_BY_STRUCT
typedef struct {
diff --git a/common_level1.h b/common_level1.h
index 74cafb6db..88aa275a5 100644
--- a/common_level1.h
+++ b/common_level1.h
@@ -46,6 +46,12 @@ float sdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG);
double dsdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG);
double ddot_k(BLASLONG, double *, BLASLONG, double *, BLASLONG);
xdouble qdot_k(BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
+float shdot_k(BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
+
+void shstobf16_k(BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG);
+void shdtobf16_k(BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG);
+void sbf16tos_k (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
+void dbf16tod_k (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG);
openblas_complex_float cdotc_k (BLASLONG, float *, BLASLONG, float *, BLASLONG);
openblas_complex_float cdotu_k (BLASLONG, float *, BLASLONG, float *, BLASLONG);
diff --git a/common_macro.h b/common_macro.h
index 8fe1f156f..3d6bcd9e8 100644
--- a/common_macro.h
+++ b/common_macro.h
@@ -646,6 +646,11 @@
#elif defined(HALF)
+#define D_TO_BF16_K SHDTOBF16_K
+#define D_BF16_TO_K DBF16TOD_K
+#define S_TO_BF16_K SHSTOBF16_K
+#define S_BF16_TO_K SBF16TOS_K
+
#define AMAX_K SAMAX_K
#define AMIN_K SAMIN_K
#define MAX_K SMAX_K
@@ -657,6 +662,7 @@
#define ASUM_K SASUM_K
#define DOTU_K SDOTU_K
#define DOTC_K SDOTC_K
+#define BF16_DOT_K SHDOT_K
#define AXPYU_K SAXPYU_K
#define AXPYC_K SAXPYC_K
#define AXPBY_K SAXPBY_K
diff --git a/common_param.h b/common_param.h
index 0437482dc..a52de98ab 100644
--- a/common_param.h
+++ b/common_param.h
@@ -51,6 +51,11 @@ typedef struct {
int shgemm_p, shgemm_q, shgemm_r;
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;
+ void (*shstobf16_k) (BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG);
+ void (*shdtobf16_k) (BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG);
+ void (*sbf16tos_k) (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
+ void (*dbf16tod_k) (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG);
+
float (*shamax_k) (BLASLONG, float *, BLASLONG);
float (*shamin_k) (BLASLONG, float *, BLASLONG);
float (*shmax_k) (BLASLONG, float *, BLASLONG);
@@ -64,7 +69,7 @@ BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG);
float (*shasum_k) (BLASLONG, float *, BLASLONG);
float (*shsum_k) (BLASLONG, float *, BLASLONG);
int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
- float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
+ float (*shdot_k) (BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float);
diff --git a/common_sh.h b/common_sh.h
index 7a0045762..5dc99b3bd 100644
--- a/common_sh.h
+++ b/common_sh.h
@@ -3,6 +3,12 @@
#ifndef DYNAMIC_ARCH
+#define SHDOT_K shdot_k
+#define SHSTOBF16_K shstobf16_k
+#define SHDTOBF16_K shdtobf16_k
+#define SBF16TOS_K sbf16tos_k
+#define DBF16TOD_K dbf16tod_k
+
#define SHGEMM_ONCOPY shgemm_oncopy
#define SHGEMM_OTCOPY shgemm_otcopy
@@ -18,6 +24,12 @@
#else
+#define SHDOT_K gotoblas -> shdot_k
+#define SHSTOBF16_K gotoblas -> shstobf16_k
+#define SHDTOBF16_K gotoblas -> shdtobf16_k
+#define SBF16TOS_K gotoblas -> sbf16tos_k
+#define DBF16TOD_K gotoblas -> dbf16tod_k
+
#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy
#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy
#define SHGEMM_INCOPY gotoblas -> shgemm_incopy
diff --git a/common_thread.h b/common_thread.h
index ec0c65b22..a18df0d78 100644
--- a/common_thread.h
+++ b/common_thread.h
@@ -59,12 +59,19 @@ extern int blas_omp_linked;
#define BLAS_PTHREAD 0x4000U
#define BLAS_NODE 0x2000U
-#define BLAS_PREC 0x0003U
-#define BLAS_SINGLE 0x0000U
-#define BLAS_DOUBLE 0x0001U
-#define BLAS_XDOUBLE 0x0002U
-#define BLAS_REAL 0x0000U
-#define BLAS_COMPLEX 0x0004U
+#define BLAS_PREC 0x000FU
+#define BLAS_INT8 0x0000U
+#define BLAS_BFLOAT16 0x0001U
+#define BLAS_SINGLE 0x0002U
+#define BLAS_DOUBLE 0x0003U
+#define BLAS_XDOUBLE 0x0004U
+#define BLAS_STOBF16 0x0008U
+#define BLAS_DTOBF16 0x0009U
+#define BLAS_BF16TOS 0x000AU
+#define BLAS_BF16TOD 0x000BU
+
+#define BLAS_REAL 0x0000U
+#define BLAS_COMPLEX 0x1000U
#define BLAS_TRANSA 0x0030U /* 2bit */
#define BLAS_TRANSA_N 0x0000U
diff --git a/common_x86_64.h b/common_x86_64.h
index bee7e8cdb..b813336c6 100644
--- a/common_x86_64.h
+++ b/common_x86_64.h
@@ -142,6 +142,29 @@ static __inline void cpuid(int op, int *eax, int *ebx, int *ecx, int *edx){
#endif
}
+static __inline void cpuid_count(int op, int count, int *eax, int *ebx, int *ecx, int *edx)
+{
+#ifdef C_MSVC
+ int cpuInfo[4] = {-1};
+ __cpuidex(cpuInfo, op, count);
+ *eax = cpuInfo[0];
+ *ebx = cpuInfo[1];
+ *ecx = cpuInfo[2];
+ *edx = cpuInfo[3];
+#else
+#if defined(__i386__) && defined(__PIC__)
+ __asm__ __volatile__
+ ("mov %%ebx, %%edi;"
+ "cpuid;"
+ "xchgl %%ebx, %%edi;"
+ : "=a" (*eax), "=D" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc");
+#else
+ __asm__ __volatile__
+ ("cpuid": "=a" (*eax), "=b" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc");
+#endif
+#endif
+}
+
/*
#define WHEREAMI
*/
diff --git a/driver/others/blas_l1_thread.c b/driver/others/blas_l1_thread.c
index e405c7465..04acbcc5f 100644
--- a/driver/others/blas_l1_thread.c
+++ b/driver/others/blas_l1_thread.c
@@ -49,9 +49,36 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha
blas_arg_t args [MAX_CPU_NUMBER];
BLASLONG i, width, astride, bstride;
- int num_cpu, calc_type;
-
- calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2;
+ int num_cpu, calc_type_a, calc_type_b;
+
+ switch (mode & BLAS_PREC) {
+ case BLAS_INT8 :
+ case BLAS_BFLOAT16:
+ case BLAS_SINGLE :
+ case BLAS_DOUBLE :
+ case BLAS_XDOUBLE :
+ calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0);
+ break;
+ case BLAS_STOBF16 :
+ calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0);
+ calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
+ break;
+ case BLAS_DTOBF16 :
+ calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0);
+ calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
+ break;
+ case BLAS_BF16TOS :
+ calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
+ calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0);
+ break;
+ case BLAS_BF16TOD :
+ calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
+ calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0);
+ break;
+ default:
+ calc_type_a = calc_type_b = 0;
+ break;
+ }
mode |= BLAS_LEGACY;
@@ -77,8 +104,8 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha
bstride = width;
}
- astride <<= calc_type;
- bstride <<= calc_type;
+ astride <<= calc_type_a;
+ bstride <<= calc_type_b;
args[num_cpu].m = width;
args[num_cpu].n = n;
@@ -120,9 +147,36 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL
blas_arg_t args [MAX_CPU_NUMBER];
BLASLONG i, width, astride, bstride;
- int num_cpu, calc_type;
-
- calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2;
+ int num_cpu, calc_type_a, calc_type_b;
+
+ switch (mode & BLAS_PREC) {
+ case BLAS_INT8 :
+ case BLAS_BFLOAT16:
+ case BLAS_SINGLE :
+ case BLAS_DOUBLE :
+ case BLAS_XDOUBLE :
+ calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0);
+ break;
+ case BLAS_STOBF16 :
+ calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0);
+ calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
+ break;
+ case BLAS_DTOBF16 :
+ calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0);
+ calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
+ break;
+ case BLAS_BF16TOS :
+ calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
+ calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0);
+ break;
+ case BLAS_BF16TOD :
+ calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
+ calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0);
+ break;
+ default:
+ calc_type_a = calc_type_b = 0;
+ break;
+ }
mode |= BLAS_LEGACY;
@@ -148,8 +202,8 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL
bstride = width;
}
- astride <<= calc_type;
- bstride <<= calc_type;
+ astride <<= calc_type_a;
+ bstride <<= calc_type_b;
args[num_cpu].m = width;
args[num_cpu].n = n;
diff --git a/driver/others/blas_server.c b/driver/others/blas_server.c
index 756e51b5d..8d3dda3bf 100644
--- a/driver/others/blas_server.c
+++ b/driver/others/blas_server.c
@@ -192,7 +192,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
if (!(mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
- if (mode & BLAS_XDOUBLE){
+ if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
/* REAL / Extended Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble,
xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -205,7 +205,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> c, args -> ldc, sb);
} else
#endif
- if (mode & BLAS_DOUBLE){
+ if ((mode & BLAS_PREC) == BLAS_DOUBLE){
/* REAL / Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
double *, BLASLONG, double *, BLASLONG,
@@ -216,21 +216,58 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
- } else {
- /* REAL / Single */
- void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
- float *, BLASLONG, float *, BLASLONG,
- float *, BLASLONG, void *) = func;
-
- afunc(args -> m, args -> n, args -> k,
- ((float *)args -> alpha)[0],
- args -> a, args -> lda,
- args -> b, args -> ldb,
- args -> c, args -> ldc, sb);
+ } else if ((mode & BLAS_PREC) == BLAS_SINGLE){
+ /* REAL / Single */
+ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
+ float *, BLASLONG, float *, BLASLONG,
+ float *, BLASLONG, void *) = func;
+
+ afunc(args -> m, args -> n, args -> k,
+ ((float *)args -> alpha)[0],
+ args -> a, args -> lda,
+ args -> b, args -> ldb,
+ args -> c, args -> ldc, sb);
+#ifdef BUILD_HALF
+ } else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){
+ /* REAL / BFLOAT16 */
+ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16,
+ bfloat16 *, BLASLONG, bfloat16 *, BLASLONG,
+ bfloat16 *, BLASLONG, void *) = func;
+
+ afunc(args -> m, args -> n, args -> k,
+ ((bfloat16 *)args -> alpha)[0],
+ args -> a, args -> lda,
+ args -> b, args -> ldb,
+ args -> c, args -> ldc, sb);
+ } else if ((mode & BLAS_PREC) == BLAS_STOBF16){
+ /* REAL / BLAS_STOBF16 */
+ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
+ float *, BLASLONG, bfloat16 *, BLASLONG,
+ float *, BLASLONG, void *) = func;
+
+ afunc(args -> m, args -> n, args -> k,
+ ((float *)args -> alpha)[0],
+ args -> a, args -> lda,
+ args -> b, args -> ldb,
+ args -> c, args -> ldc, sb);
+ } else if ((mode & BLAS_PREC) == BLAS_DTOBF16){
+ /* REAL / BLAS_DTOBF16 */
+ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
+ double *, BLASLONG, bfloat16 *, BLASLONG,
+ double *, BLASLONG, void *) = func;
+
+ afunc(args -> m, args -> n, args -> k,
+ ((double *)args -> alpha)[0],
+ args -> a, args -> lda,
+ args -> b, args -> ldb,
+ args -> c, args -> ldc, sb);
+#endif
+ } else {
+ /* REAL / Other types in future */
}
} else {
#ifdef EXPRECISION
- if (mode & BLAS_XDOUBLE){
+ if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
/* COMPLEX / Extended Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble,
xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -244,7 +281,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> c, args -> ldc, sb);
} else
#endif
- if (mode & BLAS_DOUBLE){
+ if ((mode & BLAS_PREC) == BLAS_DOUBLE) {
/* COMPLEX / Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double,
double *, BLASLONG, double *, BLASLONG,
@@ -256,7 +293,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
- } else {
+ } else if ((mode & BLAS_PREC) == BLAS_SINGLE) {
/* COMPLEX / Single */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float,
float *, BLASLONG, float *, BLASLONG,
@@ -268,7 +305,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
- }
+ } else {
+ /* COMPLEX / Other types in future */
+ }
}
}
@@ -414,33 +453,37 @@ blas_queue_t *tscq;
if (sb == NULL) {
if (!(queue -> mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
- if (queue -> mode & BLAS_XDOUBLE){
+ if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else
#endif
- if (queue -> mode & BLAS_DOUBLE){
+ if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE) {
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
- } else {
+ } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
- }
+ } else {
+ /* Other types in future */
+ }
} else {
#ifdef EXPRECISION
- if (queue -> mode & BLAS_XDOUBLE){
+ if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else
#endif
- if (queue -> mode & BLAS_DOUBLE){
+ if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
- } else {
+ } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
- }
+ } else {
+ /* Other types in future */
+ }
}
queue->sb=sb;
}
diff --git a/driver/others/blas_server_omp.c b/driver/others/blas_server_omp.c
index d9969b599..d126955e4 100644
--- a/driver/others/blas_server_omp.c
+++ b/driver/others/blas_server_omp.c
@@ -142,7 +142,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
if (!(mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
- if (mode & BLAS_XDOUBLE){
+ if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
/* REAL / Extended Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble,
xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -155,7 +155,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> c, args -> ldc, sb);
} else
#endif
- if (mode & BLAS_DOUBLE){
+ if ((mode & BLAS_PREC) == BLAS_DOUBLE){
/* REAL / Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
double *, BLASLONG, double *, BLASLONG,
@@ -166,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
- } else {
+ } else if ((mode & BLAS_PREC) == BLAS_SINGLE){
/* REAL / Single */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
float *, BLASLONG, float *, BLASLONG,
@@ -177,10 +177,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
+#ifdef BUILD_HALF
+ } else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){
+ /* REAL / BFLOAT16 */
+ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16,
+ bfloat16 *, BLASLONG, bfloat16 *, BLASLONG,
+ bfloat16 *, BLASLONG, void *) = func;
+
+ afunc(args -> m, args -> n, args -> k,
+ ((bfloat16 *)args -> alpha)[0],
+ args -> a, args -> lda,
+ args -> b, args -> ldb,
+ args -> c, args -> ldc, sb);
+ } else if ((mode & BLAS_PREC) == BLAS_STOBF16){
+ /* REAL / BLAS_STOBF16 */
+ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
+ float *, BLASLONG, bfloat16 *, BLASLONG,
+ float *, BLASLONG, void *) = func;
+
+ afunc(args -> m, args -> n, args -> k,
+ ((float *)args -> alpha)[0],
+ args -> a, args -> lda,
+ args -> b, args -> ldb,
+ args -> c, args -> ldc, sb);
+ } else if ((mode & BLAS_PREC) == BLAS_DTOBF16){
+ /* REAL / BLAS_DTOBF16 */
+ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
+ double *, BLASLONG, bfloat16 *, BLASLONG,
+ double *, BLASLONG, void *) = func;
+
+ afunc(args -> m, args -> n, args -> k,
+ ((double *)args -> alpha)[0],
+ args -> a, args -> lda,
+ args -> b, args -> ldb,
+ args -> c, args -> ldc, sb);
+#endif
+ } else {
+ /* REAL / Other types in future */
}
} else {
#ifdef EXPRECISION
- if (mode & BLAS_XDOUBLE){
+ if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
/* COMPLEX / Extended Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble,
xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -194,7 +231,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> c, args -> ldc, sb);
} else
#endif
- if (mode & BLAS_DOUBLE){
+ if ((mode & BLAS_PREC) == BLAS_DOUBLE){
/* COMPLEX / Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double,
double *, BLASLONG, double *, BLASLONG,
@@ -206,7 +243,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
- } else {
+ } else if ((mode & BLAS_PREC) == BLAS_SINGLE){
/* COMPLEX / Single */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float,
float *, BLASLONG, float *, BLASLONG,
@@ -218,8 +255,10 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
- }
- }
+ } else {
+ /* COMPLEX / Other types in future */
+ }
+ }
}
static void exec_threads(blas_queue_t *queue, int buf_index){
@@ -255,32 +294,36 @@ static void exec_threads(blas_queue_t *queue, int buf_index){
if (sb == NULL) {
if (!(queue -> mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
- if (queue -> mode & BLAS_XDOUBLE){
+ if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else
#endif
- if (queue -> mode & BLAS_DOUBLE){
+ if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
- } else {
+ } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE){
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
+ } else {
+ /* Other types in future */
}
} else {
#ifdef EXPRECISION
- if (queue -> mode & BLAS_XDOUBLE){
+ if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else
#endif
- if (queue -> mode & BLAS_DOUBLE){
+ if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
- } else {
+ } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
+ } else {
+ /* Other types in future */
}
}
queue->sb=sb;
diff --git a/driver/others/blas_server_win32.c b/driver/others/blas_server_win32.c
index 5ecc4428b..d2cc91757 100644
--- a/driver/others/blas_server_win32.c
+++ b/driver/others/blas_server_win32.c
@@ -77,7 +77,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
if (!(mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
- if (mode & BLAS_XDOUBLE){
+ if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
/* REAL / Extended Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble,
xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -90,7 +90,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> c, args -> ldc, sb);
} else
#endif
- if (mode & BLAS_DOUBLE){
+ if ((mode & BLAS_PREC) == BLAS_DOUBLE){
/* REAL / Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
double *, BLASLONG, double *, BLASLONG,
@@ -101,7 +101,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
- } else {
+ } else if ((mode & BLAS_PREC) == BLAS_SINGLE){
/* REAL / Single */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
float *, BLASLONG, float *, BLASLONG,
@@ -112,10 +112,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
+#ifdef BUILD_HALF
+ } else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){
+ /* REAL / BFLOAT16 */
+ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16,
+ bfloat16 *, BLASLONG, bfloat16 *, BLASLONG,
+ bfloat16 *, BLASLONG, void *) = func;
+
+ afunc(args -> m, args -> n, args -> k,
+ ((bfloat16 *)args -> alpha)[0],
+ args -> a, args -> lda,
+ args -> b, args -> ldb,
+ args -> c, args -> ldc, sb);
+ } else if ((mode & BLAS_PREC) == BLAS_STOBF16){
+ /* REAL / BLAS_STOBF16 */
+ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
+ float *, BLASLONG, bfloat16 *, BLASLONG,
+ float *, BLASLONG, void *) = func;
+
+ afunc(args -> m, args -> n, args -> k,
+ ((float *)args -> alpha)[0],
+ args -> a, args -> lda,
+ args -> b, args -> ldb,
+ args -> c, args -> ldc, sb);
+ } else if ((mode & BLAS_PREC) == BLAS_DTOBF16){
+ /* REAL / BLAS_DTOBF16 */
+ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
+ double *, BLASLONG, bfloat16 *, BLASLONG,
+ double *, BLASLONG, void *) = func;
+
+ afunc(args -> m, args -> n, args -> k,
+ ((double *)args -> alpha)[0],
+ args -> a, args -> lda,
+ args -> b, args -> ldb,
+ args -> c, args -> ldc, sb);
+#endif
+ } else {
+ /* REAL / Other types in future */
}
} else {
#ifdef EXPRECISION
- if (mode & BLAS_XDOUBLE){
+ if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
/* COMPLEX / Extended Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble,
xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -129,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> c, args -> ldc, sb);
} else
#endif
- if (mode & BLAS_DOUBLE){
+ if ((mode & BLAS_PREC) == BLAS_DOUBLE){
/* COMPLEX / Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double,
double *, BLASLONG, double *, BLASLONG,
@@ -141,7 +178,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
- } else {
+ } else if ((mode & BLAS_PREC) == BLAS_SINGLE) {
/* COMPLEX / Single */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float,
float *, BLASLONG, float *, BLASLONG,
@@ -153,7 +190,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
- }
+ } else {
+ /* COMPLEX / Other types in future */
+ }
}
}
@@ -233,32 +272,36 @@ static DWORD WINAPI blas_thread_server(void *arg){
if (sb == NULL) {
if (!(queue -> mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
- if (queue -> mode & BLAS_XDOUBLE){
+ if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * sizeof(xdouble)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else
#endif
- if (queue -> mode & BLAS_DOUBLE){
+ if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
- } else {
+ } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
+ } else {
+ /* Other types in future */
}
} else {
#ifdef EXPRECISION
- if (queue -> mode & BLAS_XDOUBLE){
+ if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else
#endif
- if (queue -> mode & BLAS_DOUBLE){
+ if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
- } else {
+ } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
+ } else {
+ /* Other types in future */
}
}
queue->sb=sb;
diff --git a/driver/others/dynamic.c b/driver/others/dynamic.c
index 5d71b1b2c..21d2c7948 100644
--- a/driver/others/dynamic.c
+++ b/driver/others/dynamic.c
@@ -207,6 +207,19 @@ extern gotoblas_t gotoblas_SKYLAKEX;
#else
#define gotoblas_SKYLAKEX gotoblas_PRESCOTT
#endif
+#ifdef DYN_COOPERLAKE
+extern gotoblas_t gotoblas_COOPERLAKE;
+#elif defined(DYN_SKYLAKEX)
+#define gotoblas_COOPERLAKE gotoblas_SKYLAKEX
+#elif defined(DYN_HASWELL)
+#define gotoblas_COOPERLAKE gotoblas_HASWELL
+#elif defined(DYN_SANDYBRIDGE)
+#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE
+#elif defined(DYN_NEHALEM)
+#define gotoblas_COOPERLAKE gotoblas_NEHALEM
+#else
+#define gotoblas_COOPERLAKE gotoblas_PRESCOTT
+#endif
#else // not DYNAMIC_LIST
@@ -247,14 +260,17 @@ extern gotoblas_t gotoblas_EXCAVATOR;
#ifdef NO_AVX2
#define gotoblas_HASWELL gotoblas_SANDYBRIDGE
#define gotoblas_SKYLAKEX gotoblas_SANDYBRIDGE
+#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE
#define gotoblas_ZEN gotoblas_SANDYBRIDGE
#else
extern gotoblas_t gotoblas_HASWELL;
extern gotoblas_t gotoblas_ZEN;
#ifndef NO_AVX512
extern gotoblas_t gotoblas_SKYLAKEX;
+extern gotoblas_t gotoblas_COOPERLAKE;
#else
#define gotoblas_SKYLAKEX gotoblas_HASWELL
+#define gotoblas_COOPERLAKE gotoblas_HASWELL
#endif
#endif
#else
@@ -262,6 +278,7 @@ extern gotoblas_t gotoblas_SKYLAKEX;
#define gotoblas_SANDYBRIDGE gotoblas_NEHALEM
#define gotoblas_HASWELL gotoblas_NEHALEM
#define gotoblas_SKYLAKEX gotoblas_NEHALEM
+#define gotoblas_COOPERLAKE gotoblas_NEHALEM
#define gotoblas_BULLDOZER gotoblas_BARCELONA
#define gotoblas_PILEDRIVER gotoblas_BARCELONA
#define gotoblas_STEAMROLLER gotoblas_BARCELONA
@@ -343,6 +360,23 @@ int support_avx512(){
#endif
}
+int support_avx512_bf16(){
+#if !defined(NO_AVX) && !defined(NO_AVX512)
+ int eax, ebx, ecx, edx;
+ int ret=0;
+
+ if (!support_avx512())
+ return 0;
+ cpuid_count(7, 1, &eax, &ebx, &ecx, &edx);
+ if((eax & 32) == 32){
+ ret=1; // CPUID.7.1:EAX[bit 5] indicates whether avx512_bf16 supported or not
+ }
+ return ret;
+#else
+ return 0;
+#endif
+}
+
extern void openblas_warning(int verbose, const char * msg);
#define FALLBACK_VERBOSE 1
#define NEHALEM_FALLBACK "OpenBLAS : Your OS does not support AVX instructions. OpenBLAS is using Nehalem kernels as a fallback, which may give poorer performance.\n"
@@ -524,7 +558,10 @@ static gotoblas_t *get_coretype(void){
return &gotoblas_NEHALEM; //OS doesn't support AVX. Use old kernels.
}
}
- if (model == 5) {
+ if (model == 5) {
+ // Intel Cooperlake
+ if(support_avx512_bf16())
+ return &gotoblas_COOPERLAKE;
// Intel Skylake X
if (support_avx512())
return &gotoblas_SKYLAKEX;
@@ -774,7 +811,8 @@ static char *corename[] = {
"Steamroller",
"Excavator",
"Zen",
- "SkylakeX"
+ "SkylakeX",
+ "Cooperlake"
};
char *gotoblas_corename(void) {
@@ -838,6 +876,7 @@ char *gotoblas_corename(void) {
if (gotoblas == &gotoblas_EXCAVATOR) return corename[22];
if (gotoblas == &gotoblas_ZEN) return corename[23];
if (gotoblas == &gotoblas_SKYLAKEX) return corename[24];
+ if (gotoblas == &gotoblas_COOPERLAKE) return corename[25];
return corename[0];
}
@@ -868,6 +907,7 @@ static gotoblas_t *force_coretype(char *coretype){
switch (found)
{
+ case 25: return (&gotoblas_COOPERLAKE);
case 24: return (&gotoblas_SKYLAKEX);
case 23: return (&gotoblas_ZEN);
case 22: return (&gotoblas_EXCAVATOR);
diff --git a/exports/gensymbol b/exports/gensymbol
index 73b4be248..ce4d9bb64 100644
--- a/exports/gensymbol
+++ b/exports/gensymbol
@@ -46,7 +46,7 @@
ssum, dsum, scsum, dzsum
);
-@halfblasobjs = (shgemm);
+@halfblasobjs = (shgemm, shdot, shstobf16, shdtobf16, sbf16tos, dbf16tod);
@cblasobjs = (
cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv,
cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k,
@@ -84,7 +84,7 @@
cblas_xerbla
);
-@halfcblasobjs = (cblas_shgemm);
+@halfcblasobjs = (cblas_shgemm, cblas_shdot, cblas_shstobf16, cblas_shdtobf16, cblas_sbf16tos, cblas_dbf16tod);
@exblasobjs = (
qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm,
diff --git a/interface/Makefile b/interface/Makefile
index 2dbd60073..fde6227bc 100644
--- a/interface/Makefile
+++ b/interface/Makefile
@@ -47,7 +47,9 @@ SBLAS3OBJS = \
sgeadd.$(SUFFIX)
ifeq ($(BUILD_HALF),1)
+SHBLAS1OBJS = shdot.$(SUFFIX)
SHBLAS3OBJS = shgemm.$(SUFFIX)
+SHEXTOBJS = shstobf16.$(SUFFIX) shdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX)
endif
DBLAS1OBJS = \
@@ -281,7 +283,9 @@ CSBLAS3OBJS = \
cblas_sgeadd.$(SUFFIX)
ifeq ($(BUILD_HALF),1)
+CSHBLAS1OBJS = cblas_shdot.$(SUFFIX)
CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX)
+CSHEXTOBJS = cblas_shstobf16.$(SUFFIX) cblas_shdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX)
endif
CDBLAS1OBJS = \
@@ -374,6 +378,7 @@ override CFLAGS += -I.
SBLAS1OBJS += $(CSBLAS1OBJS)
SBLAS2OBJS += $(CSBLAS2OBJS)
SBLAS3OBJS += $(CSBLAS3OBJS)
+SHBLAS1OBJS += $(CSHBLAS1OBJS)
SHBLAS3OBJS += $(CSHBLAS3OBJS)
DBLAS1OBJS += $(CDBLAS1OBJS)
DBLAS2OBJS += $(CDBLAS2OBJS)
@@ -385,10 +390,11 @@ ZBLAS1OBJS += $(CZBLAS1OBJS)
ZBLAS2OBJS += $(CZBLAS2OBJS)
ZBLAS3OBJS += $(CZBLAS3OBJS)
+SHEXTOBJS += $(CSHEXTOBJS)
endif
SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS)
-SHBLASOBJS = $(SHBLAS3OBJS)
+SHBLASOBJS = $(SHBLAS1OBJS) $(SHBLAS3OBJS)
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS)
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS)
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS)
@@ -463,7 +469,7 @@ ZBLASOBJS += $(ZLAPACKOBJS)
endif
-FUNCOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
+FUNCOBJS = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
ifdef EXPRECISION
FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS)
@@ -491,7 +497,7 @@ endif
clean ::
@rm -f functable.h
-level1 : $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS)
+level1 : $(BEXTOBJS) $(SHBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
@@ -725,6 +731,19 @@ sdsdot.$(SUFFIX) sdsdot.$(PSUFFIX) : sdsdot.c
dsdot.$(SUFFIX) dsdot.$(PSUFFIX) : dsdot.c
$(CC) $(CFLAGS) -c $< -o $(@F)
+ifeq ($(BUILD_HALF),1)
+shdot.$(SUFFIX) shdot.$(PSUFFIX) : bf16dot.c
+ $(CC) $(CFLAGS) -c $< -o $(@F)
+shstobf16.$(SUFFIX) shstobf16.$(PSUFFIX) : tobf16.c
+ $(CC) $(CFLAGS) -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F)
+shdtobf16.$(SUFFIX) shdtobf16.$(PSUFFIX) : tobf16.c
+ $(CC) $(CFLAGS) -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F)
+sbf16tos.$(SUFFIX) sbf16tos.$(PSUFFIX) : bf16to.c
+ $(CC) $(CFLAGS) -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F)
+dbf16tod.$(SUFFIX) dbf16tod.$(PSUFFIX) : bf16to.c
+ $(CC) $(CFLAGS) -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F)
+endif
+
sdot.$(SUFFIX) sdot.$(PSUFFIX) : dot.c
$(CC) $(CFLAGS) -c $< -o $(@F)
@@ -1463,6 +1482,19 @@ cblas_sdsdot.$(SUFFIX) cblas_sdsdot.$(PSUFFIX) : sdsdot.c
cblas_dsdot.$(SUFFIX) cblas_dsdot.$(PSUFFIX) : dsdot.c
$(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F)
+ifeq ($(BUILD_HALF),1)
+cblas_shdot.$(SUFFIX) cblas_shdot.$(PSUFFIX) : bf16dot.c
+ $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F)
+cblas_shstobf16.$(SUFFIX) cblas_shstobf16.$(PSUFFIX) : tobf16.c
+ $(CC) $(CFLAGS) -DCBLAS -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F)
+cblas_shdtobf16.$(SUFFIX) cblas_shdtobf16.$(PSUFFIX) : tobf16.c
+ $(CC) $(CFLAGS) -DCBLAS -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F)
+cblas_sbf16tos.$(SUFFIX) cblas_sbf16tos.$(PSUFFIX) : bf16to.c
+ $(CC) $(CFLAGS) -DCBLAS -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F)
+cblas_dbf16tod.$(SUFFIX) cblas_dbf16tod.$(PSUFFIX) : bf16to.c
+ $(CC) $(CFLAGS) -DCBLAS -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F)
+endif
+
cblas_sdot.$(SUFFIX) cblas_sdot.$(PSUFFIX) : dot.c
$(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F)
diff --git a/interface/bf16dot.c b/interface/bf16dot.c
new file mode 100644
index 000000000..33717e374
--- /dev/null
+++ b/interface/bf16dot.c
@@ -0,0 +1,52 @@
+#include <stdio.h>
+#include "common.h"
+#ifdef FUNCTION_PROFILE
+#include "functable.h"
+#endif
+
+#ifndef CBLAS
+float NAME(blasint *N, bfloat16 *x, blasint *INCX, bfloat16 *y, blasint *INCY){
+ BLASLONG n = *N;
+ BLASLONG incx = *INCX;
+ BLASLONG incy = *INCY;
+ float ret;
+ PRINT_DEBUG_NAME;
+
+ if (n <= 0) return 0.;
+
+ IDEBUG_START;
+ FUNCTION_PROFILE_START();
+
+ if (incx < 0) x -= (n - 1) * incx;
+ if (incy < 0) y -= (n - 1) * incy;
+ ret = BF16_DOT_K(n, x, incx, y, incy);
+
+ FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
+ IDEBUG_END;
+
+ return ret;
+ }
+
+#else
+
+float CNAME(blasint n, bfloat16 *x, blasint incx, bfloat16 *y, blasint incy){
+
+ float ret;
+ PRINT_DEBUG_CNAME;
+
+ if (n <= 0) return 0.;
+
+ IDEBUG_START;
+ FUNCTION_PROFILE_START();
+
+ if (incx < 0) x -= (n - 1) * incx;
+ if (incy < 0) y -= (n - 1) * incy;
+ ret = BF16_DOT_K(n, x, incx, y, incy);
+
+ FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
+ IDEBUG_END;
+
+ return ret;
+}
+
+#endif
diff --git a/interface/bf16to.c b/interface/bf16to.c
new file mode 100644
index 000000000..036c0b142
--- /dev/null
+++ b/interface/bf16to.c
@@ -0,0 +1,62 @@
+#include <stdio.h>
+#include "common.h"
+#ifdef FUNCTION_PROFILE
+#include "functable.h"
+#endif
+
+#if defined(DOUBLE_PREC)
+#define FLOAT_TYPE double
+#elif defined(SINGLE_PREC)
+#define FLOAT_TYPE float
+#else
+#endif
+
+#ifndef CBLAS
+void NAME(blasint *N, bfloat16 *in, blasint *INC_IN, FLOAT_TYPE *out, blasint *INC_OUT){
+ BLASLONG n = *N;
+ BLASLONG inc_in = *INC_IN;
+ BLASLONG inc_out = *INC_OUT;
+
+ PRINT_DEBUG_NAME;
+
+ if (n <= 0) return;
+
+ IDEBUG_START;
+ FUNCTION_PROFILE_START();
+
+ if (inc_in < 0) in -= (n - 1) * inc_in;
+ if (inc_out < 0) out -= (n - 1) * inc_out;
+
+#if defined(DOUBLE_PREC)
+ D_BF16_TO_K(n, in, inc_in, out, inc_out);
+#elif defined(SINGLE_PREC)
+ S_BF16_TO_K(n, in, inc_in, out, inc_out);
+#else
+#endif
+
+ FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
+ IDEBUG_END;
+}
+#else
+void CNAME(blasint n, bfloat16 * in, blasint inc_in, FLOAT_TYPE * out, blasint inc_out){
+ PRINT_DEBUG_CNAME;
+
+ if (n <= 0) return;
+
+ IDEBUG_START;
+ FUNCTION_PROFILE_START();
+
+ if (inc_in < 0) in -= (n - 1) * inc_in;
+ if (inc_out < 0) out -= (n - 1) * inc_out;
+
+#if defined(DOUBLE_PREC)
+ D_BF16_TO_K(n, in, inc_in, out, inc_out);
+#elif defined(SINGLE_PREC)
+ S_BF16_TO_K(n, in, inc_in, out, inc_out);
+#else
+#endif
+
+ FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
+ IDEBUG_END;
+}
+#endif
diff --git a/interface/tobf16.c b/interface/tobf16.c
new file mode 100644
index 000000000..787d9d689
--- /dev/null
+++ b/interface/tobf16.c
@@ -0,0 +1,61 @@
+#include <stdio.h>
+#include "common.h"
+#ifdef FUNCTION_PROFILE
+#include "functable.h"
+#endif
+
+#if defined(DOUBLE_PREC)
+#define FLOAT_TYPE double
+#elif defined(SINGLE_PREC)
+#define FLOAT_TYPE float
+#else
+#endif
+
+#ifndef CBLAS
+void NAME(blasint *N, FLOAT_TYPE *in, blasint *INC_IN, bfloat16 *out, blasint *INC_OUT){
+ BLASLONG n = *N;
+ BLASLONG inc_in = *INC_IN;
+ BLASLONG inc_out = *INC_OUT;
+
+ PRINT_DEBUG_NAME;
+
+ if (n <= 0) return;
+
+ IDEBUG_START;
+ FUNCTION_PROFILE_START();
+
+ if (inc_in < 0) in -= (n - 1) * inc_in;
+ if (inc_out < 0) out -= (n - 1) * inc_out;
+
+#if defined(DOUBLE_PREC)
+ D_TO_BF16_K(n, in, inc_in, out, inc_out);
+#elif defined(SINGLE_PREC)
+ S_TO_BF16_K(n, in, inc_in, out, inc_out);
+#else
+#endif
+
+ FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
+ IDEBUG_END;
+}
+#else
+void CNAME(blasint n, FLOAT_TYPE *in, blasint inc_in, bfloat16 *out, blasint inc_out){
+ PRINT_DEBUG_CNAME;
+
+ if (n <= 0) return;
+
+ IDEBUG_START;
+ FUNCTION_PROFILE_START();
+
+ if (inc_in < 0) in -= (n - 1) * inc_in;
+ if (inc_out < 0) out -= (n - 1) * inc_out;
+
+#if defined(DOUBLE_PREC)
+ D_TO_BF16_K(n, in, inc_in, out, inc_out);
+#elif defined(SINGLE_PREC)
+ S_TO_BF16_K(n, in, inc_in, out, inc_out);
+#endif
+
+ FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
+ IDEBUG_END;
+}
+#endif
diff --git a/kernel/Makefile.L1 b/kernel/Makefile.L1
index 970703230..c6576ee07 100644
--- a/kernel/Makefile.L1
+++ b/kernel/Makefile.L1
@@ -262,6 +262,20 @@ ifndef XDOTKERNEL
XDOTKERNEL = zdot.S
endif
+ifeq ($(BUILD_HALF),1)
+ifndef SHDOTKERNEL
+SHDOTKERNEL = ../x86_64/shdot.c
+endif
+
+ifndef TOBF16KERNEL
+TOBF16KERNEL = ../x86_64/tobf16.c
+endif
+
+ifndef BF16TOKERNEL
+BF16TOKERNEL = ../x86_64/bf16to.c
+endif
+endif
+
### NRM2 ###
ifndef SNRM2KERNEL
@@ -516,6 +530,15 @@ XBLASOBJS += \
xdotc_k$(TSUFFIX).$(SUFFIX) xdotu_k$(TSUFFIX).$(SUFFIX) xnrm2_k$(TSUFFIX).$(SUFFIX) xqrot_k$(TSUFFIX).$(SUFFIX) \
xscal_k$(TSUFFIX).$(SUFFIX) xswap_k$(TSUFFIX).$(SUFFIX) xsum_k$(TSUFFIX).$(SUFFIX)
+ifeq ($(BUILD_HALF),1)
+SHBLASOBJS += \
+ shdot_k$(TSUFFIX).$(SUFFIX)
+SHEXTOBJS += \
+ shstobf16_k$(TSUFFIX).$(SUFFIX) shdtobf16_k$(TSUFFIX).$(SUFFIX)
+SHEXTOBJS += \
+ sbf16tos_k$(TSUFFIX).$(SUFFIX) dbf16tod_k$(TSUFFIX).$(SUFFIX)
+endif
+
### AMAX ###
@@ -734,6 +757,19 @@ $(KDIR)ddot_k$(TSUFFIX).$(SUFFIX) $(KDIR)ddot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNEL
$(KDIR)qdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)qdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(QDOTKERNEL)
$(CC) -c $(CFLAGS) -UCOMPLEX -DXDOUBLE $< -o $@
+ifeq ($(BUILD_HALF),1)
+$(KDIR)shdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)shdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHDOTKERNEL)
+ $(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@
+$(KDIR)shstobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL)
+ $(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@
+$(KDIR)shdtobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL)
+ $(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@
+$(KDIR)sbf16tos_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL)
+ $(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@
+$(KDIR)dbf16tod_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL)
+ $(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@
+endif
+
$(KDIR)sdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)sdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SDOTKERNEL)
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE $< -o $@
diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c
index 582a1dc01..c43520310 100644
--- a/kernel/setparam-ref.c
+++ b/kernel/setparam-ref.c
@@ -62,9 +62,11 @@ gotoblas_t TABLE_NAME = {
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
#endif
+ shstobf16_kTS, shdtobf16_kTS, sbf16tos_kTS, dbf16tod_kTS,
+
samax_kTS, samin_kTS, smax_kTS, smin_kTS,
isamax_kTS, isamin_kTS, ismax_kTS, ismin_kTS,
- snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sdot_kTS,
+ snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, shdot_kTS,
dsdot_kTS,
srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS,
sgemv_nTS, sgemv_tTS, sger_kTS,
diff --git a/kernel/x86_64/KERNEL b/kernel/x86_64/KERNEL
index 4874711bb..4a2e13bed 100644
--- a/kernel/x86_64/KERNEL
+++ b/kernel/x86_64/KERNEL
@@ -146,6 +146,18 @@ ifndef XDOTKERNEL
XDOTKERNEL = zdot.S
endif
+ifndef SHDOTKERNEL
+SHDOTKERNEL = shdot.c
+endif
+
+ifndef TOBF16KERNEL
+TOBF16KERNEL = tobf16.c
+endif
+
+ifndef BF16TOKERNEL
+BF16TOKERNEL = bf16to.c
+endif
+
ifndef ISAMAXKERNEL
ISAMAXKERNEL = iamax_sse.S
endif
diff --git a/kernel/x86_64/bf16to.c b/kernel/x86_64/bf16to.c
new file mode 100644
index 000000000..fc6b5a529
--- /dev/null
+++ b/kernel/x86_64/bf16to.c
@@ -0,0 +1,114 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include <stddef.h>
+#include "common.h"
+
+#if defined(DOUBLE)
+#define FLOAT_TYPE double
+#elif defined(SINGLE)
+#define FLOAT_TYPE float
+#else
+#endif
+
+/* Notes for algorithm:
+ * - Input denormal treated as zero
+ * - Force to be QNAN
+ */
+static void bf16to_kernel_1(BLASLONG n, const bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out)
+{
+ BLASLONG register index_in = 0;
+ BLASLONG register index_out = 0;
+ BLASLONG register index = 0;
+ uint16_t * tmp = NULL;
+#if defined(DOUBLE)
+ float float_out = 0.0;
+#endif
+
+ while(index<n) {
+#if defined(DOUBLE)
+ float_out = 0.0;
+ tmp = (uint16_t*)(&float_out);
+#else
+ *(out+index_out) = 0;
+ tmp = (uint16_t*)(out+index_out);
+#endif
+
+ switch((*(in+index_in)) & 0xff80u) {
+ case (0x0000u): /* Type 1: Positive denormal */
+ tmp[1] = 0x0000u;
+ tmp[0] = 0x0000u;
+ break;
+ case (0x8000u): /* Type 2: Negative denormal */
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+ tmp[1] = 0x8000u;
+ tmp[0] = 0x0000u;
+#else
+ tmp[1] = 0x0000u;
+ tmp[0] = 0x8000u;
+#endif
+ break;
+ case (0x7f80u): /* Type 3: Positive infinity or NAN */
+ case (0xff80u): /* Type 4: Negative infinity or NAN */
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+ tmp[1] = *(in+index_in);
+#else
+ tmp[0] = *(in+index_in);
+#endif
+ /* Specific for NAN */
+ if (((*(in+index_in)) & 0x007fu) != 0) {
+ /* Force to be QNAN */
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+ tmp[1] |= 0x0040u;
+#else
+ tmp[0] |= 0x0040u;
+#endif
+ }
+ break;
+ default: /* Type 5: Normal case */
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+ tmp[1] = *(in+index_in);
+#else
+ tmp[0] = *(in+index_in);
+#endif
+ break;
+ }
+#if defined(DOUBLE)
+ *(out+index_out) = (double)float_out;
+#endif
+ index_in += inc_in;
+ index_out += inc_out;
+ index++;
+ }
+}
+
+void CNAME(BLASLONG n, bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out)
+{
+ if (n <= 0) return;
+
+ bf16to_kernel_1(n, in, inc_in, out, inc_out);
+}
diff --git a/kernel/x86_64/dtobf16_microk_cooperlake.c b/kernel/x86_64/dtobf16_microk_cooperlake.c
new file mode 100644
index 000000000..9b8ac4714
--- /dev/null
+++ b/kernel/x86_64/dtobf16_microk_cooperlake.c
@@ -0,0 +1,104 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+/* need a new enough GCC for avx512 support */
+#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
+
+#define HAVE_TOBF16_ACCL_KERNEL 1
+#include "common.h"
+#include <immintrin.h>
+
+static void tobf16_accl_kernel(BLASLONG n, const double * in, bfloat16 * out)
+{
+ /* Get the 64-bytes unaligned header number targeting for avx512
+ * processing (Assume input float array is natural aligned) */
+ int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 3) & 0x7;
+
+ if (n < align_header) {align_header = n;}
+
+ if (align_header != 0) {
+ unsigned char align_mask8 = (((unsigned char)0xff) >> (8-align_header));
+ __m512d a = _mm512_maskz_loadu_pd(*((__mmask8*) &align_mask8), &in[0]);
+ _mm_mask_storeu_epi16(&out[0], *((__mmask8*) &align_mask8), (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(a)));
+ }
+
+ if (n == align_header) {
+ return;
+ } else {
+ n -= align_header;
+ in += align_header;
+ out += align_header;
+ }
+
+ int tail_index_8 = n&(~7);
+ int tail_index_32 = n&(~31);
+ int tail_index_128 = n&(~127);
+ unsigned char tail_mask8 = (((unsigned char) 0xff) >> (8 -(n&7)));
+
+ /* Processing the main chunk with 128-elements per round */
+ for (int i = 0; i < tail_index_128; i += 128) {
+ // Fold 1
+ __m512 data1_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+ 0]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+ 8])), 1);
+ __m512 data1_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+16]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+24])), 1);
+ _mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low));
+
+ // Fold 2
+ __m512 data2_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+32]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+40])), 1);
+ __m512 data2_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+48]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+56])), 1);
+ _mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(data2_512_high, data2_512_low));
+
+ // Fold 3
+ __m512 data3_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+64]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+72])), 1);
+ __m512 data3_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+80]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+88])), 1);
+ _mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(data3_512_high, data3_512_low));
+
+ // Fold 4
+ __m512 data4_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+96]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+104])), 1);
+ __m512 data4_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+112]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+120])), 1);
+ _mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(data4_512_high, data4_512_low));
+ }
+
+ /* Processing the remaining <128 chunk with 32-elements per round */
+ for (int j = tail_index_128; j < tail_index_32; j += 32) {
+ __m512 data1_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[j+ 0]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[j+ 8])), 1);
+ __m512 data1_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[j+16]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[j+24])), 1);
+ _mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low));
+ }
+
+ /* Processing the remaining <32 chunk with 8-elements per round */
+ for (int j = tail_index_32; j < tail_index_8; j += 8) {
+ _mm_storeu_si128((__m128i *)&out[j], (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(_mm512_load_pd(&in[j]))));
+ }
+
+ /* Processing the remaining <8 chunk with masked processing */
+ if ((n&7) > 0) {
+ __m512d data_512 = _mm512_maskz_load_pd(*((__mmask8*) &tail_mask8), &in[tail_index_8]);
+ _mm_mask_storeu_epi16(&out[tail_index_8], *((__mmask8*) &tail_mask8), (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(data_512)));
+ }
+}
+
+#endif
diff --git a/kernel/x86_64/shdot.c b/kernel/x86_64/shdot.c
new file mode 100644
index 000000000..5073fda2a
--- /dev/null
+++ b/kernel/x86_64/shdot.c
@@ -0,0 +1,115 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include "common.h"
+
+#if defined(COOPERLAKE)
+#include "shdot_microk_cooperlake.c"
+#endif
+
+static float shdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y)
+{
+ float d = 0.0;
+
+#ifdef HAVE_SHDOT_ACCL_KERNEL
+ if ((inc_x == 1) && (inc_y == 1)) {
+ return shdot_accl_kernel(n, x, y);
+ }
+#endif
+
+ float * x_fp32 = malloc(sizeof(float)*n);
+ float * y_fp32 = malloc(sizeof(float)*n);
+
+ SBF16TOS_K(n, x, inc_x, x_fp32, 1);
+ SBF16TOS_K(n, y, inc_y, y_fp32, 1);
+
+ d = SDOTU_K(n, x_fp32, 1, y_fp32, 1);
+
+ free(x_fp32);
+ free(y_fp32);
+
+ return d;
+}
+
+#if defined(SMP)
+static int shdot_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, bfloat16 dummy2,
+ bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y,
+ float *result, BLASLONG dummy3)
+{
+ *(float *)result = shdot_compute(n, x, inc_x, y, inc_y);
+ return 0;
+}
+
+extern int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha,
+ void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc,
+ int (*function)(), int nthreads);
+#endif
+
+float CNAME(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y)
+{
+ float dot_result = 0.0;
+
+ if (n <= 0) return 0.0;
+
+#if defined(SMP)
+ int nthreads;
+ int thread_thres = 40960;
+ bfloat16 dummy_alpha;
+#endif
+
+#if defined(SMP)
+ if (inc_x == 0 || inc_y == 0 || n <= thread_thres)
+ nthreads = 1;
+ else
+ nthreads = num_cpu_avail(1);
+
+ int best_threads = (int) (n/(float)thread_thres + 0.5);
+
+ if (best_threads < nthreads) {
+ nthreads = best_threads;
+ }
+
+ if (nthreads <= 1) {
+ dot_result = shdot_compute(n, x, inc_x, y, inc_y);
+ } else {
+ char thread_result[MAX_CPU_NUMBER * sizeof(double) * 2];
+ int mode = BLAS_BFLOAT16 | BLAS_REAL;
+ blas_level1_thread_with_return_value(mode, n, 0, 0, &dummy_alpha,
+ x, inc_x, y, inc_y, thread_result, 0,
+ (void *)shdot_thread_func, nthreads);
+ float * ptr = (float *)thread_result;
+ for (int i = 0; i < nthreads; i++) {
+ dot_result += (*ptr);
+ ptr = (float *)(((char *)ptr) + sizeof(double) * 2);
+ }
+ }
+#else
+ dot_result = shdot_compute(n, x, inc_x, y, inc_y);
+#endif
+
+ return dot_result;
+}
diff --git a/kernel/x86_64/shdot_microk_cooperlake.c b/kernel/x86_64/shdot_microk_cooperlake.c
new file mode 100644
index 000000000..e645296f1
--- /dev/null
+++ b/kernel/x86_64/shdot_microk_cooperlake.c
@@ -0,0 +1,159 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+/* need a new enough GCC for avx512 support */
+#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
+
+#define HAVE_SHDOT_ACCL_KERNEL 1
+#include "common.h"
+#include <immintrin.h>
+
+static float shdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
+{
+ __m128 accum128 = _mm_setzero_ps();
+ if (n> 127) { /* n range from 128 to inf. */
+ long tail_index_32 = n&(~31);
+ long tail_index_128 = n&(~127);
+ unsigned int tail_mask_uint = (((unsigned int)0xffffffff) >> (32-(n&31)));
+ __mmask32 tail_mask = *((__mmask32*) &tail_mask_uint);
+
+ __m512 accum512_0 = _mm512_setzero_ps();
+ __m512 accum512_1 = _mm512_setzero_ps();
+ __m512 accum512_2 = _mm512_setzero_ps();
+ __m512 accum512_3 = _mm512_setzero_ps();
+
+ /* Processing the main chunk with 128-elements per round */
+ for (long i = 0; i < tail_index_128; i += 128) {
+ accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[i+ 0]), (__m512bh) _mm512_loadu_si512(&y[i+ 0]));
+ accum512_1 = _mm512_dpbf16_ps(accum512_1, (__m512bh) _mm512_loadu_si512(&x[i+32]), (__m512bh) _mm512_loadu_si512(&y[i+32]));
+ accum512_2 = _mm512_dpbf16_ps(accum512_2, (__m512bh) _mm512_loadu_si512(&x[i+64]), (__m512bh) _mm512_loadu_si512(&y[i+64]));
+ accum512_3 = _mm512_dpbf16_ps(accum512_3, (__m512bh) _mm512_loadu_si512(&x[i+96]), (__m512bh) _mm512_loadu_si512(&y[i+96]));
+ }
+
+ /* Processing the remaining <128 chunk with 32-elements per round */
+ for (long j = tail_index_128; j < tail_index_32; j += 32) {
+ accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[j]), (__m512bh) _mm512_loadu_si512(&y[j]));
+ }
+
+ /* Processing the remaining <32 chunk with masked 32-elements processing */
+ if ((n&31) != 0) {
+ accum512_2 = _mm512_dpbf16_ps(accum512_2,
+ (__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &x[tail_index_32]),
+ (__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &y[tail_index_32]));
+ }
+
+ /* Accumulate the 4 registers into 1 register */
+ accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
+ accum512_2 = _mm512_add_ps(accum512_2, accum512_3);
+ accum512_0 = _mm512_add_ps(accum512_0, accum512_2);
+
+ __m256 accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
+ accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
+ } else if (n > 31) { /* n range from 32 to 127 */
+ /* Processing <128 chunk with 32-elements per round */
+ __m256 accum256 = _mm256_setzero_ps();
+ __m256 accum256_1 = _mm256_setzero_ps();
+ int tail_index_32 = n&(~31);
+ for (int j = 0; j < tail_index_32; j += 32) {
+ accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[j+ 0]), (__m256bh) _mm256_loadu_si256(&y[j+ 0]));
+ accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256(&x[j+16]), (__m256bh) _mm256_loadu_si256(&y[j+16]));
+ }
+ accum256 = _mm256_add_ps(accum256, accum256_1);
+
+ /* Processing the remaining <32 chunk with 16-elements processing */
+ if ((n&16) != 0) {
+ accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[tail_index_32]), (__m256bh) _mm256_loadu_si256(&y[tail_index_32]));
+ }
+ accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
+
+ /* Processing the remaining <16 chunk with 8-elements processing */
+ if ((n&8) != 0) {
+ int tail_index_16 = n&(~15);
+ accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16]));
+ }
+
+ /* Processing the remaining <8 chunk with masked 8-elements processing */
+ if ((n&7) != 0) {
+ unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
+ __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
+ int tail_index_8 = n&(~7);
+ accum128 = _mm_dpbf16_ps(accum128,
+ (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
+ (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
+ }
+ } else if (n > 15) { /* n range from 16 to 31 */
+ /* Processing <32 chunk with 16-elements processing */
+ __m256 accum256 = _mm256_setzero_ps();
+ accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[0]), (__m256bh) _mm256_loadu_si256(&y[0]));
+ accum128 += _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
+
+ /* Processing the remaining <16 chunk with 8-elements processing */
+ if ((n&8) != 0) {
+ int tail_index_16 = n&(~15);
+ accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16]));
+ }
+
+ /* Processing the remaining <8 chunk with masked 8-elements processing */
+ if ((n&7) != 0) {
+ unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
+ __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
+ int tail_index_8 = n&(~7);
+ accum128 = _mm_dpbf16_ps(accum128,
+ (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
+ (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
+ }
+ } else if (n > 7) { /* n range from 8 to 15 */
+ /* Processing <16 chunk with 8-elements processing */
+ accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[0]), (__m128bh) _mm_loadu_si128(&y[0]));
+
+ /* Processing the remaining <8 chunk with masked 8-elements processing */
+ if ((n&7) != 0) {
+ unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
+ __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
+ int tail_index_8 = n&(~7);
+ accum128 = _mm_dpbf16_ps(accum128,
+ (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
+ (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
+ }
+ } else { /* n range from 1 to 7 */
+ unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
+ __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
+ accum128 = _mm_dpbf16_ps(accum128,
+ (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[0]),
+ (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[0]));
+ }
+
+ /* Add up the 4 elements into lowest entry */
+ __m128 accum128_1 = _mm_shuffle_ps(accum128, accum128, 14);
+ accum128 = _mm_add_ps(accum128, accum128_1);
+ accum128_1 = _mm_shuffle_ps(accum128, accum128, 1);
+ accum128 = _mm_add_ps(accum128, accum128_1);
+
+ return accum128[0];
+}
+
+#endif
diff --git a/kernel/x86_64/stobf16_microk_cooperlake.c b/kernel/x86_64/stobf16_microk_cooperlake.c
new file mode 100644
index 000000000..2756a6934
--- /dev/null
+++ b/kernel/x86_64/stobf16_microk_cooperlake.c
@@ -0,0 +1,86 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+/* need a new enough GCC for avx512 support */
+#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
+
+#define HAVE_TOBF16_ACCL_KERNEL 1
+#include "common.h"
+#include <immintrin.h>
+
+static void tobf16_accl_kernel(BLASLONG n, const float * in, bfloat16 * out)
+{
+ /* Get the 64-bytes unaligned header number targeting for avx512
+ * processing (Assume input float array is natural aligned) */
+ int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 2) & 0xf;
+
+ if (n < align_header) {align_header = n;}
+
+ if (align_header != 0) {
+ uint16_t align_mask16 = (((uint16_t)0xffff) >> (16-align_header));
+ __m512 a = _mm512_maskz_loadu_ps(*((__mmask16*) &align_mask16), &in[0]);
+ _mm256_mask_storeu_epi16(&out[0], *((__mmask16*) &align_mask16), (__m256i) _mm512_cvtneps_pbh(a));
+ }
+
+ if (n == align_header) {
+ return;
+ } else {
+ n -= align_header;
+ in += align_header;
+ out += align_header;
+ }
+
+ int tail_index_32 = n&(~31);
+ int tail_index_128 = n&(~127);
+ uint32_t tail_mask32 = (((uint32_t) 0xffffffff) >> (32-(n&31)));
+ uint16_t tail_mask16 = (((uint16_t) 0xffff) >> (16-(n&15)));
+
+ /* Processing the main chunk with 128-elements per round */
+ for (int i = 0; i < tail_index_128; i += 128) {
+ _mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 16]), _mm512_load_ps(&in[i+ 0])));
+ _mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 48]), _mm512_load_ps(&in[i+32])));
+ _mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 80]), _mm512_load_ps(&in[i+64])));
+ _mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+112]), _mm512_load_ps(&in[i+96])));
+ }
+
+ /* Processing the remaining <128 chunk with 32-elements per round */
+ for (int j = tail_index_128; j < tail_index_32; j += 32) {
+ _mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[j+ 16]), _mm512_load_ps(&in[j])));
+ }
+
+ /* Processing the remaining <32 chunk with masked processing */
+ if ((n&31) > 15) {
+ __m512 b = _mm512_load_ps(&in[tail_index_32]);
+ __m512 a = _mm512_maskz_load_ps(*((__mmask16*) &tail_mask16), &in[tail_index_32+16]);
+ _mm512_mask_storeu_epi16(&out[tail_index_32], *((__mmask32*) &tail_mask32), (__m512i) _mm512_cvtne2ps_pbh(a, b));
+ } else if ((n&31) > 0) {
+ __m512 a = _mm512_maskz_load_ps(*((__mmask16*) &tail_mask16), &in[tail_index_32]);
+ _mm256_mask_storeu_epi16(&out[tail_index_32], *((__mmask16*) &tail_mask16), (__m256i) _mm512_cvtneps_pbh(a));
+ }
+}
+
+#endif
diff --git a/kernel/x86_64/tobf16.c b/kernel/x86_64/tobf16.c
new file mode 100644
index 000000000..3d1796621
--- /dev/null
+++ b/kernel/x86_64/tobf16.c
@@ -0,0 +1,170 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include <stddef.h>
+#include "common.h"
+
+#if defined(DOUBLE)
+#define FLOAT_TYPE double
+#elif defined(SINGLE)
+#define FLOAT_TYPE float
+#else
+#endif
+
+#if defined(COOPERLAKE)
+#if defined(DOUBLE)
+#include "dtobf16_microk_cooperlake.c"
+#elif defined(SINGLE)
+#include "stobf16_microk_cooperlake.c"
+#endif
+#endif
+
+/* Notes for algorithm:
+ * - Round to Nearest Even used generally
+ * - QNAN for NAN case
+ * - Input denormals are treated as zero
+ */
+static void tobf16_generic_kernel(BLASLONG n, const FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
+{
+ BLASLONG register index_in = 0;
+ BLASLONG register index_out = 0;
+ BLASLONG register index = 0;
+ float float_in = 0.0;
+ uint32_t * uint32_in = (uint32_t *)(&float_in);
+ uint16_t * uint16_in = (uint16_t *)(&float_in);
+
+ while(index<n) {
+#if defined(DOUBLE)
+ float_in = (float)(*(in+index_in));
+#else
+ float_in = *(in+index_in);
+#endif
+
+ switch((*uint32_in) & 0xff800000u) {
+ case (0x00000000u): /* Type 1: Positive denormal */
+ *(out+index_out) = 0x0000u;
+ break;
+ case (0x80000000u): /* Type 2: Negative denormal */
+ *(out+index_out) = 0x8000u;
+ break;
+ case (0x7f800000u): /* Type 3: Positive infinity or NAN */
+ case (0xff800000u): /* Type 4: Negative infinity or NAN */
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+ *(out+index_out) = uint16_in[1];
+#else
+ *(out+index_out) = uint16_in[0];
+#endif
+ /* Specific for NAN */
+ if (((*uint32_in) & 0x007fffffu) != 0) {
+ /* Force to be QNAN */
+ *(out+index_out) |= 0x0040u;
+ }
+ break;
+ default: /* Type 5: Normal case */
+ (*uint32_in) += ((((*uint32_in) >> 16) & 0x1u) + 0x7fffu);
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+ *(out+index_out) = uint16_in[1];
+#else
+ *(out+index_out) = uint16_in[0];
+#endif
+ break;
+ }
+
+ index_in += inc_in;
+ index_out += inc_out;
+ index++;
+ }
+}
+
+#ifndef HAVE_TOBF16_ACCL_KERNEL
+static void tobf16_accl_kernel(BLASLONG n, const FLOAT_TYPE * in, bfloat16 * out)
+{
+ tobf16_generic_kernel(n, in, 1, out, 1);
+}
+#endif
+
+static void tobf16_compute(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
+{
+ if ((inc_in == 1) && (inc_out == 1)) {
+ tobf16_accl_kernel(n, in, out);
+ } else {
+ tobf16_generic_kernel(n, in, inc_in, out, inc_out);
+ }
+}
+
+#if defined(SMP)
+static int tobf16_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT_TYPE dummy2,
+ FLOAT_TYPE *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y,
+ FLOAT_TYPE *dummy3, BLASLONG dummy4)
+{
+ tobf16_compute(n, x, inc_x, y, inc_y);
+ return 0;
+}
+
+extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha,
+ void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc,
+ int (*function)(), int nthreads);
+#endif
+
+void CNAME(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
+{
+ if (n <= 0) return;
+
+#if defined(SMP)
+ int nthreads;
+ FLOAT_TYPE dummy_alpha;
+ FLOAT_TYPE dummy_c;
+#endif
+
+#if defined(SMP)
+ if (inc_in == 0 || inc_out == 0 || n <= 100000) {
+ nthreads = 1;
+ } else {
+ if (n/100000 < 100) {
+ nthreads = 4;
+ } else {
+ nthreads = 16;
+ }
+ }
+
+ if (nthreads == 1) {
+ tobf16_compute(n, in, inc_in, out, inc_out);
+ } else {
+#if defined(DOUBLE)
+ int mode = BLAS_REAL | BLAS_DTOBF16;
+#elif defined(SINGLE)
+ int mode = BLAS_REAL | BLAS_STOBF16;
+#endif
+ blas_level1_thread(mode, n, 0, 0, &dummy_alpha,
+ in, inc_in, out, inc_out, &dummy_c, 0,
+ (void *)tobf16_thread_func, nthreads);
+ }
+#else
+ tobf16_compute(n, in, inc_in, out, inc_out);
+#endif
+
+}
diff --git a/openblas_config_template.h b/openblas_config_template.h
index 9955e5c73..858b8c5cb 100644
--- a/openblas_config_template.h
+++ b/openblas_config_template.h
@@ -35,7 +35,8 @@ typedef unsigned long BLASULONG;
#endif
#ifndef BFLOAT16
-typedef unsigned short bfloat16;
+#include <stdint.h>
+typedef uint16_t bfloat16;
#endif
#ifdef OPENBLAS_USE64BITINT