summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm
diff options
context:
space:
mode:
authorAlexey Suhov <asuhov@users.noreply.github.com>2019-01-21 21:31:31 +0300
committeropenvino-pushbot <44090433+openvino-pushbot@users.noreply.github.com>2019-01-21 21:31:31 +0300
commit9de27f16bc8b712a5b8c99d1d4b4a66c9144942d (patch)
tree01a383efe94d92b9870d513c2c5ea5d15b07010a /inference-engine/thirdparty/mkl-dnn/src/cpu/gemm
parentfbc7a4a710c24def8ab199926a7da90a0394b87d (diff)
downloaddldt-9de27f16bc8b712a5b8c99d1d4b4a66c9144942d.tar.gz
dldt-9de27f16bc8b712a5b8c99d1d4b4a66c9144942d.tar.bz2
dldt-9de27f16bc8b712a5b8c99d1d4b4a66c9144942d.zip
Publishing R5 content (#72)
* Publishing R5 content * Updated ade revision * updated readme * add possibility to build CPU plugin with Intel MKL package
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/gemm')
-rw-r--r--inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.cpp157
-rw-r--r--inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.hpp17
-rw-r--r--inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.cpp9
-rw-r--r--inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.hpp30
-rw-r--r--inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp195
-rw-r--r--inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp1
-rw-r--r--inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.cpp231
-rw-r--r--inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.hpp1
-rw-r--r--inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/ref_gemm.cpp194
9 files changed, 516 insertions, 319 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.cpp
index fe8412194..146e68887 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.cpp
@@ -25,6 +25,8 @@
#include "../jit_generator.hpp"
#include "nstl.hpp"
#include "os_blas.hpp"
+#include "math_utils.hpp"
+#include "mkldnn_traits.hpp"
/* USE_MKL USE_CBLAS effect
* ------- --------- ------
@@ -52,6 +54,7 @@ mkldnn_status_t check_gemm_input(const char *transa, const char *transb,
&& *M >= 0
&& *N >= 0
&& *K >= 0;
+
if (!consistency) return invalid_arguments;
bool isTransA = utils::one_of(*transa, 'T', 't');
bool isTransB = utils::one_of(*transb, 'T', 't');
@@ -66,6 +69,19 @@ mkldnn_status_t check_gemm_input(const char *transa, const char *transb,
return success;
}
+mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc,
+ const char *transa, const char *transb, const int *M, const int *N,
+ const int *K, const int *lda, const int *ldb, const int *ldc,
+ const float *alpha, const float *beta, const bool with_bias) {
+
+ if (offsetc == nullptr) return invalid_arguments;
+ if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r'))
+ return invalid_arguments;
+
+ return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha,
+ beta, with_bias);
+}
+
struct gemm_impl_t {
gemm_impl_t(char transa, char transb, bool zero_beta, bool with_bias) {
//jit kernel has three codepaths: beta is 0, 1 or arbitrary
@@ -132,7 +148,7 @@ mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
const int *M, const int *N, const int *K, const float *alpha,
const float *A, const int *lda, const float *B, const int *ldb,
const float *beta, float *C, const int *ldc,
- const float *bias) {
+ const float *bias, const bool force_jit_gemm) {
//Check input
mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K,
lda, ldb, ldc, alpha, beta, bias != nullptr);
@@ -143,20 +159,22 @@ mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
int trA = *transa == 't' || *transa == 'T';
int trB = *transb == 't' || *transb == 'T';
#ifdef USE_CBLAS
- //Call cblas
- CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans;
- CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans;
- cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB,
- *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
- //Add bias if necessary (bias is applied to columns of C)
- if (bias) {
- cblas_int incx = 1, incy = 1;
- parallel_nd(*N, [&](int n) {
- cblas_saxpy(*M, 1.0, bias, incx, C + n*(*ldc), incy);
- });
+ if (!force_jit_gemm) {
+ //Call cblas
+ CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans;
+ CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans;
+ cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB,
+ *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
+ //Add bias if necessary (bias is applied to columns of C)
+ if (bias) {
+ cblas_int incx = 1, incy = 1;
+ parallel_nd(*N, [&](int n) {
+ cblas_saxpy(*M, 1.0, bias, incx, C + n*(*ldc), incy);
+ });
+ }
+ return mkldnn_success;
}
- return mkldnn_success;
-#else
+#endif
//Generate jit kernel and call sgemm with bias
volatile static int initialized = 0;
if (!initialized) {
@@ -176,9 +194,98 @@ mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
return mkldnn_success;
-#endif
}
+template <typename b_dt>
+mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
+ int32_t *C, const int *LDC, const int32_t *co) {
+
+ mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb,
+ M, N, K, LDA, LDB, LDC, alpha, beta, false);
+
+ if (status != mkldnn_success)
+ return status;
+
+ if (*M == 0 || *N == 0 || *K == 0)
+ return mkldnn_success;
+
+ bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
+ bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
+ bool AisN = (*transa == 'N' || *transa == 'n');
+ bool BisN = (*transb == 'N' || *transb == 'n');
+
+#if defined(USE_MKL) && defined(USE_CBLAS)
+ if (data_traits<b_dt>::data_type == data_type::u8) {
+ CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans;
+ CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans;
+ CBLAS_OFFSET Cblas_offsetc =
+ OCisR
+ ? CblasRowOffset
+ : OCisC
+ ? CblasColOffset
+ : CblasFixOffset;
+ cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc,
+ *M, *N, *K, *alpha, A, *LDA, *ao, (b_dt*)B, *LDB, *bo, *beta, C, *LDC, co);
+ return mkldnn_success;
+ }
+#endif
+ int m = *M, n = *N, k = *K, lda = *LDA, ldb = *LDB, ldc = *LDC;
+ size_t sizeA = AisN ? lda * k : lda * m;
+ size_t sizeB = BisN ? ldb * n : ldb * k;
+ size_t sizeC = ldc * n;
+
+ double *dA = (double *)malloc(sizeA * sizeof(double), PAGE_4K);
+ double *dB = (double *)malloc(sizeB * sizeof(double), PAGE_4K);
+ double *dC = (double *)malloc(sizeC * sizeof(double), PAGE_4K);
+
+ if (utils::any_null(dA, dB, dC)) {
+ free(dA);
+ free(dB);
+ free(dC);
+ return mkldnn_out_of_memory;
+ }
+
+ auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; };
+ auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; };
+
+ auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; };
+ auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; };
+
+ const int a_rows = AisN ? m : k;
+ const int a_cols = AisN ? k : m;
+ mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) {
+ da_setter(i, j,
+ static_cast<double>(ia_accessor(i, j)) + static_cast<double>(ao[0]));
+ });
+
+ const int b_rows = BisN ? k : n;
+ const int b_cols = BisN ? n : k;
+ mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) {
+ db_setter(i, j,
+ static_cast<double>(ib_accessor(i, j)) + static_cast<double>(bo[0]));
+ });
+ double one = 1.0, zero = 0.0;
+ ref_gemm<double>(transa, transb, M, N, K, &one, dA, LDA, dB, LDB, &zero,
+ dC, LDC, nullptr);
+
+ auto i2d = [=] (int32_t v) { return static_cast<double>(v); };
+ auto f2d = [=] (float v) { return static_cast<double>(v); };
+
+ mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) {
+ double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]);
+ double val = ((*beta == 0.0f) ? 0.0 : f2d(*beta) * i2d(C[i + j * ldc]))
+ + f2d(*alpha) * dC[i + j * ldc] + coffset;
+ C[i + j * ldc] = math::out_round<int32_t>(math::saturate<int32_t>(val));
+ });
+
+ free(dA);
+ free(dB);
+ free(dC);
+ return mkldnn_success;
+}
}
}
}
@@ -193,3 +300,23 @@ mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb,
return extended_sgemm(
transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
+
+mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *lda, const int8_t *ao,
+ const uint8_t *B, const int *ldb, const int8_t *bo, const float *beta,
+ int32_t *c, const int *ldc, const int32_t *co) {
+ return gemm_s8x8s32(
+ transa, transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo,
+ beta, c, ldc, co);
+}
+
+mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *lda, const int8_t *ao,
+ const int8_t *B, const int *ldb, const int8_t *bo, const float *beta,
+ int32_t *c, const int *ldc, const int32_t *co) {
+ return gemm_s8x8s32(
+ transa, transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo,
+ beta, c, ldc, co);
+}
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.hpp
index 8917de157..3f33a3713 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.hpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.hpp
@@ -22,11 +22,20 @@ mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
const int *M, const int *N, const int *K, const float *alpha,
const float *A, const int *lda, const float *B, const int *ldb,
const float *beta, float *C, const int *ldc,
- const float *bias = nullptr);
+ const float *bias = nullptr, bool force_jit_gemm = false);
+
+template <typename b_dt>
+mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *lda, const int8_t *ao,
+ const b_dt *B, const int *ldb, const int8_t *bo, const float *beta,
+ int32_t *c, const int *ldc, const int32_t *co);
+
+template <typename data_t>
void ref_gemm(const char *transa, const char *transb, const int *M,
- const int *N, const int *K, const float *alpha, const float *A,
- const int *lda, const float *B, const int *ldb, const float *beta,
- float *C, const int *ldc, const float *bias);
+ const int *N, const int *K, const data_t *alpha, const data_t *A,
+ const int *lda, const data_t *B, const int *ldb, const data_t *beta,
+ data_t *C, const int *ldc, const data_t *bias);
#ifdef USE_CBLAS
#define GEMM_IMPL_STR "gemm:blas"
#else
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.cpp
index 934ba81ce..e3b6cff8a 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.cpp
@@ -343,8 +343,9 @@ void partition_unit_diff(
// Sum the m*n values from p_src into p_dst, assuming the two-dimensional
// arrays have leading dimensions ld_src and ld_dst, respectively
+template<typename data_t>
void sum_two_matrices(
- int m, int n, float *p_src, int ld_src, float *p_dst, int ld_dst)
+ int m, int n, data_t *p_src, int ld_src, data_t *p_dst, int ld_dst)
{
int i, j;
for (j = 0; j < n; j++) {
@@ -353,6 +354,12 @@ void sum_two_matrices(
}
}
}
+
+template void sum_two_matrices<float>(
+ int m, int n, float *p_src, int ld_src, float *p_dst, int ld_dst);
+
+template void sum_two_matrices<double>(
+ int m, int n, double *p_src, int ld_src, double *p_dst, int ld_dst);
}
}
}
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.hpp
index 7a8f7fcbe..0888787b9 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.hpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.hpp
@@ -22,8 +22,34 @@ namespace impl {
namespace cpu {
namespace gemm_utils {
+
+template <typename T, bool isTransA, bool isTransB>
+struct gemm_traits {};
+
+template <bool isTransA, bool isTransB>
+struct gemm_traits<double, isTransA, isTransB> {
+ static constexpr int m = 8;
+ static constexpr int n = 6;
+ static constexpr int BM = 4032;
+ static constexpr int BN = isTransA ? 96 : 192;
+ static constexpr int BK = isTransB ? 96 : 512;
+};
+
+template <bool isTransA, bool isTransB>
+struct gemm_traits<float, isTransA, isTransB> {
+ static constexpr int m = 16;
+ static constexpr int n = 6;
+ static constexpr int BM = 4032;
+ static constexpr int BN = isTransA ? 96 : 48;
+ static constexpr int BK = isTransB ? 96 : 256;
+};
+
+template <typename T>
+using unroll_factor = gemm_traits<T, false, false>;
+
+template <typename data_type>
void sum_two_matrices(
- int m, int n, float *p_src, int ld_src, float *p_dst, int ld_dst);
+ int m, int n, data_type *p_src, int ld_src, data_type *p_dst, int ld_dst);
void calc_nthr_nocopy_avx512_common(int m,
int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k,
@@ -35,6 +61,8 @@ void calc_nthr_nocopy_avx(int m, int n, int k,
void partition_unit_diff(
int ithr, int nthr, int n, int *t_offset, int *t_block);
+
+inline double saturate(double value, double min, double max);
};
}
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp
index 08ba7b47d..8aee85fbd 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp
@@ -22,7 +22,7 @@
#include "gemm_utils.hpp"
#include "jit_avx512_common_gemm_f32.hpp"
-#define CACHE_LINE_SIZE 16
+#define CACHE_LINE_SIZE 64
namespace mkldnn {
namespace impl {
@@ -128,15 +128,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
// Function for packing if needed
auto do_pack = [&](int unroll_m) {
- inLocalLabel();
+ Label pack2, pack3, pack4, pack10;
+
mov(BO1, A);
lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
mov(LL, K);
sar(LL, 2);
- jle(".pack3", T_NEAR);
+ jle(pack3, T_NEAR);
align(16);
- L(".pack2");
+ L(pack2);
if (!isTransA) {
for (int i = 0; i < 4; i++) {
vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
@@ -216,16 +217,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
add(AO1, unroll_m * 4 * SIZE);
sub(LL, 1);
- jg(".pack2", T_NEAR);
+ jg(pack2, T_NEAR);
align(16);
- L(".pack3");
+ L(pack3);
mov(LL, K);
and_(LL, 3);
- jle(".pack10", T_NEAR);
+ jle(pack10, T_NEAR);
align(16);
- L(".pack4");
+ L(pack4);
if (!isTransA) {
vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
if (unroll_m > 16)
@@ -279,11 +280,10 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
add(AO1, unroll_m * SIZE);
sub(LL, 1);
- jg(".pack4", T_NEAR);
+ jg(pack4, T_NEAR);
align(16);
- L(".pack10");
- outLocalLabel();
+ L(pack10);
};
// Function to update C, covering masking and other considerations
@@ -617,8 +617,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
// Innerkernel; called by kernel
auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect,
bool isCopy, bool doCPrefetch, bool isUnmasked = true) {
- inLocalLabel();
-
for (int i = 0; i < 8; i++) {
if (!isDirect) {
prefetcht0(ptr[AO1
@@ -960,7 +958,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
}
sub(LL, 1);
- outLocalLabel();
};
// Main kernel; does prefetching and calls innerkernel
@@ -968,7 +965,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
// calling update
auto kernel = [&](int unroll_m, int unroll_n, bool isDirect,
bool isCopy, bool isUnmasked = true) {
- inLocalLabel();
if (!isDirect) {
lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
} else {
@@ -1020,36 +1016,38 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
}
}
+ Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18;
+
mov(LL, K);
sar(LL, 3);
sub(LL, SECOND_FETCH);
- jle(".kernel13", T_NEAR);
+ jle(kernel13, T_NEAR);
align(16);
- L(".kernel12");
+ L(kernel12);
innerkernel(
unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked);
- jg(".kernel12", T_NEAR);
+ jg(kernel12, T_NEAR);
align(16);
- L(".kernel13");
+ L(kernel13);
lea(CO2, ptr[CO1 + (16 - 1) * SIZE]);
add(LL, unroll_n);
- jle(".kernel15", T_NEAR);
+ jle(kernel15, T_NEAR);
align(16);
- L(".kernel14");
+ L(kernel14);
innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked);
- jg(".kernel14", T_NEAR);
+ jg(kernel14, T_NEAR);
align(16);
- L(".kernel15");
+ L(kernel15);
mov(LL, K);
and_(LL, 7);
- jle(".kernel18", T_NEAR);
+ jle(kernel18, T_NEAR);
align(16);
- L(".kernel16");
+ L(kernel16);
if (isDirect) {
if (isUnmasked || unroll_m > 16) {
vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
@@ -1204,10 +1202,10 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
}
sub(LL, 1);
- jg(".kernel16", T_NEAR);
+ jg(kernel16, T_NEAR);
align(16);
- L(".kernel18");
+ L(kernel18);
vbroadcastss(VALPHA, ALPHA);
if (isBetaN) {
@@ -1329,8 +1327,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
sub(BO1, rax);
add(BO1, unroll_n * SIZE);
}
-
- outLocalLabel();
};
// High-level subroutine; does packing if needed, then splits C matrix.
@@ -1338,11 +1334,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
// cases appropriately by doing 32 or 16 rows, and/or with masking,
// and/or fewer columns).
auto subloop = [&](int unroll_m) {
- inLocalLabel();
-
Label l_subloop_20x[8], l_subloop_mask_20x[8];
Label l_subloop_30x[8], l_subloop_mask_30x[8];
+ Label subloop11, subloop11mask;
+ Label subloop30, subloop30mask;
+ Label subloop31, subloop31mask;
+ Label subloop96;
+ Label subloop98, subloop98mask;
+ Label subloop99;
+
// Create mask
mov(BO1, rcx);
mov(rcx, M);
@@ -1370,7 +1371,7 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
and_(rax, 0xffff);
cmp(rax, 0xffff);
- jne(".subloop96", T_NEAR);
+ jne(subloop96, T_NEAR);
if (isTransA) {
do_pack(unroll_m);
@@ -1387,11 +1388,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (!isTransA) {
lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
cmp(M, UNROLL_M);
- jg(".subloop98", T_NEAR);
+ jg(subloop98, T_NEAR);
mov(AA, ORIG_A);
lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
- L(".subloop98");
+ L(subloop98);
}
mov(LL, N);
@@ -1399,11 +1400,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (!isTransA) {
// If N is too small, skip copy operation
cmp(LL, UNROLL_N * 3);
- jle(".subloop30", T_NEAR);
+ jle(subloop30, T_NEAR);
// If A is not aligned to cache line
cmp(FLAG, 0);
- je(".subloop30", T_NEAR);
+ je(subloop30, T_NEAR);
} else {
cmp(LL, UNROLL_N);
jl(l_subloop_20x[1], T_NEAR);
@@ -1421,11 +1422,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
jl(l_subloop_20x[1], T_NEAR);
align(16);
- L(".subloop11");
+ L(subloop11);
kernel(unroll_m, UNROLL_N, false, false);
sub(I, UNROLL_N);
cmp(I, UNROLL_N);
- jge(".subloop11", T_NEAR);
+ jge(subloop11, T_NEAR);
align(16);
for (int i = 1; i <= 7; i++) {
@@ -1434,24 +1435,24 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (i < 7) {
jne(l_subloop_20x[i + 1], T_NEAR);
} else {
- jne(".subloop99", T_NEAR);
+ jne(subloop99, T_NEAR);
}
kernel(unroll_m, i, false, false);
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
}
if (!isTransA) {
- L(".subloop30");
+ L(subloop30);
cmp(I, UNROLL_N);
jl(l_subloop_30x[1], T_NEAR);
align(16);
- L(".subloop31");
+ L(subloop31);
kernel(unroll_m, UNROLL_N, true, false);
sub(I, UNROLL_N);
cmp(I, UNROLL_N);
- jge(".subloop31", T_NEAR);
+ jge(subloop31, T_NEAR);
align(16);
for (int i = 1; i <= 7; i++) {
@@ -1460,18 +1461,18 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (i < 7) {
jne(l_subloop_30x[i + 1], T_NEAR);
} else {
- jne(".subloop99", T_NEAR);
+ jne(subloop99, T_NEAR);
}
kernel(unroll_m, i, true, false);
if (i < 7)
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
}
}
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
- L(".subloop96");
+ L(subloop96);
if (isTransA) {
do_pack(unroll_m);
}
@@ -1486,10 +1487,10 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (!isTransA) {
lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
cmp(M, UNROLL_M);
- jg(".subloop98mask", T_NEAR);
+ jg(subloop98mask, T_NEAR);
mov(AA, ORIG_A);
lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
- L(".subloop98mask");
+ L(subloop98mask);
}
mov(LL, N);
@@ -1497,11 +1498,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (!isTransA) {
// If N is too small, skip copy operation
cmp(LL, UNROLL_N * 3);
- jle(".subloop30mask", T_NEAR);
+ jle(subloop30mask, T_NEAR);
// If A is not aligned to cache line
cmp(FLAG, 0);
- je(".subloop30mask", T_NEAR);
+ je(subloop30mask, T_NEAR);
} else {
cmp(LL, UNROLL_N);
jl(l_subloop_mask_20x[1], T_NEAR);
@@ -1519,11 +1520,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
jl(l_subloop_mask_20x[1], T_NEAR);
align(16);
- L(".subloop11mask");
+ L(subloop11mask);
kernel(unroll_m, UNROLL_N, false, false, false);
sub(I, UNROLL_N);
cmp(I, UNROLL_N);
- jge(".subloop11mask", T_NEAR);
+ jge(subloop11mask, T_NEAR);
align(16);
for (int i = 1; i <= 7; i++) {
@@ -1532,24 +1533,24 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (i < 7) {
jne(l_subloop_mask_20x[i + 1], T_NEAR);
} else {
- jne(".subloop99", T_NEAR);
+ jne(subloop99, T_NEAR);
}
kernel(unroll_m, i, false, false, false);
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
}
if (!isTransA) {
- L(".subloop30mask");
+ L(subloop30mask);
cmp(I, UNROLL_N);
jl(l_subloop_mask_30x[1], T_NEAR);
align(16);
- L(".subloop31mask");
+ L(subloop31mask);
kernel(unroll_m, UNROLL_N, true, false, false);
sub(I, UNROLL_N);
cmp(I, UNROLL_N);
- jge(".subloop31mask", T_NEAR);
+ jge(subloop31mask, T_NEAR);
align(16);
for (int i = 1; i <= 7; i++) {
@@ -1558,16 +1559,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (i < 7) {
jne(l_subloop_mask_30x[i + 1], T_NEAR);
} else {
- jne(".subloop99", T_NEAR);
+ jne(subloop99, T_NEAR);
}
kernel(unroll_m, i, true, false, false);
if (i < 7)
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
}
}
- L(".subloop99");
+ L(subloop99);
// Compute address for A
if (!isTransA) {
add(A, unroll_m * SIZE);
@@ -1581,14 +1582,12 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (hasBias) {
add(BIAS, unroll_m * SIZE);
}
-
- outLocalLabel();
};
- inLocalLabel();
-
preamble();
+ Label buffer_in_ws, buffer_allocated;
+
// Get the registers
mov(B, ARG_B);
mov(LDB, ARG_LDB);
@@ -1608,7 +1607,7 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
#endif
cmp(K, STACK_K_CAPACITY);
- jg(".buffer_in_ws", T_NEAR);
+ jg(buffer_in_ws, T_NEAR);
// Create buffer and align to 4kB page
lea(rax, ptr[K * SIZE]);
@@ -1616,12 +1615,12 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
add(rax, 256);
sub(rsp, rax);
and_(rsp, -PAGE_4K);
- jmp(".buffer_allocated", T_NEAR);
+ jmp(buffer_allocated, T_NEAR);
- L(".buffer_in_ws");
+ L(buffer_in_ws);
mov(rsp, ARG_WS);
- L(".buffer_allocated");
+ L(buffer_allocated);
mov(ORIG_SP, rbp);
mov(M, ARG_M);
@@ -1665,40 +1664,40 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
}
}
+ Label main0, main1, main2, main999;
+
cmp(M, 32);
- jle(".main0", T_NEAR);
+ jle(main0, T_NEAR);
align(16);
- L(".main1");
+ L(main1);
subloop(48);
sub(M, UNROLL_M);
cmp(M, 32);
- jg(".main1", T_NEAR);
+ jg(main1, T_NEAR);
align(16);
- L(".main0");
+ L(main0);
cmp(M, 16);
- jle(".main2", T_NEAR);
+ jle(main2, T_NEAR);
subloop(32);
- jmp(".main999", T_NEAR);
+ jmp(main999, T_NEAR);
align(16);
- L(".main2");
+ L(main2);
cmp(M, 0);
- jle(".main999", T_NEAR);
+ jle(main999, T_NEAR);
subloop(16);
align(16);
- L(".main999");
+ L(main999);
// Restore original stack
mov(rsp, ORIG_SP);
vzeroupper();
postamble();
- outLocalLabel();
-
ker_ = reinterpret_cast<decltype(ker_)>(
const_cast<uint8_t *>(this->getCode()));
}
@@ -1763,7 +1762,7 @@ void jit_avx512_common_gemm_f32::sgemm_nocopy_driver(const char *transa,
if (!isTransA && !isTransB)
BK = 128;
}
- const float *curA, *curB, *curBias = NULL;
+ const float *curA, *curB, *curBias = nullptr;
float *curC;
for (Bk = 0; Bk < k; Bk += sizeK) {
@@ -1804,15 +1803,15 @@ void jit_avx512_common_gemm_f32::sgemm_nocopy_driver(const char *transa,
curB = b + Bn + (size_t)Bk * ldb;
}
curC = c + Bm + (size_t)Bn * ldc;
- if (bias != NULL) {
+ if (bias != nullptr) {
if (Bk == 0) {
curBias = bias + Bm;
} else {
- curBias = NULL;
+ curBias = nullptr;
}
}
if (Bk == 0) {
- if (*beta == 0.0 && bias == NULL)
+ if (*beta == 0.0 && bias == nullptr)
(*ker_b0_)((long long int)sizeM, (long long int)sizeN,
(long long int)sizeK, alpha, curA,
(long long int)lda, curB, (long long int)ldb,
@@ -1860,7 +1859,7 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
// Determine threading partitioning
gemm_utils::calc_nthr_nocopy_avx512_common(
m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
- assert(utils::implication(!mkldnn_thr_syncable(), nthr_k == 1));
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
// May not happen, but just in case
if (nthr < nthr_m * nthr_n * nthr_k)
@@ -1868,13 +1867,18 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
nthr_mn = nthr_m * nthr_n;
- unsigned int volatile *ompstatus = (unsigned int volatile *)ompstatus_;
- if (!ompstatus) return;
+ unsigned char * ompstatus_ = nullptr;
+ unsigned char volatile *ompstatus = nullptr;
- float *c_buffers = NULL;
- float *ws_buffers = NULL;
+ float *c_buffers = nullptr;
+ float *ws_buffers = nullptr;
if (nthr_k > 1) {
+ ompstatus_ = (unsigned char *) malloc(
+ nthr * CACHE_LINE_SIZE,
+ CACHE_LINE_SIZE);
+ ompstatus = (unsigned char volatile *) ompstatus_;
+ assert(ompstatus);
for (int i = 0; i < nthr; i++)
ompstatus[i * CACHE_LINE_SIZE] = 0;
@@ -1895,7 +1899,7 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
int n_from, n_to, myN;
int k_from, k_to, myK;
int cbase, ibase;
- const float *myA, *myB, *myBias = NULL;
+ const float *myA, *myB, *myBias = nullptr;
float *myC = C, myBeta;
float *ws = ws_buffers ?
ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
@@ -1957,7 +1961,7 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
myBeta = 0.0;
ld = MB;
- myBias = NULL;
+ myBias = nullptr;
}
sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
@@ -2004,8 +2008,8 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
}
});
- if (nthr_k > 1)
- free(c_buffers);
+ free(c_buffers);
+ free(ompstatus_);
free(ws_buffers);
}
@@ -2032,10 +2036,6 @@ jit_avx512_common_gemm_f32::jit_avx512_common_gemm_f32(
}
nthrs_ = mkldnn_get_max_threads();
- ompstatus_ = (unsigned int *)malloc(
- sizeof(unsigned int *) * nthrs_ * CACHE_LINE_SIZE, 64);
- assert(ompstatus_);
-
}
jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32()
@@ -2045,7 +2045,6 @@ jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32()
delete ker_b1_;
if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
delete ker_b0_;
- free(ompstatus_);
}
}
}
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp
index ede1cf9c1..c05733581 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp
@@ -49,7 +49,6 @@ private:
bool hasBias_;
struct xbyak_gemm;
xbyak_gemm *ker_bn_, *ker_b1_, *ker_b0_;
- unsigned int *ompstatus_;
int nthrs_;
};
}
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.cpp
index 9766a46d7..354fa0bc7 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.cpp
@@ -21,7 +21,7 @@
#include "gemm_utils.hpp"
#include "jit_avx_gemm_f32.hpp"
-#define CACHE_LINE_SIZE 16
+#define CACHE_LINE_SIZE 64
namespace mkldnn {
namespace impl {
@@ -51,7 +51,7 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
: jit_generator(code_ptr, code_size)
{
const bool is_avx2 = mayiuse(avx2);
- assert(implication(!is_avx2, mayiuse(avx)));
+ assert(IMPLICATION(!is_avx2, mayiuse(avx)));
const int UNROLL_M = is_avx2 ? 16 : 8;
const int UNROLL_N = 6;
@@ -128,10 +128,10 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
// Function for packing if needed
auto do_pack = [&](
int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
+ Label pack2, pack3, pack4, pack10;
int regIdx;
Reg64 reg;
- inLocalLabel();
mov(BO1, A);
lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
@@ -144,10 +144,10 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
mov(LL, K);
sar(LL, 2);
- jle(".pack3", T_NEAR);
+ jle(pack3, T_NEAR);
align(16);
- L(".pack2");
+ L(pack2);
if (!isTransA) {
for (int i = 0; i < 4; i++) {
regIdx = (i % 2 == 0) ? 4 : 6;
@@ -396,16 +396,16 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
add(AO1, unroll_m * 4 * SIZE);
sub(LL, 1);
- jg(".pack2", T_NEAR);
+ jg(pack2, T_NEAR);
align(16);
- L(".pack3");
+ L(pack3);
mov(LL, K);
and_(LL, 3);
- jle(".pack10", T_NEAR);
+ jle(pack10, T_NEAR);
align(16);
- L(".pack4");
+ L(pack4);
if (!isTransA) {
if (isLoad1Unmasked) {
vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
@@ -542,12 +542,10 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
add(AO1, unroll_m * SIZE);
sub(LL, 1);
- jg(".pack4", T_NEAR);
+ jg(pack4, T_NEAR);
align(16);
- L(".pack10");
-
- outLocalLabel();
+ L(pack10);
};
// Fused multiply add; may become one or two instructions
@@ -1382,8 +1380,6 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9),
Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12),
Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) {
- inLocalLabel();
-
if (!isDirect) {
lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
} else {
@@ -1431,20 +1427,23 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
mov(LL, K);
sar(LL, 3);
+ Label kernel12, kernel13, kernel14, kernel15;
+ Label kernel16, kernel17, kernel18;
+
sub(LL, SECOND_FETCH);
- jle(".kernel13", T_NEAR);
+ jle(kernel13, T_NEAR);
align(16);
- L(".kernel12");
+ L(kernel12);
innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
reg21, reg22, reg23);
- jg(".kernel12", T_NEAR);
+ jg(kernel12, T_NEAR);
align(16);
- L(".kernel13");
+ L(kernel13);
prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]);
if (unroll_n >= 2)
prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]);
@@ -1458,30 +1457,30 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]);
add(LL, SECOND_FETCH);
- jle(".kernel15", T_NEAR);
+ jle(kernel15, T_NEAR);
align(16);
- L(".kernel14");
+ L(kernel14);
innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
reg21, reg22, reg23);
- jg(".kernel14", T_NEAR);
+ jg(kernel14, T_NEAR);
align(16);
- L(".kernel15");
+ L(kernel15);
test(K, 4);
- jle(".kernel16", T_NEAR);
+ jle(kernel16, T_NEAR);
innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
reg21, reg22, reg23);
- L(".kernel16");
+ L(kernel16);
test(K, 2);
- jle(".kernel17", T_NEAR);
+ jle(kernel17, T_NEAR);
innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
@@ -1489,7 +1488,7 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
reg21, reg22, reg23);
align(16);
- L(".kernel17");
+ L(kernel17);
if (unroll_m == 16) {
if (unroll_n <= 3) {
vaddps(reg00, reg00, reg12);
@@ -1511,13 +1510,13 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
}
test(K, 1);
- jle(".kernel18", T_NEAR);
+ jle(kernel18, T_NEAR);
innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
reg05, reg06, reg07, reg08, reg09, reg10, reg11);
align(16);
- L(".kernel18");
+ L(kernel18);
vbroadcastss(VALPHA, ALPHA);
if (isBetaN) {
@@ -1804,8 +1803,6 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
sub(BO1, rax);
add(BO1, unroll_n * SIZE);
}
-
- outLocalLabel();
};
auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
@@ -1898,12 +1895,18 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
// Masking is used for tail cases where M is not divisible by 8.
auto subloop = [&](
int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
- inLocalLabel();
-
if (isTransA) {
do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked);
}
+ Label subloop11, subloop11mask;
+ Label subloop20, subloop21, subloop22, subloop23;
+ Label subloop24, subloop25;
+ Label subloop30, subloop31, subloop32, subloop33;
+ Label subloop34, subloop35;
+ Label subloop98, subloop98mask;
+ Label subloop99, subloop99mask;
+
mov(CO1, C);
lea(CO2, ptr[CO1 + LDC * 2]);
add(CO2, LDC);
@@ -1916,11 +1919,11 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
if (!isTransA) {
lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]);
cmp(M, UNROLL_M);
- jg(".subloop98", T_NEAR);
+ jg(subloop98, T_NEAR);
mov(AA, ORIG_A);
lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]);
- L(".subloop98");
+ L(subloop98);
}
mov(LL, N);
@@ -1928,14 +1931,14 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
if (!isTransA) {
// If N is too small, skip copy operation
cmp(LL, UNROLL_N * 3);
- jle(".subloop30", T_NEAR);
+ jle(subloop30, T_NEAR);
// If A is not aligned to cache line
cmp(FLAG, 0);
- je(".subloop30", T_NEAR);
+ je(subloop30, T_NEAR);
} else {
cmp(LL, UNROLL_N);
- jl(".subloop20", T_NEAR);
+ jl(subloop20, T_NEAR);
}
align(16);
@@ -1959,10 +1962,10 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
sub(I, UNROLL_N);
cmp(I, UNROLL_N);
- jl(".subloop20", T_NEAR);
+ jl(subloop20, T_NEAR);
align(16);
- L(".subloop11");
+ L(subloop11);
if (unroll_m == 16) {
kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
isLoad2Unmasked, false, false);
@@ -1972,12 +1975,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
}
sub(I, UNROLL_N);
cmp(I, UNROLL_N);
- jge(".subloop11", T_NEAR);
+ jge(subloop11, T_NEAR);
align(16);
- L(".subloop20");
+ L(subloop20);
cmp(I, 1);
- jne(".subloop21", T_NEAR);
+ jne(subloop21, T_NEAR);
if (unroll_m == 16) {
kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
false, false);
@@ -1985,12 +1988,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false,
false);
}
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
- L(".subloop21");
+ L(subloop21);
cmp(I, 2);
- jne(".subloop22", T_NEAR);
+ jne(subloop22, T_NEAR);
if (unroll_m == 16) {
kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
false, false);
@@ -1998,12 +2001,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false,
false);
}
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
- L(".subloop22");
+ L(subloop22);
cmp(I, 3);
- jne(".subloop23", T_NEAR);
+ jne(subloop23, T_NEAR);
if (unroll_m == 16) {
kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
false, false);
@@ -2011,12 +2014,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false,
false);
}
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
- L(".subloop23");
+ L(subloop23);
cmp(I, 4);
- jne(".subloop24", T_NEAR);
+ jne(subloop24, T_NEAR);
if (unroll_m == 16) {
kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
false, false);
@@ -2024,12 +2027,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false,
false);
}
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
- L(".subloop24");
+ L(subloop24);
cmp(I, 5);
- jne(".subloop99", T_NEAR);
+ jne(subloop99, T_NEAR);
if (unroll_m == 16) {
kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
false, false);
@@ -2037,16 +2040,16 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false,
false);
}
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
if (!isTransA) {
- L(".subloop30");
+ L(subloop30);
cmp(I, UNROLL_N);
- jl(".subloop25", T_NEAR);
+ jl(subloop25, T_NEAR);
align(16);
- L(".subloop31");
+ L(subloop31);
if (unroll_m == 16) {
kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
isLoad2Unmasked, true, false);
@@ -2056,12 +2059,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
}
sub(I, UNROLL_N);
cmp(I, UNROLL_N);
- jge(".subloop31", T_NEAR);
+ jge(subloop31, T_NEAR);
align(16);
- L(".subloop25");
+ L(subloop25);
cmp(I, 1);
- jne(".subloop32", T_NEAR);
+ jne(subloop32, T_NEAR);
if (unroll_m == 16) {
kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
true, false);
@@ -2069,12 +2072,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
true, false);
}
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
- L(".subloop32");
+ L(subloop32);
cmp(I, 2);
- jne(".subloop33", T_NEAR);
+ jne(subloop33, T_NEAR);
if (unroll_m == 16) {
kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
true, false);
@@ -2082,12 +2085,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
true, false);
}
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
- L(".subloop33");
+ L(subloop33);
cmp(I, 3);
- jne(".subloop34", T_NEAR);
+ jne(subloop34, T_NEAR);
if (unroll_m == 16) {
kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
true, false);
@@ -2095,12 +2098,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
true, false);
}
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
- L(".subloop34");
+ L(subloop34);
cmp(I, 4);
- jne(".subloop35", T_NEAR);
+ jne(subloop35, T_NEAR);
if (unroll_m == 16) {
kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
true, false);
@@ -2108,12 +2111,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
true, false);
}
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
- L(".subloop35");
+ L(subloop35);
cmp(I, 5);
- jne(".subloop99", T_NEAR);
+ jne(subloop99, T_NEAR);
if (unroll_m == 16) {
kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
true, false);
@@ -2124,7 +2127,7 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
align(16);
}
- L(".subloop99");
+ L(subloop99);
// Compute address for A
if (!isTransA) {
add(A, unroll_m * SIZE);
@@ -2138,14 +2141,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
if (hasBias) {
add(BIAS, unroll_m * SIZE);
}
-
- outLocalLabel();
};
- inLocalLabel();
-
preamble();
+ Label buffer_in_ws, buffer_allocated;
+
// Get the registers
mov(B, ARG_B);
mov(LDB, ARG_LDB);
@@ -2165,7 +2166,7 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
#endif
cmp(K, STACK_K_CAPACITY);
- jg(".buffer_in_ws", T_NEAR);
+ jg(buffer_in_ws, T_NEAR);
// Create buffer and align to 4kB page
lea(rax, ptr[K * SIZE]);
@@ -2173,12 +2174,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
add(rax, 256);
sub(rsp, rax);
and_(rsp, -PAGE_4K);
- jmp(".buffer_allocated", T_NEAR);
+ jmp(buffer_allocated, T_NEAR);
- L(".buffer_in_ws");
+ L(buffer_in_ws);
mov(rsp, ARG_WS);
- L(".buffer_allocated");
+ L(buffer_allocated);
mov(ORIG_SP, rbp);
mov(M, ARG_M);
@@ -2218,43 +2219,45 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
and_(rax, 0x1f);
mov(FLAG, rax);
+ Label main0, main1, main2, main3, main999;
+
cmp(M, UNROLL_M);
- jl(".main0", T_NEAR);
+ jl(main0, T_NEAR);
align(16);
- L(".main1");
+ L(main1);
subloop(UNROLL_M, true, true);
sub(M, UNROLL_M);
cmp(M, UNROLL_M);
- jge(".main1", T_NEAR);
+ jge(main1, T_NEAR);
align(16);
- L(".main0");
+ L(main0);
cmp(M, 0);
- jle(".main999", T_NEAR);
+ jle(main999, T_NEAR);
if (UNROLL_M > 8) {
cmp(M, 8);
- jle(".main2", T_NEAR);
+ jle(main2, T_NEAR);
sub(M, 8);
vbroadcastss(VMASK, M);
vpcmpgtd(VMASK, VMASK, MASK);
subloop(16, true, false);
- jmp(".main999", T_NEAR);
+ jmp(main999, T_NEAR);
align(16);
- L(".main2");
+ L(main2);
cmp(M, 8);
- jne(".main3", T_NEAR);
+ jne(main3, T_NEAR);
subloop(8, true, true);
- jmp(".main999", T_NEAR);
+ jmp(main999, T_NEAR);
}
align(16);
- L(".main3");
+ L(main3);
vbroadcastss(VMASK, M);
if (is_avx2) {
vpcmpgtd(VMASK, VMASK, MASK);
@@ -2270,7 +2273,7 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
subloop(8, false, false);
align(16);
- L(".main999");
+ L(main999);
// Restore original stack
mov(rax, ORIG_SP);
mov(rsp, rax);
@@ -2278,8 +2281,6 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
vzeroupper();
postamble();
- outLocalLabel();
-
ker_ = reinterpret_cast<decltype(ker_)>(
const_cast<uint8_t *>(this->getCode()));
}
@@ -2335,7 +2336,7 @@ void jit_avx_gemm_f32::sgemm_nocopy_driver(const char *transa,
int BM = 4032;
int BN = isTransA ? 96 : 48;
int BK = isTransB ? 96 : 256;
- const float *curA, *curB, *curBias = NULL;
+ const float *curA, *curB, *curBias = nullptr;
float *curC;
for (Bk = 0; Bk < k; Bk += sizeK) {
@@ -2376,15 +2377,15 @@ void jit_avx_gemm_f32::sgemm_nocopy_driver(const char *transa,
curB = b + Bn + (size_t)Bk * ldb;
}
curC = c + Bm + (size_t)Bn * ldc;
- if (bias != NULL) {
+ if (bias != nullptr) {
if (Bk == 0) {
curBias = bias + Bm;
} else {
- curBias = NULL;
+ curBias = nullptr;
}
}
if (Bk == 0) {
- if (*beta == 0.0 && bias == NULL)
+ if (*beta == 0.0 && bias == nullptr)
(*ker_b0_)((long long int)sizeM, (long long int)sizeN,
(long long int)sizeK, alpha, curA,
(long long int)lda, curB, (long long int)ldb,
@@ -2431,7 +2432,7 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
// Determine threading partitioning
gemm_utils::calc_nthr_nocopy_avx(
m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
- assert(utils::implication(!mkldnn_thr_syncable(), nthr_k == 1));
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
// May not happen, but just in case
if (nthr < nthr_m * nthr_n * nthr_k)
@@ -2439,13 +2440,19 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
nthr_mn = nthr_m * nthr_n;
- unsigned int volatile *ompstatus = (unsigned int volatile *)ompstatus_;
- if (!ompstatus) return;
+ unsigned char * ompstatus_ = nullptr;
+ unsigned char volatile *ompstatus = nullptr;
- float *c_buffers = NULL;
- float *ws_buffers = NULL;
+ float *c_buffers = nullptr;
+ float *ws_buffers = nullptr;
if (nthr_k > 1) {
+ ompstatus_ = (unsigned char *) malloc(
+ nthr * CACHE_LINE_SIZE,
+ CACHE_LINE_SIZE);
+ ompstatus = (unsigned char volatile *) ompstatus_;
+ assert(ompstatus);
+
for (int i = 0; i < nthr; i++)
ompstatus[i * CACHE_LINE_SIZE] = 0;
@@ -2466,7 +2473,7 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
int n_from, n_to, myN;
int k_from, k_to, myK;
int cbase, ibase;
- const float *myA, *myB, *myBias = NULL;
+ const float *myA, *myB, *myBias = nullptr;
float *myC = C, myBeta;
float *ws = ws_buffers ?
ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
@@ -2528,7 +2535,7 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
myBeta = 0.0;
ld = MB;
- myBias = NULL;
+ myBias = nullptr;
}
sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
@@ -2575,8 +2582,8 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
}
});
- if (nthr_k > 1)
- free(c_buffers);
+ free(c_buffers);
+ free(ompstatus_);
free(ws_buffers);
}
@@ -2602,9 +2609,6 @@ jit_avx_gemm_f32::jit_avx_gemm_f32(
ker_b0_ = ker_bn_;
}
nthrs_ = mkldnn_get_max_threads();
- ompstatus_ = (unsigned int *)malloc(
- sizeof(unsigned int *) * nthrs_ * CACHE_LINE_SIZE, 64);
- assert(ompstatus_);
}
jit_avx_gemm_f32::~jit_avx_gemm_f32()
@@ -2614,7 +2618,6 @@ jit_avx_gemm_f32::~jit_avx_gemm_f32()
delete ker_b1_;
if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
delete ker_b0_;
- free(ompstatus_);
}
}
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.hpp
index 0f0cc46ea..dd34e09f0 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.hpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.hpp
@@ -49,7 +49,6 @@ private:
bool hasBias_;
struct xbyak_gemm;
xbyak_gemm *ker_bn_, *ker_b1_, *ker_b0_;
- unsigned int *ompstatus_;
int nthrs_;
};
}
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/ref_gemm.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/ref_gemm.cpp
index 3310bf5e9..e0331e0ef 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/ref_gemm.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/ref_gemm.cpp
@@ -27,63 +27,68 @@ namespace impl {
namespace cpu {
using namespace mkldnn::impl::utils;
+using namespace gemm_utils;
-constexpr int unroll_m = 16;
-constexpr int unroll_n = 6;
+
+template <typename data_t>
static void copy_A(
- bool isTransA, int K, const float *A, const int lda, float *ws) {
+ bool isTransA, int K, const data_t *A, const int lda, data_t *ws) {
for (int k = 0; k < K; k++) {
PRAGMA_OMP_SIMD()
- for (int i = 0; i < unroll_m; i++) {
+ for (int i = 0; i < gemm_utils::unroll_factor<data_t>::m; i++) {
ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda];
}
- ws += unroll_m;
+ ws += unroll_factor<data_t>::m;
}
}
-template <bool isTransA, bool isTransB>
-static void kernel_mxn(int K, const float *A, const int lda,
- const float *B, const int ldb, float *C, const int ldc,
- const float alpha, const float beta) {
- float c[unroll_m * unroll_n] = { 0. };
+template <typename data_t, bool isTransA, bool isTransB>
+static void kernel_mxn(int K, const data_t *A, const int lda,
+ const data_t *B, const int ldb, data_t *C, const int ldc,
+ const data_t alpha, const data_t beta) {
+ data_t c[unroll_factor<data_t>::m * unroll_factor<data_t>::n] =
+ { static_cast<data_t>(0.) };
for (int k = 0; k < K; k++) {
- for (int j = 0; j < unroll_n; j++) {
- float b = isTransB ? B[j + k * ldb] : B[k + j * ldb];
+ for (int j = 0; j < unroll_factor<data_t>::n; j++) {
+ data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb];
PRAGMA_OMP_SIMD()
- for (int i = 0; i < unroll_m; i++) {
- float a = isTransA ? A[i * lda + k] : A[i + lda * k];
- c[i + unroll_m * j] += a * b;
+ for (int i = 0; i < unroll_factor<data_t>::m; i++) {
+ data_t a = isTransA ? A[i * lda + k] : A[i + lda * k];
+ c[i + unroll_factor<data_t>::m * j] += a * b;
}
}
}
- for (int j = 0; j < unroll_n; j++) {
+ for (int j = 0; j < unroll_factor<data_t>::n; j++) {
PRAGMA_OMP_SIMD()
- for (int i = 0; i < unroll_m; i++) {
- C[i + j * ldc] = (beta == 0.0f)
- ? alpha * c[i + unroll_m * j]
- : alpha * c[i + unroll_m * j] + beta * C[i + j * ldc];
+ for (int i = 0; i < unroll_factor<data_t>::m; i++) {
+ C[i + j * ldc] = (beta == static_cast<data_t>(0.))
+ ? alpha * c[i + unroll_factor<data_t>::m * j]
+ : alpha * c[i + unroll_factor<data_t>::m * j]
+ + beta * C[i + j * ldc];
}
}
}
-template <bool isTransA, bool isTransB>
+template <typename data_t, bool isTransA, bool isTransB>
static void block_ker(const int M, const int N, const int K,
- const float *A, const int lda, const float *B, const int ldb, float *C,
- const int ldc, const float alpha, const float beta, float *ws,
- bool do_copy) {
- int Nu = rnd_dn(N, unroll_n), Mu = rnd_dn(M, unroll_m);
- for (int i = 0; i < Mu; i += unroll_m) {
- for (int j = 0; j < Nu; j += unroll_n) {
- const float *b = isTransB ? &B[j] : &B[j * ldb];
- const float *a = isTransA ? &A[i * lda] : &A[i];
+ const data_t *A, const int lda, const data_t *B, const int ldb,
+ data_t *C, const int ldc, const data_t alpha, const data_t beta,
+ data_t *ws, bool do_copy) {
+ int Nu = rnd_dn(N, unroll_factor<data_t>::n);
+ int Mu = rnd_dn(M, unroll_factor<data_t>::m);
+ for (int i = 0; i < Mu; i += unroll_factor<data_t>::m) {
+ for (int j = 0; j < Nu; j += unroll_factor<data_t>::n) {
+ const data_t *b = isTransB ? &B[j] : &B[j * ldb];
+ const data_t *a = isTransA ? &A[i * lda] : &A[i];
if (do_copy) {
if (j == 0) {
- copy_A(isTransA, K, a, lda, ws);
+ copy_A<data_t>(isTransA, K, a, lda, ws);
}
- kernel_mxn<false, isTransB>(
- K, ws, unroll_m, b, ldb, &C[i + j * ldc], ldc, alpha, beta);
+ kernel_mxn<data_t, false, isTransB>(
+ K, ws, unroll_factor<data_t>::m, b, ldb,
+ &C[i + j * ldc], ldc, alpha, beta);
} else {
- kernel_mxn<isTransA, isTransB>(
+ kernel_mxn<data_t, isTransA, isTransB>(
K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta);
}
}
@@ -91,10 +96,12 @@ static void block_ker(const int M, const int N, const int K,
// tail processing
for (int i = 0; i < M; i++) {
for (int j = Nu; j < N; j++) {
- float c = beta == 0.0f ? 0.0f : beta * C[i + j * ldc];
+ data_t c = beta == static_cast<data_t>(0.)
+ ? static_cast<data_t>(0.)
+ : beta * C[i + j * ldc];
for (int p = 0; p < K; p++) {
- float b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
- float a = isTransA ? A[p + i * lda] : A[i + p * lda];
+ data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
+ data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
c += alpha * a * b;
}
C[i + j * ldc] = c;
@@ -102,10 +109,12 @@ static void block_ker(const int M, const int N, const int K,
}
for (int i = Mu; i < M; i++) {
for (int j = 0; j < Nu; j++) {
- float c = beta == 0.0f ? 0.0f : beta * C[i + j * ldc];
+ data_t c = beta == static_cast<data_t>(0.)
+ ? static_cast<data_t>(0.)
+ : beta * C[i + j * ldc];
for (int p = 0; p < K; p++) {
- float b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
- float a = isTransA ? A[p + i * lda] : A[i + p * lda];
+ data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
+ data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
c += alpha * a * b;
}
C[i + j * ldc] = c;
@@ -113,25 +122,28 @@ static void block_ker(const int M, const int N, const int K,
}
}
-template <bool isTransA, bool isTransB>
-void gemm_ithr(const int M, const int N, const int K, const float alpha,
- const float *A, const int lda, const float *B, const int ldb,
- const float beta, float *C, const int ldc, bool do_copy, float *ws) {
- int BM = 4032;
- int BN = isTransA ? 96 : 48;
- int BK = isTransB ? 96 : 256;
- const float *curA, *curB;
- float *curC;
+template <typename data_t, bool isTransA, bool isTransB>
+void gemm_ithr(const int M, const int N, const int K, const data_t alpha,
+ const data_t *A, const int lda, const data_t *B, const int ldb,
+ const data_t beta, data_t *C, const int ldc, bool do_copy, data_t *ws) {
+ constexpr int BM = gemm_traits<data_t, isTransA, isTransB>::BM;
+ constexpr int BN = gemm_traits<data_t, isTransA, isTransB>::BN;
+ constexpr int BK = gemm_traits<data_t, isTransA, isTransB>::BK;
+
+ const data_t *curA;
+ const data_t *curB;
+ data_t *curC;
if ((M <= 0) || (N <= 0))
return;
- if ((K <= 0) || (alpha == 0.0f)) {
- if (beta == 0.0f) {
- for (int j = 0; j < N * M; j++)
- C[j] = 0.0f;
- } else if (beta != 1.0f) {
- for (int j = 0; j < N * M; j++)
+ if ((K <= 0) || (alpha == static_cast<data_t>(0))) {
+ ptrdiff_t MN = (ptrdiff_t)N * M;
+ if (beta == static_cast<data_t>(0.)) {
+ for (ptrdiff_t j = 0; j < MN; j++)
+ C[j] = static_cast<data_t>(0.);
+ } else if (beta != static_cast<data_t>(1.)) {
+ for (ptrdiff_t j = 0; j < MN; j++)
C[j] *= beta;
}
return;
@@ -147,25 +159,27 @@ void gemm_ithr(const int M, const int N, const int K, const float alpha,
curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb;
curC = C + Bm + Bn * ldc;
if (Bk == 0) {
- block_ker<isTransA, isTransB>(mb, nb, kb, curA, lda, curB,
- ldb, curC, ldc, alpha, beta, ws, do_copy);
+ block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
+ curB, ldb, curC, ldc, alpha, beta, ws, do_copy);
} else {
- block_ker<isTransA, isTransB>(mb, nb, kb, curA, lda, curB,
- ldb, curC, ldc, alpha, 1.0f, ws, do_copy);
+ block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
+ curB, ldb, curC, ldc, alpha, static_cast<data_t>(1.0),
+ ws, do_copy);
}
}
}
}
}
+template <typename data_t>
void ref_gemm(const char *transa_, const char *transb_, const int *M_,
- const int *N_, const int *K_, const float *alpha_, const float *A,
- const int *lda_, const float *B, const int *ldb_, const float *beta_,
- float *C, const int *ldc_, const float *bias) {
+ const int *N_, const int *K_, const data_t *alpha_, const data_t *A,
+ const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_,
+ data_t *C, const int *ldc_, const data_t *bias) {
bool isTransA = (*transa_ == 'T' || *transa_ == 't');
bool isTransB = (*transb_ == 'T' || *transb_ == 't');
const int M = *M_, N = *N_, K = *K_, lda = *lda_, ldb = *ldb_, ldc = *ldc_;
- const float alpha = *alpha_, beta = *beta_;
+ const data_t alpha = *alpha_, beta = *beta_;
int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
int nthr_m, nthr_n, nthr_k;
@@ -173,26 +187,27 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_,
// thread balancing over M, N, K & size of blocking dimensions
gemm_utils::calc_nthr_nocopy_avx(
M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
- assert(utils::implication(!mkldnn_thr_syncable(), nthr_k == 1));
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
- float *c_buffers = nullptr, *ws_buffers = nullptr;
+ data_t *c_buffers = nullptr;
+ data_t *ws_buffers = nullptr;
if (nthr_k > 1) {
- c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
- * sizeof(float), PAGE_4K);
+ c_buffers = (data_t *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
+ * sizeof(data_t), PAGE_4K);
if (!c_buffers) {
nthr_k = 1;
KB = K;
}
}
- bool do_copy = (NB / unroll_n > 3);
+ bool do_copy = (NB / unroll_factor<data_t>::n > 3);
const int nthr_mn = nthr_m * nthr_n;
const int nthr = nthr_mn * nthr_k;
- const size_t ws_elems_per_thr = K * unroll_m;
+ const size_t ws_elems_per_thr = K * unroll_factor<data_t>::m;
const size_t ws_size_per_thr
- = utils::rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
+ = utils::rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K);
if (do_copy) {
- ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
+ ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K);
if (!ws_buffers)
do_copy = false;
}
@@ -205,8 +220,8 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_,
int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
- float *ws = do_copy
- ? ws_buffers + ithr * ws_size_per_thr / sizeof(float)
+ data_t *ws = do_copy
+ ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t)
: nullptr;
int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0,
@@ -224,7 +239,7 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_,
get_thr_block(k_from, k_to, myK, KB, K, ithr_k);
if (myM > 0 && myN > 0) {
- float myBeta, *myC;
+ data_t myBeta, *myC;
int ld;
if (ithr_k == 0) {
myC = &(C[m_from + n_from * ldc]);
@@ -235,28 +250,28 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_,
myBeta = 0.0f;
ld = MB;
}
- const float *myA = isTransA
+ const data_t *myA = isTransA
? &(A[k_from + m_from * lda])
: &(A[m_from + k_from * lda]);
- const float *myB = isTransB
+ const data_t *myB = isTransB
? &(B[n_from + k_from * ldb])
: &(B[k_from + n_from * ldb]);
if (!isTransA) {
if (!isTransB) {
- gemm_ithr<false, false>(myM, myN, myK, alpha, myA, lda, myB,
- ldb, myBeta, myC, ld, do_copy, ws);
+ gemm_ithr<data_t, false, false>(myM, myN, myK, alpha, myA,
+ lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
} else {
- gemm_ithr<false, true>(myM, myN, myK, alpha, myA, lda, myB,
- ldb, myBeta, myC, ld, do_copy, ws);
+ gemm_ithr<data_t, false, true>(myM, myN, myK, alpha, myA,
+ lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
}
} else {
if (!isTransB) {
- gemm_ithr<true, false>(myM, myN, myK, alpha, myA, lda, myB,
- ldb, myBeta, myC, ld, do_copy, ws);
+ gemm_ithr<data_t, true, false>(myM, myN, myK, alpha, myA,
+ lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
} else {
- gemm_ithr<true, true>(myM, myN, myK, alpha, myA, lda, myB,
- ldb, myBeta, myC, ld, do_copy, ws);
+ gemm_ithr<data_t, true, true>(myM, myN, myK, alpha, myA,
+ lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
}
}
}
@@ -270,7 +285,8 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_,
gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset,
&block);
for (int ik = 1; ik < nthr_k; ++ik) {
- float *myC = c_buffers + MB * (NB * (cbase + ik - 1) + offset);
+ data_t *myC = c_buffers + MB * (NB * (cbase + ik - 1) + offset);
+
gemm_utils::sum_two_matrices(myM, block, myC, MB,
&C[m_from + (n_from + offset) * ldc], ldc);
}
@@ -286,6 +302,16 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_,
free(ws_buffers);
free(c_buffers);
}
+
+template void ref_gemm<float>(const char *transa_, const char *transb_,
+ const int *M_, const int *N_, const int *K_, const float *alpha_,
+ const float *A, const int *lda_, const float *B, const int *ldb_,
+ const float *beta_, float *C, const int *ldc_, const float *bias);
+
+template void ref_gemm<double>(const char *transa_, const char *transb_,
+ const int *M_, const int *N_, const int *K_, const double *alpha_,
+ const double *A, const int *lda_, const double *B, const int *ldb_,
+ const double *beta_, double *C, const int *ldc_, const double *bias);
}
}
}