diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp | 347 |
1 files changed, 196 insertions, 151 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp index 1113b277c..7d88935a4 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp @@ -69,41 +69,31 @@ bool jit_avx512_core_u8s8s32x_fwd_kernel::maybe_relu(int position) void jit_avx512_core_u8s8s32x_fwd_kernel::prepare_output(int ur_w) { - Label l_first_load, l_ret; - - mov(reg_channel, ptr[param1 + GET_OFF(channel)]); - cmp(reg_channel, 0); // FISRT load - je(l_first_load, T_NEAR); - - for (int k = 0; k < jcp.nb_oc_blocking; k++) - for (int j = 0; j < ur_w; j++) { - Zmm zmm = zmm_out(j, k); - int offset = jcp.typesize_acc * (k*ur_w + j) * jcp.oc_block; - vmovups(zmm, EVEX_compress_addr(reg_acc_s32, offset)); - } - jmp(l_ret, T_NEAR); - - L(l_first_load); for (int k = 0; k < jcp.nb_oc_blocking; k++) for (int j = 0; j < ur_w; j++) { Zmm zmm = zmm_out(j, k); vpxord(zmm, zmm, zmm); } +} - L(l_ret); +void jit_avx512_core_u8s8s32x_fwd_kernel::cvt2ps(data_type_t type_in, + zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) { + zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in; + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(zmm, op); break; + case data_type::s8: vpmovsxbd(zmm, op); break; + case data_type::u8: vpmovzxbd(zmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(zmm_in, zmm_in); } -void jit_avx512_core_u8s8s32x_fwd_kernel::store_output(int ur_w) +void jit_avx512_core_u8s8s32x_fwd_kernel::store_output(int ur_w, + int last_oc_block_flag) { - Label l_update_acc, l_ret; - - mov(reg_channel, ptr[param1 + GET_OFF(channel)]); - - int adjusment = jcp.nb_ic - ((jcp.nb_ic_blocking <= 1) - ? 0 - : jcp.nb_ic_blocking) - 1; - cmp(reg_channel, adjusment); // LAST channel - jl(l_update_acc, T_NEAR); + int nb_oc_block = jcp.nb_oc_blocking; mov(reg_bias, ptr[param1 + GET_OFF(bias)]); mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); @@ -117,47 +107,37 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::store_output(int ur_w) mov(reg_ptr_sum_scale, (size_t)p_sum_scale); vpxord(zmm_zero, zmm_zero, zmm_zero); - for (int k = 0; k < jcp.nb_oc_blocking; k++) { + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag == 1 && k == nb_oc_block - 1; int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * jcp.oc_block); auto zmm_bias = zmm_tmp; if (jcp.with_bias) { int bias_offset = jcp.typesize_bia * k * jcp.oc_block; auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); - switch (jcp.bia_dt) { - case data_type::f32: - case data_type::s32: vmovups(zmm_bias, bias_addr); break; - case data_type::s8: vpmovsxbd(zmm_bias, bias_addr); break; - case data_type::u8: vpmovzxbd(zmm_bias, bias_addr); break; - default: assert(!"unsupported dst data type"); - } - if (jcp.bia_dt != data_type::f32) - vcvtdq2ps(zmm_bias, zmm_bias); + + cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag); } for (int j = 0; j < ur_w; j++) { int aux_output_offset = jcp.typesize_out * (k * jcp.oc_block - + j * jcp.oc * jcp.ngroups); + + j * jcp.oc_without_padding * jcp.ngroups); auto addr = EVEX_compress_addr(reg_out, aux_output_offset); - Xmm xmm = xmm_out(j, k); Zmm zmm = zmm_out(j, k); vcvtdq2ps(zmm, zmm); if (jcp.with_bias) vaddps(zmm, zmm, zmm_bias); - vmulps(zmm, zmm, EVEX_compress_addr(reg_ptr_scales, scale_offset)); + + zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm; + vmulps(mask_zmm, zmm, + EVEX_compress_addr(reg_ptr_scales, scale_offset)); if (maybe_relu(0)) vmaxps(zmm, zmm_zero, zmm); if (p_sum_scale) { // post_op: sum auto zmm_prev_dst = zmm_bcast; - switch (jcp.dst_dt) { - case data_type::f32: - case data_type::s32: vmovups(zmm_prev_dst, addr); break; - case data_type::s8: vpmovsxbd(zmm_prev_dst, addr); break; - case data_type::u8: vpmovzxbd(zmm_prev_dst, addr); break; - default: assert(!"unknown dst_dt"); - } - if (jcp.dst_dt != data_type::f32) - vcvtdq2ps(zmm_prev_dst, zmm_prev_dst); + + cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag); + if (*p_sum_scale == 1.f) vaddps(zmm, zmm_prev_dst); else @@ -174,29 +154,28 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::store_output(int ur_w) else assert(!"unimplemented"); } + } + + for (int j = 0; j < ur_w; j++) { + int aux_output_offset = jcp.typesize_out * (k * jcp.oc_block + + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_out, aux_output_offset); + + Zmm zmm = zmm_out(j, k); + zmm_t r_zmm = mask_flag ? zmm | ktail_mask : zmm; switch (jcp.dst_dt) { case data_type::f32: - case data_type::s32: vmovups(addr, zmm); break; - case data_type::s8: vpmovsdb(xmm, zmm); vmovups(addr, xmm); break; - case data_type::u8: vpmovusdb(xmm, zmm); vmovups(addr, xmm); break; + case data_type::s32: vmovups(addr, r_zmm); break; + case data_type::s8: vpmovsdb(addr, r_zmm); break; + case data_type::u8: vpmovusdb(addr, r_zmm); break; default: assert(!"unknown dst_dt"); } } } - jmp(l_ret, T_NEAR); - - L(l_update_acc); - for (int k = 0; k < jcp.nb_oc_blocking; k++) - for (int j = 0; j < ur_w; j++) { - Zmm zmm = zmm_out(j, k); - int offset = jcp.typesize_acc * (k*ur_w + j) * jcp.oc_block; - vmovups(EVEX_compress_addr(reg_acc_s32, offset), zmm); - } - L(l_ret); } -void jit_avx512_core_u8s8s32x_fwd_kernel::compute_loop(int ur_w, - int pad_l, int pad_r) +void jit_avx512_core_u8s8s32x_fwd_kernel::compute_ker(int ur_w, + int pad_l, int pad_r, int last_ic_block_flag) { int kw = jcp.kw; int stride_w = jcp.stride_w; @@ -205,48 +184,42 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::compute_loop(int ur_w, int ch_block_all = jcp.ch_block * ic_block * oc_block; int nb_oc_block = jcp.nb_oc_blocking; - int nb_ic_block = jcp.nb_ic_blocking; Label kh_label, skip_kh_loop; int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all; - int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw * jcp.ic - * jcp.ngroups; + int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * jcp.ic_without_padding * jcp.ngroups; - auto input_offset = [=](int oi, int nb_ic, int ic, int ki) { + auto input_offset = [=](int oi, int ic, int ki) { return jcp.typesize_in - * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) * jcp.ic - * jcp.ngroups - + 4 * ic + nb_ic * jcp.ic_block); + * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) + * jcp.ic_without_padding * jcp.ngroups + 4 * ic); }; - auto kernel_offset = [=](int ii, int nb_ic, int ic, int ki) { + auto kernel_offset = [=](int ii, int ic, int ki) { return jcp.typesize_in * ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all - + 4 * ic * oc_block - + nb_ic * jcp.kh * jcp.kw * ch_block_all); + + 4 * ic * oc_block); }; auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { - if (jcp.is_depthwise) { + if (jcp.ver == ver_vnni) { + // also okay for depthwise since src is zero-extended + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else if (jcp.is_depthwise) { vpmulld(zmm_tmp, vreg_src, vreg_wei); vpaddd(vreg_acc, vreg_acc, zmm_tmp); } else { - if (jcp.ver == ver_vnni) { - vpdpbusd(vreg_acc, vreg_src, vreg_wei); - } else { - vpmaddubsw(zmm_tmp, vreg_src, vreg_wei); - vpmaddwd(zmm_tmp, zmm_tmp, zmm_one); - vpaddd(vreg_acc, vreg_acc, zmm_tmp); - } + vpmaddubsw(zmm_tmp, vreg_src, vreg_wei); + vpmaddwd(zmm_tmp, zmm_tmp, zmm_one); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); } }; - prepare_output(ur_w); - mov(aux_reg_inp, reg_inp); mov(aux_reg_ker, reg_ker); mov(reg_kj, reg_kh); - if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) { + if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { cmp(reg_kj, 0); je(skip_kh_loop, T_NEAR); } @@ -254,37 +227,48 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::compute_loop(int ur_w, for (int ki = 0; ki < kw; ki++) { int jj_start = get_ow_start(ki, pad_l); int jj_end = get_ow_end(ur_w, ki, pad_r); + int tail_size = jcp.ic_without_padding % 4; + /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */ + int icb = jcp.is_depthwise + ? 1 + : (last_ic_block_flag != no_last_block) + ? div_up((jcp.ic_without_padding % ic_block), 4) + : ic_block / 4; + for (int ic = 0; ic < icb; ic++) { + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = input_offset(jj, ic, ki); + if (jcp.is_depthwise) { + vpmovzxbd(zmm_inp(jj, nb_oc_block), + EVEX_compress_addr( + aux_reg_inp, aux_input_offset)); + } else if (last_ic_block_flag == last_sp_block + && tail_size != 0 && ic == icb - 1) { + Xmm xmm_tmp = Xmm(zmm_inp(jj, nb_oc_block).getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_tmp, xmm_tmp, + ptr[aux_reg_inp + aux_input_offset + r], r); + vpbroadcastd(zmm_inp(jj, nb_oc_block), xmm_tmp); + } else { + vpbroadcastd(zmm_inp(jj, nb_oc_block), + EVEX_compress_addr( + aux_reg_inp, aux_input_offset)); + } + } - for (int cc = 0; cc < nb_ic_block; cc++) { - for (int ic = 0; ic < (jcp.is_depthwise ? 1 : ic_block / 4); - ic++) { - for (int jj = jj_start; jj < jj_end; jj++) { - int aux_input_offset = input_offset(jj, cc, ic, ki); + for (int ii = 0; ii < nb_oc_block; ii++) { + int aux_kernel_offset = kernel_offset(ii, ic, ki); + if (jj_end - jj_start > 0) { if (jcp.is_depthwise) - vpmovzxbd(zmm_inp(jj, nb_oc_block), - EVEX_compress_addr( - aux_reg_inp, aux_input_offset)); + vpmovsxbd( + zmm_wei, EVEX_compress_addr(aux_reg_ker, + aux_kernel_offset)); else - vpbroadcastd(zmm_inp(jj, nb_oc_block), - EVEX_compress_addr(aux_reg_inp, - aux_input_offset)); + vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker, + aux_kernel_offset)); } - - for (int ii = 0; ii < nb_oc_block; ii++) { - int aux_kernel_offset = kernel_offset(ii, cc, ic, ki); - if (jj_end - jj_start > 0) { - if (jcp.is_depthwise) - vpmovsxbd( - zmm_wei, EVEX_compress_addr(aux_reg_ker, - aux_kernel_offset)); - else - vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker, - aux_kernel_offset)); - } - for (int jj = jj_start; jj < jj_end; jj++) { - compute(zmm_out(jj, ii), zmm_wei, - zmm_inp(jj, nb_oc_block)); - } + for (int jj = jj_start; jj < jj_end; jj++) { + compute(zmm_out(jj, ii), zmm_wei, + zmm_inp(jj, nb_oc_block)); } } } @@ -296,22 +280,78 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::compute_loop(int ur_w, jg(kh_label, T_NEAR); } L(skip_kh_loop); +} + +void jit_avx512_core_u8s8s32x_fwd_kernel::compute_loop( + int ur_w, int pad_l, int pad_r, bool is_last_sp_block) +{ + prepare_output(ur_w); + + // IC loop + Label icb_label; + mov(reg_icb, jcp.nb_ic); + L(icb_label); + if (jcp.ic_without_padding != jcp.ic) { + Label common_ker, end_ker; + + cmp(reg_icb, 1); // The last IC block + jne(common_ker, T_NEAR); + + compute_ker(ur_w, pad_l, pad_r, + is_last_sp_block ? last_sp_block : last_ic_block); + jmp(end_ker, T_NEAR); + + L(common_ker); + compute_ker(ur_w, pad_l, pad_r, no_last_block); + + L(end_ker); + } else { + compute_ker(ur_w, pad_l, pad_r, no_last_block); + } + // End of IC Loop + int inp_step = jcp.ic_block; + int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block; + add(reg_inp, jcp.typesize_in * inp_step); + add(reg_ker, jcp.typesize_in * ker_step); + + dec(reg_icb); + cmp(reg_icb, 0); + jg(icb_label, T_NEAR); - store_output(ur_w); + sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic); + sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic); + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + Label common_store, end_store; + + if (jcp.is_depthwise) + cmp(reg_oc_blocks, jcp.nb_ch - 1); + else + cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); + + jne(common_store, T_NEAR); + + store_output(ur_w, 1); + jmp(end_store, T_NEAR); + + L(common_store); + store_output(ur_w, 0); + + L(end_store); + } else { + store_output(ur_w, 0); + } } void jit_avx512_core_u8s8s32x_fwd_kernel::generate() { int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad) - * jcp.ic * jcp.ngroups; + * jcp.ic_without_padding * jcp.ngroups; int inp_shift = jcp.typesize_in * - (jcp.ur_w * jcp.stride_w * jcp.ic * jcp.ngroups); + (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding + * jcp.ngroups); int out_shift = jcp.typesize_out * - (jcp.ur_w * jcp.oc * jcp.ngroups); - int acc_shift = jcp.typesize_acc * - (jcp.ur_w * jcp.oc_block * jcp.nb_oc_blocking); - + (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups); preamble(); xor_(reg_scratch, reg_scratch); @@ -323,7 +363,17 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::generate() mov(reg_out, ptr[param1 + GET_OFF(dst)]); mov(reg_ker, ptr[param1 + GET_OFF(filt)]); mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); - mov(reg_acc_s32, ptr[param1 + GET_OFF(acc_s32)]); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.is_depthwise + ? jcp.ngroups % jcp.ch_block + : jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); + Reg32 regw_tmp = reg_oi.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) @@ -331,52 +381,47 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::generate() int n_oi = jcp.ow / jcp.ur_w; int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1); - if (r_pad1 > 0) n_oi--; + if (r_pad1 > 0 || jcp.ur_w_tail == 0) + n_oi--; xor_(reg_oi, reg_oi); if (jcp.ow == jcp.ur_w) { - compute_loop(jcp.ur_w, jcp.l_pad, r_pad); + compute_loop(jcp.ur_w, jcp.l_pad, r_pad, true); } else { if (n_oi == 0) { - compute_loop(jcp.ur_w, jcp.l_pad, r_pad1); + compute_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0); add(reg_inp, inp_shift_pad); add(reg_out, out_shift); - add(reg_acc_s32, acc_shift); if (jcp.ur_w_tail != 0) { - compute_loop(jcp.ur_w_tail, 0, r_pad); + compute_loop(jcp.ur_w_tail, 0, r_pad, true); } } else { if (jcp.l_pad > 0) { - compute_loop(jcp.ur_w, jcp.l_pad, 0); + compute_loop(jcp.ur_w, jcp.l_pad, 0, false); add(reg_inp, inp_shift_pad); add(reg_out, out_shift); - add(reg_acc_s32, acc_shift); inc(reg_oi); } if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)) { - if (jcp.l_pad <= 0 && r_pad1 > 0) - n_oi--; Label ow_loop_label; L(ow_loop_label); { - compute_loop(jcp.ur_w, 0, 0); + compute_loop(jcp.ur_w, 0, 0, false); add(reg_inp, inp_shift); add(reg_out, out_shift); - add(reg_acc_s32, acc_shift); inc(reg_oi); cmp(reg_oi, n_oi); jl(ow_loop_label, T_NEAR); } } - if (r_pad1 > 0) { - compute_loop(jcp.ur_w, 0, r_pad1); + if (r_pad1 > 0 || jcp.ur_w_tail == 0) { + compute_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0); add(reg_inp, inp_shift); add(reg_out, out_shift); - add(reg_acc_s32, acc_shift); } if (jcp.ur_w_tail != 0) { - compute_loop(jcp.ur_w_tail, 0, r_pad); + compute_loop(jcp.ur_w_tail, 0, r_pad, true); } } } @@ -397,7 +442,7 @@ bool jit_avx512_core_u8s8s32x_fwd_kernel::post_ops_ok( && p.entry_[idx].eltwise.alpha == 0.; }; - switch (p.len_) { + switch (p.len_) { case 0: return true; case 1: return true && implication(jcp.with_eltwise, p.contain(sum, 0)) @@ -429,7 +474,6 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, const memory_desc_wrapper dst_d(&dst_pd); const memory_desc_wrapper bias_d(&bias_pd); - const int regs = 28; const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; if (!(mayiuse(avx512_core) && @@ -444,7 +488,9 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; jcp.mb = src_d.dims()[0]; jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = jcp.ic; jcp.ih = src_d.dims()[2]; jcp.iw = src_d.dims()[3]; jcp.oh = dst_d.dims()[2]; @@ -469,12 +515,16 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, jcp.ch_block = 16; jcp.ic_block = 1; jcp.oc_block = 1; - if (jcp.ngroups % jcp.ch_block != 0) - return status::unimplemented; } else { jcp.ch_block = 1; jcp.ic_block = 16; jcp.oc_block = 16; + + if (jcp.ngroups == 1) { + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } + if (jcp.ic % jcp.ic_block != 0) return status::unimplemented; } @@ -482,6 +532,9 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, jcp.dilate_h = cd.dilates[0]; jcp.dilate_w = cd.dilates[1]; + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + if (!post_ops_ok(jcp, attr)) return status::unimplemented; @@ -489,6 +542,8 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, if (mayiuse(avx512_core_vnni)) jcp.ver = ver_vnni; + const int regs = (jcp.ver == ver_vnni && !jcp.is_depthwise) ? 31 : 28; + const auto w_format = with_groups ? (jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i) : OIhw4i16o4i; if (weights_d.format() == any) @@ -516,25 +571,14 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, jcp.typesize_in = types::data_type_size(src_d.data_type()); jcp.typesize_out = types::data_type_size(dst_d.data_type()); - jcp.typesize_acc = sizeof(int32_t); jcp.typesize_bia = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; - jcp.nb_ch = jcp.ngroups / jcp.ch_block; + jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); jcp.nb_ic = jcp.ic / jcp.ic_block; jcp.nb_oc = jcp.oc / jcp.oc_block; - jcp.nb_ic_blocking = (!(jcp.nb_ic % 8)) - ? 8 - : (!(jcp.nb_ic % 4)) - ? 4 - : (!(jcp.nb_ic % 2)) ? 2 : 1; - if (jcp.kh >= 7 || jcp.kw >= 7) // Note: Large code issue on SKX - jcp.nb_ic_blocking = (!(jcp.nb_ic % 4)) - ? 4 - : (!(jcp.nb_ic % 2)) ? 2 : 1; - // If OC blocking is incommensurate with the number of OC blocks (general // requirement for all convolutions), or if it results in an unrolling // factor smaller than the left padding (special requirement for SSD:fc6), @@ -546,7 +590,8 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, break; jcp.ur_w = regs / (jcp.nb_oc_blocking + 1); - if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; + if (jcp.ow < jcp.ur_w) + jcp.ur_w = jcp.ow; jcp.ur_w_tail = jcp.ow % jcp.ur_w; bool args_ok = true |