summary | refs | log | tree | commit | diff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp')
-rw-r--r-- inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp | 82
1 files changed, 44 insertions, 38 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp
index 33f293013..fc7a4e363 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp
@@ -41,7 +41,7 @@ using jit_conv_ker_t = void (*)(jit_conv_call_s *);
inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
const void *src, const void *dst, const void *filt, const void *bias,
- int channel, int kh_padding)
+ int channel, int kh_padding, int oc_off)
{
PIPELINE(src);
PIPELINE(dst);
@@ -49,6 +49,7 @@ inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
PIPELINE(bias);
PIPELINE(channel);
PIPELINE(kh_padding);
+ PIPELINE(oc_off);
if (p.src)
ker(&p);
@@ -100,7 +101,7 @@ void _jit_avx512_common_convolution_fwd_t
if (jcp.aligned_threads)
nthr = jcp.aligned_threads;
else
- nthr = omp_get_max_threads();
+ nthr = mkldnn_get_max_threads();
if (conf_.want_padded_bias()) {
for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
@@ -108,10 +109,7 @@ void _jit_avx512_common_convolution_fwd_t
bias = padded_bias_;
}
-#pragma omp parallel num_threads(nthr)
- {
- int ithr = omp_get_thread_num();
-
+ parallel(nthr, [&](const int ithr, const int nthr) {
int start, end, start_copy;
balance211(work_amount, nthr, ithr, start, end);
start_copy = start;
@@ -151,6 +149,8 @@ void _jit_avx512_common_convolution_fwd_t
auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, ih_s);
auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2);
+ int oc_off = g_oc * sizeof(dst_data_t);
+
for (int icb = icb_l2;
icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) {
auto src_c = src_w;
@@ -170,7 +170,7 @@ void _jit_avx512_common_convolution_fwd_t
jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
src_c + i_t_overflow * dilate_h * src_h_stride,
dst_c, wht_w + i_t_overflow * wht_h_stride,
- bias_w, icb, kh_padding);
+ bias_w, icb, kh_padding, oc_off);
src_c += src_h_stride * jcp.stride_h;
dst_c += dst_h_stride;
@@ -191,8 +191,8 @@ void _jit_avx512_common_convolution_fwd_t
}
jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
- src, dst, weights, bias, 0, 0);
- }
+ src, dst, weights, bias, 0, 0, 0);
+ });
}
template <bool with_relu, data_type_t src_type, data_type_t wei_type,
@@ -214,10 +214,13 @@ void _jit_avx512_common_convolution_fwd_t
const int MB = conf_.MB();
assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
-# pragma omp parallel
- {
- int ithr = omp_get_thread_num(), nthr = omp_get_num_threads();
+ if (conf_.want_padded_bias()) {
+ for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
+ padded_bias_[oc] = bias[oc];
+ bias = padded_bias_;
+ }
+ parallel(0, [&](const int ithr, const int nthr) {
int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
int start, end, start_copy;
int work_amount = MB * jcp.ngroups * oc_chunks * jcp.od * jcp.oh;
@@ -315,7 +318,7 @@ void _jit_avx512_common_convolution_fwd_t
}
jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
src, dst, weights, bias, 0, 0, 0);
- }
+ });
}
template struct _jit_avx512_common_convolution_fwd_t<false, data_type::f32>;
@@ -341,10 +344,7 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
const auto &jcp = kernel_->jcp;
const int MB = conf_.MB();
-# pragma omp parallel
- {
- int ithr = omp_get_thread_num(), nthr = omp_get_num_threads();
-
+ parallel(0, [&](const int ithr, const int nthr) {
int start, end, start_copy;
int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
int work_amount = jcp.ngroups * MB * ic_chunks * jcp.ih;
@@ -431,7 +431,7 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
diff_src_w + ij * diff_src_h_stride,
diff_dst_w + oj * diff_dst_h_stride,
wht_w + k_lo * wht_h_stride,
- 0, ocb, k_len);
+ 0, ocb, k_len, 0);
}
diff_dst_w += diff_dst_c_stride;
wht_w += wht_oc_stride;
@@ -449,8 +449,8 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
}
jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
- diff_src, diff_dst, weights, 0, 0, 1);
- }
+ diff_src, diff_dst, weights, 0, 0, 1, 0);
+ });
}
template <data_type_t diff_dst_type, data_type_t wei_type,
@@ -469,10 +469,7 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
const auto &jcp = kernel_->jcp;
const int MB = conf_.MB();
-# pragma omp parallel
- {
- int ithr = omp_get_thread_num(), nthr = omp_get_num_threads();
-
+ parallel(0, [&](const int ithr, const int nthr) {
int start, end, start_copy;
int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
int work_amount = jcp.ngroups * MB * ic_chunks * jcp.id * jcp.ih;
@@ -625,7 +622,7 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
diff_src, diff_dst, weights, 0, 0, 1, 1);
- }
+ });
}
template struct jit_avx512_common_convolution_bwd_data_t<data_type::f32>;
@@ -964,6 +961,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
tr_ctx.tr_src = tr_src_
+ ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld;
+ assert(utils::implication(!mkldnn_thr_syncable(), nthr_oc_b_ == 1));
tr_ctx.nthr_oc_b = nthr_oc_b_;
int ih_start{0}, ih_end{0};
balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end);
@@ -1051,7 +1049,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
? &tr_diff_dst_[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
: &ti->diff_dst[diff_dst_d.blk_off(img, _oc)],
diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
- 0, (img == ti->img_start), 0);
+ 0, (img == ti->img_start), 0, 0);
}
}
@@ -1069,7 +1067,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
diff_wei + wht_blk_off(
diff_weights_d, ti->g_start,
ti->oc_b_start, ti->ic_b_start),
- 0, 0, 0);
+ 0, 0, 0, 0);
}
}
}
@@ -1093,7 +1091,6 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
: (diff_weights_data_t*)ws_reduction_ + (nthr_mb_ - 1) * wei_size
+ (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;
-
const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
const int input_step = jcp.stride_d * jcp.ih * jcp.iw * inp_mult;
const int output_step = jcp.ow * jcp.oh * jcp.oc_block;
@@ -1110,7 +1107,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
const int id_s = od_s * jcp.stride_d;
const int idp = jcp.id + jcp.f_pad + jcp.back_pad;
- if (id_s < idp - jcp.back_pad - jcp.kd + 1) {
+ if (id_s < idp - jcp.back_pad - (jcp.kd - 1) * (jcp.dilate_d)) {
for (int g = ti->g_start; g < ti->g_end; ++g) {
for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
@@ -1320,7 +1317,8 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
const diff_weights_data_t *diff_bias_ws
= ws_reduction_ + (size_t)(nthr_mb_ - 1) * wei_size;
- #pragma omp barrier
+ if (nthr_mb_ > 1) mkldnn_thr_barrier();
+
if (ti->ithr == 0)
{
for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
@@ -1330,15 +1328,12 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
}
}
-
template <data_type_t src_type, data_type_t diff_dst_type,
data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
diff_weights_type>::execute_backward_weights() {
-# pragma omp parallel num_threads(nthr_)
- {
- int ithr = omp_get_thread_num();
- assert(nthr_ == omp_get_num_threads());
+ parallel(nthr_, [&](const int ithr, const int nthr) {
+ assert(nthr_ == nthr);
thread_info_t thread_info(this, ithr);
@@ -1351,7 +1346,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info);
if (conf_.with_bias()) compute_diff_bias_3d(&thread_info);
}
- }
+ });
/* TODO: put that into compute_diff_bias() */
if (conf_.want_padded_bias()) {
@@ -1366,7 +1361,7 @@ template <data_type_t src_type, data_type_t diff_dst_type,
data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
diff_weights_type>::balance() {
- const int max_threads = omp_get_max_threads();
+ const int max_threads = mkldnn_get_max_threads();
const auto &j = conf_.jcp_;
nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
@@ -1376,6 +1371,13 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
return;
}
+ if (!mkldnn_thr_syncable()
+ && utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
+ // should not happen -- the driver is not ready
+ // for TBB-like non-synchronous threading yet
+ return;
+ }
+
if (j.ver == ver_4fma && j.is_1stconv) {
nthr_g_ = 1;
nthr_oc_b_ = 1;
@@ -1424,7 +1426,6 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
const int nthr_par = nthr / nthr_mb;
const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
-
int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
@@ -1435,6 +1436,8 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
nthr_ic_b_ = nthr_ic_b;
}
}
+
+ if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
}
if (j.ver != ver_vnni && !mayiuse(avx512_mic)) {
@@ -1471,6 +1474,8 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
nthr_ic_b_ = nthr_ic_b;
}
}
+
+ if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
}
}
@@ -1478,6 +1483,7 @@ void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
nthr_mb_ = min(j.mb * j.od, max_threads);
nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
assert(nthr_ <= max_threads);
+ assert(utils::implication(!mkldnn_thr_syncable(), nthr_mb_ == 1));
}
template struct jit_avx512_common_convolution_bwd_weights_t<data_type::f32>;