diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp | 195 |
1 files changed, 97 insertions, 98 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp index 08ba7b47d..8aee85fbd 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp @@ -22,7 +22,7 @@ #include "gemm_utils.hpp" #include "jit_avx512_common_gemm_f32.hpp" -#define CACHE_LINE_SIZE 16 +#define CACHE_LINE_SIZE 64 namespace mkldnn { namespace impl { @@ -128,15 +128,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { // Function for packing if needed auto do_pack = [&](int unroll_m) { - inLocalLabel(); + Label pack2, pack3, pack4, pack10; + mov(BO1, A); lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); mov(LL, K); sar(LL, 2); - jle(".pack3", T_NEAR); + jle(pack3, T_NEAR); align(16); - L(".pack2"); + L(pack2); if (!isTransA) { for (int i = 0; i < 4; i++) { vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); @@ -216,16 +217,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { add(AO1, unroll_m * 4 * SIZE); sub(LL, 1); - jg(".pack2", T_NEAR); + jg(pack2, T_NEAR); align(16); - L(".pack3"); + L(pack3); mov(LL, K); and_(LL, 3); - jle(".pack10", T_NEAR); + jle(pack10, T_NEAR); align(16); - L(".pack4"); + L(pack4); if (!isTransA) { vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); if (unroll_m > 16) @@ -279,11 +280,10 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { add(AO1, unroll_m * SIZE); sub(LL, 1); - jg(".pack4", T_NEAR); + jg(pack4, T_NEAR); align(16); - L(".pack10"); - outLocalLabel(); + L(pack10); }; // Function to update C, covering masking and other considerations @@ -617,8 +617,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { // Innerkernel; called by kernel auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect, bool isCopy, bool doCPrefetch, bool isUnmasked = true) { - inLocalLabel(); - for (int i = 0; i < 8; i++) { if (!isDirect) { prefetcht0(ptr[AO1 @@ -960,7 +958,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { } sub(LL, 1); - outLocalLabel(); }; // Main kernel; does prefetching and calls innerkernel @@ -968,7 +965,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { // calling update auto kernel = [&](int unroll_m, int unroll_n, bool isDirect, bool isCopy, bool isUnmasked = true) { - inLocalLabel(); if (!isDirect) { lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); } else { @@ -1020,36 +1016,38 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { } } + Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18; + mov(LL, K); sar(LL, 3); sub(LL, SECOND_FETCH); - jle(".kernel13", T_NEAR); + jle(kernel13, T_NEAR); align(16); - L(".kernel12"); + L(kernel12); innerkernel( unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked); - jg(".kernel12", T_NEAR); + jg(kernel12, T_NEAR); align(16); - L(".kernel13"); + L(kernel13); lea(CO2, ptr[CO1 + (16 - 1) * SIZE]); add(LL, unroll_n); - jle(".kernel15", T_NEAR); + jle(kernel15, T_NEAR); align(16); - L(".kernel14"); + L(kernel14); innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked); - jg(".kernel14", T_NEAR); + jg(kernel14, T_NEAR); align(16); - L(".kernel15"); + L(kernel15); mov(LL, K); and_(LL, 7); - jle(".kernel18", T_NEAR); + jle(kernel18, T_NEAR); align(16); - L(".kernel16"); + L(kernel16); if (isDirect) { if (isUnmasked || unroll_m > 16) { vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); @@ -1204,10 +1202,10 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { } sub(LL, 1); - jg(".kernel16", T_NEAR); + jg(kernel16, T_NEAR); align(16); - L(".kernel18"); + L(kernel18); vbroadcastss(VALPHA, ALPHA); if (isBetaN) { @@ -1329,8 +1327,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { sub(BO1, rax); add(BO1, unroll_n * SIZE); } - - outLocalLabel(); }; // High-level subroutine; does packing if needed, then splits C matrix. @@ -1338,11 +1334,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { // cases appropriately by doing 32 or 16 rows, and/or with masking, // and/or fewer columns). auto subloop = [&](int unroll_m) { - inLocalLabel(); - Label l_subloop_20x[8], l_subloop_mask_20x[8]; Label l_subloop_30x[8], l_subloop_mask_30x[8]; + Label subloop11, subloop11mask; + Label subloop30, subloop30mask; + Label subloop31, subloop31mask; + Label subloop96; + Label subloop98, subloop98mask; + Label subloop99; + // Create mask mov(BO1, rcx); mov(rcx, M); @@ -1370,7 +1371,7 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { and_(rax, 0xffff); cmp(rax, 0xffff); - jne(".subloop96", T_NEAR); + jne(subloop96, T_NEAR); if (isTransA) { do_pack(unroll_m); @@ -1387,11 +1388,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (!isTransA) { lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); cmp(M, UNROLL_M); - jg(".subloop98", T_NEAR); + jg(subloop98, T_NEAR); mov(AA, ORIG_A); lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]); - L(".subloop98"); + L(subloop98); } mov(LL, N); @@ -1399,11 +1400,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (!isTransA) { // If N is too small, skip copy operation cmp(LL, UNROLL_N * 3); - jle(".subloop30", T_NEAR); + jle(subloop30, T_NEAR); // If A is not aligned to cache line cmp(FLAG, 0); - je(".subloop30", T_NEAR); + je(subloop30, T_NEAR); } else { cmp(LL, UNROLL_N); jl(l_subloop_20x[1], T_NEAR); @@ -1421,11 +1422,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { jl(l_subloop_20x[1], T_NEAR); align(16); - L(".subloop11"); + L(subloop11); kernel(unroll_m, UNROLL_N, false, false); sub(I, UNROLL_N); cmp(I, UNROLL_N); - jge(".subloop11", T_NEAR); + jge(subloop11, T_NEAR); align(16); for (int i = 1; i <= 7; i++) { @@ -1434,24 +1435,24 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (i < 7) { jne(l_subloop_20x[i + 1], T_NEAR); } else { - jne(".subloop99", T_NEAR); + jne(subloop99, T_NEAR); } kernel(unroll_m, i, false, false); - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); } if (!isTransA) { - L(".subloop30"); + L(subloop30); cmp(I, UNROLL_N); jl(l_subloop_30x[1], T_NEAR); align(16); - L(".subloop31"); + L(subloop31); kernel(unroll_m, UNROLL_N, true, false); sub(I, UNROLL_N); cmp(I, UNROLL_N); - jge(".subloop31", T_NEAR); + jge(subloop31, T_NEAR); align(16); for (int i = 1; i <= 7; i++) { @@ -1460,18 +1461,18 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (i < 7) { jne(l_subloop_30x[i + 1], T_NEAR); } else { - jne(".subloop99", T_NEAR); + jne(subloop99, T_NEAR); } kernel(unroll_m, i, true, false); if (i < 7) - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); } } - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); - L(".subloop96"); + L(subloop96); if (isTransA) { do_pack(unroll_m); } @@ -1486,10 +1487,10 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (!isTransA) { lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); cmp(M, UNROLL_M); - jg(".subloop98mask", T_NEAR); + jg(subloop98mask, T_NEAR); mov(AA, ORIG_A); lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]); - L(".subloop98mask"); + L(subloop98mask); } mov(LL, N); @@ -1497,11 +1498,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (!isTransA) { // If N is too small, skip copy operation cmp(LL, UNROLL_N * 3); - jle(".subloop30mask", T_NEAR); + jle(subloop30mask, T_NEAR); // If A is not aligned to cache line cmp(FLAG, 0); - je(".subloop30mask", T_NEAR); + je(subloop30mask, T_NEAR); } else { cmp(LL, UNROLL_N); jl(l_subloop_mask_20x[1], T_NEAR); @@ -1519,11 +1520,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { jl(l_subloop_mask_20x[1], T_NEAR); align(16); - L(".subloop11mask"); + L(subloop11mask); kernel(unroll_m, UNROLL_N, false, false, false); sub(I, UNROLL_N); cmp(I, UNROLL_N); - jge(".subloop11mask", T_NEAR); + jge(subloop11mask, T_NEAR); align(16); for (int i = 1; i <= 7; i++) { @@ -1532,24 +1533,24 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (i < 7) { jne(l_subloop_mask_20x[i + 1], T_NEAR); } else { - jne(".subloop99", T_NEAR); + jne(subloop99, T_NEAR); } kernel(unroll_m, i, false, false, false); - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); } if (!isTransA) { - L(".subloop30mask"); + L(subloop30mask); cmp(I, UNROLL_N); jl(l_subloop_mask_30x[1], T_NEAR); align(16); - L(".subloop31mask"); + L(subloop31mask); kernel(unroll_m, UNROLL_N, true, false, false); sub(I, UNROLL_N); cmp(I, UNROLL_N); - jge(".subloop31mask", T_NEAR); + jge(subloop31mask, T_NEAR); align(16); for (int i = 1; i <= 7; i++) { @@ -1558,16 +1559,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (i < 7) { jne(l_subloop_mask_30x[i + 1], T_NEAR); } else { - jne(".subloop99", T_NEAR); + jne(subloop99, T_NEAR); } kernel(unroll_m, i, true, false, false); if (i < 7) - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); } } - L(".subloop99"); + L(subloop99); // Compute address for A if (!isTransA) { add(A, unroll_m * SIZE); @@ -1581,14 +1582,12 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (hasBias) { add(BIAS, unroll_m * SIZE); } - - outLocalLabel(); }; - inLocalLabel(); - preamble(); + Label buffer_in_ws, buffer_allocated; + // Get the registers mov(B, ARG_B); mov(LDB, ARG_LDB); @@ -1608,7 +1607,7 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { #endif cmp(K, STACK_K_CAPACITY); - jg(".buffer_in_ws", T_NEAR); + jg(buffer_in_ws, T_NEAR); // Create buffer and align to 4kB page lea(rax, ptr[K * SIZE]); @@ -1616,12 +1615,12 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { add(rax, 256); sub(rsp, rax); and_(rsp, -PAGE_4K); - jmp(".buffer_allocated", T_NEAR); + jmp(buffer_allocated, T_NEAR); - L(".buffer_in_ws"); + L(buffer_in_ws); mov(rsp, ARG_WS); - L(".buffer_allocated"); + L(buffer_allocated); mov(ORIG_SP, rbp); mov(M, ARG_M); @@ -1665,40 +1664,40 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { } } + Label main0, main1, main2, main999; + cmp(M, 32); - jle(".main0", T_NEAR); + jle(main0, T_NEAR); align(16); - L(".main1"); + L(main1); subloop(48); sub(M, UNROLL_M); cmp(M, 32); - jg(".main1", T_NEAR); + jg(main1, T_NEAR); align(16); - L(".main0"); + L(main0); cmp(M, 16); - jle(".main2", T_NEAR); + jle(main2, T_NEAR); subloop(32); - jmp(".main999", T_NEAR); + jmp(main999, T_NEAR); align(16); - L(".main2"); + L(main2); cmp(M, 0); - jle(".main999", T_NEAR); + jle(main999, T_NEAR); subloop(16); align(16); - L(".main999"); + L(main999); // Restore original stack mov(rsp, ORIG_SP); vzeroupper(); postamble(); - outLocalLabel(); - ker_ = reinterpret_cast<decltype(ker_)>( const_cast<uint8_t *>(this->getCode())); } @@ -1763,7 +1762,7 @@ void jit_avx512_common_gemm_f32::sgemm_nocopy_driver(const char *transa, if (!isTransA && !isTransB) BK = 128; } - const float *curA, *curB, *curBias = NULL; + const float *curA, *curB, *curBias = nullptr; float *curC; for (Bk = 0; Bk < k; Bk += sizeK) { @@ -1804,15 +1803,15 @@ void jit_avx512_common_gemm_f32::sgemm_nocopy_driver(const char *transa, curB = b + Bn + (size_t)Bk * ldb; } curC = c + Bm + (size_t)Bn * ldc; - if (bias != NULL) { + if (bias != nullptr) { if (Bk == 0) { curBias = bias + Bm; } else { - curBias = NULL; + curBias = nullptr; } } if (Bk == 0) { - if (*beta == 0.0 && bias == NULL) + if (*beta == 0.0 && bias == nullptr) (*ker_b0_)((long long int)sizeM, (long long int)sizeN, (long long int)sizeK, alpha, curA, (long long int)lda, curB, (long long int)ldb, @@ -1860,7 +1859,7 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb, // Determine threading partitioning gemm_utils::calc_nthr_nocopy_avx512_common( m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); - assert(utils::implication(!mkldnn_thr_syncable(), nthr_k == 1)); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); // May not happen, but just in case if (nthr < nthr_m * nthr_n * nthr_k) @@ -1868,13 +1867,18 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb, nthr_mn = nthr_m * nthr_n; - unsigned int volatile *ompstatus = (unsigned int volatile *)ompstatus_; - if (!ompstatus) return; + unsigned char * ompstatus_ = nullptr; + unsigned char volatile *ompstatus = nullptr; - float *c_buffers = NULL; - float *ws_buffers = NULL; + float *c_buffers = nullptr; + float *ws_buffers = nullptr; if (nthr_k > 1) { + ompstatus_ = (unsigned char *) malloc( + nthr * CACHE_LINE_SIZE, + CACHE_LINE_SIZE); + ompstatus = (unsigned char volatile *) ompstatus_; + assert(ompstatus); for (int i = 0; i < nthr; i++) ompstatus[i * CACHE_LINE_SIZE] = 0; @@ -1895,7 +1899,7 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb, int n_from, n_to, myN; int k_from, k_to, myK; int cbase, ibase; - const float *myA, *myB, *myBias = NULL; + const float *myA, *myB, *myBias = nullptr; float *myC = C, myBeta; float *ws = ws_buffers ? ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; @@ -1957,7 +1961,7 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb, myC = c_buffers + MB * NB * (cbase + ithr_k - 1); myBeta = 0.0; ld = MB; - myBias = NULL; + myBias = nullptr; } sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, @@ -2004,8 +2008,8 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb, } }); - if (nthr_k > 1) - free(c_buffers); + free(c_buffers); + free(ompstatus_); free(ws_buffers); } @@ -2032,10 +2036,6 @@ jit_avx512_common_gemm_f32::jit_avx512_common_gemm_f32( } nthrs_ = mkldnn_get_max_threads(); - ompstatus_ = (unsigned int *)malloc( - sizeof(unsigned int *) * nthrs_ * CACHE_LINE_SIZE, 64); - assert(ompstatus_); - } jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32() @@ -2045,7 +2045,6 @@ jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32() delete ker_b1_; if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_)) delete ker_b0_; - free(ompstatus_); } } } |