summary | refs | log | tree | commit | diff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp')
-rw-r--r-- inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp 314
1 files changed, 161 insertions, 153 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp
index 551fef401..ae171de98 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp
@@ -67,12 +67,7 @@ void _jit_avx512_common_1x1_convolution_fwd_t
auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
auto dst = reinterpret_cast<dst_data_t *>(this->memory());
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-
- const auto &jcp = kernel_->jcp;
- const int MB = conf_.MB();
+ auto &jcp = kernel_->jcp;
if (conf_.want_padded_bias()) {
assert(jcp.ngroups == 1);
for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
@@ -80,126 +75,128 @@ void _jit_avx512_common_1x1_convolution_fwd_t
bias = padded_bias_;
}
- const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
+ parallel(0, [&](const int ithr, const int nthr) {
+ execute_forward_thr(ithr, nthr, src, weights, bias, dst);
+ });
+}
+
+template <bool with_relu, data_type_t src_type, data_type_t wei_type,
+ data_type_t dst_type>
+void _jit_avx512_common_1x1_convolution_fwd_t
+ <with_relu, src_type, wei_type, dst_type>::execute_forward_thr(
+ const int ithr, const int nthr,
+ const src_data_t *src, const wei_data_t *weights,
+ const dst_data_t *bias, dst_data_t *dst)
+{
+ const memory_desc_wrapper src_d(conf_.src_pd());
+ const memory_desc_wrapper dst_d(conf_.dst_pd());
+ const memory_desc_wrapper weights_d(conf_.weights_pd(0));
const int stride_h = conf_.cdesc()->strides[0];
const int stride_w = conf_.cdesc()->strides[1];
const int pad_t = conf_.cdesc()->padding[0][0];
const int pad_l = conf_.cdesc()->padding[0][1];
+ auto &jcp = kernel_->jcp;
+ const int MB = conf_.MB();
+ const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
+
auto step = [](int default_step, int remaining, int tail_step) {
assert(default_step <= tail_step);
return remaining < tail_step ? remaining : default_step;
};
-# pragma omp parallel
- {
- int ithr = omp_get_thread_num(), nthr = omp_get_num_threads();
-
- auto p = jit_1x1_conv_call_s();
+ auto p = jit_1x1_conv_call_s();
- auto rp = rtus_driver_t<avx512_common>::call_params_t();
+ auto rp = rtus_driver_t<avx512_common>::call_params_t();
- const int nb_oc = jcp.nb_load;
- const int nb_ic = jcp.nb_reduce;
- const int nb_ic_blocking = jcp.nb_reduce_blocking;
- const int os_block = jcp.bcast_block;
+ const int nb_oc = jcp.nb_load;
+ const int nb_ic = jcp.nb_reduce;
+ const int nb_ic_blocking = jcp.nb_reduce_blocking;
+ const int os_block = jcp.bcast_block;
- int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0};
- balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
- jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count);
+ int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0};
+ balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
+ jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count);
- auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
+ auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
int &oh, int &ow, int &ih, int &iw)
- {
- int osb{0};
- nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb,
- jcp.nb_bcast);
- bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
- jcp.nb_bcast_blocking_max);
- bcast_step = nstl::min(bcast_step, bcast_end - iwork);
-
- const int os = osb * os_block;
- oh = os / jcp.ow;
- ow = os % jcp.ow;
-
- ih = nstl::max(oh * stride_h - pad_t, 0);
- iw = nstl::max(ow * stride_w - pad_l, 0);
- rp.iw_start = iw;
-
- p.bcast_dim = this_block_size(os, jcp.os,
- bcast_step * os_block);
- rp.os = p.bcast_dim;
- };
-
- auto init_load = [&](int ocb, int &load_step)
- {
- load_step = step(jcp.nb_load_blocking, ocb_end - ocb,
- jcp.nb_load_blocking_max);
- p.load_dim = this_block_size(ocb * jcp.oc_block,
- ocb_end * jcp.oc_block, load_step * jcp.oc_block);
- };
-
- auto init_reduce = [&](int icb)
- {
- const int nb_ic_blocking_step =
- nstl::min(icb + nb_ic_blocking, nb_ic) - icb;
- p.reduce_pos_flag = 0
- | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
- | (icb + nb_ic_blocking_step >= nb_ic
- ? FLAG_REDUCE_LAST : 0);
-
- p.reduce_dim = this_block_size(icb * jcp.ic_block,
- jcp.ic, nb_ic_blocking_step * jcp.ic_block);
- rp.icb = p.reduce_dim / jcp.reduce_block;
- };
-
- auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow,
- int ih, int iw)
- {
-
- const int _ocb = g * nb_oc + ocb;
- const size_t dst_off = dst_d.blk_off(n, _ocb, oh, ow);
-
- p.output_data = &dst[dst_off];
- p.bias_data = &bias[_ocb * jcp.oc_block];
- p.load_data = &weights[conf_.with_groups()
- ? weights_d.blk_off(g, ocb, icb)
- : weights_d.blk_off(ocb, icb)];
-
- const int _icb = g * nb_ic + icb;
- if (conf_.rtus_.reduce_src_) {
- rp.ws = scratch_ + ithr * ws_per_thread_
- + _icb * jcp.is * jcp.ic_block;
- if (ocb == ocb_start) {
- rp.src = src + src_d.blk_off(n, _icb, ih, iw);
- rtus_driver_->ker_(&rp);
- }
- p.bcast_data = rp.ws;
- } else
- p.bcast_data = src + src_d.blk_off(n, _icb, ih, iw);
+ {
+ int osb{0};
+ nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb,
+ jcp.nb_bcast);
+ bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
+ jcp.nb_bcast_blocking_max);
+ bcast_step = nstl::min(bcast_step, bcast_end - iwork);
+
+ const int os = osb * os_block;
+ oh = os / jcp.ow;
+ ow = os % jcp.ow;
+
+ ih = nstl::max(oh * stride_h - pad_t, 0);
+ iw = nstl::max(ow * stride_w - pad_l, 0);
+ rp.iw_start = iw;
+
+ p.bcast_dim = this_block_size(os, jcp.os,
+ bcast_step * os_block);
+ rp.os = p.bcast_dim;
+ };
- kernel_->jit_ker(&p);
- };
+ auto init_load = [&](int ocb, int &load_step)
+ {
+ load_step = step(jcp.nb_load_blocking, ocb_end - ocb,
+ jcp.nb_load_blocking_max);
+ p.load_dim = this_block_size(ocb * jcp.oc_block,
+ ocb_end * jcp.oc_block, load_step * jcp.oc_block);
+ };
- if (jcp.loop_order == loop_rlb) {
- for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
- init_reduce(icb);
- int ocb = ocb_start;
- while (ocb < ocb_end) {
- int load_step;
- init_load(ocb, load_step);
- int iwork = bcast_start;
- while (iwork < bcast_end) {
- int n, g, bcast_step, oh, ow, ih, iw;
- init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
- inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
- iwork += bcast_step;
- }
- ocb += load_step;
- }
+ auto init_reduce = [&](int icb)
+ {
+ const int nb_ic_blocking_step =
+ nstl::min(icb + nb_ic_blocking, nb_ic) - icb;
+ p.first_last_flag = 0
+ | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
+ | (icb + nb_ic_blocking_step >= nb_ic
+ ? FLAG_REDUCE_LAST : 0);
+
+ p.reduce_dim = this_block_size(icb * jcp.ic_block,
+ jcp.ic, nb_ic_blocking_step * jcp.ic_block);
+ rp.icb = p.reduce_dim / jcp.reduce_block;
+ };
+
+ auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow,
+ int ih, int iw)
+ {
+
+ const int _ocb = g * nb_oc + ocb;
+ const size_t dst_off = dst_d.blk_off(n, _ocb, oh, ow);
+
+ p.output_data = &dst[dst_off];
+ p.bias_data = &bias[_ocb * jcp.oc_block];
+ p.load_data = &weights[conf_.with_groups()
+ ? weights_d.blk_off(g, ocb, icb)
+ : weights_d.blk_off(ocb, icb)];
+
+ const int _icb = g * nb_ic + icb;
+ if (conf_.rtus_.reduce_src_) {
+ rp.ws = scratch_ + ithr * ws_per_thread_
+ + _icb * jcp.is * jcp.ic_block;
+ if (ocb == ocb_start) {
+ rp.src = src + src_d.blk_off(n, _icb, ih, iw);
+ rtus_driver_->ker_(&rp);
}
- } else if (jcp.loop_order == loop_lbr) {
+ p.bcast_data = rp.ws;
+ } else
+ p.bcast_data = src + src_d.blk_off(n, _icb, ih, iw);
+
+ p.oc_off = _ocb * jcp.oc_block * sizeof(dst_data_t);
+
+ kernel_->jit_ker(&p);
+ };
+
+ if (jcp.loop_order == loop_rlb) {
+ for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
+ init_reduce(icb);
int ocb = ocb_start;
while (ocb < ocb_end) {
int load_step;
@@ -208,32 +205,32 @@ void _jit_avx512_common_1x1_convolution_fwd_t
while (iwork < bcast_end) {
int n, g, bcast_step, oh, ow, ih, iw;
init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
- for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
- init_reduce(icb);
- inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
- }
+ inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
iwork += bcast_step;
}
ocb += load_step;
}
- } else if (jcp.loop_order == loop_rbl) {
- for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
- init_reduce(icb);
- int iwork = bcast_start;
- while (iwork < bcast_end) {
- int n, g, bcast_step, oh, ow, ih, iw;
- init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
- int ocb = ocb_start;
- while (ocb < ocb_end) {
- int load_step;
- init_load(ocb, load_step);
- inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
- ocb += load_step;
- }
- iwork += bcast_step;
+ }
+ } else if (jcp.loop_order == loop_lbr) {
+ int ocb = ocb_start;
+ while (ocb < ocb_end) {
+ int load_step;
+ init_load(ocb, load_step);
+ int iwork = bcast_start;
+ while (iwork < bcast_end) {
+ int n, g, bcast_step, oh, ow, ih, iw;
+ init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
+ for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
+ init_reduce(icb);
+ inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
}
+ iwork += bcast_step;
}
- } else if (jcp.loop_order == loop_blr) {
+ ocb += load_step;
+ }
+ } else if (jcp.loop_order == loop_rbl) {
+ for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
+ init_reduce(icb);
int iwork = bcast_start;
while (iwork < bcast_end) {
int n, g, bcast_step, oh, ow, ih, iw;
@@ -242,20 +239,35 @@ void _jit_avx512_common_1x1_convolution_fwd_t
while (ocb < ocb_end) {
int load_step;
init_load(ocb, load_step);
- for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
- init_reduce(icb);
- inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
- }
+ inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
ocb += load_step;
}
iwork += bcast_step;
}
- } else {
- assert(!"unsupported loop order");
}
+ } else if (jcp.loop_order == loop_blr) {
+ int iwork = bcast_start;
+ while (iwork < bcast_end) {
+ int n, g, bcast_step, oh, ow, ih, iw;
+ init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
+ int ocb = ocb_start;
+ while (ocb < ocb_end) {
+ int load_step;
+ init_load(ocb, load_step);
+ for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
+ init_reduce(icb);
+ inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
+ }
+ ocb += load_step;
+ }
+ iwork += bcast_step;
+ }
+ } else {
+ assert(!"unsupported loop order");
}
}
+
template struct _jit_avx512_common_1x1_convolution_fwd_t<true, data_type::f32>;
template struct _jit_avx512_common_1x1_convolution_fwd_t<false, data_type::f32>;
template struct _jit_avx512_common_1x1_convolution_fwd_t<false, data_type::s16,
@@ -302,10 +314,7 @@ void _jit_avx512_common_1x1_convolution_bwd_data_t
return remaining < tail_step ? remaining : default_step;
};
-# pragma omp parallel
- {
- int ithr = omp_get_thread_num(), nthr = omp_get_num_threads();
-
+ parallel(0, [&](const int ithr, const int nthr) {
auto p = jit_1x1_conv_call_s();
auto rp = rtus_driver_t<avx512_common>::call_params_t();
@@ -385,7 +394,7 @@ void _jit_avx512_common_1x1_convolution_bwd_data_t
? weights_d.blk_off(g, ocb, icb)
: weights_d.blk_off(ocb, icb)];
- p.reduce_pos_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;
+ p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;
p.reduce_dim = this_block_size(ocb * jcp.oc_block,
jcp.oc, nb_oc_blocking_step * jcp.oc_block);
@@ -397,7 +406,7 @@ void _jit_avx512_common_1x1_convolution_bwd_data_t
}
}
}
- }
+ });
}
template struct _jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
@@ -424,11 +433,6 @@ jit_avx512_common_1x1_convolution_bwd_weights_t ::
const auto &jcp = kernel_->jcp;
- bctx_ = (simple_barrier::ctx_t *)malloc(
- jcp.nthr * sizeof(simple_barrier::ctx_t), 64);
- for (int i = 0; i < jcp.nthr; ++i)
- simple_barrier::ctx_init(&bctx_[i]);
-
const int wei_size = jcp.ngroups * jcp.oc * jcp.ic;
ws_reduction_ =
(data_t *)malloc((jcp.nthr_mb - 1) * wei_size * sizeof(data_t), 64);
@@ -450,15 +454,18 @@ jit_avx512_common_1x1_convolution_bwd_weights_t ::
const ptrdiff_t tr_src_size = (ptrdiff_t)jcp.nthr_mb
* (ptrdiff_t)jcp.ngroups * (ptrdiff_t)jcp.ic * jcp.tr_is;
tr_src_ = (data_t *)malloc(tr_src_size * sizeof(data_t), 64);
-# pragma omp parallel for
- for (ptrdiff_t i = 0; i < tr_src_size; i++)
- tr_src_[i] = 0;
+ parallel_nd(tr_src_size, [&](ptrdiff_t i) { tr_src_[i] = 0; });
auto tp = jit_transpose4x16_src_t();
tp.src_pf0_distance = 4;
tp.tr_src_pf0_distance = 0;
tp.src_pf1 = true;
tp.tr_src_pf1 = false;
trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp);
+
+ bctx_ = (simple_barrier::ctx_t *)malloc(
+ jcp.nthr * sizeof(simple_barrier::ctx_t), 64);
+ for (int i = 0; i < jcp.nthr; ++i)
+ simple_barrier::ctx_init(&bctx_[i]);
}
init_rtus_driver<avx512_common>(this);
@@ -557,6 +564,9 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
};
auto ker = [&](const int ithr, const int nthr) {
+ assert(nthr == jcp.nthr);
+ assert(utils::implication(!mkldnn_thr_syncable(), jcp.nthr_mb == 1));
+
const int ithr_ic_b = ithr % jcp.nthr_ic_b;
const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b;
const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g;
@@ -666,7 +676,7 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
p.reduce_dim = sp_b_step * jcp.reduce_block;
rp.os = p.reduce_dim;
- p.reduce_pos_flag = 0
+ p.first_last_flag = 0
| (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST : 0)
| (sp_b_end == sp_nb ? FLAG_SP_LAST : 0);
@@ -737,6 +747,7 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
}
}
};
+
auto ker_bias = [&](int ithr, int nthr) {
auto rb = this->reducer_bias_;
assert(nthr == rb->balancer_.nthr_);
@@ -784,14 +795,11 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
rb->reduce(ithr, diff_bias);
};
-#pragma omp parallel num_threads(jcp.nthr)
- {
- int ithr = omp_get_thread_num();
- assert(jcp.nthr == omp_get_num_threads());
+ parallel(jcp.nthr, [&](const int ithr, const int nthr) {
ker(ithr, jcp.nthr);
if (conf_.with_bias())
ker_bias(ithr, jcp.nthr);
- }
+ });
/* TODO: put this in ker_bias */
if (conf_.want_padded_bias()) {