summary | refs | log | tree | commit | diff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx2_convolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx2_convolution.cpp')
-rw-r--r-- inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx2_convolution.cpp 223
1 files changed, 148 insertions, 75 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx2_convolution.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx2_convolution.cpp
index 6aba297a0..2874bf929 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx2_convolution.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx2_convolution.cpp
@@ -31,6 +31,20 @@ using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::utils;
+#define src_blk_off(f, n, c, d, h, w) /* offset via (f).blk_off: passes d only when conf_.ndims() == 5; needs conf_ in scope */ \
+ (conf_.ndims() == 5 \
+ ? (f).blk_off(n, c, d, h, w) \
+ : (f).blk_off(n, c, h, w))
+
+#define wht_blk_off(f, g, oc, ic, kd, kh, kw) /* weights offset via (f).blk_off: kd only when conf_.ndims() == 5, g only when conf_.with_groups(); needs conf_ in scope */ \
+ (conf_.ndims() == 5 \
+ ? (conf_.with_groups() \
+ ? (f).blk_off(g, oc, ic, kd, kh, kw) \
+ : (f).blk_off(oc, ic, kd, kh, kw)) \
+ : (conf_.with_groups() \
+ ? (f).blk_off(g, oc, ic, kh, kw) \
+ : (f).blk_off(oc, ic, kh, kw)))
+
template <bool with_relu>
void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward() {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
@@ -44,10 +58,11 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward() {
const memory_desc_wrapper bias_d(conf_.weights_pd(1));
const auto &jcp = kernel_->jcp;
- int MB = conf_.MB();
+ const int MB = conf_.MB();
int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
- const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
+ const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.od
+ * jcp.oh;
auto ker = [&](const int ithr, const int nthr) {
size_t start{0}, end{0};
@@ -60,9 +75,9 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward() {
if (icb_step_rem < jcp.nb_ic_blocking_max)
icb_step = icb_step_rem;
- size_t n{0}, g{0}, ocbb{0}, oh{0};
+ size_t n{0}, g{0}, ocbb{0}, oh{0}, od{0};
nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work,
- oh, jcp.oh);
+ od, jcp.od, oh, jcp.oh);
for (size_t iwork = start; iwork < end; ++iwork) {
int ocb = ocbb * jcp.nb_oc_blocking;
int ocb_num = jcp.nb_oc_blocking;
@@ -75,23 +90,31 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward() {
const int i_b_overflow = nstl::max(jcp.ih, ij
+ (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih;
+ const int dj = od * jcp.stride_d;
+ const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
+ const int d_b_overflow = nstl::max(jcp.id, dj
+ + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id;
+
const size_t _oc = g * jcp.nb_oc + ocb;
const size_t _ic = g * jcp.nb_ic + icb;
const int ih = nstl::max(ij - jcp.t_pad
+ div_up(i_t_overflow,
(jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0);
- par_conv.src = &src[src_d.blk_off(n,
- jcp.ic == 3 ? 0 : _ic, ih, 0)];
- par_conv.dst = &dst[dst_d.blk_off(n, _oc, oh, 0)];
+ const int id = nstl::max(dj - jcp.f_pad
+ + div_up(d_t_overflow,
+ (jcp.dilate_d+1)) * (jcp.dilate_d + 1), 0);
+
+ par_conv.src = &src[src_blk_off(src_d, n,
+ jcp.ic == 3 ? 0 : _ic, id, ih, 0)];
+
+ par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)];
const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
- par_conv.filt = &weights[conf_.with_groups()
- ? weights_d.blk_off(g, ocb,
- jcp.ic == 3 ? 0 : icb, wh, 0)
- : weights_d.blk_off(ocb,
- jcp.ic == 3 ? 0 : icb, wh, 0)];
+ const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
+ par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb,
+ jcp.ic == 3 ? 0 : icb, wd, wh, 0)];
if (icb == 0) {
if (bias)
@@ -100,10 +123,12 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward() {
par_conv.flags |= FLAG_IC_FIRST;
}
- if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) {
+ if (icb + 1 == jcp.nb_ic) {
par_conv.flags |= FLAG_IC_LAST;
}
+ par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
+
par_conv.oc_blocks =
nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
@@ -112,19 +137,28 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward() {
- div_up(i_t_overflow, (jcp.dilate_h + 1))
- div_up(i_b_overflow, (jcp.dilate_h + 1));
par_conv.kh_padding = nstl::max(0, kh_padding);
+
+ const int kd_padding = jcp.kd
+ - div_up(d_t_overflow, (jcp.dilate_d + 1))
+ - div_up(d_b_overflow, (jcp.dilate_d + 1));
+ par_conv.kd_padding = nstl::max(0, kd_padding);
+
kernel_->jit_ker(&par_conv);
}
nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work,
- oh, jcp.oh);
+ od, jcp.od, oh, jcp.oh);
}
icbb += icb_step;
}
};
- #pragma omp parallel
- {
- ker(omp_get_thread_num(), omp_get_num_threads());
+ if (conf_.want_padded_bias()) {
+ for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
+ padded_bias_[oc] = bias[oc];
+ bias = padded_bias_;
}
+
+ parallel(0, ker);
}
template <bool with_relu>
@@ -142,6 +176,8 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward_fusing() {
const auto &jcp_dw = kernel_dw_->jcp;
const int MB = conf_.MB();
+ auto dw_bias = jcp.dw_conv_biases;
+
int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
@@ -189,10 +225,12 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward_fusing() {
par_conv.flags |= FLAG_IC_FIRST;
}
- if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) {
+ if (icb + 1 == jcp.nb_ic) {
par_conv.flags |= FLAG_IC_LAST;
}
+ par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
+
par_conv.oc_blocks =
nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
@@ -224,7 +262,7 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward_fusing() {
par_conv_dw.kh_padding = jcp_dw.kh;
par_conv_dw.filt = &jcp.dw_conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
- par_conv_dw.bias = &jcp.dw_conv_biases[chb * jcp_dw.ch_block];
+ par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
kernel_dw_->jit_ker(&par_conv_dw);
@@ -263,10 +301,17 @@ void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward_fusing() {
}
};
- #pragma omp parallel
- {
- ker(omp_get_thread_num(), omp_get_num_threads());
+ if (conf_.want_padded_bias()) {
+ for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
+ padded_bias_[oc] = bias[oc];
+ bias = padded_bias_;
+
+ for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
+ dw_padded_bias_[oc] = dw_bias[oc];
+ dw_bias = dw_padded_bias_;
}
+
+ parallel(0, ker);
}
template void _jit_avx2_convolution_fwd_t<true>::execute_forward();
@@ -296,74 +341,78 @@ void jit_avx2_convolution_bwd_data_t::execute_backward_data() {
size_t n{0}, g{0}, icbb{0}, ih{0};
nd_iterator_init(start, n, MB, g, jcp.ngroups, icbb, icb_work, ih, jcp.ih);
for (size_t iwork = start; iwork < end; ++iwork) {
- for (int oc = 0; oc < jcp.nb_oc; ++oc) {
- jit_conv_call_s par_conv = {};
+ for (int oc = 0; oc < jcp.nb_oc; ++oc)
+ for (int id = 0; id < jcp.id; ++id) {
+ auto par_conv = jit_conv_call_s();
+
+ const int idp = jcp.id + 2 * jcp.f_pad;
+ const int d_t_overflow = nstl::max(0,
+ jcp.kd - 1 - id - jcp.f_pad);
+ const int back_pad = idp - jcp.id - jcp.f_pad;
+ const int d_b_overflow = nstl::max(0,
+ jcp.kd - 1 - (jcp.id - 1 - id) - back_pad);
+ const int od = id + jcp.f_pad - d_b_overflow;
+
+ const int simd_w = 8;
const int i_t_overflow = nstl::max(0,
- jcp.kh - 1 - (int)ih - jcp.t_pad);
+ jcp.kh - 1 - (int)ih - jcp.t_pad);
const int b_pad = jcp.ihp - jcp.ih - jcp.t_pad;
const int i_b_overflow = nstl::max(0,
- jcp.kh - 1 - (jcp.ih - 1 - (int)ih) - b_pad);
+ jcp.kh - 1 - (jcp.ih - 1 - (int)ih) - b_pad);
int oh = ih + jcp.t_pad - i_b_overflow;
+
int stride_off_h = oh % jcp.stride_h;
oh /= jcp.stride_h;
- const int simd_w = 8;
+ par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
+ /*jcp.ic == 3 ? 0 :*/
+ g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)];
+ par_conv.dst = &diff_dst[src_blk_off(diff_dst_d,
+ n, g * jcp.nb_oc + oc, od, oh, 0)];
+ par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
+ jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
+ d_b_overflow, i_b_overflow + stride_off_h, 0)];
- par_conv.src = &diff_src[diff_src_d.blk_off(n,
- /*jcp.ic == 3 ? 0 :*/
- g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, ih, 0)];
- par_conv.dst = &diff_dst[diff_dst_d.blk_off(
- n, g * jcp.nb_oc + oc, oh, 0)];
- par_conv.filt = &weights[
- conf_.with_groups() ? weights_d.blk_off(g, oc,
- jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
- i_b_overflow + stride_off_h, 0)
- : weights_d.blk_off(oc,
- jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
- i_b_overflow + stride_off_h, 0)];
par_conv.src_prf = nullptr;
par_conv.dst_prf = nullptr;
par_conv.filt_prf = nullptr;
// TODO: move initialization into the kernel
- if (oc == 0)
- {
- for (int iw = 0; iw < jcp.iw; iw++)
- {
- for (int b = 0; b < jcp.nb_ic_blocking; b++)
- {
- int current_ic =
+ if (oc == 0) {
+ for (int iw = 0; iw < jcp.iw; iw++) {
+ for (int b = 0; b < jcp.nb_ic_blocking; b++) {
+ int current_ic =
(jcp.ic == 3 ? 0 : g * jcp.nb_ic)
+ jcp.nb_ic_blocking * icbb + b;
- int current_idx =
- diff_src_d.blk_off(n, current_ic, ih, iw);
- for (int v = 0; v < simd_w; v++)
- diff_src[current_idx + v] = 0.0;
- }
+ int current_idx =
+ src_blk_off(diff_src_d, n, current_ic,
+ id, ih, iw);
+ for (int v = 0; v < simd_w; v++)
+ diff_src[current_idx + v] = 0.0;
}
}
+ }
+ par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;
par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow - stride_off_h);
par_conv.kw_padding = 0;
if (par_conv.kh_padding > 0)
kernel_->jit_ker(&par_conv);
- }
+ }
nd_iterator_step(n, MB, g, jcp.ngroups, icbb, icb_work, ih, jcp.ih);
}
};
-#pragma omp parallel
- {
- ker(omp_get_thread_num(), omp_get_num_threads());
- }
+ parallel(0, ker);
}
void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
- auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
+ auto diff_bias_in = reinterpret_cast<data_t *>(this->memory(1));
+ data_t *diff_bias = conf_.want_padded_bias() ? padded_bias_ : diff_bias_in;
const memory_desc_wrapper src_d(conf_.src_pd(0));
const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
@@ -381,36 +430,55 @@ void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() {
if (w_njobs == 0) return;
/* reduction dimension */
- int img_start{0}, img_end{0};
- balance211(jcp.mb, rw->balancer_.nthr_per_group_,
- rw->balancer_.id_in_group(ithr), img_start, img_end);
+ int img_od_start{0}, img_od_end{0}, img{0}, od_s{0};
+ balance211(jcp.mb * jcp.od, rw->balancer_.nthr_per_group_,
+ rw->balancer_.id_in_group(ithr), img_od_start, img_od_end);
+
+ int img_start = img_od_start, img_end = img_od_end;
+ nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
+ const int img_first = img;
/* jobs */
int g_start{0}, ocb_start{0}, icb_start{0};
nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start,
jcp.nb_oc, icb_start, jcp.nb_ic);
- for (int img = img_start; img < img_end; ++img) {
+ while (img_start < img_end) {
int g = g_start, ocb = ocb_start, icb = icb_start;
+
+ const int work_rem = img_end - img_start;
+ const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
+ const int id_s = od_s * jcp.stride_d;
+ const int idp = jcp.id + jcp.f_pad + jcp.back_pad;
+
+ if (id_s < idp - jcp.back_pad - jcp.kd + 1)
for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) {
const size_t _oc = g * jcp.nb_oc + ocb;
const size_t _ic = g * jcp.nb_ic + icb;
- auto par_conv = jit_conv_call_s();
- par_conv.src = &src[src_d.blk_off(img, _ic)];
- par_conv.dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
- par_conv.filt = &rw->get_local_ptr(ithr, diff_weights)[
- w_job_loc * rw->balancer_.job_size_];
-
/* TODO: put dw <-- 0 in kernel */
- if (img == img_start)
- array_set((data_t *)par_conv.filt, 0,
+ if (img == img_first)
+ array_set((data_t *)&rw->get_local_ptr(ithr, diff_weights)[
+ w_job_loc * rw->balancer_.job_size_], 0,
rw->balancer_.job_size_);
- kernel_->jit_ker(&par_conv);
+ for (int od = od_s; od < od_e; ++od) {
+ const int id = od * jcp.stride_d;
+ if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break;
+
+ auto par_conv = jit_conv_call_s();
+ par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)];
+ par_conv.dst =
+ &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)];
+ par_conv.filt = &rw->get_local_ptr(ithr, diff_weights)[
+ w_job_loc * rw->balancer_.job_size_];
+
+ kernel_->jit_ker(&par_conv);
+ }
nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, icb,
jcp.nb_ic);
}
+ nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
}
rw->reduce(ithr, diff_weights);
};
@@ -447,7 +515,7 @@ void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() {
for (int o = 0; o < 8; ++o)
d_bias[o] = 0.;
- for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) {
+ for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) {
PRAGMA_OMP_SIMD()
for (int o = 0; o < 8; ++o)
d_bias[o] += d_dst[o];
@@ -460,13 +528,18 @@ void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() {
rb->reduce(ithr, diff_bias);
};
-# pragma omp parallel
- {
- int ithr = omp_get_thread_num();
- int nthr = omp_get_num_threads();
+
+ parallel(0, [&](const int ithr, const int nthr) {
ker(ithr, nthr);
if (conf_.with_bias())
ker_bias(ithr, nthr);
+ });
+
+ /* TODO: put this in ker_bias */
+ if (conf_.want_padded_bias()) {
+ assert(jcp.ngroups == 1);
+ for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
+ diff_bias_in[oc] = diff_bias[oc];
}
}