diff options
author | Alexey Suhov <asuhov@users.noreply.github.com> | 2019-01-21 21:31:31 +0300 |
---|---|---|
committer | openvino-pushbot <44090433+openvino-pushbot@users.noreply.github.com> | 2019-01-21 21:31:31 +0300 |
commit | 9de27f16bc8b712a5b8c99d1d4b4a66c9144942d (patch) | |
tree | 01a383efe94d92b9870d513c2c5ea5d15b07010a /inference-engine/thirdparty/mkl-dnn/src/cpu/gemm | |
parent | fbc7a4a710c24def8ab199926a7da90a0394b87d (diff) | |
download | dldt-9de27f16bc8b712a5b8c99d1d4b4a66c9144942d.tar.gz dldt-9de27f16bc8b712a5b8c99d1d4b4a66c9144942d.tar.bz2 dldt-9de27f16bc8b712a5b8c99d1d4b4a66c9144942d.zip |
Publishing R5 content (#72)
* Publishing R5 content
* Updated ade revision
* updated readme
* add possibility to build CPU plugin with Intel MKL package
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/gemm')
9 files changed, 516 insertions, 319 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.cpp index fe8412194..146e68887 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.cpp @@ -25,6 +25,8 @@ #include "../jit_generator.hpp" #include "nstl.hpp" #include "os_blas.hpp" +#include "math_utils.hpp" +#include "mkldnn_traits.hpp" /* USE_MKL USE_CBLAS effect * ------- --------- ------ @@ -52,6 +54,7 @@ mkldnn_status_t check_gemm_input(const char *transa, const char *transb, && *M >= 0 && *N >= 0 && *K >= 0; + if (!consistency) return invalid_arguments; bool isTransA = utils::one_of(*transa, 'T', 't'); bool isTransB = utils::one_of(*transb, 'T', 't'); @@ -66,6 +69,19 @@ mkldnn_status_t check_gemm_input(const char *transa, const char *transb, return success; } +mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc, + const char *transa, const char *transb, const int *M, const int *N, + const int *K, const int *lda, const int *ldb, const int *ldc, + const float *alpha, const float *beta, const bool with_bias) { + + if (offsetc == nullptr) return invalid_arguments; + if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r')) + return invalid_arguments; + + return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha, + beta, with_bias); +} + struct gemm_impl_t { gemm_impl_t(char transa, char transb, bool zero_beta, bool with_bias) { //jit kernel has three codepaths: beta is 0, 1 or arbitrary @@ -132,7 +148,7 @@ mkldnn_status_t extended_sgemm(const char *transa, const char *transb, const int *M, const int *N, const int *K, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc, - const float *bias) { + const float *bias, const bool force_jit_gemm) { //Check input mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta, bias != nullptr); @@ -143,20 +159,22 @@ mkldnn_status_t extended_sgemm(const char *transa, const char *transb, int trA = *transa == 't' || *transa == 'T'; int trB = *transb == 't' || *transb == 'T'; #ifdef USE_CBLAS - //Call cblas - CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans; - CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans; - cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB, - *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc); - //Add bias if necessary (bias is applied to columns of C) - if (bias) { - cblas_int incx = 1, incy = 1; - parallel_nd(*N, [&](int n) { - cblas_saxpy(*M, 1.0, bias, incx, C + n*(*ldc), incy); - }); + if (!force_jit_gemm) { + //Call cblas + CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans; + cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB, + *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc); + //Add bias if necessary (bias is applied to columns of C) + if (bias) { + cblas_int incx = 1, incy = 1; + parallel_nd(*N, [&](int n) { + cblas_saxpy(*M, 1.0, bias, incx, C + n*(*ldc), incy); + }); + } + return mkldnn_success; } - return mkldnn_success; -#else +#endif //Generate jit kernel and call sgemm with bias volatile static int initialized = 0; if (!initialized) { @@ -176,9 +194,98 @@ mkldnn_status_t extended_sgemm(const char *transa, const char *transb, transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); return mkldnn_success; -#endif } +template <typename b_dt> +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co) { + + mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb, + M, N, K, LDA, LDB, LDC, alpha, beta, false); + + if (status != mkldnn_success) + return status; + + if (*M == 0 || *N == 0 || *K == 0) + return mkldnn_success; + + bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); + bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); + bool AisN = (*transa == 'N' || *transa == 'n'); + bool BisN = (*transb == 'N' || *transb == 'n'); + +#if defined(USE_MKL) && defined(USE_CBLAS) + if (data_traits<b_dt>::data_type == data_type::u8) { + CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans; + CBLAS_OFFSET Cblas_offsetc = + OCisR + ? CblasRowOffset + : OCisC + ? CblasColOffset + : CblasFixOffset; + cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc, + *M, *N, *K, *alpha, A, *LDA, *ao, (b_dt*)B, *LDB, *bo, *beta, C, *LDC, co); + return mkldnn_success; + } +#endif + int m = *M, n = *N, k = *K, lda = *LDA, ldb = *LDB, ldc = *LDC; + size_t sizeA = AisN ? lda * k : lda * m; + size_t sizeB = BisN ? ldb * n : ldb * k; + size_t sizeC = ldc * n; + + double *dA = (double *)malloc(sizeA * sizeof(double), PAGE_4K); + double *dB = (double *)malloc(sizeB * sizeof(double), PAGE_4K); + double *dC = (double *)malloc(sizeC * sizeof(double), PAGE_4K); + + if (utils::any_null(dA, dB, dC)) { + free(dA); + free(dB); + free(dC); + return mkldnn_out_of_memory; + } + + auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; }; + auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; }; + + auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; }; + auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; }; + + const int a_rows = AisN ? m : k; + const int a_cols = AisN ? k : m; + mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) { + da_setter(i, j, + static_cast<double>(ia_accessor(i, j)) + static_cast<double>(ao[0])); + }); + + const int b_rows = BisN ? k : n; + const int b_cols = BisN ? n : k; + mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) { + db_setter(i, j, + static_cast<double>(ib_accessor(i, j)) + static_cast<double>(bo[0])); + }); + double one = 1.0, zero = 0.0; + ref_gemm<double>(transa, transb, M, N, K, &one, dA, LDA, dB, LDB, &zero, + dC, LDC, nullptr); + + auto i2d = [=] (int32_t v) { return static_cast<double>(v); }; + auto f2d = [=] (float v) { return static_cast<double>(v); }; + + mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) { + double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]); + double val = ((*beta == 0.0f) ? 0.0 : f2d(*beta) * i2d(C[i + j * ldc])) + + f2d(*alpha) * dC[i + j * ldc] + coffset; + C[i + j * ldc] = math::out_round<int32_t>(math::saturate<int32_t>(val)); + }); + + free(dA); + free(dB); + free(dC); + return mkldnn_success; +} } } } @@ -193,3 +300,23 @@ mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb, return extended_sgemm( transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } + +mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *lda, const int8_t *ao, + const uint8_t *B, const int *ldb, const int8_t *bo, const float *beta, + int32_t *c, const int *ldc, const int32_t *co) { + return gemm_s8x8s32( + transa, transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo, + beta, c, ldc, co); +} + +mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *lda, const int8_t *ao, + const int8_t *B, const int *ldb, const int8_t *bo, const float *beta, + int32_t *c, const int *ldc, const int32_t *co) { + return gemm_s8x8s32( + transa, transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo, + beta, c, ldc, co); +} diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.hpp index 8917de157..3f33a3713 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.hpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm.hpp @@ -22,11 +22,20 @@ mkldnn_status_t extended_sgemm(const char *transa, const char *transb, const int *M, const int *N, const int *K, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc, - const float *bias = nullptr); + const float *bias = nullptr, bool force_jit_gemm = false); + +template <typename b_dt> +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *lda, const int8_t *ao, + const b_dt *B, const int *ldb, const int8_t *bo, const float *beta, + int32_t *c, const int *ldc, const int32_t *co); + +template <typename data_t> void ref_gemm(const char *transa, const char *transb, const int *M, - const int *N, const int *K, const float *alpha, const float *A, - const int *lda, const float *B, const int *ldb, const float *beta, - float *C, const int *ldc, const float *bias); + const int *N, const int *K, const data_t *alpha, const data_t *A, + const int *lda, const data_t *B, const int *ldb, const data_t *beta, + data_t *C, const int *ldc, const data_t *bias); #ifdef USE_CBLAS #define GEMM_IMPL_STR "gemm:blas" #else diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.cpp index 934ba81ce..e3b6cff8a 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.cpp @@ -343,8 +343,9 @@ void partition_unit_diff( // Sum the m*n values from p_src into p_dst, assuming the two-dimensional // arrays have leading dimensions ld_src and ld_dst, respectively +template<typename data_t> void sum_two_matrices( - int m, int n, float *p_src, int ld_src, float *p_dst, int ld_dst) + int m, int n, data_t *p_src, int ld_src, data_t *p_dst, int ld_dst) { int i, j; for (j = 0; j < n; j++) { @@ -353,6 +354,12 @@ void sum_two_matrices( } } } + +template void sum_two_matrices<float>( + int m, int n, float *p_src, int ld_src, float *p_dst, int ld_dst); + +template void sum_two_matrices<double>( + int m, int n, double *p_src, int ld_src, double *p_dst, int ld_dst); } } } diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.hpp index 7a8f7fcbe..0888787b9 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.hpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/gemm_utils.hpp @@ -22,8 +22,34 @@ namespace impl { namespace cpu { namespace gemm_utils { + +template <typename T, bool isTransA, bool isTransB> +struct gemm_traits {}; + +template <bool isTransA, bool isTransB> +struct gemm_traits<double, isTransA, isTransB> { + static constexpr int m = 8; + static constexpr int n = 6; + static constexpr int BM = 4032; + static constexpr int BN = isTransA ? 96 : 192; + static constexpr int BK = isTransB ? 96 : 512; +}; + +template <bool isTransA, bool isTransB> +struct gemm_traits<float, isTransA, isTransB> { + static constexpr int m = 16; + static constexpr int n = 6; + static constexpr int BM = 4032; + static constexpr int BN = isTransA ? 96 : 48; + static constexpr int BK = isTransB ? 96 : 256; +}; + +template <typename T> +using unroll_factor = gemm_traits<T, false, false>; + +template <typename data_type> void sum_two_matrices( - int m, int n, float *p_src, int ld_src, float *p_dst, int ld_dst); + int m, int n, data_type *p_src, int ld_src, data_type *p_dst, int ld_dst); void calc_nthr_nocopy_avx512_common(int m, int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, @@ -35,6 +61,8 @@ void calc_nthr_nocopy_avx(int m, int n, int k, void partition_unit_diff( int ithr, int nthr, int n, int *t_offset, int *t_block); + +inline double saturate(double value, double min, double max); }; } diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp index 08ba7b47d..8aee85fbd 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp @@ -22,7 +22,7 @@ #include "gemm_utils.hpp" #include "jit_avx512_common_gemm_f32.hpp" -#define CACHE_LINE_SIZE 16 +#define CACHE_LINE_SIZE 64 namespace mkldnn { namespace impl { @@ -128,15 +128,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { // Function for packing if needed auto do_pack = [&](int unroll_m) { - inLocalLabel(); + Label pack2, pack3, pack4, pack10; + mov(BO1, A); lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); mov(LL, K); sar(LL, 2); - jle(".pack3", T_NEAR); + jle(pack3, T_NEAR); align(16); - L(".pack2"); + L(pack2); if (!isTransA) { for (int i = 0; i < 4; i++) { vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); @@ -216,16 +217,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { add(AO1, unroll_m * 4 * SIZE); sub(LL, 1); - jg(".pack2", T_NEAR); + jg(pack2, T_NEAR); align(16); - L(".pack3"); + L(pack3); mov(LL, K); and_(LL, 3); - jle(".pack10", T_NEAR); + jle(pack10, T_NEAR); align(16); - L(".pack4"); + L(pack4); if (!isTransA) { vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); if (unroll_m > 16) @@ -279,11 +280,10 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { add(AO1, unroll_m * SIZE); sub(LL, 1); - jg(".pack4", T_NEAR); + jg(pack4, T_NEAR); align(16); - L(".pack10"); - outLocalLabel(); + L(pack10); }; // Function to update C, covering masking and other considerations @@ -617,8 +617,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { // Innerkernel; called by kernel auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect, bool isCopy, bool doCPrefetch, bool isUnmasked = true) { - inLocalLabel(); - for (int i = 0; i < 8; i++) { if (!isDirect) { prefetcht0(ptr[AO1 @@ -960,7 +958,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { } sub(LL, 1); - outLocalLabel(); }; // Main kernel; does prefetching and calls innerkernel @@ -968,7 +965,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { // calling update auto kernel = [&](int unroll_m, int unroll_n, bool isDirect, bool isCopy, bool isUnmasked = true) { - inLocalLabel(); if (!isDirect) { lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); } else { @@ -1020,36 +1016,38 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { } } + Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18; + mov(LL, K); sar(LL, 3); sub(LL, SECOND_FETCH); - jle(".kernel13", T_NEAR); + jle(kernel13, T_NEAR); align(16); - L(".kernel12"); + L(kernel12); innerkernel( unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked); - jg(".kernel12", T_NEAR); + jg(kernel12, T_NEAR); align(16); - L(".kernel13"); + L(kernel13); lea(CO2, ptr[CO1 + (16 - 1) * SIZE]); add(LL, unroll_n); - jle(".kernel15", T_NEAR); + jle(kernel15, T_NEAR); align(16); - L(".kernel14"); + L(kernel14); innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked); - jg(".kernel14", T_NEAR); + jg(kernel14, T_NEAR); align(16); - L(".kernel15"); + L(kernel15); mov(LL, K); and_(LL, 7); - jle(".kernel18", T_NEAR); + jle(kernel18, T_NEAR); align(16); - L(".kernel16"); + L(kernel16); if (isDirect) { if (isUnmasked || unroll_m > 16) { vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); @@ -1204,10 +1202,10 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { } sub(LL, 1); - jg(".kernel16", T_NEAR); + jg(kernel16, T_NEAR); align(16); - L(".kernel18"); + L(kernel18); vbroadcastss(VALPHA, ALPHA); if (isBetaN) { @@ -1329,8 +1327,6 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { sub(BO1, rax); add(BO1, unroll_n * SIZE); } - - outLocalLabel(); }; // High-level subroutine; does packing if needed, then splits C matrix. @@ -1338,11 +1334,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { // cases appropriately by doing 32 or 16 rows, and/or with masking, // and/or fewer columns). auto subloop = [&](int unroll_m) { - inLocalLabel(); - Label l_subloop_20x[8], l_subloop_mask_20x[8]; Label l_subloop_30x[8], l_subloop_mask_30x[8]; + Label subloop11, subloop11mask; + Label subloop30, subloop30mask; + Label subloop31, subloop31mask; + Label subloop96; + Label subloop98, subloop98mask; + Label subloop99; + // Create mask mov(BO1, rcx); mov(rcx, M); @@ -1370,7 +1371,7 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { and_(rax, 0xffff); cmp(rax, 0xffff); - jne(".subloop96", T_NEAR); + jne(subloop96, T_NEAR); if (isTransA) { do_pack(unroll_m); @@ -1387,11 +1388,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (!isTransA) { lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); cmp(M, UNROLL_M); - jg(".subloop98", T_NEAR); + jg(subloop98, T_NEAR); mov(AA, ORIG_A); lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]); - L(".subloop98"); + L(subloop98); } mov(LL, N); @@ -1399,11 +1400,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (!isTransA) { // If N is too small, skip copy operation cmp(LL, UNROLL_N * 3); - jle(".subloop30", T_NEAR); + jle(subloop30, T_NEAR); // If A is not aligned to cache line cmp(FLAG, 0); - je(".subloop30", T_NEAR); + je(subloop30, T_NEAR); } else { cmp(LL, UNROLL_N); jl(l_subloop_20x[1], T_NEAR); @@ -1421,11 +1422,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { jl(l_subloop_20x[1], T_NEAR); align(16); - L(".subloop11"); + L(subloop11); kernel(unroll_m, UNROLL_N, false, false); sub(I, UNROLL_N); cmp(I, UNROLL_N); - jge(".subloop11", T_NEAR); + jge(subloop11, T_NEAR); align(16); for (int i = 1; i <= 7; i++) { @@ -1434,24 +1435,24 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (i < 7) { jne(l_subloop_20x[i + 1], T_NEAR); } else { - jne(".subloop99", T_NEAR); + jne(subloop99, T_NEAR); } kernel(unroll_m, i, false, false); - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); } if (!isTransA) { - L(".subloop30"); + L(subloop30); cmp(I, UNROLL_N); jl(l_subloop_30x[1], T_NEAR); align(16); - L(".subloop31"); + L(subloop31); kernel(unroll_m, UNROLL_N, true, false); sub(I, UNROLL_N); cmp(I, UNROLL_N); - jge(".subloop31", T_NEAR); + jge(subloop31, T_NEAR); align(16); for (int i = 1; i <= 7; i++) { @@ -1460,18 +1461,18 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (i < 7) { jne(l_subloop_30x[i + 1], T_NEAR); } else { - jne(".subloop99", T_NEAR); + jne(subloop99, T_NEAR); } kernel(unroll_m, i, true, false); if (i < 7) - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); } } - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); - L(".subloop96"); + L(subloop96); if (isTransA) { do_pack(unroll_m); } @@ -1486,10 +1487,10 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (!isTransA) { lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); cmp(M, UNROLL_M); - jg(".subloop98mask", T_NEAR); + jg(subloop98mask, T_NEAR); mov(AA, ORIG_A); lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]); - L(".subloop98mask"); + L(subloop98mask); } mov(LL, N); @@ -1497,11 +1498,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (!isTransA) { // If N is too small, skip copy operation cmp(LL, UNROLL_N * 3); - jle(".subloop30mask", T_NEAR); + jle(subloop30mask, T_NEAR); // If A is not aligned to cache line cmp(FLAG, 0); - je(".subloop30mask", T_NEAR); + je(subloop30mask, T_NEAR); } else { cmp(LL, UNROLL_N); jl(l_subloop_mask_20x[1], T_NEAR); @@ -1519,11 +1520,11 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { jl(l_subloop_mask_20x[1], T_NEAR); align(16); - L(".subloop11mask"); + L(subloop11mask); kernel(unroll_m, UNROLL_N, false, false, false); sub(I, UNROLL_N); cmp(I, UNROLL_N); - jge(".subloop11mask", T_NEAR); + jge(subloop11mask, T_NEAR); align(16); for (int i = 1; i <= 7; i++) { @@ -1532,24 +1533,24 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (i < 7) { jne(l_subloop_mask_20x[i + 1], T_NEAR); } else { - jne(".subloop99", T_NEAR); + jne(subloop99, T_NEAR); } kernel(unroll_m, i, false, false, false); - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); } if (!isTransA) { - L(".subloop30mask"); + L(subloop30mask); cmp(I, UNROLL_N); jl(l_subloop_mask_30x[1], T_NEAR); align(16); - L(".subloop31mask"); + L(subloop31mask); kernel(unroll_m, UNROLL_N, true, false, false); sub(I, UNROLL_N); cmp(I, UNROLL_N); - jge(".subloop31mask", T_NEAR); + jge(subloop31mask, T_NEAR); align(16); for (int i = 1; i <= 7; i++) { @@ -1558,16 +1559,16 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (i < 7) { jne(l_subloop_mask_30x[i + 1], T_NEAR); } else { - jne(".subloop99", T_NEAR); + jne(subloop99, T_NEAR); } kernel(unroll_m, i, true, false, false); if (i < 7) - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); } } - L(".subloop99"); + L(subloop99); // Compute address for A if (!isTransA) { add(A, unroll_m * SIZE); @@ -1581,14 +1582,12 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { if (hasBias) { add(BIAS, unroll_m * SIZE); } - - outLocalLabel(); }; - inLocalLabel(); - preamble(); + Label buffer_in_ws, buffer_allocated; + // Get the registers mov(B, ARG_B); mov(LDB, ARG_LDB); @@ -1608,7 +1607,7 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { #endif cmp(K, STACK_K_CAPACITY); - jg(".buffer_in_ws", T_NEAR); + jg(buffer_in_ws, T_NEAR); // Create buffer and align to 4kB page lea(rax, ptr[K * SIZE]); @@ -1616,12 +1615,12 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { add(rax, 256); sub(rsp, rax); and_(rsp, -PAGE_4K); - jmp(".buffer_allocated", T_NEAR); + jmp(buffer_allocated, T_NEAR); - L(".buffer_in_ws"); + L(buffer_in_ws); mov(rsp, ARG_WS); - L(".buffer_allocated"); + L(buffer_allocated); mov(ORIG_SP, rbp); mov(M, ARG_M); @@ -1665,40 +1664,40 @@ struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator { } } + Label main0, main1, main2, main999; + cmp(M, 32); - jle(".main0", T_NEAR); + jle(main0, T_NEAR); align(16); - L(".main1"); + L(main1); subloop(48); sub(M, UNROLL_M); cmp(M, 32); - jg(".main1", T_NEAR); + jg(main1, T_NEAR); align(16); - L(".main0"); + L(main0); cmp(M, 16); - jle(".main2", T_NEAR); + jle(main2, T_NEAR); subloop(32); - jmp(".main999", T_NEAR); + jmp(main999, T_NEAR); align(16); - L(".main2"); + L(main2); cmp(M, 0); - jle(".main999", T_NEAR); + jle(main999, T_NEAR); subloop(16); align(16); - L(".main999"); + L(main999); // Restore original stack mov(rsp, ORIG_SP); vzeroupper(); postamble(); - outLocalLabel(); - ker_ = reinterpret_cast<decltype(ker_)>( const_cast<uint8_t *>(this->getCode())); } @@ -1763,7 +1762,7 @@ void jit_avx512_common_gemm_f32::sgemm_nocopy_driver(const char *transa, if (!isTransA && !isTransB) BK = 128; } - const float *curA, *curB, *curBias = NULL; + const float *curA, *curB, *curBias = nullptr; float *curC; for (Bk = 0; Bk < k; Bk += sizeK) { @@ -1804,15 +1803,15 @@ void jit_avx512_common_gemm_f32::sgemm_nocopy_driver(const char *transa, curB = b + Bn + (size_t)Bk * ldb; } curC = c + Bm + (size_t)Bn * ldc; - if (bias != NULL) { + if (bias != nullptr) { if (Bk == 0) { curBias = bias + Bm; } else { - curBias = NULL; + curBias = nullptr; } } if (Bk == 0) { - if (*beta == 0.0 && bias == NULL) + if (*beta == 0.0 && bias == nullptr) (*ker_b0_)((long long int)sizeM, (long long int)sizeN, (long long int)sizeK, alpha, curA, (long long int)lda, curB, (long long int)ldb, @@ -1860,7 +1859,7 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb, // Determine threading partitioning gemm_utils::calc_nthr_nocopy_avx512_common( m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); - assert(utils::implication(!mkldnn_thr_syncable(), nthr_k == 1)); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); // May not happen, but just in case if (nthr < nthr_m * nthr_n * nthr_k) @@ -1868,13 +1867,18 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb, nthr_mn = nthr_m * nthr_n; - unsigned int volatile *ompstatus = (unsigned int volatile *)ompstatus_; - if (!ompstatus) return; + unsigned char * ompstatus_ = nullptr; + unsigned char volatile *ompstatus = nullptr; - float *c_buffers = NULL; - float *ws_buffers = NULL; + float *c_buffers = nullptr; + float *ws_buffers = nullptr; if (nthr_k > 1) { + ompstatus_ = (unsigned char *) malloc( + nthr * CACHE_LINE_SIZE, + CACHE_LINE_SIZE); + ompstatus = (unsigned char volatile *) ompstatus_; + assert(ompstatus); for (int i = 0; i < nthr; i++) ompstatus[i * CACHE_LINE_SIZE] = 0; @@ -1895,7 +1899,7 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb, int n_from, n_to, myN; int k_from, k_to, myK; int cbase, ibase; - const float *myA, *myB, *myBias = NULL; + const float *myA, *myB, *myBias = nullptr; float *myC = C, myBeta; float *ws = ws_buffers ? ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; @@ -1957,7 +1961,7 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb, myC = c_buffers + MB * NB * (cbase + ithr_k - 1); myBeta = 0.0; ld = MB; - myBias = NULL; + myBias = nullptr; } sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, @@ -2004,8 +2008,8 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb, } }); - if (nthr_k > 1) - free(c_buffers); + free(c_buffers); + free(ompstatus_); free(ws_buffers); } @@ -2032,10 +2036,6 @@ jit_avx512_common_gemm_f32::jit_avx512_common_gemm_f32( } nthrs_ = mkldnn_get_max_threads(); - ompstatus_ = (unsigned int *)malloc( - sizeof(unsigned int *) * nthrs_ * CACHE_LINE_SIZE, 64); - assert(ompstatus_); - } jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32() @@ -2045,7 +2045,6 @@ jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32() delete ker_b1_; if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_)) delete ker_b0_; - free(ompstatus_); } } } diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp index ede1cf9c1..c05733581 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp @@ -49,7 +49,6 @@ private: bool hasBias_; struct xbyak_gemm; xbyak_gemm *ker_bn_, *ker_b1_, *ker_b0_; - unsigned int *ompstatus_; int nthrs_; }; } diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.cpp index 9766a46d7..354fa0bc7 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.cpp @@ -21,7 +21,7 @@ #include "gemm_utils.hpp" #include "jit_avx_gemm_f32.hpp" -#define CACHE_LINE_SIZE 16 +#define CACHE_LINE_SIZE 64 namespace mkldnn { namespace impl { @@ -51,7 +51,7 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { : jit_generator(code_ptr, code_size) { const bool is_avx2 = mayiuse(avx2); - assert(implication(!is_avx2, mayiuse(avx))); + assert(IMPLICATION(!is_avx2, mayiuse(avx))); const int UNROLL_M = is_avx2 ? 16 : 8; const int UNROLL_N = 6; @@ -128,10 +128,10 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { // Function for packing if needed auto do_pack = [&]( int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) { + Label pack2, pack3, pack4, pack10; int regIdx; Reg64 reg; - inLocalLabel(); mov(BO1, A); lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]); @@ -144,10 +144,10 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { mov(LL, K); sar(LL, 2); - jle(".pack3", T_NEAR); + jle(pack3, T_NEAR); align(16); - L(".pack2"); + L(pack2); if (!isTransA) { for (int i = 0; i < 4; i++) { regIdx = (i % 2 == 0) ? 4 : 6; @@ -396,16 +396,16 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { add(AO1, unroll_m * 4 * SIZE); sub(LL, 1); - jg(".pack2", T_NEAR); + jg(pack2, T_NEAR); align(16); - L(".pack3"); + L(pack3); mov(LL, K); and_(LL, 3); - jle(".pack10", T_NEAR); + jle(pack10, T_NEAR); align(16); - L(".pack4"); + L(pack4); if (!isTransA) { if (isLoad1Unmasked) { vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); @@ -542,12 +542,10 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { add(AO1, unroll_m * SIZE); sub(LL, 1); - jg(".pack4", T_NEAR); + jg(pack4, T_NEAR); align(16); - L(".pack10"); - - outLocalLabel(); + L(pack10); }; // Fused multiply add; may become one or two instructions @@ -1382,8 +1380,6 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9), Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12), Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) { - inLocalLabel(); - if (!isDirect) { lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]); } else { @@ -1431,20 +1427,23 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { mov(LL, K); sar(LL, 3); + Label kernel12, kernel13, kernel14, kernel15; + Label kernel16, kernel17, kernel18; + sub(LL, SECOND_FETCH); - jle(".kernel13", T_NEAR); + jle(kernel13, T_NEAR); align(16); - L(".kernel12"); + L(kernel12); innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, reg21, reg22, reg23); - jg(".kernel12", T_NEAR); + jg(kernel12, T_NEAR); align(16); - L(".kernel13"); + L(kernel13); prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]); if (unroll_n >= 2) prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]); @@ -1458,30 +1457,30 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]); add(LL, SECOND_FETCH); - jle(".kernel15", T_NEAR); + jle(kernel15, T_NEAR); align(16); - L(".kernel14"); + L(kernel14); innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, reg21, reg22, reg23); - jg(".kernel14", T_NEAR); + jg(kernel14, T_NEAR); align(16); - L(".kernel15"); + L(kernel15); test(K, 4); - jle(".kernel16", T_NEAR); + jle(kernel16, T_NEAR); innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, reg21, reg22, reg23); - L(".kernel16"); + L(kernel16); test(K, 2); - jle(".kernel17", T_NEAR); + jle(kernel17, T_NEAR); innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, @@ -1489,7 +1488,7 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { reg21, reg22, reg23); align(16); - L(".kernel17"); + L(kernel17); if (unroll_m == 16) { if (unroll_n <= 3) { vaddps(reg00, reg00, reg12); @@ -1511,13 +1510,13 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { } test(K, 1); - jle(".kernel18", T_NEAR); + jle(kernel18, T_NEAR); innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, reg05, reg06, reg07, reg08, reg09, reg10, reg11); align(16); - L(".kernel18"); + L(kernel18); vbroadcastss(VALPHA, ALPHA); if (isBetaN) { @@ -1804,8 +1803,6 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { sub(BO1, rax); add(BO1, unroll_n * SIZE); } - - outLocalLabel(); }; auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, @@ -1898,12 +1895,18 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { // Masking is used for tail cases where M is not divisible by 8. auto subloop = [&]( int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) { - inLocalLabel(); - if (isTransA) { do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked); } + Label subloop11, subloop11mask; + Label subloop20, subloop21, subloop22, subloop23; + Label subloop24, subloop25; + Label subloop30, subloop31, subloop32, subloop33; + Label subloop34, subloop35; + Label subloop98, subloop98mask; + Label subloop99, subloop99mask; + mov(CO1, C); lea(CO2, ptr[CO1 + LDC * 2]); add(CO2, LDC); @@ -1916,11 +1919,11 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { if (!isTransA) { lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]); cmp(M, UNROLL_M); - jg(".subloop98", T_NEAR); + jg(subloop98, T_NEAR); mov(AA, ORIG_A); lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]); - L(".subloop98"); + L(subloop98); } mov(LL, N); @@ -1928,14 +1931,14 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { if (!isTransA) { // If N is too small, skip copy operation cmp(LL, UNROLL_N * 3); - jle(".subloop30", T_NEAR); + jle(subloop30, T_NEAR); // If A is not aligned to cache line cmp(FLAG, 0); - je(".subloop30", T_NEAR); + je(subloop30, T_NEAR); } else { cmp(LL, UNROLL_N); - jl(".subloop20", T_NEAR); + jl(subloop20, T_NEAR); } align(16); @@ -1959,10 +1962,10 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { sub(I, UNROLL_N); cmp(I, UNROLL_N); - jl(".subloop20", T_NEAR); + jl(subloop20, T_NEAR); align(16); - L(".subloop11"); + L(subloop11); if (unroll_m == 16) { kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked, false, false); @@ -1972,12 +1975,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { } sub(I, UNROLL_N); cmp(I, UNROLL_N); - jge(".subloop11", T_NEAR); + jge(subloop11, T_NEAR); align(16); - L(".subloop20"); + L(subloop20); cmp(I, 1); - jne(".subloop21", T_NEAR); + jne(subloop21, T_NEAR); if (unroll_m == 16) { kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false, false); @@ -1985,12 +1988,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false, false); } - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); - L(".subloop21"); + L(subloop21); cmp(I, 2); - jne(".subloop22", T_NEAR); + jne(subloop22, T_NEAR); if (unroll_m == 16) { kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false, false); @@ -1998,12 +2001,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false, false); } - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); - L(".subloop22"); + L(subloop22); cmp(I, 3); - jne(".subloop23", T_NEAR); + jne(subloop23, T_NEAR); if (unroll_m == 16) { kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false, false); @@ -2011,12 +2014,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false, false); } - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); - L(".subloop23"); + L(subloop23); cmp(I, 4); - jne(".subloop24", T_NEAR); + jne(subloop24, T_NEAR); if (unroll_m == 16) { kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false, false); @@ -2024,12 +2027,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false, false); } - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); - L(".subloop24"); + L(subloop24); cmp(I, 5); - jne(".subloop99", T_NEAR); + jne(subloop99, T_NEAR); if (unroll_m == 16) { kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false, false); @@ -2037,16 +2040,16 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false, false); } - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); if (!isTransA) { - L(".subloop30"); + L(subloop30); cmp(I, UNROLL_N); - jl(".subloop25", T_NEAR); + jl(subloop25, T_NEAR); align(16); - L(".subloop31"); + L(subloop31); if (unroll_m == 16) { kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked, true, false); @@ -2056,12 +2059,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { } sub(I, UNROLL_N); cmp(I, UNROLL_N); - jge(".subloop31", T_NEAR); + jge(subloop31, T_NEAR); align(16); - L(".subloop25"); + L(subloop25); cmp(I, 1); - jne(".subloop32", T_NEAR); + jne(subloop32, T_NEAR); if (unroll_m == 16) { kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, true, false); @@ -2069,12 +2072,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, true, false); } - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); - L(".subloop32"); + L(subloop32); cmp(I, 2); - jne(".subloop33", T_NEAR); + jne(subloop33, T_NEAR); if (unroll_m == 16) { kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, true, false); @@ -2082,12 +2085,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, true, false); } - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); - L(".subloop33"); + L(subloop33); cmp(I, 3); - jne(".subloop34", T_NEAR); + jne(subloop34, T_NEAR); if (unroll_m == 16) { kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, true, false); @@ -2095,12 +2098,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, true, false); } - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); - L(".subloop34"); + L(subloop34); cmp(I, 4); - jne(".subloop35", T_NEAR); + jne(subloop35, T_NEAR); if (unroll_m == 16) { kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, true, false); @@ -2108,12 +2111,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, true, false); } - jmp(".subloop99", T_NEAR); + jmp(subloop99, T_NEAR); align(16); - L(".subloop35"); + L(subloop35); cmp(I, 5); - jne(".subloop99", T_NEAR); + jne(subloop99, T_NEAR); if (unroll_m == 16) { kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, true, false); @@ -2124,7 +2127,7 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { align(16); } - L(".subloop99"); + L(subloop99); // Compute address for A if (!isTransA) { add(A, unroll_m * SIZE); @@ -2138,14 +2141,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { if (hasBias) { add(BIAS, unroll_m * SIZE); } - - outLocalLabel(); }; - inLocalLabel(); - preamble(); + Label buffer_in_ws, buffer_allocated; + // Get the registers mov(B, ARG_B); mov(LDB, ARG_LDB); @@ -2165,7 +2166,7 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { #endif cmp(K, STACK_K_CAPACITY); - jg(".buffer_in_ws", T_NEAR); + jg(buffer_in_ws, T_NEAR); // Create buffer and align to 4kB page lea(rax, ptr[K * SIZE]); @@ -2173,12 +2174,12 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { add(rax, 256); sub(rsp, rax); and_(rsp, -PAGE_4K); - jmp(".buffer_allocated", T_NEAR); + jmp(buffer_allocated, T_NEAR); - L(".buffer_in_ws"); + L(buffer_in_ws); mov(rsp, ARG_WS); - L(".buffer_allocated"); + L(buffer_allocated); mov(ORIG_SP, rbp); mov(M, ARG_M); @@ -2218,43 +2219,45 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { and_(rax, 0x1f); mov(FLAG, rax); + Label main0, main1, main2, main3, main999; + cmp(M, UNROLL_M); - jl(".main0", T_NEAR); + jl(main0, T_NEAR); align(16); - L(".main1"); + L(main1); subloop(UNROLL_M, true, true); sub(M, UNROLL_M); cmp(M, UNROLL_M); - jge(".main1", T_NEAR); + jge(main1, T_NEAR); align(16); - L(".main0"); + L(main0); cmp(M, 0); - jle(".main999", T_NEAR); + jle(main999, T_NEAR); if (UNROLL_M > 8) { cmp(M, 8); - jle(".main2", T_NEAR); + jle(main2, T_NEAR); sub(M, 8); vbroadcastss(VMASK, M); vpcmpgtd(VMASK, VMASK, MASK); subloop(16, true, false); - jmp(".main999", T_NEAR); + jmp(main999, T_NEAR); align(16); - L(".main2"); + L(main2); cmp(M, 8); - jne(".main3", T_NEAR); + jne(main3, T_NEAR); subloop(8, true, true); - jmp(".main999", T_NEAR); + jmp(main999, T_NEAR); } align(16); - L(".main3"); + L(main3); vbroadcastss(VMASK, M); if (is_avx2) { vpcmpgtd(VMASK, VMASK, MASK); @@ -2270,7 +2273,7 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { subloop(8, false, false); align(16); - L(".main999"); + L(main999); // Restore original stack mov(rax, ORIG_SP); mov(rsp, rax); @@ -2278,8 +2281,6 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator { vzeroupper(); postamble(); - outLocalLabel(); - ker_ = reinterpret_cast<decltype(ker_)>( const_cast<uint8_t *>(this->getCode())); } @@ -2335,7 +2336,7 @@ void jit_avx_gemm_f32::sgemm_nocopy_driver(const char *transa, int BM = 4032; int BN = isTransA ? 96 : 48; int BK = isTransB ? 96 : 256; - const float *curA, *curB, *curBias = NULL; + const float *curA, *curB, *curBias = nullptr; float *curC; for (Bk = 0; Bk < k; Bk += sizeK) { @@ -2376,15 +2377,15 @@ void jit_avx_gemm_f32::sgemm_nocopy_driver(const char *transa, curB = b + Bn + (size_t)Bk * ldb; } curC = c + Bm + (size_t)Bn * ldc; - if (bias != NULL) { + if (bias != nullptr) { if (Bk == 0) { curBias = bias + Bm; } else { - curBias = NULL; + curBias = nullptr; } } if (Bk == 0) { - if (*beta == 0.0 && bias == NULL) + if (*beta == 0.0 && bias == nullptr) (*ker_b0_)((long long int)sizeM, (long long int)sizeN, (long long int)sizeK, alpha, curA, (long long int)lda, curB, (long long int)ldb, @@ -2431,7 +2432,7 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb, // Determine threading partitioning gemm_utils::calc_nthr_nocopy_avx( m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); - assert(utils::implication(!mkldnn_thr_syncable(), nthr_k == 1)); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); // May not happen, but just in case if (nthr < nthr_m * nthr_n * nthr_k) @@ -2439,13 +2440,19 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb, nthr_mn = nthr_m * nthr_n; - unsigned int volatile *ompstatus = (unsigned int volatile *)ompstatus_; - if (!ompstatus) return; + unsigned char * ompstatus_ = nullptr; + unsigned char volatile *ompstatus = nullptr; - float *c_buffers = NULL; - float *ws_buffers = NULL; + float *c_buffers = nullptr; + float *ws_buffers = nullptr; if (nthr_k > 1) { + ompstatus_ = (unsigned char *) malloc( + nthr * CACHE_LINE_SIZE, + CACHE_LINE_SIZE); + ompstatus = (unsigned char volatile *) ompstatus_; + assert(ompstatus); + for (int i = 0; i < nthr; i++) ompstatus[i * CACHE_LINE_SIZE] = 0; @@ -2466,7 +2473,7 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb, int n_from, n_to, myN; int k_from, k_to, myK; int cbase, ibase; - const float *myA, *myB, *myBias = NULL; + const float *myA, *myB, *myBias = nullptr; float *myC = C, myBeta; float *ws = ws_buffers ? ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; @@ -2528,7 +2535,7 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb, myC = c_buffers + MB * NB * (cbase + ithr_k - 1); myBeta = 0.0; ld = MB; - myBias = NULL; + myBias = nullptr; } sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, @@ -2575,8 +2582,8 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb, } }); - if (nthr_k > 1) - free(c_buffers); + free(c_buffers); + free(ompstatus_); free(ws_buffers); } @@ -2602,9 +2609,6 @@ jit_avx_gemm_f32::jit_avx_gemm_f32( ker_b0_ = ker_bn_; } nthrs_ = mkldnn_get_max_threads(); - ompstatus_ = (unsigned int *)malloc( - sizeof(unsigned int *) * nthrs_ * CACHE_LINE_SIZE, 64); - assert(ompstatus_); } jit_avx_gemm_f32::~jit_avx_gemm_f32() @@ -2614,7 +2618,6 @@ jit_avx_gemm_f32::~jit_avx_gemm_f32() delete ker_b1_; if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_)) delete ker_b0_; - free(ompstatus_); } } diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.hpp index 0f0cc46ea..dd34e09f0 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.hpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/jit_avx_gemm_f32.hpp @@ -49,7 +49,6 @@ private: bool hasBias_; struct xbyak_gemm; xbyak_gemm *ker_bn_, *ker_b1_, *ker_b0_; - unsigned int *ompstatus_; int nthrs_; }; } diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/ref_gemm.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/ref_gemm.cpp index 3310bf5e9..e0331e0ef 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/ref_gemm.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm/ref_gemm.cpp @@ -27,63 +27,68 @@ namespace impl { namespace cpu { using namespace mkldnn::impl::utils; +using namespace gemm_utils; -constexpr int unroll_m = 16; -constexpr int unroll_n = 6; + +template <typename data_t> static void copy_A( - bool isTransA, int K, const float *A, const int lda, float *ws) { + bool isTransA, int K, const data_t *A, const int lda, data_t *ws) { for (int k = 0; k < K; k++) { PRAGMA_OMP_SIMD() - for (int i = 0; i < unroll_m; i++) { + for (int i = 0; i < gemm_utils::unroll_factor<data_t>::m; i++) { ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda]; } - ws += unroll_m; + ws += unroll_factor<data_t>::m; } } -template <bool isTransA, bool isTransB> -static void kernel_mxn(int K, const float *A, const int lda, - const float *B, const int ldb, float *C, const int ldc, - const float alpha, const float beta) { - float c[unroll_m * unroll_n] = { 0. }; +template <typename data_t, bool isTransA, bool isTransB> +static void kernel_mxn(int K, const data_t *A, const int lda, + const data_t *B, const int ldb, data_t *C, const int ldc, + const data_t alpha, const data_t beta) { + data_t c[unroll_factor<data_t>::m * unroll_factor<data_t>::n] = + { static_cast<data_t>(0.) }; for (int k = 0; k < K; k++) { - for (int j = 0; j < unroll_n; j++) { - float b = isTransB ? B[j + k * ldb] : B[k + j * ldb]; + for (int j = 0; j < unroll_factor<data_t>::n; j++) { + data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb]; PRAGMA_OMP_SIMD() - for (int i = 0; i < unroll_m; i++) { - float a = isTransA ? A[i * lda + k] : A[i + lda * k]; - c[i + unroll_m * j] += a * b; + for (int i = 0; i < unroll_factor<data_t>::m; i++) { + data_t a = isTransA ? A[i * lda + k] : A[i + lda * k]; + c[i + unroll_factor<data_t>::m * j] += a * b; } } } - for (int j = 0; j < unroll_n; j++) { + for (int j = 0; j < unroll_factor<data_t>::n; j++) { PRAGMA_OMP_SIMD() - for (int i = 0; i < unroll_m; i++) { - C[i + j * ldc] = (beta == 0.0f) - ? alpha * c[i + unroll_m * j] - : alpha * c[i + unroll_m * j] + beta * C[i + j * ldc]; + for (int i = 0; i < unroll_factor<data_t>::m; i++) { + C[i + j * ldc] = (beta == static_cast<data_t>(0.)) + ? alpha * c[i + unroll_factor<data_t>::m * j] + : alpha * c[i + unroll_factor<data_t>::m * j] + + beta * C[i + j * ldc]; } } } -template <bool isTransA, bool isTransB> +template <typename data_t, bool isTransA, bool isTransB> static void block_ker(const int M, const int N, const int K, - const float *A, const int lda, const float *B, const int ldb, float *C, - const int ldc, const float alpha, const float beta, float *ws, - bool do_copy) { - int Nu = rnd_dn(N, unroll_n), Mu = rnd_dn(M, unroll_m); - for (int i = 0; i < Mu; i += unroll_m) { - for (int j = 0; j < Nu; j += unroll_n) { - const float *b = isTransB ? &B[j] : &B[j * ldb]; - const float *a = isTransA ? &A[i * lda] : &A[i]; + const data_t *A, const int lda, const data_t *B, const int ldb, + data_t *C, const int ldc, const data_t alpha, const data_t beta, + data_t *ws, bool do_copy) { + int Nu = rnd_dn(N, unroll_factor<data_t>::n); + int Mu = rnd_dn(M, unroll_factor<data_t>::m); + for (int i = 0; i < Mu; i += unroll_factor<data_t>::m) { + for (int j = 0; j < Nu; j += unroll_factor<data_t>::n) { + const data_t *b = isTransB ? &B[j] : &B[j * ldb]; + const data_t *a = isTransA ? &A[i * lda] : &A[i]; if (do_copy) { if (j == 0) { - copy_A(isTransA, K, a, lda, ws); + copy_A<data_t>(isTransA, K, a, lda, ws); } - kernel_mxn<false, isTransB>( - K, ws, unroll_m, b, ldb, &C[i + j * ldc], ldc, alpha, beta); + kernel_mxn<data_t, false, isTransB>( + K, ws, unroll_factor<data_t>::m, b, ldb, + &C[i + j * ldc], ldc, alpha, beta); } else { - kernel_mxn<isTransA, isTransB>( + kernel_mxn<data_t, isTransA, isTransB>( K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta); } } @@ -91,10 +96,12 @@ static void block_ker(const int M, const int N, const int K, // tail processing for (int i = 0; i < M; i++) { for (int j = Nu; j < N; j++) { - float c = beta == 0.0f ? 0.0f : beta * C[i + j * ldc]; + data_t c = beta == static_cast<data_t>(0.) + ? static_cast<data_t>(0.) + : beta * C[i + j * ldc]; for (int p = 0; p < K; p++) { - float b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; - float a = isTransA ? A[p + i * lda] : A[i + p * lda]; + data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; + data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; c += alpha * a * b; } C[i + j * ldc] = c; @@ -102,10 +109,12 @@ static void block_ker(const int M, const int N, const int K, } for (int i = Mu; i < M; i++) { for (int j = 0; j < Nu; j++) { - float c = beta == 0.0f ? 0.0f : beta * C[i + j * ldc]; + data_t c = beta == static_cast<data_t>(0.) + ? static_cast<data_t>(0.) + : beta * C[i + j * ldc]; for (int p = 0; p < K; p++) { - float b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; - float a = isTransA ? A[p + i * lda] : A[i + p * lda]; + data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; + data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; c += alpha * a * b; } C[i + j * ldc] = c; @@ -113,25 +122,28 @@ static void block_ker(const int M, const int N, const int K, } } -template <bool isTransA, bool isTransB> -void gemm_ithr(const int M, const int N, const int K, const float alpha, - const float *A, const int lda, const float *B, const int ldb, - const float beta, float *C, const int ldc, bool do_copy, float *ws) { - int BM = 4032; - int BN = isTransA ? 96 : 48; - int BK = isTransB ? 96 : 256; - const float *curA, *curB; - float *curC; +template <typename data_t, bool isTransA, bool isTransB> +void gemm_ithr(const int M, const int N, const int K, const data_t alpha, + const data_t *A, const int lda, const data_t *B, const int ldb, + const data_t beta, data_t *C, const int ldc, bool do_copy, data_t *ws) { + constexpr int BM = gemm_traits<data_t, isTransA, isTransB>::BM; + constexpr int BN = gemm_traits<data_t, isTransA, isTransB>::BN; + constexpr int BK = gemm_traits<data_t, isTransA, isTransB>::BK; + + const data_t *curA; + const data_t *curB; + data_t *curC; if ((M <= 0) || (N <= 0)) return; - if ((K <= 0) || (alpha == 0.0f)) { - if (beta == 0.0f) { - for (int j = 0; j < N * M; j++) - C[j] = 0.0f; - } else if (beta != 1.0f) { - for (int j = 0; j < N * M; j++) + if ((K <= 0) || (alpha == static_cast<data_t>(0))) { + ptrdiff_t MN = (ptrdiff_t)N * M; + if (beta == static_cast<data_t>(0.)) { + for (ptrdiff_t j = 0; j < MN; j++) + C[j] = static_cast<data_t>(0.); + } else if (beta != static_cast<data_t>(1.)) { + for (ptrdiff_t j = 0; j < MN; j++) C[j] *= beta; } return; @@ -147,25 +159,27 @@ void gemm_ithr(const int M, const int N, const int K, const float alpha, curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb; curC = C + Bm + Bn * ldc; if (Bk == 0) { - block_ker<isTransA, isTransB>(mb, nb, kb, curA, lda, curB, - ldb, curC, ldc, alpha, beta, ws, do_copy); + block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda, + curB, ldb, curC, ldc, alpha, beta, ws, do_copy); } else { - block_ker<isTransA, isTransB>(mb, nb, kb, curA, lda, curB, - ldb, curC, ldc, alpha, 1.0f, ws, do_copy); + block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda, + curB, ldb, curC, ldc, alpha, static_cast<data_t>(1.0), + ws, do_copy); } } } } } +template <typename data_t> void ref_gemm(const char *transa_, const char *transb_, const int *M_, - const int *N_, const int *K_, const float *alpha_, const float *A, - const int *lda_, const float *B, const int *ldb_, const float *beta_, - float *C, const int *ldc_, const float *bias) { + const int *N_, const int *K_, const data_t *alpha_, const data_t *A, + const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_, + data_t *C, const int *ldc_, const data_t *bias) { bool isTransA = (*transa_ == 'T' || *transa_ == 't'); bool isTransB = (*transb_ == 'T' || *transb_ == 't'); const int M = *M_, N = *N_, K = *K_, lda = *lda_, ldb = *ldb_, ldc = *ldc_; - const float alpha = *alpha_, beta = *beta_; + const data_t alpha = *alpha_, beta = *beta_; int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads(); int nthr_m, nthr_n, nthr_k; @@ -173,26 +187,27 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_, // thread balancing over M, N, K & size of blocking dimensions gemm_utils::calc_nthr_nocopy_avx( M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); - assert(utils::implication(!mkldnn_thr_syncable(), nthr_k == 1)); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); - float *c_buffers = nullptr, *ws_buffers = nullptr; + data_t *c_buffers = nullptr; + data_t *ws_buffers = nullptr; if (nthr_k > 1) { - c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB - * sizeof(float), PAGE_4K); + c_buffers = (data_t *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(data_t), PAGE_4K); if (!c_buffers) { nthr_k = 1; KB = K; } } - bool do_copy = (NB / unroll_n > 3); + bool do_copy = (NB / unroll_factor<data_t>::n > 3); const int nthr_mn = nthr_m * nthr_n; const int nthr = nthr_mn * nthr_k; - const size_t ws_elems_per_thr = K * unroll_m; + const size_t ws_elems_per_thr = K * unroll_factor<data_t>::m; const size_t ws_size_per_thr - = utils::rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K); + = utils::rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K); if (do_copy) { - ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K); + ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K); if (!ws_buffers) do_copy = false; } @@ -205,8 +220,8 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_, int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); - float *ws = do_copy - ? ws_buffers + ithr * ws_size_per_thr / sizeof(float) + data_t *ws = do_copy + ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t) : nullptr; int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0, @@ -224,7 +239,7 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_, get_thr_block(k_from, k_to, myK, KB, K, ithr_k); if (myM > 0 && myN > 0) { - float myBeta, *myC; + data_t myBeta, *myC; int ld; if (ithr_k == 0) { myC = &(C[m_from + n_from * ldc]); @@ -235,28 +250,28 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_, myBeta = 0.0f; ld = MB; } - const float *myA = isTransA + const data_t *myA = isTransA ? &(A[k_from + m_from * lda]) : &(A[m_from + k_from * lda]); - const float *myB = isTransB + const data_t *myB = isTransB ? &(B[n_from + k_from * ldb]) : &(B[k_from + n_from * ldb]); if (!isTransA) { if (!isTransB) { - gemm_ithr<false, false>(myM, myN, myK, alpha, myA, lda, myB, - ldb, myBeta, myC, ld, do_copy, ws); + gemm_ithr<data_t, false, false>(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); } else { - gemm_ithr<false, true>(myM, myN, myK, alpha, myA, lda, myB, - ldb, myBeta, myC, ld, do_copy, ws); + gemm_ithr<data_t, false, true>(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); } } else { if (!isTransB) { - gemm_ithr<true, false>(myM, myN, myK, alpha, myA, lda, myB, - ldb, myBeta, myC, ld, do_copy, ws); + gemm_ithr<data_t, true, false>(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); } else { - gemm_ithr<true, true>(myM, myN, myK, alpha, myA, lda, myB, - ldb, myBeta, myC, ld, do_copy, ws); + gemm_ithr<data_t, true, true>(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); } } } @@ -270,7 +285,8 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_, gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset, &block); for (int ik = 1; ik < nthr_k; ++ik) { - float *myC = c_buffers + MB * (NB * (cbase + ik - 1) + offset); + data_t *myC = c_buffers + MB * (NB * (cbase + ik - 1) + offset); + gemm_utils::sum_two_matrices(myM, block, myC, MB, &C[m_from + (n_from + offset) * ldc], ldc); } @@ -286,6 +302,16 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_, free(ws_buffers); free(c_buffers); } + +template void ref_gemm<float>(const char *transa_, const char *transb_, + const int *M_, const int *N_, const int *K_, const float *alpha_, + const float *A, const int *lda_, const float *B, const int *ldb_, + const float *beta_, float *C, const int *ldc_, const float *bias); + +template void ref_gemm<double>(const char *transa_, const char *transb_, + const int *M_, const int *N_, const int *K_, const double *alpha_, + const double *A, const int *lda_, const double *B, const int *ldb_, + const double *beta_, double *C, const int *ldc_, const double *bias); } } } |