diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp | 314 |
1 files changed, 161 insertions, 153 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp index 551fef401..ae171de98 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp @@ -67,12 +67,7 @@ void _jit_avx512_common_1x1_convolution_fwd_t auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2)); auto dst = reinterpret_cast<dst_data_t *>(this->memory()); - const memory_desc_wrapper src_d(conf_.src_pd()); - const memory_desc_wrapper dst_d(conf_.dst_pd()); - const memory_desc_wrapper weights_d(conf_.weights_pd(0)); - - const auto &jcp = kernel_->jcp; - const int MB = conf_.MB(); + auto &jcp = kernel_->jcp; if (conf_.want_padded_bias()) { assert(jcp.ngroups == 1); for (int oc = 0; oc < jcp.oc_without_padding; ++oc) @@ -80,126 +75,128 @@ void _jit_avx512_common_1x1_convolution_fwd_t bias = padded_bias_; } - const int work_amount = MB * jcp.ngroups * jcp.nb_bcast; + parallel(0, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src, weights, bias, dst); + }); +} + +template <bool with_relu, data_type_t src_type, data_type_t wei_type, + data_type_t dst_type> +void _jit_avx512_common_1x1_convolution_fwd_t + <with_relu, src_type, wei_type, dst_type>::execute_forward_thr( + const int ithr, const int nthr, + const src_data_t *src, const wei_data_t *weights, + const dst_data_t *bias, dst_data_t *dst) +{ + const memory_desc_wrapper src_d(conf_.src_pd()); + const memory_desc_wrapper dst_d(conf_.dst_pd()); + const memory_desc_wrapper weights_d(conf_.weights_pd(0)); const int stride_h = conf_.cdesc()->strides[0]; const int stride_w = conf_.cdesc()->strides[1]; const int pad_t = conf_.cdesc()->padding[0][0]; const int pad_l = conf_.cdesc()->padding[0][1]; + auto &jcp = kernel_->jcp; + const int MB = conf_.MB(); + const int work_amount = MB * jcp.ngroups * jcp.nb_bcast; + auto step = [](int default_step, int remaining, int tail_step) { assert(default_step <= tail_step); return remaining < tail_step ? remaining : default_step; }; -# pragma omp parallel - { - int ithr = omp_get_thread_num(), nthr = omp_get_num_threads(); - - auto p = jit_1x1_conv_call_s(); + auto p = jit_1x1_conv_call_s(); - auto rp = rtus_driver_t<avx512_common>::call_params_t(); + auto rp = rtus_driver_t<avx512_common>::call_params_t(); - const int nb_oc = jcp.nb_load; - const int nb_ic = jcp.nb_reduce; - const int nb_ic_blocking = jcp.nb_reduce_blocking; - const int os_block = jcp.bcast_block; + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; - int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0}; - balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, - jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count); + int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count); - auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, + auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, int &oh, int &ow, int &ih, int &iw) - { - int osb{0}; - nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb, - jcp.nb_bcast); - bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, - jcp.nb_bcast_blocking_max); - bcast_step = nstl::min(bcast_step, bcast_end - iwork); - - const int os = osb * os_block; - oh = os / jcp.ow; - ow = os % jcp.ow; - - ih = nstl::max(oh * stride_h - pad_t, 0); - iw = nstl::max(ow * stride_w - pad_l, 0); - rp.iw_start = iw; - - p.bcast_dim = this_block_size(os, jcp.os, - bcast_step * os_block); - rp.os = p.bcast_dim; - }; - - auto init_load = [&](int ocb, int &load_step) - { - load_step = step(jcp.nb_load_blocking, ocb_end - ocb, - jcp.nb_load_blocking_max); - p.load_dim = this_block_size(ocb * jcp.oc_block, - ocb_end * jcp.oc_block, load_step * jcp.oc_block); - }; - - auto init_reduce = [&](int icb) - { - const int nb_ic_blocking_step = - nstl::min(icb + nb_ic_blocking, nb_ic) - icb; - p.reduce_pos_flag = 0 - | (icb == 0 ? FLAG_REDUCE_FIRST : 0) - | (icb + nb_ic_blocking_step >= nb_ic - ? FLAG_REDUCE_LAST : 0); - - p.reduce_dim = this_block_size(icb * jcp.ic_block, - jcp.ic, nb_ic_blocking_step * jcp.ic_block); - rp.icb = p.reduce_dim / jcp.reduce_block; - }; - - auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow, - int ih, int iw) - { - - const int _ocb = g * nb_oc + ocb; - const size_t dst_off = dst_d.blk_off(n, _ocb, oh, ow); - - p.output_data = &dst[dst_off]; - p.bias_data = &bias[_ocb * jcp.oc_block]; - p.load_data = &weights[conf_.with_groups() - ? weights_d.blk_off(g, ocb, icb) - : weights_d.blk_off(ocb, icb)]; - - const int _icb = g * nb_ic + icb; - if (conf_.rtus_.reduce_src_) { - rp.ws = scratch_ + ithr * ws_per_thread_ - + _icb * jcp.is * jcp.ic_block; - if (ocb == ocb_start) { - rp.src = src + src_d.blk_off(n, _icb, ih, iw); - rtus_driver_->ker_(&rp); - } - p.bcast_data = rp.ws; - } else - p.bcast_data = src + src_d.blk_off(n, _icb, ih, iw); + { + int osb{0}; + nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb, + jcp.nb_bcast); + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + oh = os / jcp.ow; + ow = os % jcp.ow; + + ih = nstl::max(oh * stride_h - pad_t, 0); + iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + }; - kernel_->jit_ker(&p); - }; + auto init_load = [&](int ocb, int &load_step) + { + load_step = step(jcp.nb_load_blocking, ocb_end - ocb, + jcp.nb_load_blocking_max); + p.load_dim = this_block_size(ocb * jcp.oc_block, + ocb_end * jcp.oc_block, load_step * jcp.oc_block); + }; - if (jcp.loop_order == loop_rlb) { - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - init_reduce(icb); - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, load_step); - int iwork = bcast_start; - while (iwork < bcast_end) { - int n, g, bcast_step, oh, ow, ih, iw; - init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); - inner_ker(ocb, icb, n, g, oh, ow, ih, iw); - iwork += bcast_step; - } - ocb += load_step; - } + auto init_reduce = [&](int icb) + { + const int nb_ic_blocking_step = + nstl::min(icb + nb_ic_blocking, nb_ic) - icb; + p.first_last_flag = 0 + | (icb == 0 ? FLAG_REDUCE_FIRST : 0) + | (icb + nb_ic_blocking_step >= nb_ic + ? FLAG_REDUCE_LAST : 0); + + p.reduce_dim = this_block_size(icb * jcp.ic_block, + jcp.ic, nb_ic_blocking_step * jcp.ic_block); + rp.icb = p.reduce_dim / jcp.reduce_block; + }; + + auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow, + int ih, int iw) + { + + const int _ocb = g * nb_oc + ocb; + const size_t dst_off = dst_d.blk_off(n, _ocb, oh, ow); + + p.output_data = &dst[dst_off]; + p.bias_data = &bias[_ocb * jcp.oc_block]; + p.load_data = &weights[conf_.with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + const int _icb = g * nb_ic + icb; + if (conf_.rtus_.reduce_src_) { + rp.ws = scratch_ + ithr * ws_per_thread_ + + _icb * jcp.is * jcp.ic_block; + if (ocb == ocb_start) { + rp.src = src + src_d.blk_off(n, _icb, ih, iw); + rtus_driver_->ker_(&rp); } - } else if (jcp.loop_order == loop_lbr) { + p.bcast_data = rp.ws; + } else + p.bcast_data = src + src_d.blk_off(n, _icb, ih, iw); + + p.oc_off = _ocb * jcp.oc_block * sizeof(dst_data_t); + + kernel_->jit_ker(&p); + }; + + if (jcp.loop_order == loop_rlb) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); int ocb = ocb_start; while (ocb < ocb_end) { int load_step; @@ -208,32 +205,32 @@ void _jit_avx512_common_1x1_convolution_fwd_t while (iwork < bcast_end) { int n, g, bcast_step, oh, ow, ih, iw; init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - init_reduce(icb); - inner_ker(ocb, icb, n, g, oh, ow, ih, iw); - } + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); iwork += bcast_step; } ocb += load_step; } - } else if (jcp.loop_order == loop_rbl) { - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - init_reduce(icb); - int iwork = bcast_start; - while (iwork < bcast_end) { - int n, g, bcast_step, oh, ow, ih, iw; - init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, load_step); - inner_ker(ocb, icb, n, g, oh, ow, ih, iw); - ocb += load_step; - } - iwork += bcast_step; + } + } else if (jcp.loop_order == loop_lbr) { + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); } + iwork += bcast_step; } - } else if (jcp.loop_order == loop_blr) { + ocb += load_step; + } + } else if (jcp.loop_order == loop_rbl) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); int iwork = bcast_start; while (iwork < bcast_end) { int n, g, bcast_step, oh, ow, ih, iw; @@ -242,20 +239,35 @@ void _jit_avx512_common_1x1_convolution_fwd_t while (ocb < ocb_end) { int load_step; init_load(ocb, load_step); - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - init_reduce(icb); - inner_ker(ocb, icb, n, g, oh, ow, ih, iw); - } + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); ocb += load_step; } iwork += bcast_step; } - } else { - assert(!"unsupported loop order"); } + } else if (jcp.loop_order == loop_blr) { + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + } + ocb += load_step; + } + iwork += bcast_step; + } + } else { + assert(!"unsupported loop order"); } } + template struct _jit_avx512_common_1x1_convolution_fwd_t<true, data_type::f32>; template struct _jit_avx512_common_1x1_convolution_fwd_t<false, data_type::f32>; template struct _jit_avx512_common_1x1_convolution_fwd_t<false, data_type::s16, @@ -302,10 +314,7 @@ void _jit_avx512_common_1x1_convolution_bwd_data_t return remaining < tail_step ? remaining : default_step; }; -# pragma omp parallel - { - int ithr = omp_get_thread_num(), nthr = omp_get_num_threads(); - + parallel(0, [&](const int ithr, const int nthr) { auto p = jit_1x1_conv_call_s(); auto rp = rtus_driver_t<avx512_common>::call_params_t(); @@ -385,7 +394,7 @@ void _jit_avx512_common_1x1_convolution_bwd_data_t ? weights_d.blk_off(g, ocb, icb) : weights_d.blk_off(ocb, icb)]; - p.reduce_pos_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; + p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, nb_oc_blocking_step * jcp.oc_block); @@ -397,7 +406,7 @@ void _jit_avx512_common_1x1_convolution_bwd_data_t } } } - } + }); } template struct _jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>; @@ -424,11 +433,6 @@ jit_avx512_common_1x1_convolution_bwd_weights_t :: const auto &jcp = kernel_->jcp; - bctx_ = (simple_barrier::ctx_t *)malloc( - jcp.nthr * sizeof(simple_barrier::ctx_t), 64); - for (int i = 0; i < jcp.nthr; ++i) - simple_barrier::ctx_init(&bctx_[i]); - const int wei_size = jcp.ngroups * jcp.oc * jcp.ic; ws_reduction_ = (data_t *)malloc((jcp.nthr_mb - 1) * wei_size * sizeof(data_t), 64); @@ -450,15 +454,18 @@ jit_avx512_common_1x1_convolution_bwd_weights_t :: const ptrdiff_t tr_src_size = (ptrdiff_t)jcp.nthr_mb * (ptrdiff_t)jcp.ngroups * (ptrdiff_t)jcp.ic * jcp.tr_is; tr_src_ = (data_t *)malloc(tr_src_size * sizeof(data_t), 64); -# pragma omp parallel for - for (ptrdiff_t i = 0; i < tr_src_size; i++) - tr_src_[i] = 0; + parallel_nd(tr_src_size, [&](ptrdiff_t i) { tr_src_[i] = 0; }); auto tp = jit_transpose4x16_src_t(); tp.src_pf0_distance = 4; tp.tr_src_pf0_distance = 0; tp.src_pf1 = true; tp.tr_src_pf1 = false; trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp); + + bctx_ = (simple_barrier::ctx_t *)malloc( + jcp.nthr * sizeof(simple_barrier::ctx_t), 64); + for (int i = 0; i < jcp.nthr; ++i) + simple_barrier::ctx_init(&bctx_[i]); } init_rtus_driver<avx512_common>(this); @@ -557,6 +564,9 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights() }; auto ker = [&](const int ithr, const int nthr) { + assert(nthr == jcp.nthr); + assert(utils::implication(!mkldnn_thr_syncable(), jcp.nthr_mb == 1)); + const int ithr_ic_b = ithr % jcp.nthr_ic_b; const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; @@ -666,7 +676,7 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights() p.reduce_dim = sp_b_step * jcp.reduce_block; rp.os = p.reduce_dim; - p.reduce_pos_flag = 0 + p.first_last_flag = 0 | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST : 0) | (sp_b_end == sp_nb ? FLAG_SP_LAST : 0); @@ -737,6 +747,7 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights() } } }; + auto ker_bias = [&](int ithr, int nthr) { auto rb = this->reducer_bias_; assert(nthr == rb->balancer_.nthr_); @@ -784,14 +795,11 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights() rb->reduce(ithr, diff_bias); }; -#pragma omp parallel num_threads(jcp.nthr) - { - int ithr = omp_get_thread_num(); - assert(jcp.nthr == omp_get_num_threads()); + parallel(jcp.nthr, [&](const int ithr, const int nthr) { ker(ithr, jcp.nthr); if (conf_.with_bias()) ker_bias(ithr, jcp.nthr); - } + }); /* TODO: put this in ker_bias */ if (conf_.want_padded_bias()) { |