diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp | 82 |
1 files changed, 44 insertions, 38 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp index 33f293013..fc7a4e363 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp @@ -41,7 +41,7 @@ using jit_conv_ker_t = void (*)(jit_conv_call_s *); inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, const void *src, const void *dst, const void *filt, const void *bias, - int channel, int kh_padding) + int channel, int kh_padding, int oc_off) { PIPELINE(src); PIPELINE(dst); @@ -49,6 +49,7 @@ inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, PIPELINE(bias); PIPELINE(channel); PIPELINE(kh_padding); + PIPELINE(oc_off); if (p.src) ker(&p); @@ -100,7 +101,7 @@ void _jit_avx512_common_convolution_fwd_t if (jcp.aligned_threads) nthr = jcp.aligned_threads; else - nthr = omp_get_max_threads(); + nthr = mkldnn_get_max_threads(); if (conf_.want_padded_bias()) { for (int oc = 0; oc < jcp.oc_without_padding; ++oc) @@ -108,10 +109,7 @@ void _jit_avx512_common_convolution_fwd_t bias = padded_bias_; } -#pragma omp parallel num_threads(nthr) - { - int ithr = omp_get_thread_num(); - + parallel(nthr, [&](const int ithr, const int nthr) { int start, end, start_copy; balance211(work_amount, nthr, ithr, start, end); start_copy = start; @@ -151,6 +149,8 @@ void _jit_avx512_common_convolution_fwd_t auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, ih_s); auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2); + int oc_off = g_oc * sizeof(dst_data_t); + for (int icb = icb_l2; icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) { auto src_c = src_w; @@ -170,7 +170,7 @@ void _jit_avx512_common_convolution_fwd_t jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, src_c + i_t_overflow * dilate_h * src_h_stride, dst_c, wht_w + i_t_overflow * wht_h_stride, - bias_w, icb, kh_padding); + bias_w, icb, kh_padding, oc_off); src_c += src_h_stride * jcp.stride_h; dst_c += dst_h_stride; @@ -191,8 +191,8 @@ void _jit_avx512_common_convolution_fwd_t } jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, - src, dst, weights, bias, 0, 0); - } + src, dst, weights, bias, 0, 0, 0); + }); } template <bool with_relu, data_type_t src_type, data_type_t wei_type, @@ -214,10 +214,13 @@ void _jit_avx512_common_convolution_fwd_t const int MB = conf_.MB(); assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); -# pragma omp parallel - { - int ithr = omp_get_thread_num(), nthr = omp_get_num_threads(); + if (conf_.want_padded_bias()) { + for (int oc = 0; oc < jcp.oc_without_padding; ++oc) + padded_bias_[oc] = bias[oc]; + bias = padded_bias_; + } + parallel(0, [&](const int ithr, const int nthr) { int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; int start, end, start_copy; int work_amount = MB * jcp.ngroups * oc_chunks * jcp.od * jcp.oh; @@ -315,7 +318,7 @@ void _jit_avx512_common_convolution_fwd_t } jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, src, dst, weights, bias, 0, 0, 0); - } + }); } template struct _jit_avx512_common_convolution_fwd_t<false, data_type::f32>; @@ -341,10 +344,7 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type, const auto &jcp = kernel_->jcp; const int MB = conf_.MB(); -# pragma omp parallel - { - int ithr = omp_get_thread_num(), nthr = omp_get_num_threads(); - + parallel(0, [&](const int ithr, const int nthr) { int start, end, start_copy; int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; int work_amount = jcp.ngroups * MB * ic_chunks * jcp.ih; @@ -431,7 +431,7 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type, diff_src_w + ij * diff_src_h_stride, diff_dst_w + oj * diff_dst_h_stride, wht_w + k_lo * wht_h_stride, - 0, ocb, k_len); + 0, ocb, k_len, 0); } diff_dst_w += diff_dst_c_stride; wht_w += wht_oc_stride; @@ -449,8 +449,8 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type, } jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, - diff_src, diff_dst, weights, 0, 0, 1); - } + diff_src, diff_dst, weights, 0, 0, 1, 0); + }); } template <data_type_t diff_dst_type, data_type_t wei_type, @@ -469,10 +469,7 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type, const auto &jcp = kernel_->jcp; const int MB = conf_.MB(); -# pragma omp parallel - { - int ithr = omp_get_thread_num(), nthr = omp_get_num_threads(); - + parallel(0, [&](const int ithr, const int nthr) { int start, end, start_copy; int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; int work_amount = jcp.ngroups * MB * ic_chunks * jcp.id * jcp.ih; @@ -625,7 +622,7 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type, jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, diff_src, diff_dst, weights, 0, 0, 1, 1); - } + }); } template struct jit_avx512_common_convolution_bwd_data_t<data_type::f32>; @@ -964,6 +961,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, tr_ctx.tr_src = tr_src_ + ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld; + assert(utils::implication(!mkldnn_thr_syncable(), nthr_oc_b_ == 1)); tr_ctx.nthr_oc_b = nthr_oc_b_; int ih_start{0}, ih_end{0}; balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end); @@ -1051,7 +1049,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, ? &tr_diff_dst_[tr_diff_dst_off(ti->ithr_mb, _oc, 0)] : &ti->diff_dst[diff_dst_d.blk_off(img, _oc)], diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), - 0, (img == ti->img_start), 0); + 0, (img == ti->img_start), 0, 0); } } @@ -1069,7 +1067,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, diff_wei + wht_blk_off( diff_weights_d, ti->g_start, ti->oc_b_start, ti->ic_b_start), - 0, 0, 0); + 0, 0, 0, 0); } } } @@ -1093,7 +1091,6 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, : (diff_weights_data_t*)ws_reduction_ + (nthr_mb_ - 1) * wei_size + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc; - const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; const int input_step = jcp.stride_d * jcp.ih * jcp.iw * inp_mult; const int output_step = jcp.ow * jcp.oh * jcp.oc_block; @@ -1110,7 +1107,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, const int id_s = od_s * jcp.stride_d; const int idp = jcp.id + jcp.f_pad + jcp.back_pad; - if (id_s < idp - jcp.back_pad - jcp.kd + 1) { + if (id_s < idp - jcp.back_pad - (jcp.kd - 1) * (jcp.dilate_d)) { for (int g = ti->g_start; g < ti->g_end; ++g) { for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { @@ -1320,7 +1317,8 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, const diff_weights_data_t *diff_bias_ws = ws_reduction_ + (size_t)(nthr_mb_ - 1) * wei_size; - #pragma omp barrier + if (nthr_mb_ > 1) mkldnn_thr_barrier(); + if (ti->ithr == 0) { for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { @@ -1330,15 +1328,12 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, } } - template <data_type_t src_type, data_type_t diff_dst_type, data_type_t diff_weights_type> void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, diff_weights_type>::execute_backward_weights() { -# pragma omp parallel num_threads(nthr_) - { - int ithr = omp_get_thread_num(); - assert(nthr_ == omp_get_num_threads()); + parallel(nthr_, [&](const int ithr, const int nthr) { + assert(nthr_ == nthr); thread_info_t thread_info(this, ithr); @@ -1351,7 +1346,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info); if (conf_.with_bias()) compute_diff_bias_3d(&thread_info); } - } + }); /* TODO: put that into compute_diff_bias() */ if (conf_.want_padded_bias()) { @@ -1366,7 +1361,7 @@ template <data_type_t src_type, data_type_t diff_dst_type, data_type_t diff_weights_type> void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, diff_weights_type>::balance() { - const int max_threads = omp_get_max_threads(); + const int max_threads = mkldnn_get_max_threads(); const auto &j = conf_.jcp_; nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1; @@ -1376,6 +1371,13 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, return; } + if (!mkldnn_thr_syncable() + && utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) { + // should not happen -- the driver is not ready + // for TBB-like non-synchronous threading yet + return; + } + if (j.ver == ver_4fma && j.is_1stconv) { nthr_g_ = 1; nthr_oc_b_ = 1; @@ -1424,7 +1426,6 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, const int nthr_par = nthr / nthr_mb; const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { - int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); @@ -1435,6 +1436,8 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, nthr_ic_b_ = nthr_ic_b; } } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } } if (j.ver != ver_vnni && !mayiuse(avx512_mic)) { @@ -1471,6 +1474,8 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, nthr_ic_b_ = nthr_ic_b; } } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } } } @@ -1478,6 +1483,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, nthr_mb_ = min(j.mb * j.od, max_threads); nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_; assert(nthr_ <= max_threads); + assert(utils::implication(!mkldnn_thr_syncable(), nthr_mb_ == 1)); } template struct jit_avx512_common_convolution_bwd_weights_t<data_type::f32>; |