summary | refs | log | tree | commit | diff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp')
-rw-r--r-- inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp 347
1 files changed, 196 insertions, 151 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp
index 1113b277c..7d88935a4 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.cpp
@@ -69,41 +69,31 @@ bool jit_avx512_core_u8s8s32x_fwd_kernel::maybe_relu(int position)
void jit_avx512_core_u8s8s32x_fwd_kernel::prepare_output(int ur_w)
{
- Label l_first_load, l_ret;
-
- mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
- cmp(reg_channel, 0); // FISRT load
- je(l_first_load, T_NEAR);
-
- for (int k = 0; k < jcp.nb_oc_blocking; k++)
- for (int j = 0; j < ur_w; j++) {
- Zmm zmm = zmm_out(j, k);
- int offset = jcp.typesize_acc * (k*ur_w + j) * jcp.oc_block;
- vmovups(zmm, EVEX_compress_addr(reg_acc_s32, offset));
- }
- jmp(l_ret, T_NEAR);
-
- L(l_first_load);
for (int k = 0; k < jcp.nb_oc_blocking; k++)
for (int j = 0; j < ur_w; j++) {
Zmm zmm = zmm_out(j, k);
vpxord(zmm, zmm, zmm);
}
+}
- L(l_ret);
+void jit_avx512_core_u8s8s32x_fwd_kernel::cvt2ps(data_type_t type_in,
+ zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) {
+ zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
+ switch (type_in) {
+ case data_type::f32:
+ case data_type::s32: vmovups(zmm, op); break;
+ case data_type::s8: vpmovsxbd(zmm, op); break;
+ case data_type::u8: vpmovzxbd(zmm, op); break;
+ default: assert(!"unsupported data type");
+ }
+ if (type_in != data_type::f32)
+ vcvtdq2ps(zmm_in, zmm_in);
}
-void jit_avx512_core_u8s8s32x_fwd_kernel::store_output(int ur_w)
+void jit_avx512_core_u8s8s32x_fwd_kernel::store_output(int ur_w,
+ int last_oc_block_flag)
{
- Label l_update_acc, l_ret;
-
- mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
-
- int adjusment = jcp.nb_ic - ((jcp.nb_ic_blocking <= 1)
- ? 0
- : jcp.nb_ic_blocking) - 1;
- cmp(reg_channel, adjusment); // LAST channel
- jl(l_update_acc, T_NEAR);
+ int nb_oc_block = jcp.nb_oc_blocking;
mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
@@ -117,47 +107,37 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::store_output(int ur_w)
mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
vpxord(zmm_zero, zmm_zero, zmm_zero);
- for (int k = 0; k < jcp.nb_oc_blocking; k++) {
+ for (int k = 0; k < nb_oc_block; k++) {
+ const bool mask_flag = last_oc_block_flag == 1 && k == nb_oc_block - 1;
int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * jcp.oc_block);
auto zmm_bias = zmm_tmp;
if (jcp.with_bias) {
int bias_offset = jcp.typesize_bia * k * jcp.oc_block;
auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
- switch (jcp.bia_dt) {
- case data_type::f32:
- case data_type::s32: vmovups(zmm_bias, bias_addr); break;
- case data_type::s8: vpmovsxbd(zmm_bias, bias_addr); break;
- case data_type::u8: vpmovzxbd(zmm_bias, bias_addr); break;
- default: assert(!"unsupported dst data type");
- }
- if (jcp.bia_dt != data_type::f32)
- vcvtdq2ps(zmm_bias, zmm_bias);
+
+ cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
}
for (int j = 0; j < ur_w; j++) {
int aux_output_offset
= jcp.typesize_out * (k * jcp.oc_block
- + j * jcp.oc * jcp.ngroups);
+ + j * jcp.oc_without_padding * jcp.ngroups);
auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
- Xmm xmm = xmm_out(j, k);
Zmm zmm = zmm_out(j, k);
vcvtdq2ps(zmm, zmm);
if (jcp.with_bias)
vaddps(zmm, zmm, zmm_bias);
- vmulps(zmm, zmm, EVEX_compress_addr(reg_ptr_scales, scale_offset));
+
+ zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm;
+ vmulps(mask_zmm, zmm,
+ EVEX_compress_addr(reg_ptr_scales, scale_offset));
if (maybe_relu(0))
vmaxps(zmm, zmm_zero, zmm);
if (p_sum_scale) { // post_op: sum
auto zmm_prev_dst = zmm_bcast;
- switch (jcp.dst_dt) {
- case data_type::f32:
- case data_type::s32: vmovups(zmm_prev_dst, addr); break;
- case data_type::s8: vpmovsxbd(zmm_prev_dst, addr); break;
- case data_type::u8: vpmovzxbd(zmm_prev_dst, addr); break;
- default: assert(!"unknown dst_dt");
- }
- if (jcp.dst_dt != data_type::f32)
- vcvtdq2ps(zmm_prev_dst, zmm_prev_dst);
+
+ cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag);
+
if (*p_sum_scale == 1.f)
vaddps(zmm, zmm_prev_dst);
else
@@ -174,29 +154,28 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::store_output(int ur_w)
else
assert(!"unimplemented");
}
+ }
+
+ for (int j = 0; j < ur_w; j++) {
+ int aux_output_offset = jcp.typesize_out * (k * jcp.oc_block
+ + j * jcp.oc_without_padding * jcp.ngroups);
+ auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
+
+ Zmm zmm = zmm_out(j, k);
+ zmm_t r_zmm = mask_flag ? zmm | ktail_mask : zmm;
switch (jcp.dst_dt) {
case data_type::f32:
- case data_type::s32: vmovups(addr, zmm); break;
- case data_type::s8: vpmovsdb(xmm, zmm); vmovups(addr, xmm); break;
- case data_type::u8: vpmovusdb(xmm, zmm); vmovups(addr, xmm); break;
+ case data_type::s32: vmovups(addr, r_zmm); break;
+ case data_type::s8: vpmovsdb(addr, r_zmm); break;
+ case data_type::u8: vpmovusdb(addr, r_zmm); break;
default: assert(!"unknown dst_dt");
}
}
}
- jmp(l_ret, T_NEAR);
-
- L(l_update_acc);
- for (int k = 0; k < jcp.nb_oc_blocking; k++)
- for (int j = 0; j < ur_w; j++) {
- Zmm zmm = zmm_out(j, k);
- int offset = jcp.typesize_acc * (k*ur_w + j) * jcp.oc_block;
- vmovups(EVEX_compress_addr(reg_acc_s32, offset), zmm);
- }
- L(l_ret);
}
-void jit_avx512_core_u8s8s32x_fwd_kernel::compute_loop(int ur_w,
- int pad_l, int pad_r)
+void jit_avx512_core_u8s8s32x_fwd_kernel::compute_ker(int ur_w,
+ int pad_l, int pad_r, int last_ic_block_flag)
{
int kw = jcp.kw;
int stride_w = jcp.stride_w;
@@ -205,48 +184,42 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::compute_loop(int ur_w,
int ch_block_all = jcp.ch_block * ic_block * oc_block;
int nb_oc_block = jcp.nb_oc_blocking;
- int nb_ic_block = jcp.nb_ic_blocking;
Label kh_label, skip_kh_loop;
int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
- int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw * jcp.ic
- * jcp.ngroups;
+ int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
+ * jcp.ic_without_padding * jcp.ngroups;
- auto input_offset = [=](int oi, int nb_ic, int ic, int ki) {
+ auto input_offset = [=](int oi, int ic, int ki) {
return jcp.typesize_in
- * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) * jcp.ic
- * jcp.ngroups
- + 4 * ic + nb_ic * jcp.ic_block);
+ * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
+ * jcp.ic_without_padding * jcp.ngroups + 4 * ic);
};
- auto kernel_offset = [=](int ii, int nb_ic, int ic, int ki) {
+ auto kernel_offset = [=](int ii, int ic, int ki) {
return jcp.typesize_in
* ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all
- + 4 * ic * oc_block
- + nb_ic * jcp.kh * jcp.kw * ch_block_all);
+ + 4 * ic * oc_block);
};
auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
- if (jcp.is_depthwise) {
+ if (jcp.ver == ver_vnni) {
+ // also okay for depthwise since src is zero-extended
+ vpdpbusd(vreg_acc, vreg_src, vreg_wei);
+ } else if (jcp.is_depthwise) {
vpmulld(zmm_tmp, vreg_src, vreg_wei);
vpaddd(vreg_acc, vreg_acc, zmm_tmp);
} else {
- if (jcp.ver == ver_vnni) {
- vpdpbusd(vreg_acc, vreg_src, vreg_wei);
- } else {
- vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
- vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
- vpaddd(vreg_acc, vreg_acc, zmm_tmp);
- }
+ vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
+ vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
+ vpaddd(vreg_acc, vreg_acc, zmm_tmp);
}
};
- prepare_output(ur_w);
-
mov(aux_reg_inp, reg_inp);
mov(aux_reg_ker, reg_ker);
mov(reg_kj, reg_kh);
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) {
+ if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
cmp(reg_kj, 0);
je(skip_kh_loop, T_NEAR);
}
@@ -254,37 +227,48 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::compute_loop(int ur_w,
for (int ki = 0; ki < kw; ki++) {
int jj_start = get_ow_start(ki, pad_l);
int jj_end = get_ow_end(ur_w, ki, pad_r);
+ int tail_size = jcp.ic_without_padding % 4;
+ /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */
+ int icb = jcp.is_depthwise
+ ? 1
+ : (last_ic_block_flag != no_last_block)
+ ? div_up((jcp.ic_without_padding % ic_block), 4)
+ : ic_block / 4;
+ for (int ic = 0; ic < icb; ic++) {
+ for (int jj = jj_start; jj < jj_end; jj++) {
+ int aux_input_offset = input_offset(jj, ic, ki);
+ if (jcp.is_depthwise) {
+ vpmovzxbd(zmm_inp(jj, nb_oc_block),
+ EVEX_compress_addr(
+ aux_reg_inp, aux_input_offset));
+ } else if (last_ic_block_flag == last_sp_block
+ && tail_size != 0 && ic == icb - 1) {
+ Xmm xmm_tmp = Xmm(zmm_inp(jj, nb_oc_block).getIdx());
+ for (int r = 0; r < tail_size; ++r)
+ vpinsrb(xmm_tmp, xmm_tmp,
+ ptr[aux_reg_inp + aux_input_offset + r], r);
+ vpbroadcastd(zmm_inp(jj, nb_oc_block), xmm_tmp);
+ } else {
+ vpbroadcastd(zmm_inp(jj, nb_oc_block),
+ EVEX_compress_addr(
+ aux_reg_inp, aux_input_offset));
+ }
+ }
- for (int cc = 0; cc < nb_ic_block; cc++) {
- for (int ic = 0; ic < (jcp.is_depthwise ? 1 : ic_block / 4);
- ic++) {
- for (int jj = jj_start; jj < jj_end; jj++) {
- int aux_input_offset = input_offset(jj, cc, ic, ki);
+ for (int ii = 0; ii < nb_oc_block; ii++) {
+ int aux_kernel_offset = kernel_offset(ii, ic, ki);
+ if (jj_end - jj_start > 0) {
if (jcp.is_depthwise)
- vpmovzxbd(zmm_inp(jj, nb_oc_block),
- EVEX_compress_addr(
- aux_reg_inp, aux_input_offset));
+ vpmovsxbd(
+ zmm_wei, EVEX_compress_addr(aux_reg_ker,
+ aux_kernel_offset));
else
- vpbroadcastd(zmm_inp(jj, nb_oc_block),
- EVEX_compress_addr(aux_reg_inp,
- aux_input_offset));
+ vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker,
+ aux_kernel_offset));
}
-
- for (int ii = 0; ii < nb_oc_block; ii++) {
- int aux_kernel_offset = kernel_offset(ii, cc, ic, ki);
- if (jj_end - jj_start > 0) {
- if (jcp.is_depthwise)
- vpmovsxbd(
- zmm_wei, EVEX_compress_addr(aux_reg_ker,
- aux_kernel_offset));
- else
- vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker,
- aux_kernel_offset));
- }
- for (int jj = jj_start; jj < jj_end; jj++) {
- compute(zmm_out(jj, ii), zmm_wei,
- zmm_inp(jj, nb_oc_block));
- }
+ for (int jj = jj_start; jj < jj_end; jj++) {
+ compute(zmm_out(jj, ii), zmm_wei,
+ zmm_inp(jj, nb_oc_block));
}
}
}
@@ -296,22 +280,78 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::compute_loop(int ur_w,
jg(kh_label, T_NEAR);
}
L(skip_kh_loop);
+}
+
+void jit_avx512_core_u8s8s32x_fwd_kernel::compute_loop(
+ int ur_w, int pad_l, int pad_r, bool is_last_sp_block)
+{
+ prepare_output(ur_w);
+
+ // IC loop
+ Label icb_label;
+ mov(reg_icb, jcp.nb_ic);
+ L(icb_label);
+ if (jcp.ic_without_padding != jcp.ic) {
+ Label common_ker, end_ker;
+
+ cmp(reg_icb, 1); // The last IC block
+ jne(common_ker, T_NEAR);
+
+ compute_ker(ur_w, pad_l, pad_r,
+ is_last_sp_block ? last_sp_block : last_ic_block);
+ jmp(end_ker, T_NEAR);
+
+ L(common_ker);
+ compute_ker(ur_w, pad_l, pad_r, no_last_block);
+
+ L(end_ker);
+ } else {
+ compute_ker(ur_w, pad_l, pad_r, no_last_block);
+ }
+ // End of IC Loop
+ int inp_step = jcp.ic_block;
+ int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block;
+ add(reg_inp, jcp.typesize_in * inp_step);
+ add(reg_ker, jcp.typesize_in * ker_step);
+
+ dec(reg_icb);
+ cmp(reg_icb, 0);
+ jg(icb_label, T_NEAR);
- store_output(ur_w);
+ sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic);
+ sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic);
+ if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
+ Label common_store, end_store;
+
+ if (jcp.is_depthwise)
+ cmp(reg_oc_blocks, jcp.nb_ch - 1);
+ else
+ cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
+
+ jne(common_store, T_NEAR);
+
+ store_output(ur_w, 1);
+ jmp(end_store, T_NEAR);
+
+ L(common_store);
+ store_output(ur_w, 0);
+
+ L(end_store);
+ } else {
+ store_output(ur_w, 0);
+ }
}
void jit_avx512_core_u8s8s32x_fwd_kernel::generate()
{
int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad)
- * jcp.ic * jcp.ngroups;
+ * jcp.ic_without_padding * jcp.ngroups;
int inp_shift = jcp.typesize_in *
- (jcp.ur_w * jcp.stride_w * jcp.ic * jcp.ngroups);
+ (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding
+ * jcp.ngroups);
int out_shift = jcp.typesize_out *
- (jcp.ur_w * jcp.oc * jcp.ngroups);
- int acc_shift = jcp.typesize_acc *
- (jcp.ur_w * jcp.oc_block * jcp.nb_oc_blocking);
-
+ (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
preamble();
xor_(reg_scratch, reg_scratch);
@@ -323,7 +363,17 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::generate()
mov(reg_out, ptr[param1 + GET_OFF(dst)]);
mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
- mov(reg_acc_s32, ptr[param1 + GET_OFF(acc_s32)]);
+
+ if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
+ int tail_size = jcp.is_depthwise
+ ? jcp.ngroups % jcp.ch_block
+ : jcp.oc_without_padding % jcp.oc_block;
+ int mask = (1 << tail_size) - 1;
+ mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
+ Reg32 regw_tmp = reg_oi.cvt32();
+ mov(regw_tmp, mask);
+ kmovw(ktail_mask, regw_tmp);
+ }
int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
+ (jcp.kw - 1) * (jcp.dilate_w + 1)
@@ -331,52 +381,47 @@ void jit_avx512_core_u8s8s32x_fwd_kernel::generate()
int n_oi = jcp.ow / jcp.ur_w;
int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w
+ (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1);
- if (r_pad1 > 0) n_oi--;
+ if (r_pad1 > 0 || jcp.ur_w_tail == 0)
+ n_oi--;
xor_(reg_oi, reg_oi);
if (jcp.ow == jcp.ur_w) {
- compute_loop(jcp.ur_w, jcp.l_pad, r_pad);
+ compute_loop(jcp.ur_w, jcp.l_pad, r_pad, true);
} else {
if (n_oi == 0) {
- compute_loop(jcp.ur_w, jcp.l_pad, r_pad1);
+ compute_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0);
add(reg_inp, inp_shift_pad);
add(reg_out, out_shift);
- add(reg_acc_s32, acc_shift);
if (jcp.ur_w_tail != 0) {
- compute_loop(jcp.ur_w_tail, 0, r_pad);
+ compute_loop(jcp.ur_w_tail, 0, r_pad, true);
}
} else {
if (jcp.l_pad > 0) {
- compute_loop(jcp.ur_w, jcp.l_pad, 0);
+ compute_loop(jcp.ur_w, jcp.l_pad, 0, false);
add(reg_inp, inp_shift_pad);
add(reg_out, out_shift);
- add(reg_acc_s32, acc_shift);
inc(reg_oi);
}
if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)) {
- if (jcp.l_pad <= 0 && r_pad1 > 0)
- n_oi--;
Label ow_loop_label;
L(ow_loop_label); {
- compute_loop(jcp.ur_w, 0, 0);
+ compute_loop(jcp.ur_w, 0, 0, false);
add(reg_inp, inp_shift);
add(reg_out, out_shift);
- add(reg_acc_s32, acc_shift);
inc(reg_oi);
cmp(reg_oi, n_oi);
jl(ow_loop_label, T_NEAR);
}
}
- if (r_pad1 > 0) {
- compute_loop(jcp.ur_w, 0, r_pad1);
+ if (r_pad1 > 0 || jcp.ur_w_tail == 0) {
+ compute_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
add(reg_inp, inp_shift);
add(reg_out, out_shift);
- add(reg_acc_s32, acc_shift);
}
if (jcp.ur_w_tail != 0) {
- compute_loop(jcp.ur_w_tail, 0, r_pad);
+ compute_loop(jcp.ur_w_tail, 0, r_pad, true);
}
}
}
@@ -397,7 +442,7 @@ bool jit_avx512_core_u8s8s32x_fwd_kernel::post_ops_ok(
&& p.entry_[idx].eltwise.alpha == 0.;
};
- switch (p.len_) {
+ switch (p.len_) {
case 0: return true;
case 1: return true
&& implication(jcp.with_eltwise, p.contain(sum, 0))
@@ -429,7 +474,6 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
const memory_desc_wrapper dst_d(&dst_pd);
const memory_desc_wrapper bias_d(&bias_pd);
- const int regs = 28;
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
if (!(mayiuse(avx512_core) &&
@@ -444,7 +488,9 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
jcp.mb = src_d.dims()[0];
jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
jcp.ic = src_d.dims()[1] / jcp.ngroups;
+ jcp.ic_without_padding = jcp.ic;
jcp.ih = src_d.dims()[2];
jcp.iw = src_d.dims()[3];
jcp.oh = dst_d.dims()[2];
@@ -469,12 +515,16 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
jcp.ch_block = 16;
jcp.ic_block = 1;
jcp.oc_block = 1;
- if (jcp.ngroups % jcp.ch_block != 0)
- return status::unimplemented;
} else {
jcp.ch_block = 1;
jcp.ic_block = 16;
jcp.oc_block = 16;
+
+ if (jcp.ngroups == 1) {
+ jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
+ jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
+ }
+
if (jcp.ic % jcp.ic_block != 0)
return status::unimplemented;
}
@@ -482,6 +532,9 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
jcp.dilate_h = cd.dilates[0];
jcp.dilate_w = cd.dilates[1];
+ jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
+ - (jcp.ih + jcp.t_pad - 1);
+
if (!post_ops_ok(jcp, attr))
return status::unimplemented;
@@ -489,6 +542,8 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
if (mayiuse(avx512_core_vnni))
jcp.ver = ver_vnni;
+ const int regs = (jcp.ver == ver_vnni && !jcp.is_depthwise) ? 31 : 28;
+
const auto w_format = with_groups
? (jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i) : OIhw4i16o4i;
if (weights_d.format() == any)
@@ -516,25 +571,14 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
jcp.typesize_in = types::data_type_size(src_d.data_type());
jcp.typesize_out = types::data_type_size(dst_d.data_type());
- jcp.typesize_acc = sizeof(int32_t);
jcp.typesize_bia = jcp.with_bias
? types::data_type_size(bias_d.data_type())
: 0;
- jcp.nb_ch = jcp.ngroups / jcp.ch_block;
+ jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
jcp.nb_ic = jcp.ic / jcp.ic_block;
jcp.nb_oc = jcp.oc / jcp.oc_block;
- jcp.nb_ic_blocking = (!(jcp.nb_ic % 8))
- ? 8
- : (!(jcp.nb_ic % 4))
- ? 4
- : (!(jcp.nb_ic % 2)) ? 2 : 1;
- if (jcp.kh >= 7 || jcp.kw >= 7) // Note: Large code issue on SKX
- jcp.nb_ic_blocking = (!(jcp.nb_ic % 4))
- ? 4
- : (!(jcp.nb_ic % 2)) ? 2 : 1;
-
// If OC blocking is incommensurate with the number of OC blocks (general
// requirement for all convolutions), or if it results in an unrolling
// factor smaller than the left padding (special requirement for SSD:fc6),
@@ -546,7 +590,8 @@ status_t jit_avx512_core_u8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
break;
jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
- if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
+ if (jcp.ow < jcp.ur_w)
+ jcp.ur_w = jcp.ow;
jcp.ur_w_tail = jcp.ow % jcp.ur_w;
bool args_ok = true