diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp | 704 |
1 files changed, 426 insertions, 278 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp index d4bd41b8e..45f516c80 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp @@ -36,6 +36,17 @@ using namespace mkldnn::impl::memory_format; using namespace mkldnn::impl::utils; using namespace Xbyak; +namespace { + // Below scales are applied to source and weights data accordingly + // because this winograd implementation + // transforms source which may increase values up to 4x + // and transforms weights which may increase values up to 9/4x + const float adj_src_scale = 1.f / 4.f; + const float adj_wei_scale = 4.f / 9.f; + // Winograd transforms need ic and oc to be multiples of 16 + const int load_block = 16; +} + /// SRC TRANSFORMS ///////////////////////////////////////////////////////////// struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS( @@ -60,10 +71,19 @@ struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator { } void generate(); - Xmm vreg_inp(int i) { + int reg_inp_ind(int i) { assert(i < jcp.alpha * jcp.alpha); - return Xmm(31 - i); + return (31 - i); + } + + Xmm vreg_inp(int i) { + return Xmm(reg_inp_ind(i)); } + + Zmm zmm_inp(int i) { + return Zmm(reg_inp_ind(i)); + } + Xmm vreg_tmp(int i) { assert(i < jcp.alpha * jcp.alpha); return Xmm(15 - i); @@ -93,11 +113,15 @@ struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator { Reg64 reg_ic_block = r8; int unsign_val_in_wino_domain; + + Reg64 reg_scratch_src_alpha = rdx; + Xmm xmm_src_alpha = Xmm(0); + Zmm zmm_src_alpha = Zmm(0); }; + void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() { Label ic_block_label; - const int load_block = 16; int out_offset = 0, inp_offset = 0; preamble(); @@ -119,20 +143,30 @@ void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() { kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); } + mov(reg_scratch_src_alpha, float2int(adj_src_scale)); + mov(reg_ic_block, jcp.ic / load_block); L(ic_block_label); { + vmovq(xmm_src_alpha, reg_scratch_src_alpha); + vbroadcastss(zmm_src_alpha, xmm_src_alpha); + for(int y = 0; y < jcp.alpha; y++) { kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]); for(int x = 0; x < jcp.alpha; x++) { - vpxord(vreg_inp(y*jcp.alpha + x), vreg_inp(y*jcp.alpha + x), - vreg_inp(y*jcp.alpha + x)); + Zmm zmm_i = zmm_inp(y*jcp.alpha + x); + Xmm vreg_i = vreg_inp(y*jcp.alpha + x); + vpxord(vreg_i, vreg_i, vreg_i); kandw(r_mask, y_mask, x_mask(x)); inp_offset = sizeof(uint8_t) * ((-jcp.t_pad + y) * jcp.iw * jcp.ic + (-jcp.l_pad + x) * jcp.ic); - vmovdqu8(vreg_inp(y*jcp.alpha + x) | r_mask, - EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + vmovdqu8(vreg_i | r_mask, EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + vpmovzxbd(zmm_i, vreg_i); // to int32 + vcvtdq2ps(zmm_i, zmm_i); // to fp32 + vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha + vcvtps2dq(zmm_i | T_rn_sae, zmm_i); // to int32 + vpmovusdb(vreg_i, zmm_i); // to u8 } } for(int y = 0; y < 4; y++) { @@ -163,8 +197,7 @@ void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() { add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block); } dec(reg_ic_block); - cmp(reg_ic_block, 0); - jg(ic_block_label, T_NEAR); + jnz(ic_block_label, T_NEAR); postamble(); } @@ -204,29 +237,30 @@ struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator { } Zmm vreg_stg(int id) { // 8 const int id_reg_stg = jcp.alpha * jcp.alpha + id; - assert(id_reg_stg < jcp.alpha * jcp.alpha + 8); + assert(id < 8); return Zmm(31 - id_reg_stg); } Zmm vreg_out(int id) { // 4 const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; - assert(id_reg_out < jcp.alpha * jcp.alpha + 12); + assert(id < 4); return Zmm(31 - id_reg_out); } Xmm xmm_out(int id) { // 4 const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; - assert(id_reg_out < jcp.alpha * jcp.alpha + 12); + assert(id < 4); return Xmm(31 - id_reg_out); } Zmm vreg_tmp(int id) { // 2 const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id; - assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14); + assert(id < 2); return Zmm(31 - id_reg_tmp); } Zmm vreg_zero = Zmm(0); Zmm vreg_bias = Zmm(1); Zmm vreg_prev_dst = Zmm(2); - + Zmm zmm_bias_alpha = Zmm(2); + Xmm xmm_bias_alpha = Xmm(2); Opmask y_mask = Opmask(1); Opmask r_mask = Opmask(2); @@ -234,6 +268,9 @@ struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator { assert(id < 4); return Opmask(3 + id); } + + Reg64 reg_scratch_bias_alpha = r15; + Reg64 reg_ptr_src = r14; Reg64 reg_ptr_dst = r13; @@ -246,9 +283,10 @@ struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator { Reg64 reg_oc_block = r8; Reg64 reg_ptr_bias = rbx; - Reg64 reg_ptr_scales = rcx; + Reg64 reg_ptr_scales = abi_not_param1; Reg64 reg_ptr_sum_scale = rdx; }; + bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) { using namespace primitive_kind; const auto &p = attr_.post_ops_; @@ -273,11 +311,10 @@ bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) { return false; } + void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() { Label oc_block_label; - const int load_block = 16; - auto loop_body = [=]() { const auto &p = attr_.post_ops_; const int sum_idx = p.find(primitive_kind::sum); @@ -309,6 +346,9 @@ void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() { if (jcp.with_bias) { + vmovq(xmm_bias_alpha, reg_scratch_bias_alpha); + vbroadcastss(zmm_bias_alpha, xmm_bias_alpha); + auto bias_addr = ptr [ reg_ptr_bias ]; switch (jcp.bia_dt) { case data_type::f32: @@ -319,6 +359,7 @@ void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() { } if (jcp.bia_dt != data_type::f32) vcvtdq2ps(vreg_bias, vreg_bias); + vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha } for(int y = 0; y < jcp.m; y++) { kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(int16_t) * y ]); @@ -394,6 +435,9 @@ void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() { READ_PARAM(reg_ptr_scales, scales); # undef READ_PARAM + if (jcp.with_bias) + mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale)); + mov(reg_aux_ptr_src, reg_ptr_src); mov(reg_aux_ptr_dst, reg_ptr_dst); @@ -415,8 +459,7 @@ void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() { add(reg_ptr_bias, sizeof(jcp.typesize_bia) * load_block); } dec(reg_oc_block); - cmp(reg_oc_block, 0); - jg(oc_block_label, T_NEAR); + jnz(oc_block_label, T_NEAR); sub(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); sub(reg_ptr_bias, oc_blocks * sizeof(jcp.typesize_bia) * load_block); @@ -464,7 +507,8 @@ struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator { return Zmm(31 - id_reg_out); } Zmm vreg_wei(int i) { - assert(31 - jcp.n2_block * jcp.m_block - i > 2); + assert(31 - jcp.n2_block * jcp.m_block - i + > (jcp.ver == ver_vnni ? 0 : 2)); return Zmm(31 - jcp.n2_block * jcp.m_block - i); } @@ -473,20 +517,20 @@ struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator { Zmm vreg_tmp = Zmm(2); Reg64 reg_ptr_src = r15; - Reg64 reg_ptr_dst = r14; - Reg64 reg_ptr_wei = r13; - Reg64 reg_ptr_dst_b = r12; - Reg64 reg_aux_dst = r11; + Reg64 reg_aux_dst_b = r13; + Reg64 reg_aux_dst = r12; + Reg64 reg_aux_dst2 = r11; Reg64 reg_aux_wei = r10; - Reg64 reg_aux_dst_b = r9; + Reg64 reg_aux_wei2 = r9; Reg64 reg_aux_src = r8; - Reg64 reg_aux_wei2 = rax; + Reg64 reg_aux_src2 = rax; + Reg64 reg_mb = rbx; + Reg64 reg_nnb = abi_not_param1; Reg64 reg_scratch = rdx; - Reg64 reg_nnb = rcx; Reg64 reg_K = rsi; - }; + bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok( jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) { using namespace primitive_kind; @@ -502,11 +546,11 @@ bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok( switch (p.len_) { case 0: return true; case 1: return true - && implication(jcp.with_relu, p.contain(sum, 0)) - && implication(!jcp.with_relu, is_relu(0) || p.contain(sum, 0)); + && IMPLICATION(jcp.with_relu, p.contain(sum, 0)) + && IMPLICATION(!jcp.with_relu, is_relu(0) || p.contain(sum, 0)); case 2: return true - && implication(jcp.with_relu, p.contain(sum, 0) && is_relu(1)) - && implication(!jcp.with_relu, false + && IMPLICATION(jcp.with_relu, p.contain(sum, 0) && is_relu(1)) + && IMPLICATION(!jcp.with_relu, false || (p.contain(sum, 0) && is_relu(1)) || (p.contain(sum, 1) && is_relu(0))); case 3: return true @@ -517,8 +561,9 @@ bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok( return false; } + void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() { - Label nnb_loop_label, K_loop_label[2]; + Label nnb_loop_label, K_loop_label, mb_loop_label; auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { if (jcp.ver == ver_vnni) { @@ -534,82 +579,85 @@ void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() { # define READ_PARAM(reg, field) \ mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) READ_PARAM(reg_ptr_src, src); - READ_PARAM(reg_ptr_dst, dst); - READ_PARAM(reg_ptr_wei, wei); - READ_PARAM(reg_ptr_dst_b, dst_b); + READ_PARAM(reg_aux_dst, dst); + READ_PARAM(reg_aux_wei, wei); + READ_PARAM(reg_aux_dst_b, dst_b); # undef READ_PARAM - xor_(reg_scratch, reg_scratch); - Reg16 _t = reg_scratch.cvt16(); - mov(_t, 0x1); - vpbroadcastw(vreg_one, _t); - - mov(reg_aux_dst, reg_ptr_dst); - mov(reg_aux_wei, reg_ptr_wei); - mov(reg_aux_dst_b, reg_ptr_dst_b); + if (jcp.ver != ver_vnni) { + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(vreg_one, _t); + } if (!jcp.small_mb) { mov(reg_nnb, jcp.n_chunks); L(nnb_loop_label); } - for (int mb = 0; mb < jcp.M / jcp.m_block; mb++) - { - for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { - for (int m = 0; m < jcp.m_block; m++) { - int offset = jcp.typesize_acc * nb2 * jcp.n_block; - vmovups(vreg_out(nb2, m), + mov(reg_aux_dst2, reg_aux_dst); + mov(reg_aux_src, reg_ptr_src); + mov(reg_mb, jcp.M / jcp.m_block); + L(mb_loop_label); + { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + for (int m = 0; m < jcp.m_block; m++) { + int offset = jcp.typesize_acc * nb2 * jcp.n_block; + vmovups(vreg_out(nb2, m), EVEX_compress_addr(reg_aux_dst_b, offset)); - } } - mov(reg_aux_src, reg_ptr_src); - mov(reg_aux_wei2, reg_aux_wei); - mov(reg_K, jcp.k_chunks); - L(K_loop_label[mb]); { - for (int k = 0; k < jcp.k2_block; k += 4) - { - for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { - int wei_offset = jcp.typesize_in * - ((nb2 * jcp.n_block) * jcp.K); - vmovups(vreg_wei(nb2), + } + mov(reg_aux_src2, reg_aux_src); + mov(reg_aux_wei2, reg_aux_wei); + mov(reg_K, jcp.k_chunks); + L(K_loop_label); + { + for (int k = 0; k < jcp.k2_block; k += 4) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int wei_offset + = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K); + vmovups(vreg_wei(nb2), EVEX_compress_addr(reg_aux_wei2, wei_offset)); - } - for (int m = 0; m < jcp.m_block; m++) { - int inp_offset = jcp.typesize_in * - (m + mb * jcp.m_block) * jcp.K; - vpbroadcastd(vreg_src, - EVEX_compress_addr(reg_aux_src,inp_offset)); - for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) - compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src); - } - add(reg_aux_src, jcp.typesize_in * 4); - add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block); } + for (int m = 0; m < jcp.m_block; m++) { + int inp_offset = jcp.typesize_in * m * jcp.K; + vpbroadcastd(vreg_src, + EVEX_compress_addr(reg_aux_src2, inp_offset)); + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) + compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src); + } + add(reg_aux_src2, jcp.typesize_in * 4); + add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block); } - dec(reg_K); - cmp(reg_K, 0); - jg(K_loop_label[mb], T_NEAR); + } + dec(reg_K); + jnz(K_loop_label, T_NEAR); - for (int m = 0; m < jcp.m_block; m++) { - for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { - int offset = jcp.typesize_acc * - ((mb * jcp.m_block + m) * jcp.N + nb2 * jcp.n_block); - vmovups(EVEX_compress_addr(reg_aux_dst,offset), - vreg_out(nb2, m)); - } + for (int m = 0; m < jcp.m_block; m++) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int offset = jcp.typesize_acc * (m * jcp.N + nb2 * jcp.n_block); + vmovups(EVEX_compress_addr(reg_aux_dst2, offset), + vreg_out(nb2, m)); } } + add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K); + add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N); + } + dec(reg_mb); + jnz(mb_loop_label, T_NEAR); + if (!jcp.small_mb) { add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block * jcp.n_block); add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block); add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K); dec(reg_nnb); - cmp(reg_nnb, 0); - jg(nnb_loop_label, T_NEAR); + jnz(nnb_loop_label, T_NEAR); } postamble(); } + status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t ::init_conf(jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd, @@ -652,25 +700,27 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t if (mayiuse(avx512_core_vnni)) jcp.ver = ver_vnni; + // block sizes needed for GEMM kernel jcp.ic_block = 4; jcp.oc_block = 16; bool ok = true - && jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1 + && jcp.oc % load_block == 0 && jcp.ic % load_block == 0 && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 - && jcp.stride_h == 1 && jcp.stride_w == 1 - && jcp.dilate_h == 0 && jcp.dilate_w == 0 + && everyone_is(3, jcp.kh, jcp.kw) + && everyone_is(1, jcp.stride_h, jcp.stride_w) + && everyone_is(0, jcp.dilate_h, jcp.dilate_w) && jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad - && jcp.t_pad < 2 && jcp.t_pad >= 0 - && jcp.l_pad < 2 && jcp.l_pad >= 0; + && one_of(jcp.t_pad, 0, 1) + && one_of(jcp.l_pad, 0, 1); if (!ok) return status::unimplemented; jcp.src_fmt = src_d.format(); jcp.with_bias = cd.bias_desc.format != memory_format::undef; jcp.with_relu = with_relu; jcp.relu_negative_slope = relu_negative_slope; - if (!implication(with_relu, relu_negative_slope == 0.)) + if (!IMPLICATION(with_relu, relu_negative_slope == 0.)) return status::unimplemented; if (!post_ops_ok(jcp, attr)) return status::unimplemented; @@ -692,29 +742,131 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t jcp.r = 3; jcp.alpha = jcp.m + jcp.r - 1; - jcp.yb = 1; - int opt_val = 14, cur_val = 0; - for (int i = 14; i >= 8; i -= 2) { - cur_val = ((jcp.oh / i) * i + i) - jcp.oh; - if (jcp.oh % i == 0) { - jcp.yb = i; break; - } else if (cur_val < opt_val) { - jcp.yb = i; - opt_val = cur_val; + int aa = jcp.alpha * jcp.alpha; + int nthr = mkldnn_get_max_threads(); + int L1_cap = get_cache_size(1, true); + int L2_cap = get_cache_size(2, true); + // need 1 extra reg for bcast, and 2 tmp regs for non-vnni + int free_regs = jcp.ver == ver_vnni ? 31 : 29; + + auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) { + float thr_eff; + float Z = (float)jcp.ic + jcp.oc; + float Y = (float)jcp.ic * jcp.oc; + if (small_mb == 0) { // outer par + int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix); + thr_eff = (float)nblocks / rnd_up(nblocks, nthr); + } else { // inner par + int tranw = iy * ix / jcp.alpha; + int gemmw = aa * (jcp.nb_oc / n2_b); + int tranw_r = rnd_up(tranw, nthr); + int gemmw_r = rnd_up(gemmw, nthr); + thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y); } - } + return thr_eff; + }; - const int nthreads = mkldnn_get_max_threads(); - jcp.xb = 4; - int oh_blocks = (jcp.oh < jcp.yb) ? 1 : (jcp.oh / jcp.yb); - int ow_blocks = (jcp.ow < jcp.xb) ? 1 : (jcp.ow / jcp.xb); + auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) { + float mem_eff, req_mem; + int M = ix * iy / jcp.alpha; + if (small_mb == 0) { // outer parallelization strategy + // memory for wino transforms (other memory has poor reuse) + req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc); + mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f; + } else { // inner parallelization strategy + // memory used during gemm + int N = jcp.oc_block * n2_b; + req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N; + mem_eff = nstl::min(1.f, L2_cap / req_mem); + // memory used during wino transforms + int M_per_thr = div_up(M, nthr); + req_mem = (float)aa * M_per_thr + * (jcp.ic + jcp.typesize_acc * jcp.oc); + if (req_mem > L2_cap) + mem_eff = 0.1f; + } + return mem_eff; + }; + + auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff, + float mem_eff, float reg_eff) { + // these coefficients are chosen empirically + float mem_fac = 0.1f, reg_fac = 0.2f; + // normalized overhead relative to memory and register components + float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff; + // thread and work components affect all others + tot_eff *= thr_eff * work_eff; + return tot_eff; + }; + + auto find_m_n2_blocks = [&](bool small_mb, int ix, int iy, float work_eff, + int &m_block, int &n2_block, float &tot_eff) { + int M = (ix * iy) / jcp.alpha; + int max_m_block = nstl::min(M, free_regs); + int max_n2_block = nstl::min(jcp.nb_oc, free_regs); + tot_eff = 0.f; + for (int im = max_m_block; im > 0; im--) { + if (M % im) + continue; + for (int in2 = max_n2_block; in2 > 0; in2--) { + int used_regs = (im + 1) * in2; + float mem_eff = get_mem_eff(small_mb, ix, iy, in2); + float reg_eff = (float)(im * in2) / (im + in2); + float thr_eff = get_thr_eff(small_mb, ix, iy, in2); + float cur_tot_eff = get_tot_eff( + small_mb, thr_eff, work_eff, mem_eff, reg_eff); + if (jcp.nb_oc % in2 || used_regs > free_regs + || cur_tot_eff <= tot_eff) + continue; + tot_eff = cur_tot_eff; + m_block = im; + n2_block = in2; + } + } + }; + + /* Selecting xb and yb blocking */ + int min_yb = jcp.m; + int min_xb = jcp.m; + int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2)); + int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2)); + float best_eff = 0.f; + for (int ix = min_xb; ix <= max_xb; ix += 2) { + assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2); + for (int iy = max_yb; iy >= min_yb; iy -= 2) { + assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2); + + int m_b[2]; + int n2_b[2]; + bool small_mb; + float inner_eff, outer_eff, work_eff; + + int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix); + work_eff = (float)jcp.oh * jcp.ow / tiled_area; + if (best_eff > 0.f && work_eff < 4.f / 9.f) + continue; // no gain from Winograd transformation + + /* outer parallelization */ + find_m_n2_blocks(0, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff); + + /* inner parallelization */ + find_m_n2_blocks(1, ix, iy, work_eff, m_b[1], n2_b[1], inner_eff); + + small_mb = inner_eff > outer_eff; + float eff = small_mb ? inner_eff : outer_eff; + if (eff > best_eff) { + best_eff = eff; + jcp.yb = iy; + jcp.xb = ix; + jcp.m_block = m_b[small_mb]; + jcp.n2_block = n2_b[small_mb]; + jcp.small_mb = small_mb; + } + } + } - const int work_amount = jcp.mb * oh_blocks * ow_blocks; - if (work_amount < nthreads && jcp.ow < 24) { - jcp.small_mb = true; - jcp.xb = (jcp.ow < 9) ? jcp.yb : 4; - } else - jcp.small_mb = false; + assert((jcp.m_block + 1) * jcp.n2_block <= free_regs); + assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0); jcp.inp_stride = jcp.yb * jcp.xb / 4 * jcp.ic; jcp.out_stride = jcp.yb * jcp.xb / 4 * jcp.oc; @@ -725,31 +877,20 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t jcp.N = jcp.oc; jcp.K = jcp.ic; - jcp.m_block = jcp.xb * jcp.yb / 8; jcp.n_block = jcp.oc_block; jcp.k_block = jcp.ic_block; - int n_nblock = jcp.N / jcp.n_block; - jcp.n2_block = (!(n_nblock % 4)) - ? 4 - : (!(n_nblock % 2)) ? 2 : 1; - const int skx_free_regs = 28; - if (jcp.n2_block * jcp.m_block > (skx_free_regs - jcp.n2_block)) { - jcp.n2_block /= 2; - } - jcp.n_chunks = n_nblock / jcp.n2_block; + jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block; - int k_nblock = jcp.K / jcp.k_block; - jcp.k2_block = 1; - for (int i = 16; i >= 2; i /= 2) - if (!(k_nblock % i)) { - jcp.k2_block = i; break; - } + // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4 + // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is + // a multiple of load_block = 16, we just use that for now. + jcp.k2_block = load_block; jcp.k_chunks = jcp.K / jcp.k2_block; const auto &oscales = attr.output_scales_; jcp.is_oc_scale = oscales.mask_ == 1 << 1; - assert(utils::implication(!jcp.is_oc_scale, oscales.mask_ == 0)); + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); /* re-create weights primitive descriptor and set weights wino_blocking */ @@ -767,6 +908,8 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t wd.oc_block = jcp.oc_block; wd.oc2_block = jcp.n2_block; wd.ic2_block = 1; + wd.adj_scale = adj_wei_scale; + size_t max_size = types::data_type_size(data_type::s8) * jcp.alpha * jcp.alpha * jcp.ic * jcp.oc; max_size += types::data_type_size(data_type::s32) * @@ -797,7 +940,8 @@ _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, dst_data_type>:: _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *pd, const input_vector &inputs, const output_vector &outputs) : cpu_primitive_t(&conf_, inputs, outputs) - , conf_(*pd) { + , conf_(*pd) + , scratchpad_(nullptr) { const int nthreads = mkldnn_get_max_threads(); kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t( conf_.jcp_, *conf_.attr()); @@ -806,25 +950,27 @@ _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, dst_data_type>:: dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t( conf_.jcp_, *conf_.attr()); - int wino_size_offset = (conf_.jcp_.yb / 2) * (conf_.jcp_.xb / 2) - + (conf_.jcp_.xb); - size_wino_wei = conf_.jcp_.alpha * conf_.jcp_.alpha * conf_.jcp_.oc - * conf_.jcp_.ic; - size_wino_src = (conf_.jcp_.ic * 16) * (wino_size_offset); - size_wino_dst = (conf_.jcp_.oc * 16) * (wino_size_offset); + const int tilesize = conf_.jcp_.alpha * conf_.jcp_.alpha; + const int numtiles = (conf_.jcp_.yb / 2) * (conf_.jcp_.xb / 2); + const int alltiles = tilesize * numtiles; + size_wino_wei_ = tilesize * conf_.jcp_.oc * conf_.jcp_.ic; + size_wino_src_ = sizeof(src_data_t) * alltiles * conf_.jcp_.ic; + size_wino_src_ = rnd_up(size_wino_src_, PAGE_4K); + size_wino_src_ /= sizeof(src_data_t); + size_wino_dst_ = alltiles * conf_.jcp_.oc; - size_t workspace_size = nthreads - * (sizeof(src_data_t) * size_wino_src - + sizeof(acc_data_t) * size_wino_dst); + size_t workspace_size = (conf_.jcp_.small_mb ? 1 : nthreads) + * (sizeof(src_data_t) * size_wino_src_ + + sizeof(acc_data_t) * size_wino_dst_); - workspace = malloc(workspace_size, 4096); - char *_t = static_cast<char *>(workspace); + scratchpad_ = create_scratchpad(workspace_size); + assert(scratchpad_); // TODO: add proper check and raise exception? - size_t shift = 0; - wino_src_ = (src_data_t *)(_t + shift); + wino_shift_ = (conf_.jcp_.small_mb ? 1 : nthreads) * sizeof(src_data_t) + * size_wino_src_; - shift += nthreads * sizeof(src_data_t) * size_wino_src; - wino_dst_ = (acc_data_t *)(_t + shift); + updated_output_scales_ = conf_.attr()->output_scales_; + updated_output_scales_.scale(1.f / (adj_src_scale * adj_wei_scale)); } template <bool with_relu, data_type_t dst_data_type> @@ -833,8 +979,7 @@ _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, delete kernel_; delete src_trans_; delete dst_trans_; - - free(workspace); + delete scratchpad_; } template <bool with_relu, data_type_t dst_data_type> @@ -856,30 +1001,32 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, auto dst = reinterpret_cast<dst_data_t *>(memory(0)); const auto &jcp = kernel_->jcp; - const auto &oscales = conf_.attr()->output_scales_; + const auto &oscales = updated_output_scales_; - wino_wei_ = wei; - dst_bias_ = (const acc_data_t*)(wei + size_wino_wei); + auto wino_wei = wei; + auto dst_bias = (const acc_data_t *)(wei + size_wino_wei_); + auto wino_src_base = (src_data_t *)scratchpad_->get(); + auto wino_dst_base = (acc_data_t *)(scratchpad_->get() + wino_shift_); parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb), - [&](int mb, int tile_y_b, int tile_x_b) { + [&](int mb, int tile_y_b, int tile_x_b) { int tile_y = tile_y_b * jcp.yb; int tile_x = tile_x_b * jcp.xb; int ithr = mkldnn_get_thread_num(); - auto wino_src = wino_src_ + size_wino_src * ithr; - auto wino_dst = wino_dst_ + size_wino_dst * ithr; - - auto src_trans_p = jit_avx512_core_u8s8s32x_wino_conv_src_trans_t - ::call_params_t(); - auto dst_trans_p = jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t - ::call_params_t(); - auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t - ::call_params_t(); - - { /* transformation of input tensor to winograd domain */ - for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + auto wino_src = wino_src_base + size_wino_src_ * ithr; + auto wino_dst = wino_dst_base + size_wino_dst_ * ithr; + + auto src_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t(); + auto dst_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t(); + auto gemm_p = + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::call_params_t(); + + /* transformation of input tensor to winograd domain */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) { unsigned short v_y_masks[4], v_x_masks[4]; @@ -889,19 +1036,20 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, int v_ys = nstl::max(0, jcp.t_pad - y); int v_ye = nstl::min(jcp.alpha, - nstl::max(0, jcp.ih + jcp.t_pad - y)); + nstl::max(0, jcp.ih + jcp.t_pad - y)); int v_xs = nstl::max(0, jcp.l_pad - x); int v_xe = nstl::min(jcp.alpha, - nstl::max(0, jcp.iw + jcp.l_pad - x)); + nstl::max(0, jcp.iw + jcp.l_pad - x)); - #pragma unroll(4) +#pragma unroll(4) for (int i = 0; i < jcp.alpha; i++) { v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; } - auto local_s = src + mb * jcp.ih * jcp.iw * jcp.ic - + y * jcp.iw * jcp.ic + x * jcp.ic; + auto local_s = src + + mb * jcp.ih * jcp.iw * jcp.ic + + y * jcp.iw * jcp.ic + x * jcp.ic; auto local_w = wino_src + m * jcp.ic; src_trans_p.src = local_s; @@ -910,20 +1058,22 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, src_trans_p.v_x_masks = v_x_masks; src_trans_->ker_(&src_trans_p); - }} - } - { /* gemms */ - for (int tile_ij = 0; tile_ij < 16; tile_ij++) { - gemm_p.src = wino_src + jcp.inp_stride * tile_ij; - gemm_p.dst = wino_dst + jcp.out_stride * tile_ij; - gemm_p.wei = wino_wei_ + jcp.wei_stride * tile_ij; - gemm_p.dst_b = dst_bias_ + jcp.bia_stride * tile_ij; - - kernel_->ker_(&gemm_p); } } - { /* transformation from winograd domain to output tensor */ - for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + /* gemms */ + for (int tile_ij = 0; tile_ij < 16; tile_ij++) { + // start threads at different GEMMs to help bring weights into LLC + int offset = (tile_ij + ithr) % 16; + gemm_p.src = wino_src + jcp.inp_stride * offset; + gemm_p.dst = wino_dst + jcp.out_stride * offset; + gemm_p.wei = wino_wei + jcp.wei_stride * offset; + gemm_p.dst_b = dst_bias + jcp.bia_stride * offset; + + kernel_->ker_(&gemm_p); + } + + /* transformation from winograd domain to output tensor */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) { unsigned short v_y_masks[2], v_x_masks[2]; @@ -931,13 +1081,14 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, int x = x_in_block + tile_x; int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); - #pragma unroll(2) +#pragma unroll(2) for (int i = 0; i < jcp.m; i++) { v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; } - auto local_d = dst + mb * jcp.oh * jcp.ow * jcp.oc - + y * jcp.ow * jcp.oc + x * jcp.oc; + auto local_d = dst + + mb * jcp.oh * jcp.ow * jcp.oc + + y * jcp.ow * jcp.oc + x * jcp.oc; auto local_w = wino_dst + m * jcp.oc; auto scales = oscales.scales_; @@ -950,7 +1101,7 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, dst_trans_p.bias = bia; dst_trans_->ker_(&dst_trans_p); - }} + } } }); } @@ -964,113 +1115,110 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, auto dst = reinterpret_cast<dst_data_t *>(memory(0)); const auto &jcp = kernel_->jcp; - const auto &oscales = conf_.attr()->output_scales_; + const auto &oscales = updated_output_scales_; - wino_wei_ = wei; - dst_bias_ = (const acc_data_t*)(wei + size_wino_wei); + auto wino_wei = wei; + auto dst_bias = (const acc_data_t *)(wei + size_wino_wei_); + auto wino_src = (src_data_t *)scratchpad_->get(); + auto wino_dst = (acc_data_t *)(scratchpad_->get() + wino_shift_); for (int mb = 0; mb < jcp.mb; mb++) { for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) { for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { - { /* transformation of input tensor to winograd domain */ - - parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), - [&](int y_in_block_b, int x_in_block_b) { - - int y_in_block = y_in_block_b * 2; - int x_in_block = x_in_block_b * 2; - auto src_trans_p = - jit_avx512_core_u8s8s32x_wino_conv_src_trans_t - ::call_params_t(); + /* transformation of input tensor to winograd domain */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), + [&](int y_in_block_b, int x_in_block_b) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; - unsigned short v_y_masks[4], v_x_masks[4]; + auto src_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t(); - int y = y_in_block + tile_y; - int x = x_in_block + tile_x; - int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + unsigned short v_y_masks[4], v_x_masks[4]; - int v_ys = nstl::max(0, jcp.t_pad - y); - int v_ye = nstl::min(jcp.alpha, - nstl::max(0, jcp.ih + jcp.t_pad - y)); + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); - int v_xs = nstl::max(0, jcp.l_pad - x); - int v_xe = nstl::min(jcp.alpha, - nstl::max(0, jcp.iw + jcp.l_pad - x)); + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min( + jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y)); - #pragma unroll(4) - for (int i = 0; i < jcp.alpha; i++) { - v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; - v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; - } - auto local_s = src + mb * jcp.ih * jcp.iw * jcp.ic - + y * jcp.iw * jcp.ic + x * jcp.ic; - auto local_w = wino_src_ + m * jcp.ic; + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min( + jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x)); - src_trans_p.src = local_s; - src_trans_p.wino_src = local_w; - src_trans_p.v_y_masks = v_y_masks; - src_trans_p.v_x_masks = v_x_masks; - - src_trans_->ker_(&src_trans_p); - }); - } - { /* gemms */ - parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) { - auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t - ::call_params_t(); - - auto _t_src = wino_src_ + jcp.inp_stride * tile_ij; - auto _t_dst = wino_dst_ + jcp.out_stride * tile_ij; - auto _t_wei = wino_wei_ + jcp.wei_stride * tile_ij; - auto _t_dst_b = dst_bias_ + jcp.bia_stride * tile_ij; - - gemm_p.src = _t_src; - gemm_p.dst = _t_dst + nnb * jcp.n2_block * jcp.n_block; - gemm_p.wei = _t_wei + nnb * jcp.n2_block * jcp.n_block * jcp.K; - gemm_p.dst_b = _t_dst_b + nnb * jcp.n2_block * jcp.n_block; - - kernel_->ker_(&gemm_p); - }); - } - { /* transformation from winograd domain to output tensor */ - parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), - [&](int y_in_block_b, int x_in_block_b) { - int y_in_block = y_in_block_b * 2; - int x_in_block = x_in_block_b * 2; - - auto dst_trans_p = - jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t - ::call_params_t(); - - unsigned short v_y_masks[2], v_x_masks[2]; - - int y = y_in_block + tile_y; - int x = x_in_block + tile_x; - int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); - - #pragma unroll(2) - for (int i = 0; i < jcp.m; i++) { - v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; - v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; - } - auto local_d = dst + mb * jcp.oh * jcp.ow * jcp.oc - + y * jcp.ow * jcp.oc + x * jcp.oc; - auto local_w = wino_dst_ + m * jcp.oc; - - auto scales = oscales.scales_; - dst_trans_p.dst = local_d; - dst_trans_p.wino_dst = local_w; - dst_trans_p.v_y_masks = v_y_masks; - dst_trans_p.v_x_masks = v_x_masks; - - dst_trans_p.scales = scales; - dst_trans_p.bias = bia; - - dst_trans_->ker_(&dst_trans_p); - }); - } - }} - } +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; + v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; + } + auto local_s = src + + mb * jcp.ih * jcp.iw * jcp.ic + + y * jcp.iw * jcp.ic + x * jcp.ic; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + }); + + /* gemms */ + parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) { + auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t:: + call_params_t(); + + gemm_p.src = wino_src + jcp.inp_stride * tile_ij; + gemm_p.dst = wino_dst + jcp.out_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + gemm_p.wei = wino_wei + jcp.wei_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block * jcp.K; + gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + + kernel_->ker_(&gemm_p); + }); + + /* transformation from winograd domain to output tensor */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), + [&](int y_in_block_b, int x_in_block_b) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto dst_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t(); + + unsigned short v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; + v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; + } + auto local_d = dst + + mb * jcp.oh * jcp.ow * jcp.oc + + y * jcp.ow * jcp.oc + x * jcp.oc; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales.scales_; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + }); + }}} } template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<true, |