diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp | 615 |
1 files changed, 374 insertions, 241 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp index d16cd1ac9..80206ceb7 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp @@ -74,14 +74,15 @@ void jit_avx512_common_conv_fwd_kernel::prepare_output(int ur_w) for (int j = 0; j < ur_w; j++) { Zmm zmm = zmm_out(j, k); vpxord(zmm, zmm, zmm); - int aux_output_offset = get_output_offset(j, k); - mic_prefetcht1(EVEX_compress_addr(reg_out_prf, aux_output_offset)); + size_t aux_output_offset = get_output_offset(j, k); + mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf, + aux_output_offset, reg_out_long_offt)); } } void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w) { - Label no_update_label, store_label, relu_label; + Label no_update_label, store_label, postproc_label; mov(reg_channel, ptr[param1 + GET_OFF(channel)]); if (jcp.with_bias) { @@ -96,15 +97,16 @@ void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w) for (int k = 0; k < jcp.nb_oc_blocking; k++) for (int j = 0; j < ur_w; j++) { Zmm zmm = zmm_out(j, k); - int aux_output_offset = get_output_offset(j, k); - vadd(zmm, reg_out, aux_output_offset); + size_t aux_output_offset = get_output_offset(j, k); + vadd(zmm, + make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt)); } if (!jcp.with_sum) { - jmp(relu_label, T_NEAR); + jmp(postproc_label, T_NEAR); } else { cmp(reg_channel, 0); - jne(relu_label, T_NEAR); + jne(postproc_label, T_NEAR); } L(no_update_label); @@ -113,26 +115,51 @@ void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w) int bias_offset = jcp.typesize_out * k * jcp.oc_block; for (int j = 0; j < ur_w; j++) { Zmm zmm = zmm_out(j, k); - vadd(zmm, reg_bias, bias_offset); + vadd(zmm, EVEX_compress_addr(reg_bias, bias_offset)); } mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64)); } } - L(relu_label); - if (this->jcp.with_eltwise) { - cmp(reg_channel, jcp.nb_ic - 1); - jl(store_label, T_NEAR); + L(postproc_label); - inject(eltwise_generator.prepareConstants(jcp.eltwise_alpha, jcp.eltwise_beta)); + cmp(reg_channel, jcp.nb_ic - 1); + jl(store_label, T_NEAR); - // TODO (dmitrygo): need to find appropriate way to share labels. - mov(imm_addr64, l_table); - for (int k = 0; k < jcp.nb_oc_blocking; k++) { - for (int j = 0; j < ur_w; j++) { - Zmm zmm_reg_out = zmm_out(j, k); - inject(eltwise_generator.computeVector(zmm_reg_out, zmm_reg_out)); + int eltwise_inj_idx = 0; + int depthwise_inj_idx = 0; + const auto &p = attr_.post_ops_; + + if (p.len_ == 0 && eltwise_injectors.size() == 1) { + for (int k = 0; k < jcp.nb_oc_blocking; k++) + eltwise_injectors[0]->compute_vector_range( + k*jcp.ur_w, k*jcp.ur_w + ur_w); + } + + for (int i = 0; i < p.len_; i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + for (int k = 0; k < jcp.nb_oc_blocking; k++) + eltwise_injectors[eltwise_inj_idx]->compute_vector_range( + k*jcp.ur_w, k*jcp.ur_w + ur_w); + + eltwise_inj_idx++; + } else if (post_op.is_depthwise()) { + mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data)); + mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data)); + + add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]); + add(reg_d_bias, ptr[this->param1 + GET_OFF(oc_off)]); + + for (int k = 0; k < jcp.nb_oc_blocking; k++) { + depthwise_injectors[depthwise_inj_idx]->compute_vector_range( + k*jcp.ur_w, k*jcp.ur_w + ur_w, reg_d_weights, reg_d_bias); + + add(reg_d_weights, jcp.oc_block * sizeof(float)); + add(reg_d_bias, jcp.oc_block * sizeof(float)); } + + depthwise_inj_idx++; } } @@ -140,10 +167,12 @@ void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w) for (int k = 0; k < jcp.nb_oc_blocking; k++) for (int j = 0; j < ur_w; j++) { Zmm zmm = zmm_out(j, k); - int aux_output_offset - = typesize * (k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block; - vmovups(EVEX_compress_addr(reg_out, aux_output_offset), zmm); - mic_prefetcht0(EVEX_compress_addr(reg_out_prf, aux_output_offset)); + size_t aux_output_offset = (size_t)typesize * + ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block; + vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset, + reg_out_long_offt), zmm); + mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf, + aux_output_offset, reg_out_long_offt)); } } @@ -160,13 +189,19 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, Label kh_label, kd_label, skip_kd_loop; + prepare_output(ur_w); + if (jcp.ndims == 4) { mov(aux_reg_inp, reg_inp); mov(aux_reg_ker, reg_ker); mov(aux_reg_inp_prf, reg_inp_prf); } - prepare_output(ur_w); + size_t max_input_offset = (size_t)jcp.typesize_in + * ((size_t)(kw * (jcp.dilate_w + 1) + ur_w * stride_w - pad_l) + + (size_t)ic_block * iw * ih * jcp.id); + assert(reg_inp_prf == reg_long_offt); + if (max_input_offset > INT_MAX) push(reg_inp_prf); if (jcp.ndims == 5) { push(reg_out_prf); @@ -177,7 +212,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, mov(aux_reg_inp_d, reg_inp); mov(aux_reg_inp_d_prf, reg_inp_prf); - if ((jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) { + if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) { cmp(reg_ki, 0); je(skip_kd_loop, T_NEAR); } @@ -185,7 +220,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, } mov(reg_kj, reg_kh); Label skip_kh_loop; - if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) { + if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { cmp(reg_kj, 0); je(skip_kh_loop, T_NEAR); } @@ -214,11 +249,12 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, int j_end = get_ow_end(ur_w, ki, pad_r); for (int j = j_start, prf_count=0; j < j_end; j++) { - int aux_input_offset = jcp.typesize_in - * ((ki * (jcp.dilate_w + 1) + j * stride_w - pad_l) - + ic * iw * ih * jcp.id); + size_t aux_input_offset = (size_t)jcp.typesize_in + * ((size_t)(ki * (jcp.dilate_w + 1) + j * stride_w + - pad_l) + (size_t)ic * iw * ih * jcp.id); v4fmaddps(zmm_out(j, 0), zmm_ker(0), - EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset, + reg_long_offt)); if (ki + prf_count < kw && prf_count < 4 && ((ki < 2 && j % 4) || j % 2)) { int aux_ker_offset = jcp.typesize_in @@ -230,13 +266,13 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, } if (ki == 0 && j % (64 / (stride_w * jcp.typesize_in)) == 0) { - mic_prefetcht0(EVEX_compress_addr(aux_reg_inp_prf, - aux_input_offset)); + mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf, + aux_input_offset, reg_long_offt)); } if (ki == 1 && j % (64 / (stride_w * jcp.typesize_in)) == 0) { - mic_prefetcht0(EVEX_compress_addr(aux_reg_inp, - aux_input_offset+jcp.typesize_in * iw)); + mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset+jcp.typesize_in * iw, reg_long_offt)); } } } @@ -266,6 +302,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, } store_output(ur_w); + if (max_input_offset > INT_MAX) pop(reg_inp_prf); } void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, @@ -287,13 +324,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, assert(jcp.oc % jcp.nb_oc_blocking == 0); - if (jcp.ndims == 4) { - mov(aux_reg_inp, reg_inp); - mov(aux_reg_ker, reg_ker); - mov(aux_reg_ker_prf, reg_ker_prf); - mov(aux_reg_inp_prf, reg_inp_prf); - } - auto kernel_offset = [=](int ocb, int ic, int ki) { int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki; int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; @@ -322,6 +352,13 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, prepare_output(ur_w); + if (jcp.ndims == 4) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_ker_prf, reg_ker_prf); + mov(aux_reg_inp_prf, reg_inp_prf); + } + if (jcp.ndims == 5) { push(reg_out_prf); push(reg_out); @@ -332,7 +369,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, mov(aux_reg_inp_d_prf, reg_inp_prf); mov(aux_reg_ker_d_prf, reg_ker_prf); - if ((jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) { + if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) { cmp(reg_ki, 0); je(skip_kd_loop, T_NEAR); } @@ -342,7 +379,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, mov(reg_kj, reg_kh); } Label skip_kh_loop; - if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) { + if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { cmp(reg_kj, 0); je(skip_kh_loop, T_NEAR); } @@ -548,6 +585,11 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w, mov(aux_reg_ker_prf, reg_ker_prf); } + size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id; + assert(reg_inp_prf == reg_long_offt); + if (max_input_offset > INT_MAX) push(reg_inp_prf); + + if (jcp.ndims == 5) { push(reg_out_prf); push(reg_out); @@ -558,7 +600,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w, mov(aux_reg_inp_d_prf, reg_inp_prf); mov(aux_reg_ker_d_prf, reg_ker_prf); - if ((jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) { + if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) { cmp(reg_ki, 0); je(skip_kd_loop, T_NEAR); } @@ -568,7 +610,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w, mov(reg_kj, reg_kh); } Label skip_kh_loop; - if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) { + if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { cmp(reg_kj, 0); je(skip_kh_loop, T_NEAR); } @@ -609,10 +651,10 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w, int j_start = get_ow_start(ki, pad_l); int j_end = get_ow_end(ur_w, ki, pad_r); for (int j = j_start; j < j_end; j++) { - int aux_input_offset = get_input_offset(ki, ic, j, pad_l); - vfmadd231ps(zmm_out(j, 0), zmm_kernel, - EVEX_compress_addr( - aux_reg_inp, aux_input_offset, true)); + size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l); + auto addr = EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt, true); + vfmadd231ps(zmm_out(j, 0), zmm_kernel, addr); int fma_idx = step * ur_w + j; int prf_slot_idx = fma_idx / prf_inst_spacing; if (fma_idx % prf_inst_spacing == prf_inst_trigger) { @@ -627,8 +669,8 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w, } else if (prf_inp) { int inp_prf_idx = prf_slot_idx - ker_prfs; if (inp_prf_idx < num_inp_prfs) { - int inp_prf_stride = nstl::max(kw, stride_w); - int inp_prf_offset; + size_t inp_prf_stride = nstl::max(kw, stride_w); + size_t inp_prf_offset; if (!jcp.is_1stconv) { inp_prf_offset = ic_block * jcp.typesize_in @@ -636,17 +678,18 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w, * inp_prf_stride + (inp_prf_idx % kw)); } else { - int ic_prf_stride - = jcp.typesize_in * iw * ih * id; - int iw_prf_stride + size_t ic_prf_stride = + (size_t)jcp.typesize_in * iw * ih * id; + size_t iw_prf_stride = jcp.typesize_in * simd_w; inp_prf_offset = ((inp_prf_idx / ic_block) * iw_prf_stride + (inp_prf_idx % ic_block) * ic_prf_stride); } - mic_prefetcht0(EVEX_compress_addr( - aux_reg_inp_prf, inp_prf_offset)); + mic_prefetcht0(EVEX_compress_addr_safe( + aux_reg_inp_prf, inp_prf_offset, + reg_long_offt)); } } } @@ -686,7 +729,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w, pop(reg_out); pop(reg_out_prf); } - + if (max_input_offset > INT_MAX) pop(reg_inp_prf); store_output(ur_w); } @@ -707,18 +750,19 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w, auto input_offset = [=](int oi, int ic, int ki) { - return jcp.typesize_in - * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) * inp_mul - + ic * (!jcp.is_1stconv ? 1 : jcp.iw * jcp.ih * jcp.id)); + return (size_t)jcp.typesize_in + * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) + * inp_mul + (size_t)ic + * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id)); }; + prepare_output(ur_w); + if (jcp.ndims == 4) { mov(aux_reg_inp, reg_inp); mov(aux_reg_ker, reg_ker); } - prepare_output(ur_w); - if (jcp.ndims == 5) { push(reg_out); @@ -726,7 +770,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w, mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); mov(aux_reg_inp_d, reg_inp); - if ((jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) { + if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) { cmp(reg_ki, 0); je(skip_kd_loop, T_NEAR); } @@ -735,7 +779,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w, } else { mov(reg_kj, reg_kh); } - if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) { + if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { cmp(reg_kj, 0); je(skip_kh_loop, T_NEAR); } @@ -753,9 +797,10 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w, for (int ic = 0; ic < ic_block; ic++) { if (jcp.kernel_kind == expl_bcast) { for (int jj = jj_start; jj < jj_end; jj++) { - int aux_input_offset = input_offset(jj, ic, ki); + size_t aux_input_offset = input_offset(jj, ic, ki); vbroadcastss(zmm_inp(jj, nb_oc_block), - ptr[aux_reg_inp + aux_input_offset]); + EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt)); } } for (int ii = 0; ii < nb_oc_block; ii++) { @@ -769,10 +814,12 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w, if (jcp.kernel_kind == expl_bcast) vfmadd231ps(zmm_out(jj, ii), zmm_inp(jj, nb_oc_block), zmm_wei); - else + else { + size_t aux_input_offset = input_offset(jj, ic, ki); vfmadd231ps(zmm_out(jj, ii), zmm_wei, - EVEX_compress_addr(aux_reg_inp, - input_offset(jj, ic, ki), true)); + EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt, true)); + } } } } @@ -813,6 +860,13 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni( const int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block; + size_t max_input_offset = (size_t)jcp.typesize_in + * jcp.ic_block * jcp.iw * jcp.ih * jcp.id; + assert(reg_inp_prf == reg_long_offt); + if (max_input_offset > INT_MAX) push(reg_inp_prf); + + prepare_output(ur_w); + if (jcp.ndims == 4) { mov(aux_reg_inp, reg_inp); mov(aux_reg_ker, reg_ker); @@ -820,8 +874,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni( mov(aux_reg_inp_prf, reg_inp_prf); } - prepare_output(ur_w); - Label skip_kh_loop, skip_kd_loop; if (jcp.ndims == 5) { @@ -834,7 +886,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni( mov(aux_reg_inp_d_prf, reg_inp_prf); mov(aux_reg_ker_d_prf, reg_ker_prf); - if (jcp.kd <= jcp.f_pad) { + if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) { cmp(reg_ki, 0); je(skip_kd_loop, T_NEAR); } @@ -843,7 +895,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni( } else { mov(reg_kj, reg_kh); } - if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) { + if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { cmp(reg_kj, 0); je(skip_kh_loop, T_NEAR); } @@ -861,9 +913,10 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni( for (int ic = 0; ic < jcp.ic_block / 2; ic += channel_inc) { if (jcp.kernel_kind == expl_bcast) { for (int oi = ow_start; oi < ow_end; oi++) { - int input_offset = get_input_offset(ki, ic, oi, pad_l); + size_t input_offset = get_input_offset(ki, ic, oi, pad_l); vpbroadcastd(zmm_inp(oi, jcp.nb_oc_blocking), - ptr[aux_reg_inp + input_offset]); + EVEX_compress_addr_safe(aux_reg_inp, input_offset, + reg_long_offt)); } } for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) { @@ -881,14 +934,14 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni( } } for (int oi = ow_start, prf_count = 0; oi < ow_end; oi++) { - int input_offset = get_input_offset(ki, ic, oi, pad_l); - + size_t input_offset = get_input_offset(ki, ic, oi, pad_l); if (jcp.kernel_kind == expl_bcast) { vpdpwssd(zmm_out(oi, kk), zmm_wei, zmm_inp(oi, jcp.nb_oc_blocking)); } else { vpXdpwssd(zmm_out(oi, kk), Zmm(ker_reg_base_idx), - aux_reg_inp, input_offset); + EVEX_compress_addr_safe(aux_reg_inp, input_offset, + reg_long_offt, jcp.ver != ver_4vnni)); } if ((oi % 2) && (prf_count < ker_load_number)) { int kernel_offset = get_kernel_offset( @@ -897,12 +950,12 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni( kernel_offset)); } if (!(oi % 2) && ki == 0 && ic == 0 && kk == 0) { - mic_prefetcht1(EVEX_compress_addr(aux_reg_inp_prf, - input_offset)); + mic_prefetcht1(EVEX_compress_addr_safe( + aux_reg_inp_prf, input_offset, reg_long_offt)); } if (!(oi % 2) && ki == 1 && ic == 0 && kk == 0) { - mic_prefetcht0(EVEX_compress_addr(aux_reg_inp, - input_offset + shift_input_ptr)); + mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp, + input_offset + shift_input_ptr, reg_long_offt)); } } } @@ -936,7 +989,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni( pop(reg_out); pop(reg_out_prf); } - + if (max_input_offset > INT_MAX) pop(reg_inp_prf); store_output(ur_w); } @@ -967,6 +1020,30 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop(int ur_w, void jit_avx512_common_conv_fwd_kernel::generate() { + if (jcp.with_eltwise) { + eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>( + this, jcp.eltwise_alg, jcp.eltwise_alpha, 0 + )); + } + + const auto &p = attr_.post_ops_; + for (int i = 0; i < p.len_; i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>( + this, + post_op.eltwise.alg, + post_op.eltwise.alpha, + post_op.eltwise.beta + )); + } else if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>( + this, + post_op.depthwise.alg + )); + } + } + int iw = jcp.iw; int ow = jcp.ow; int kw = jcp.kw; @@ -975,25 +1052,11 @@ void jit_avx512_common_conv_fwd_kernel::generate() int ur_w_tail = jcp.ur_w_tail; int dilate_w = jcp.dilate_w + 1; int stride_w = jcp.stride_w; - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - int inp_mult = !jcp.is_1stconv ? ic_block : 1; + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult; - int inp_shift = jcp.typesize_in * (ur_w * stride_w * inp_mult); - int out_shift = jcp.typesize_out * (ur_w * oc_block); - - nstl::vector<int> shared_vecs; - shared_vecs.push_back(27); - shared_vecs.push_back(28); - shared_vecs.push_back(29); - shared_vecs.push_back(30); - shared_vecs.push_back(31); - - nstl::vector<Reg64> shared_regs; - shared_regs.push_back(imm_addr64); - - eltwise_generator.init(jcp.eltwise_alg, shared_vecs, shared_regs); + int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult; + int out_shift = jcp.typesize_out * ur_w * jcp.oc_block; preamble(); mov(reg_inp, ptr[param1 + GET_OFF(src)]); @@ -1069,11 +1132,8 @@ void jit_avx512_common_conv_fwd_kernel::generate() postamble(); - // TODO (dmitrygo): need to find appropriate way to share labels. - align(64); - L(l_table); - inject(eltwise_generator.prepareTable()); - eltwise_generator.release(); + for (auto& inj : eltwise_injectors) + inj->prepare_table(); } bool jit_avx512_common_conv_fwd_kernel::post_ops_ok( @@ -1081,16 +1141,22 @@ bool jit_avx512_common_conv_fwd_kernel::post_ops_ok( const auto &p = attr.post_ops_; auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); }; auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); }; switch (p.len_) { case 0: return true; // no post_ops case 1: - return true // sum OR relu - && !jcp.with_eltwise && (is_eltwise(0) || is_sum(0)); + return true // sum OR eltwise OR depthwise + && !jcp.with_eltwise && (is_simple(0) || is_sum(0)); case 2: return true // sum->relu - && !jcp.with_eltwise && (is_sum(0) && is_eltwise(1)); + && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) || + (is_simple(0) && is_simple(1))); + case 3: + return true // sum->relu + && !jcp.with_eltwise && (is_sum(0) && is_simple(1) && is_simple(2)); default: return false; } @@ -1145,12 +1211,18 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf( jcp.stride_w = cd.strides[ndims-3]; jcp.src_fmt = src_d.format(); jcp.with_eltwise = with_relu; + jcp.eltwise_alg = mkldnn_eltwise_relu; jcp.eltwise_alpha = relu_negative_slope; jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; jcp.dilate_h = cd.dilates[ndims-4]; jcp.dilate_w = cd.dilates[ndims-3]; + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + jcp.back_pad = (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1); + jcp.is_1stconv = is_1stconv(jcp); jcp.oc_block = simd_w; @@ -1165,34 +1237,17 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf( jcp.oc = rnd_up(jcp.oc, jcp.oc_block); jcp.ic = rnd_up(jcp.ic, jcp.ic_block); } - - - jcp.with_eltwise = with_relu; - jcp.eltwise_alg = mkldnn_eltwise_relu; - jcp.eltwise_alpha = relu_negative_slope; + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0; + if (!args_ok) + return status::unimplemented; if (!post_ops_ok(jcp, attr)) return status::unimplemented; const auto &p = attr.post_ops_; jcp.with_sum = p.find(primitive_kind::sum) != -1; - if (!jcp.with_eltwise) { - int eltwise_ind = p.find(primitive_kind::eltwise); - if (eltwise_ind != -1) { - jcp.with_eltwise = true; - jcp.eltwise_alg = p.entry_[eltwise_ind].eltwise.alg; - jcp.eltwise_alpha = p.entry_[eltwise_ind].eltwise.alpha; - jcp.eltwise_beta = p.entry_[eltwise_ind].eltwise.beta; - jcp.eltwise_scale = p.entry_[eltwise_ind].eltwise.scale; - } - } - - if (jcp.with_eltwise) { - int nvecs_elt = jit_uni_eltwise_vector_f32<avx512_common>::sharedVecsCount(jcp.eltwise_alg); - int elt_regs = 32 - nvecs_elt; - - regs = nstl::min(regs, elt_regs); - } jcp.is_1stconv = is_1stconv(jcp); if (jcp.ic % simd_w != 0 && !jcp.is_1stconv) @@ -1368,7 +1423,8 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf( if (one_of(jcp.ver, ver_4vnni, ver_4fma) && !jcp.is_1stconv) { if (jcp.kw == 3 && jcp.kh == 3 && jcp.ow == 7 && jcp.oh == 7) { - jcp.nb_oc_blocking = 2; + if (jcp.nb_oc % 2 == 0) + jcp.nb_oc_blocking = 2; } else { for (int i = jcp.nb_oc; i > 0; i--) if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) { @@ -1398,6 +1454,7 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf( jcp.kernel_kind = embd_bcast; unsigned int inp_size = jcp.mb * (jcp.ih / jcp.stride_h) * (jcp.iw / jcp.stride_w) * jcp.ic; + if (inp_size == 0) inp_size = 1; unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw; // Estimate whether we need to limit the number of threads @@ -1466,9 +1523,7 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf( jcp.ur_w_tail = jcp.ow % jcp.ur_w; - bool args_ok = true - && jcp.oc % jcp.oc_block == 0 - && jcp.ic % jcp.ic_block == 0 + args_ok = true && jcp.l_pad <= jcp.ur_w && jcp.ic <= src_d.blocking_desc().padding_dims[1] && jcp.oc <= dst_d.blocking_desc().padding_dims[1] @@ -1491,11 +1546,12 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf( if (jcp.ver == ver_4fma) { for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic; divf++) { - int l2_src = jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id; - int l2_dst = jcp.ow * jcp.oc_block * jcp.nb_oc_blocking + size_t l2_src + = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id; + size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking * jcp.oh * jcp.od; - int l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh - * jcp.nb_oc_blocking * temp_nb * jcp.kd; + size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block + * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd; if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { if (jcp.kh == 3 && jcp.oh == 7) { jcp.nb_ic_L2 = 1; @@ -1510,10 +1566,6 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf( } } - // TODO (dmitrygo): we need at least 5 vector registers to fuse eltwise. Need to adapt unrolling scheme - if (jcp.with_eltwise && (jcp.nb_oc_blocking * jcp.ur_w) > 27) - return status::unimplemented; - return status::success; } @@ -1523,9 +1575,11 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w) for (int j = 0; j < ur_w; j++) { Zmm zmm = zmm_out(j, k); vpxord(zmm, zmm, zmm); - int aux_src_offset - = typesize * (k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; - mic_prefetcht1(EVEX_compress_addr(reg_src_prf, aux_src_offset)); + size_t aux_src_offset + = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) + * jcp.ic_block; + mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, + reg_long_offt)); } } } @@ -1540,9 +1594,10 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w) for (int k = 0; k < jcp.nb_ic_blocking; k++) { for (int j = 0; j < ur_w; j++) { Zmm zmm = zmm_out(j, k); - int aux_src_offset - = typesize * (k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; - vadd(zmm, reg_src, aux_src_offset); + size_t aux_src_offset = (size_t)typesize + * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; + vadd(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset, + reg_long_offt)); } } @@ -1550,10 +1605,12 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w) for (int k = 0; k < jcp.nb_ic_blocking; k++) { for (int j = 0; j < ur_w; j++) { Zmm zmm = zmm_out(j, k); - int aux_src_offset - = typesize * (k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; - vmovups(EVEX_compress_addr(reg_src, aux_src_offset), zmm); - mic_prefetcht0(EVEX_compress_addr(reg_src_prf, aux_src_offset)); + size_t aux_src_offset = (size_t)typesize + * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; + vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset, + reg_long_offt), zmm); + mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, + reg_long_offt)); } } } @@ -1574,13 +1631,6 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma( bool check_last_kh = (jcp.kh > 3); - if (jcp.ndims == 4) { - mov(aux_reg_dst, reg_dst); - mov(aux_reg_ker, reg_ker); - mov(aux_reg_dst_prf, reg_dst_prf); - mov(aux_reg_ker_prf, reg_ker_prf); - } - auto kernel_offset = [=](int icb, int oc, int ki) { int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; @@ -1606,6 +1656,13 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma( prepare_output(ur_w); + if (jcp.ndims == 4) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_dst_prf, reg_dst_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + if (jcp.ndims == 5) { push(reg_src_prf); push(reg_src); @@ -1764,11 +1821,6 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_vnni( const int ker_load_number = jcp.ver == ver_4vnni ? 4 : 1; Label kh_label; - mov(aux_reg_dst, reg_dst); - mov(aux_reg_ker, reg_ker); - mov(aux_reg_dst_prf, reg_dst_prf); - mov(aux_reg_ker_prf, reg_ker_prf); - auto kernel_offset = [=](int icb, int oc, int ki) { int blk_idx = icb * jcp.kh * jcp.kw + ki; int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; @@ -1778,6 +1830,11 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_vnni( prepare_output(ur_w); + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_dst_prf, reg_dst_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + mov(reg_kj, reg_kh); L(kh_label); { for (int ki = 0; ki < kw; ki++) { @@ -2047,13 +2104,13 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core( return typesize * (blk_offset + oc_offset); }; + prepare_output(ur_w); + if (jcp.ndims == 4) { mov(aux_reg_dst, reg_dst); mov(aux_reg_ker, reg_ker); } - prepare_output(ur_w); - if (jcp.ndims == 5) { push(reg_src_prf); push(reg_src); @@ -2067,10 +2124,8 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core( } else { mov(reg_kj, reg_kh); } - if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) { - cmp(reg_kj, 0); - je(skip_kh_loop, T_NEAR); - } + cmp(reg_kj, 0); + je(skip_kh_loop, T_NEAR); if (jcp.ndims == 5) { mov(aux_reg_dst, aux_reg_dst_d); @@ -2095,7 +2150,7 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core( if (jj_end - jj_start > 0) vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); - for (int jj = jj_start; jj < jj_end; jj++) + for (int jj = jj_start; jj < jj_end; jj += stride_w) if (jcp.kernel_kind == expl_bcast) vfmadd231ps(zmm_out(jj, ii), zmm_inp(jj, nb_ic_block), zmm_wei); @@ -2496,11 +2551,11 @@ status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf( if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) { for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc; divf++) { - int l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih + size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih * jcp.id; - int l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od; - int l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh * jcp.kd - * jcp.nb_ic_blocking * temp_nb; + size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od; + size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh + * jcp.kd * jcp.nb_ic_blocking * temp_nb; if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { if (jcp.kh == 3 && jcp.ih == 7) { jcp.nb_oc_L2 = 1; @@ -2535,7 +2590,8 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; int iw = (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) ? jcp.tr_iw : jcp.iw; - sub(aux_reg_input, jcp.typesize_in * jcp.ih * iw * inp_mult); + sub(aux_reg_input, + jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult); sub(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block); dec(kj); @@ -2552,7 +2608,7 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; int iw = (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) ? jcp.tr_iw : jcp.iw; - sub(reg_input, jcp.typesize_in * iw * inp_mult); + sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult); sub(reg_kernel, jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block); dec(kj); @@ -2596,19 +2652,21 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma( + output_offset)); for (int i_kw = 0; i_kw < kw; i_kw++) { - int i_iw = i_ur * jcp.stride_w + i_kw; - if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - - pad_r) continue; + int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1); + if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w + + (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue; for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { - const int i_offset = input_offset - + typesize * (jcp.ver == ver_4fma + const size_t i_offset = (size_t)input_offset + + (size_t)typesize * (jcp.ver == ver_4fma ? (i_iw - pad_l + i_ic * jcp.tr_iw) : (jcp.is_1stconv - ? (i_iw - pad_l) + i_ic * (jcp.ih*jcp.iw*jcp.id) + ? (i_iw - pad_l) + (size_t)i_ic + * ((size_t)jcp.ih*jcp.iw*jcp.id) : (i_iw - pad_l) * ic_block + i_ic)); vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic), Zmm(kw * ic_block_step + i_ur % 4), - EVEX_compress_addr(reg_input, i_offset, true)); + EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt, + true)); } } } @@ -2868,8 +2926,8 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 ? jcp.tr_iw : jcp.iw; int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow; - int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + jcp.kw - 1 - - (jcp.iw + jcp.l_pad - 1)); + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); int l_pad = jcp.l_pad; if (jcp.ndims == 5) { @@ -2892,15 +2950,16 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0, i_b_ic + ic_block_step >= jcp.ic_block); } - add(reg_input, jcp.typesize_in * iw * inp_mul); - add(reg_kernel, jcp.typesize_out * (jcp.kw) * ic_block * oc_block); + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block); dec(kj); cmp(kj, 0); jg(kh_label, T_NEAR); } if (jcp.ndims == 5) { - add(aux_reg_input, jcp.typesize_in * jcp.ih * iw * inp_mul); + add(aux_reg_input, + jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul); add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block); dec(ki); @@ -2923,7 +2982,7 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow; int r_pad = nstl::max(0, - (ow - 1) * jcp.stride_w + jcp.kw - 1 + (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); int l_pad = jcp.l_pad; @@ -2943,10 +3002,13 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 L(ic_block_label); { compute_ic_block_step(ow, l_pad, r_pad, ic_block_step, 0, 0, 0); - int inp_icblk_stride = jcp.is_1stconv ? jcp.ih * jcp.iw * jcp.id + size_t inp_icblk_stride = jcp.is_1stconv + ? (size_t)jcp.ih * jcp.iw * jcp.id : (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni) ? jcp.tr_iw : 1); - add(reg_input, jcp.typesize_in * ic_block_step * inp_icblk_stride); + size_t input_offset + = inp_icblk_stride * jcp.typesize_in * ic_block_step; + safe_add(reg_input, input_offset, reg_long_offt); add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); add(b_ic, ic_block_step); cmp(b_ic, jcp.ic_block); @@ -2954,10 +3016,13 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 } if (jcp.is_1stconv) { - sub(reg_input, jcp.typesize_in * jcp.ih*jcp.iw*jcp.id * ic_block); - add(reg_input, jcp.typesize_in * jcp.iw); + size_t input_offset + = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, input_offset, reg_long_offt); + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); } else if (!utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) { - add(reg_input, jcp.typesize_in * (jcp.iw - 1) * ic_block); + add(reg_input, jcp.typesize_in + * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block); } add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); dec(kj); @@ -2965,8 +3030,8 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 jg(kh_label, T_NEAR); } if (jcp.ndims == 5) { - add(aux_reg_input, jcp.typesize_in * jcp.ih * jcp.iw - * ((jcp.is_1stconv) ? 1 : ic_block)); + add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih + * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block); dec(ki); @@ -2985,8 +3050,8 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 int oc_block = jcp.oc_block; int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow; - int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + jcp.kw - 1 - - (jcp.iw + jcp.l_pad - 1)); + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); int l_pad = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? 0 : jcp.l_pad; @@ -3053,7 +3118,9 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 ? jcp.ih * jcp.iw * jcp.id : (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni) ? jcp.tr_iw : 1); - add(reg_input, jcp.typesize_in * ic_block_step * inp_icblk_stride); + size_t input_offset + = inp_icblk_stride * jcp.typesize_in * ic_block_step; + safe_add(reg_input, input_offset, reg_long_offt); add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); add(b_ic, ic_block_step); @@ -3061,10 +3128,13 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 jl(ic_block_label, T_NEAR); } if (jcp.is_1stconv) { - sub(reg_input, jcp.typesize_in * jcp.ih*jcp.iw*jcp.id * ic_block); - add(reg_input, jcp.typesize_in * jcp.iw); + size_t input_offset + = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, input_offset, reg_long_offt); + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); } else if (!utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) { - add(reg_input, jcp.typesize_in * (jcp.iw - 1) * ic_block); + add(reg_input, jcp.typesize_in + * ((jcp.dilate_h + 1 ) * jcp.iw - 1) * ic_block); } add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); dec(kj); @@ -3072,15 +3142,14 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 jg(kh_label, T_NEAR); } if (jcp.ndims == 5) { - add(aux_reg_input, jcp.typesize_in * jcp.ih * jcp.iw * - ((jcp.is_1stconv) ? 1 : ic_block )); + add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih + * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block); dec(ki); cmp(ki, 0); jg(kd_label, T_NEAR); } - } void jit_avx512_common_conv_bwd_weights_kernel_f32 @@ -3179,19 +3248,19 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel() void jit_avx512_common_conv_bwd_weights_kernel_f32 ::compute_oh_loop_common() { - int back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - 1 - - (jcp.id + jcp.f_pad - 1)); - - int b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - 1 - - (jcp.ih + jcp.t_pad - 1)); + int back_pad = jcp.back_pad; + int b_pad = jcp.b_pad; int t_pad = jcp.t_pad; + bool is_dilated = jcp.dilate_h != 0; + int dilate_h = jcp.dilate_h + 1; int stride_h = jcp.stride_h; - int idp = jcp.id + jcp.f_pad + jcp.back_pad; + int idp = jcp.id + jcp.f_pad + back_pad; const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; int iw = utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni) ? jcp.tr_iw : jcp.iw; Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label, - oh_bpad_label, oh_bpad_label_end, od_label, od_label_end; + oh_bpad_label, oh_bpad_label_end, od_label, od_label_end, + oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end; maybe_zero_kernel(); if (jcp.ndims == 5 && jcp.with_bias) bias_kernel(); @@ -3218,30 +3287,60 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 xor_(reg_ih_count, reg_ih_count); xor_(reg_oj, reg_oj); if (t_pad > 0) { - mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih); - add(reg_kernel, jcp.typesize_out * t_pad * jcp.kw * jcp.ic_block + const int kh_range = 1 + (jcp.kh - 1) * dilate_h; + const int overflow + = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h)); + const int underflow = div_up(t_pad, dilate_h); + const int initial_inp_ker_overlap = jcp.kh - overflow - underflow; + mov(reg_kh, initial_inp_ker_overlap); + add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block * jcp.oc_block); // generate loop to process kernel while it remains within t_pad + ih - if (jcp.kh < t_pad + jcp.ih) { + if (kh_range < t_pad + jcp.ih) { + if (is_dilated) { + const int tail = t_pad % dilate_h; + const int shift = tail == 0 ? 0 : dilate_h - tail; + mov(reg_tmp, shift); + if (tail != 0) + add(reg_input, jcp.typesize_in * shift * iw * inp_mult); + } L(oh_tpad_label); { compute_oh_step_disp(); add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + if (is_dilated) { + inc(reg_tmp); + cmp(reg_tmp, dilate_h); + jl(oh_dilate_label_shift, T_NEAR); + // unshift input as new kernel element enters + sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult); + xor_(reg_tmp, reg_tmp); + } + // kernel overlap only changes when (t_pad + oj) % dilate_h == 0 sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw * jcp.ic_block * jcp.oc_block); - + add(reg_kh, stride_h); + if (is_dilated) { + jmp(oh_dilate_label_noshift, T_NEAR); + L(oh_dilate_label_shift); + // shift input as old kernel element progresses + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + L(oh_dilate_label_noshift); + } inc(reg_oj); add(reg_ih_count, stride_h); - add(reg_kh, stride_h); - // at least jcp.ih cells of kernel ultimately overlap with input - const int final_inp_ker_overlap = nstl::min(jcp.kh, jcp.ih); + // final number of kernel elements that overlap with input + const int final_inp_ker_overlap + = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h)); cmp(reg_kh, final_inp_ker_overlap); jl(oh_tpad_label, T_NEAR); } } // need second loop to process kernel if it is larger than the input - if (jcp.kh >= jcp.ih + (t_pad % stride_h == 0 ? stride_h : + // (does not apply to dilations as they must have unit stride) + if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h : t_pad % stride_h)) { + assert(!is_dilated); mov(reg_kh, jcp.ih); L(oh_tpad_tail_label); { compute_oh_step_disp(); @@ -3257,9 +3356,12 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 } } // correct any excess shifts to kernel and input + // (does not apply to dilations as they must have unit stride, + // kernel must fit inside input, and padding is smaller than input) if (t_pad <= jcp.oh * stride_h) { // kernel has moved beyond padding (adjust for stride effects) if (t_pad % stride_h != 0) { + assert(!is_dilated); int inp_corr = stride_h - t_pad % stride_h; add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw * jcp.ic_block * jcp.oc_block); @@ -3267,12 +3369,13 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 } } else { // kernel still overlaps padding (complete reset) + assert(!is_dilated); sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h) * jcp.kw * jcp.ic_block * jcp.oc_block); } } - cmp(reg_ih_count, jcp.ihp - b_pad - jcp.kh + 1); + cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h); jge(oh_label_end, T_NEAR); cmp(reg_oj, jcp.oh); jge(oh_label, T_NEAR); @@ -3286,7 +3389,7 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 inc(reg_oj); add(reg_ih_count, stride_h); - cmp(reg_ih_count, jcp.ihp - b_pad - jcp.kh + 1); + cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h); jge(oh_label_end, T_NEAR); cmp(reg_oj, jcp.oh); @@ -3298,17 +3401,29 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 cmp(reg_oj, jcp.oh); jge(oh_bpad_label_end, T_NEAR); - mov(reg_kh, jcp.ihp - b_pad); - sub(reg_kh, reg_ih_count); + if (is_dilated) { + mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations + mov(reg_tmp, 0); + } else { + mov(reg_kh, jcp.ihp - b_pad); + sub(reg_kh, reg_ih_count); + } L(oh_bpad_label); { compute_oh_step_disp(); add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); add(reg_output, jcp.typesize_in * ow * jcp.oc_block); - + if (is_dilated) { + inc(reg_tmp); + cmp(reg_tmp, dilate_h); + jl(oh_dilate_label_end, T_NEAR); + xor_(reg_tmp, reg_tmp); + } sub(reg_kh, stride_h); cmp(reg_kh, 0); jle(oh_bpad_label_end, T_NEAR); + if (is_dilated) + L(oh_dilate_label_end); inc(reg_oj); cmp(reg_oj, jcp.oh); @@ -3323,13 +3438,13 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32 pop(reg_output_d); pop(reg_input_d); - add(reg_input_d, jcp.typesize_in *jcp.stride_d*jcp.ih * iw * inp_mult); - add(reg_output_d, jcp.typesize_in * ow * jcp.oh * jcp.oc_block); + add(reg_input_d, jcp.typesize_in * jcp.stride_d * jcp.ih * iw * inp_mult); + add(reg_output_d, jcp.typesize_in * jcp.oh * ow * jcp.oc_block); dec(reg_oi); add(reg_id_count, jcp.stride_d); - cmp(reg_id_count, idp - back_pad - jcp.kd + 1); + cmp(reg_id_count, idp - back_pad - (jcp.kd - 1) * (jcp.dilate_d + 1)); jge(od_label_end, T_NEAR); cmp(reg_oi, 0); @@ -3345,6 +3460,7 @@ bool jit_avx512_common_conv_bwd_weights_kernel_f32 // FIXME: use register mapping from the class declaration bool ok = one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni) && (jcp.ver == ver_4fma || !one_of(1, jcp.kh, jcp.kw)) + && everyone_is(0, jcp.dilate_h, jcp.dilate_w) && everyone_is(1, jcp.stride_h, jcp.stride_w); if (!ok) return false; if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2) @@ -3901,10 +4017,11 @@ bool jit_avx512_common_conv_bwd_weights_kernel_f32 return true; } -bool jit_avx512_common_conv_bwd_weights_kernel_f32 +bool jit_avx512_common_conv_bwd_weights_kernel_f32 ::flat_4ops_compute() { const auto &j = jcp; - const bool ok = j.ver == ver_4fma && j.is_1stconv; + const bool ok = j.ver == ver_4fma && j.is_1stconv + && everyone_is(0, j.dilate_h, j.dilate_w); if (!ok) return false; Reg64 reg_ptr_tr_src = r8; @@ -4160,15 +4277,24 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf( jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; jcp.dilate_h = cd.dilates[ndims-4]; jcp.dilate_w = cd.dilates[ndims-3]; - if (jcp.dilate_h != 0 || jcp.dilate_w != 0 || jcp.dilate_d != 0) + + const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1); + bool ok = true + // general condition to simplify dilations + && implication(jcp.dilate_d != 0, jcp.stride_d == 1) + && implication(jcp.dilate_h != 0, jcp.stride_h == 1) + && implication(jcp.dilate_w != 0, jcp.stride_w == 1) + // special condition to simplify dilations in compute_oh_loop_common + && implication(jcp.dilate_h != 0, kh_range <= jcp.ih); + if (!ok) return status::unimplemented; - jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - - jcp.l_pad); - jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - - jcp.t_pad); - jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - jcp.id - - jcp.f_pad); + jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h + + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1)); + jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1)); if ( ndims == 5 ) if (jcp.f_pad != 0 || jcp.back_pad != 0) @@ -4219,9 +4345,10 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf( /* kernel applicability check wrt boundaries * the conditions are quite general across the kernels we have, * but ideally the check should belong to a specific kernel... */ + const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2; const bool boundaries_ok = true - && jcp.t_pad <= jcp.kh / 2 - && jcp.b_pad <= jcp.kh / 2; + && jcp.t_pad <= max_pad + && jcp.b_pad <= max_pad; if (!boundaries_ok) return status::unimplemented; @@ -4261,6 +4388,8 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf( const bool use_4fma = true && ndims == 4 && mayiuse(avx512_mic_4ops) + && mkldnn_thr_syncable() + && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad) && jcp.kw <= 28 - jcp.with_bias && jcp.stride_w == 4 @@ -4309,8 +4438,10 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf( jcp.nb_ic = jcp.ic / jcp.ic_block; jcp.src_fmt = src_d.format(); if ((mayiuse(avx512_mic_4ops) || mayiuse(avx512_core_vnni)) + && mkldnn_thr_syncable() && ndims == 4 && jcp.stride_w == 1 + && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) && ((src_d.data_type() == data_type::s16 && diff_weights_d.data_type() == data_type::s32 && diff_dst_d.data_type() == data_type::s16))) { @@ -4321,7 +4452,9 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf( src_d.data_type(), diff_weights_d.data_type(), diff_dst_d.data_type())) { jcp.ver = ver_fma; - if (ndims == 4 && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1) { + if (ndims == 4 && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1 && + everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) && + mkldnn_thr_syncable()) { jcp.ver = ver_4fma; } } else { |