summary | refs | log | tree | commit | diff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp')
-rw-r--r-- inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp 195
1 files changed, 97 insertions, 98 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp
index 08ba7b47d..8aee85fbd 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp
@@ -22,7 +22,7 @@
#include "gemm_utils.hpp"
#include "jit_avx512_common_gemm_f32.hpp"
-#define CACHE_LINE_SIZE 16
+#define CACHE_LINE_SIZE 64
namespace mkldnn {
namespace impl {
@@ -128,15 +128,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
// Function for packing if needed
auto do_pack = [&](int unroll_m) {
- inLocalLabel();
+ Label pack2, pack3, pack4, pack10;
+
mov(BO1, A);
lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
mov(LL, K);
sar(LL, 2);
- jle(".pack3", T_NEAR);
+ jle(pack3, T_NEAR);
align(16);
- L(".pack2");
+ L(pack2);
if (!isTransA) {
for (int i = 0; i < 4; i++) {
vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
@@ -216,16 +217,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
add(AO1, unroll_m * 4 * SIZE);
sub(LL, 1);
- jg(".pack2", T_NEAR);
+ jg(pack2, T_NEAR);
align(16);
- L(".pack3");
+ L(pack3);
mov(LL, K);
and_(LL, 3);
- jle(".pack10", T_NEAR);
+ jle(pack10, T_NEAR);
align(16);
- L(".pack4");
+ L(pack4);
if (!isTransA) {
vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
if (unroll_m > 16)
@@ -279,11 +280,10 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
add(AO1, unroll_m * SIZE);
sub(LL, 1);
- jg(".pack4", T_NEAR);
+ jg(pack4, T_NEAR);
align(16);
- L(".pack10");
- outLocalLabel();
+ L(pack10);
};
// Function to update C, covering masking and other considerations
@@ -617,8 +617,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
// Innerkernel; called by kernel
auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect,
bool isCopy, bool doCPrefetch, bool isUnmasked = true) {
- inLocalLabel();
-
for (int i = 0; i < 8; i++) {
if (!isDirect) {
prefetcht0(ptr[AO1
@@ -960,7 +958,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
}
sub(LL, 1);
- outLocalLabel();
};
// Main kernel; does prefetching and calls innerkernel
@@ -968,7 +965,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
// calling update
auto kernel = [&](int unroll_m, int unroll_n, bool isDirect,
bool isCopy, bool isUnmasked = true) {
- inLocalLabel();
if (!isDirect) {
lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
} else {
@@ -1020,36 +1016,38 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
}
}
+ Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18;
+
mov(LL, K);
sar(LL, 3);
sub(LL, SECOND_FETCH);
- jle(".kernel13", T_NEAR);
+ jle(kernel13, T_NEAR);
align(16);
- L(".kernel12");
+ L(kernel12);
innerkernel(
unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked);
- jg(".kernel12", T_NEAR);
+ jg(kernel12, T_NEAR);
align(16);
- L(".kernel13");
+ L(kernel13);
lea(CO2, ptr[CO1 + (16 - 1) * SIZE]);
add(LL, unroll_n);
- jle(".kernel15", T_NEAR);
+ jle(kernel15, T_NEAR);
align(16);
- L(".kernel14");
+ L(kernel14);
innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked);
- jg(".kernel14", T_NEAR);
+ jg(kernel14, T_NEAR);
align(16);
- L(".kernel15");
+ L(kernel15);
mov(LL, K);
and_(LL, 7);
- jle(".kernel18", T_NEAR);
+ jle(kernel18, T_NEAR);
align(16);
- L(".kernel16");
+ L(kernel16);
if (isDirect) {
if (isUnmasked || unroll_m > 16) {
vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
@@ -1204,10 +1202,10 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
}
sub(LL, 1);
- jg(".kernel16", T_NEAR);
+ jg(kernel16, T_NEAR);
align(16);
- L(".kernel18");
+ L(kernel18);
vbroadcastss(VALPHA, ALPHA);
if (isBetaN) {
@@ -1329,8 +1327,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
sub(BO1, rax);
add(BO1, unroll_n * SIZE);
}
-
- outLocalLabel();
};
// High-level subroutine; does packing if needed, then splits C matrix.
@@ -1338,11 +1334,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
// cases appropriately by doing 32 or 16 rows, and/or with masking,
// and/or fewer columns).
auto subloop = [&](int unroll_m) {
- inLocalLabel();
-
Label l_subloop_20x[8], l_subloop_mask_20x[8];
Label l_subloop_30x[8], l_subloop_mask_30x[8];
+ Label subloop11, subloop11mask;
+ Label subloop30, subloop30mask;
+ Label subloop31, subloop31mask;
+ Label subloop96;
+ Label subloop98, subloop98mask;
+ Label subloop99;
+
// Create mask
mov(BO1, rcx);
mov(rcx, M);
@@ -1370,7 +1371,7 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
and_(rax, 0xffff);
cmp(rax, 0xffff);
- jne(".subloop96", T_NEAR);
+ jne(subloop96, T_NEAR);
if (isTransA) {
do_pack(unroll_m);
@@ -1387,11 +1388,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (!isTransA) {
lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
cmp(M, UNROLL_M);
- jg(".subloop98", T_NEAR);
+ jg(subloop98, T_NEAR);
mov(AA, ORIG_A);
lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
- L(".subloop98");
+ L(subloop98);
}
mov(LL, N);
@@ -1399,11 +1400,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (!isTransA) {
// If N is too small, skip copy operation
cmp(LL, UNROLL_N * 3);
- jle(".subloop30", T_NEAR);
+ jle(subloop30, T_NEAR);
// If A is not aligned to cache line
cmp(FLAG, 0);
- je(".subloop30", T_NEAR);
+ je(subloop30, T_NEAR);
} else {
cmp(LL, UNROLL_N);
jl(l_subloop_20x[1], T_NEAR);
@@ -1421,11 +1422,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
jl(l_subloop_20x[1], T_NEAR);
align(16);
- L(".subloop11");
+ L(subloop11);
kernel(unroll_m, UNROLL_N, false, false);
sub(I, UNROLL_N);
cmp(I, UNROLL_N);
- jge(".subloop11", T_NEAR);
+ jge(subloop11, T_NEAR);
align(16);
for (int i = 1; i <= 7; i++) {
@@ -1434,24 +1435,24 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (i < 7) {
jne(l_subloop_20x[i + 1], T_NEAR);
} else {
- jne(".subloop99", T_NEAR);
+ jne(subloop99, T_NEAR);
}
kernel(unroll_m, i, false, false);
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
}
if (!isTransA) {
- L(".subloop30");
+ L(subloop30);
cmp(I, UNROLL_N);
jl(l_subloop_30x[1], T_NEAR);
align(16);
- L(".subloop31");
+ L(subloop31);
kernel(unroll_m, UNROLL_N, true, false);
sub(I, UNROLL_N);
cmp(I, UNROLL_N);
- jge(".subloop31", T_NEAR);
+ jge(subloop31, T_NEAR);
align(16);
for (int i = 1; i <= 7; i++) {
@@ -1460,18 +1461,18 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (i < 7) {
jne(l_subloop_30x[i + 1], T_NEAR);
} else {
- jne(".subloop99", T_NEAR);
+ jne(subloop99, T_NEAR);
}
kernel(unroll_m, i, true, false);
if (i < 7)
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
}
}
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
- L(".subloop96");
+ L(subloop96);
if (isTransA) {
do_pack(unroll_m);
}
@@ -1486,10 +1487,10 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (!isTransA) {
lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
cmp(M, UNROLL_M);
- jg(".subloop98mask", T_NEAR);
+ jg(subloop98mask, T_NEAR);
mov(AA, ORIG_A);
lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
- L(".subloop98mask");
+ L(subloop98mask);
}
mov(LL, N);
@@ -1497,11 +1498,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (!isTransA) {
// If N is too small, skip copy operation
cmp(LL, UNROLL_N * 3);
- jle(".subloop30mask", T_NEAR);
+ jle(subloop30mask, T_NEAR);
// If A is not aligned to cache line
cmp(FLAG, 0);
- je(".subloop30mask", T_NEAR);
+ je(subloop30mask, T_NEAR);
} else {
cmp(LL, UNROLL_N);
jl(l_subloop_mask_20x[1], T_NEAR);
@@ -1519,11 +1520,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
jl(l_subloop_mask_20x[1], T_NEAR);
align(16);
- L(".subloop11mask");
+ L(subloop11mask);
kernel(unroll_m, UNROLL_N, false, false, false);
sub(I, UNROLL_N);
cmp(I, UNROLL_N);
- jge(".subloop11mask", T_NEAR);
+ jge(subloop11mask, T_NEAR);
align(16);
for (int i = 1; i <= 7; i++) {
@@ -1532,24 +1533,24 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (i < 7) {
jne(l_subloop_mask_20x[i + 1], T_NEAR);
} else {
- jne(".subloop99", T_NEAR);
+ jne(subloop99, T_NEAR);
}
kernel(unroll_m, i, false, false, false);
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
}
if (!isTransA) {
- L(".subloop30mask");
+ L(subloop30mask);
cmp(I, UNROLL_N);
jl(l_subloop_mask_30x[1], T_NEAR);
align(16);
- L(".subloop31mask");
+ L(subloop31mask);
kernel(unroll_m, UNROLL_N, true, false, false);
sub(I, UNROLL_N);
cmp(I, UNROLL_N);
- jge(".subloop31mask", T_NEAR);
+ jge(subloop31mask, T_NEAR);
align(16);
for (int i = 1; i <= 7; i++) {
@@ -1558,16 +1559,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (i < 7) {
jne(l_subloop_mask_30x[i + 1], T_NEAR);
} else {
- jne(".subloop99", T_NEAR);
+ jne(subloop99, T_NEAR);
}
kernel(unroll_m, i, true, false, false);
if (i < 7)
- jmp(".subloop99", T_NEAR);
+ jmp(subloop99, T_NEAR);
align(16);
}
}
- L(".subloop99");
+ L(subloop99);
// Compute address for A
if (!isTransA) {
add(A, unroll_m * SIZE);
@@ -1581,14 +1582,12 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
if (hasBias) {
add(BIAS, unroll_m * SIZE);
}
-
- outLocalLabel();
};
- inLocalLabel();
-
preamble();
+ Label buffer_in_ws, buffer_allocated;
+
// Get the registers
mov(B, ARG_B);
mov(LDB, ARG_LDB);
@@ -1608,7 +1607,7 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
#endif
cmp(K, STACK_K_CAPACITY);
- jg(".buffer_in_ws", T_NEAR);
+ jg(buffer_in_ws, T_NEAR);
// Create buffer and align to 4kB page
lea(rax, ptr[K * SIZE]);
@@ -1616,12 +1615,12 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
add(rax, 256);
sub(rsp, rax);
and_(rsp, -PAGE_4K);
- jmp(".buffer_allocated", T_NEAR);
+ jmp(buffer_allocated, T_NEAR);
- L(".buffer_in_ws");
+ L(buffer_in_ws);
mov(rsp, ARG_WS);
- L(".buffer_allocated");
+ L(buffer_allocated);
mov(ORIG_SP, rbp);
mov(M, ARG_M);
@@ -1665,40 +1664,40 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
}
}
+ Label main0, main1, main2, main999;
+
cmp(M, 32);
- jle(".main0", T_NEAR);
+ jle(main0, T_NEAR);
align(16);
- L(".main1");
+ L(main1);
subloop(48);
sub(M, UNROLL_M);
cmp(M, 32);
- jg(".main1", T_NEAR);
+ jg(main1, T_NEAR);
align(16);
- L(".main0");
+ L(main0);
cmp(M, 16);
- jle(".main2", T_NEAR);
+ jle(main2, T_NEAR);
subloop(32);
- jmp(".main999", T_NEAR);
+ jmp(main999, T_NEAR);
align(16);
- L(".main2");
+ L(main2);
cmp(M, 0);
- jle(".main999", T_NEAR);
+ jle(main999, T_NEAR);
subloop(16);
align(16);
- L(".main999");
+ L(main999);
// Restore original stack
mov(rsp, ORIG_SP);
vzeroupper();
postamble();
- outLocalLabel();
-
ker_ = reinterpret_cast<decltype(ker_)>(
const_cast<uint8_t *>(this->getCode()));
}
@@ -1763,7 +1762,7 @@ void jit_avx512_common_gemm_f32::sgemm_nocopy_driver(const char *transa,
if (!isTransA && !isTransB)
BK = 128;
}
- const float *curA, *curB, *curBias = NULL;
+ const float *curA, *curB, *curBias = nullptr;
float *curC;
for (Bk = 0; Bk < k; Bk += sizeK) {
@@ -1804,15 +1803,15 @@ void jit_avx512_common_gemm_f32::sgemm_nocopy_driver(const char *transa,
curB = b + Bn + (size_t)Bk * ldb;
}
curC = c + Bm + (size_t)Bn * ldc;
- if (bias != NULL) {
+ if (bias != nullptr) {
if (Bk == 0) {
curBias = bias + Bm;
} else {
- curBias = NULL;
+ curBias = nullptr;
}
}
if (Bk == 0) {
- if (*beta == 0.0 && bias == NULL)
+ if (*beta == 0.0 && bias == nullptr)
(*ker_b0_)((long long int)sizeM, (long long int)sizeN,
(long long int)sizeK, alpha, curA,
(long long int)lda, curB, (long long int)ldb,
@@ -1860,7 +1859,7 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
// Determine threading partitioning
gemm_utils::calc_nthr_nocopy_avx512_common(
m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
- assert(utils::implication(!mkldnn_thr_syncable(), nthr_k == 1));
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
// May not happen, but just in case
if (nthr < nthr_m * nthr_n * nthr_k)
@@ -1868,13 +1867,18 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
nthr_mn = nthr_m * nthr_n;
- unsigned int volatile *ompstatus = (unsigned int volatile *)ompstatus_;
- if (!ompstatus) return;
+ unsigned char * ompstatus_ = nullptr;
+ unsigned char volatile *ompstatus = nullptr;
- float *c_buffers = NULL;
- float *ws_buffers = NULL;
+ float *c_buffers = nullptr;
+ float *ws_buffers = nullptr;
if (nthr_k > 1) {
+ ompstatus_ = (unsigned char *) malloc(
+ nthr * CACHE_LINE_SIZE,
+ CACHE_LINE_SIZE);
+ ompstatus = (unsigned char volatile *) ompstatus_;
+ assert(ompstatus);
for (int i = 0; i < nthr; i++)
ompstatus[i * CACHE_LINE_SIZE] = 0;
@@ -1895,7 +1899,7 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
int n_from, n_to, myN;
int k_from, k_to, myK;
int cbase, ibase;
- const float *myA, *myB, *myBias = NULL;
+ const float *myA, *myB, *myBias = nullptr;
float *myC = C, myBeta;
float *ws = ws_buffers ?
ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
@@ -1957,7 +1961,7 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
myBeta = 0.0;
ld = MB;
- myBias = NULL;
+ myBias = nullptr;
}
sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
@@ -2004,8 +2008,8 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
}
});
- if (nthr_k > 1)
- free(c_buffers);
+ free(c_buffers);
+ free(ompstatus_);
free(ws_buffers);
}
@@ -2032,10 +2036,6 @@ jit_avx512_common_gemm_f32::jit_avx512_common_gemm_f32(
}
nthrs_ = mkldnn_get_max_threads();
- ompstatus_ = (unsigned int *)malloc(
- sizeof(unsigned int *) * nthrs_ * CACHE_LINE_SIZE, 64);
- assert(ompstatus_);
-
}
jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32()
@@ -2045,7 +2045,6 @@ jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32()
delete ker_b1_;
if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
delete ker_b0_;
- free(ompstatus_);
}
}
}