Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp')
-rw-r--r--  inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp  704
1 file changed, 426 insertions(+), 278 deletions(-)
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp
index d4bd41b8e..45f516c80 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp
@@ -36,6 +36,17 @@ using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
+namespace {
+ // Below scales are applied to source and weights data accordingly
+ // because this winograd implementation
+ // transforms source which may increase values up to 4x
+ // and transforms weights which may increase values up to 9/4x
+ const float adj_src_scale = 1.f / 4.f;
+ const float adj_wei_scale = 4.f / 9.f;
+ // Winograd transforms need ic and oc to be multiples of 16
+ const int load_block = 16;
+}
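As a quick arithmetic check on these factors, a minimal standalone sketch (illustrative only, not part of the patch): the source transform may grow values by up to 4x and the weight transform by up to 9/4x, so the product of the two adjustment scales is 1/9, and the destination side later multiplies the output scales by 1/(adj_src_scale * adj_wei_scale) = 9 to compensate (see the updated_output_scales_ change further down in this diff).

// Illustrative sketch only; the names mirror the constants above.
#include <cassert>
#include <cmath>

int main() {
    const float adj_src_scale = 1.f / 4.f; // source transform grows values up to 4x
    const float adj_wei_scale = 4.f / 9.f; // weight transform grows values up to 9/4x
    // The int32 accumulator therefore carries an extra factor of 1/9 ...
    const float acc_adj = adj_src_scale * adj_wei_scale;
    // ... which is undone by scaling the output scales by 9.
    const float compensation = 1.f / (adj_src_scale * adj_wei_scale);
    assert(std::fabs(acc_adj * compensation - 1.f) < 1e-6f);
    return 0;
}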
+
/// SRC TRANSFORMS /////////////////////////////////////////////////////////////
struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(
@@ -60,10 +71,19 @@ struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator {
}
void generate();
- Xmm vreg_inp(int i) {
+ int reg_inp_ind(int i) {
assert(i < jcp.alpha * jcp.alpha);
- return Xmm(31 - i);
+ return (31 - i);
+ }
+
+ Xmm vreg_inp(int i) {
+ return Xmm(reg_inp_ind(i));
}
+
+ Zmm zmm_inp(int i) {
+ return Zmm(reg_inp_ind(i));
+ }
+
Xmm vreg_tmp(int i) {
assert(i < jcp.alpha * jcp.alpha);
return Xmm(15 - i);
@@ -93,11 +113,15 @@ struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator {
Reg64 reg_ic_block = r8;
int unsign_val_in_wino_domain;
+
+ Reg64 reg_scratch_src_alpha = rdx;
+ Xmm xmm_src_alpha = Xmm(0);
+ Zmm zmm_src_alpha = Zmm(0);
};
+
void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
Label ic_block_label;
- const int load_block = 16;
int out_offset = 0, inp_offset = 0;
preamble();
@@ -119,20 +143,30 @@ void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
}
+ mov(reg_scratch_src_alpha, float2int(adj_src_scale));
+
mov(reg_ic_block, jcp.ic / load_block);
L(ic_block_label);
{
+ vmovq(xmm_src_alpha, reg_scratch_src_alpha);
+ vbroadcastss(zmm_src_alpha, xmm_src_alpha);
+
for(int y = 0; y < jcp.alpha; y++) {
kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]);
for(int x = 0; x < jcp.alpha; x++) {
- vpxord(vreg_inp(y*jcp.alpha + x), vreg_inp(y*jcp.alpha + x),
- vreg_inp(y*jcp.alpha + x));
+ Zmm zmm_i = zmm_inp(y*jcp.alpha + x);
+ Xmm vreg_i = vreg_inp(y*jcp.alpha + x);
+ vpxord(vreg_i, vreg_i, vreg_i);
kandw(r_mask, y_mask, x_mask(x));
inp_offset = sizeof(uint8_t) *
((-jcp.t_pad + y) * jcp.iw * jcp.ic
+ (-jcp.l_pad + x) * jcp.ic);
- vmovdqu8(vreg_inp(y*jcp.alpha + x) | r_mask,
- EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
+ vmovdqu8(vreg_i | r_mask, EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
+ vpmovzxbd(zmm_i, vreg_i); // to int32
+ vcvtdq2ps(zmm_i, zmm_i); // to fp32
+ vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha
+ vcvtps2dq(zmm_i | T_rn_sae, zmm_i); // to int32
+ vpmovusdb(vreg_i, zmm_i); // to u8
}
}
for(int y = 0; y < 4; y++) {
@@ -163,8 +197,7 @@ void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block);
}
dec(reg_ic_block);
- cmp(reg_ic_block, 0);
- jg(ic_block_label, T_NEAR);
+ jnz(ic_block_label, T_NEAR);
postamble();
}
@@ -204,29 +237,30 @@ struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator {
}
Zmm vreg_stg(int id) { // 8
const int id_reg_stg = jcp.alpha * jcp.alpha + id;
- assert(id_reg_stg < jcp.alpha * jcp.alpha + 8);
+ assert(id < 8);
return Zmm(31 - id_reg_stg);
}
Zmm vreg_out(int id) { // 4
const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
- assert(id_reg_out < jcp.alpha * jcp.alpha + 12);
+ assert(id < 4);
return Zmm(31 - id_reg_out);
}
Xmm xmm_out(int id) { // 4
const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
- assert(id_reg_out < jcp.alpha * jcp.alpha + 12);
+ assert(id < 4);
return Xmm(31 - id_reg_out);
}
Zmm vreg_tmp(int id) { // 2
const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id;
- assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14);
+ assert(id < 2);
return Zmm(31 - id_reg_tmp);
}
Zmm vreg_zero = Zmm(0);
Zmm vreg_bias = Zmm(1);
Zmm vreg_prev_dst = Zmm(2);
-
+ Zmm zmm_bias_alpha = Zmm(2);
+ Xmm xmm_bias_alpha = Xmm(2);
Opmask y_mask = Opmask(1);
Opmask r_mask = Opmask(2);
@@ -234,6 +268,9 @@ struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator {
assert(id < 4);
return Opmask(3 + id);
}
+
+ Reg64 reg_scratch_bias_alpha = r15;
+
Reg64 reg_ptr_src = r14;
Reg64 reg_ptr_dst = r13;
@@ -246,9 +283,10 @@ struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator {
Reg64 reg_oc_block = r8;
Reg64 reg_ptr_bias = rbx;
- Reg64 reg_ptr_scales = rcx;
+ Reg64 reg_ptr_scales = abi_not_param1;
Reg64 reg_ptr_sum_scale = rdx;
};
+
bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) {
using namespace primitive_kind;
const auto &p = attr_.post_ops_;
@@ -273,11 +311,10 @@ bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) {
return false;
}
+
void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
Label oc_block_label;
- const int load_block = 16;
-
auto loop_body = [=]() {
const auto &p = attr_.post_ops_;
const int sum_idx = p.find(primitive_kind::sum);
@@ -309,6 +346,9 @@ void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
if (jcp.with_bias) {
+ vmovq(xmm_bias_alpha, reg_scratch_bias_alpha);
+ vbroadcastss(zmm_bias_alpha, xmm_bias_alpha);
+
auto bias_addr = ptr [ reg_ptr_bias ];
switch (jcp.bia_dt) {
case data_type::f32:
@@ -319,6 +359,7 @@ void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
}
if (jcp.bia_dt != data_type::f32)
vcvtdq2ps(vreg_bias, vreg_bias);
+ vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha
}
for(int y = 0; y < jcp.m; y++) {
kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(int16_t) * y ]);
@@ -394,6 +435,9 @@ void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
READ_PARAM(reg_ptr_scales, scales);
# undef READ_PARAM
+ if (jcp.with_bias)
+ mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale));
+
mov(reg_aux_ptr_src, reg_ptr_src);
mov(reg_aux_ptr_dst, reg_ptr_dst);
@@ -415,8 +459,7 @@ void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
add(reg_ptr_bias, sizeof(jcp.typesize_bia) * load_block);
}
dec(reg_oc_block);
- cmp(reg_oc_block, 0);
- jg(oc_block_label, T_NEAR);
+ jnz(oc_block_label, T_NEAR);
sub(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
sub(reg_ptr_bias, oc_blocks * sizeof(jcp.typesize_bia) * load_block);
@@ -464,7 +507,8 @@ struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator {
return Zmm(31 - id_reg_out);
}
Zmm vreg_wei(int i) {
- assert(31 - jcp.n2_block * jcp.m_block - i > 2);
+ assert(31 - jcp.n2_block * jcp.m_block - i
+ > (jcp.ver == ver_vnni ? 0 : 2));
return Zmm(31 - jcp.n2_block * jcp.m_block - i);
}
@@ -473,20 +517,20 @@ struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator {
Zmm vreg_tmp = Zmm(2);
Reg64 reg_ptr_src = r15;
- Reg64 reg_ptr_dst = r14;
- Reg64 reg_ptr_wei = r13;
- Reg64 reg_ptr_dst_b = r12;
- Reg64 reg_aux_dst = r11;
+ Reg64 reg_aux_dst_b = r13;
+ Reg64 reg_aux_dst = r12;
+ Reg64 reg_aux_dst2 = r11;
Reg64 reg_aux_wei = r10;
- Reg64 reg_aux_dst_b = r9;
+ Reg64 reg_aux_wei2 = r9;
Reg64 reg_aux_src = r8;
- Reg64 reg_aux_wei2 = rax;
+ Reg64 reg_aux_src2 = rax;
+ Reg64 reg_mb = rbx;
+ Reg64 reg_nnb = abi_not_param1;
Reg64 reg_scratch = rdx;
- Reg64 reg_nnb = rcx;
Reg64 reg_K = rsi;
-
};
+
bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok(
jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) {
using namespace primitive_kind;
@@ -502,11 +546,11 @@ bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok(
switch (p.len_) {
case 0: return true;
case 1: return true
- && implication(jcp.with_relu, p.contain(sum, 0))
- && implication(!jcp.with_relu, is_relu(0) || p.contain(sum, 0));
+ && IMPLICATION(jcp.with_relu, p.contain(sum, 0))
+ && IMPLICATION(!jcp.with_relu, is_relu(0) || p.contain(sum, 0));
case 2: return true
- && implication(jcp.with_relu, p.contain(sum, 0) && is_relu(1))
- && implication(!jcp.with_relu, false
+ && IMPLICATION(jcp.with_relu, p.contain(sum, 0) && is_relu(1))
+ && IMPLICATION(!jcp.with_relu, false
|| (p.contain(sum, 0) && is_relu(1))
|| (p.contain(sum, 1) && is_relu(0)));
case 3: return true
@@ -517,8 +561,9 @@ bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok(
return false;
}
+
void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() {
- Label nnb_loop_label, K_loop_label[2];
+ Label nnb_loop_label, K_loop_label, mb_loop_label;
auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
if (jcp.ver == ver_vnni) {
@@ -534,82 +579,85 @@ void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() {
# define READ_PARAM(reg, field) \
mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
READ_PARAM(reg_ptr_src, src);
- READ_PARAM(reg_ptr_dst, dst);
- READ_PARAM(reg_ptr_wei, wei);
- READ_PARAM(reg_ptr_dst_b, dst_b);
+ READ_PARAM(reg_aux_dst, dst);
+ READ_PARAM(reg_aux_wei, wei);
+ READ_PARAM(reg_aux_dst_b, dst_b);
# undef READ_PARAM
- xor_(reg_scratch, reg_scratch);
- Reg16 _t = reg_scratch.cvt16();
- mov(_t, 0x1);
- vpbroadcastw(vreg_one, _t);
-
- mov(reg_aux_dst, reg_ptr_dst);
- mov(reg_aux_wei, reg_ptr_wei);
- mov(reg_aux_dst_b, reg_ptr_dst_b);
+ if (jcp.ver != ver_vnni) {
+ xor_(reg_scratch, reg_scratch);
+ Reg16 _t = reg_scratch.cvt16();
+ mov(_t, 0x1);
+ vpbroadcastw(vreg_one, _t);
+ }
if (!jcp.small_mb) {
mov(reg_nnb, jcp.n_chunks);
L(nnb_loop_label);
}
- for (int mb = 0; mb < jcp.M / jcp.m_block; mb++)
- {
- for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
- for (int m = 0; m < jcp.m_block; m++) {
- int offset = jcp.typesize_acc * nb2 * jcp.n_block;
- vmovups(vreg_out(nb2, m),
+ mov(reg_aux_dst2, reg_aux_dst);
+ mov(reg_aux_src, reg_ptr_src);
+ mov(reg_mb, jcp.M / jcp.m_block);
+ L(mb_loop_label);
+ {
+ for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
+ for (int m = 0; m < jcp.m_block; m++) {
+ int offset = jcp.typesize_acc * nb2 * jcp.n_block;
+ vmovups(vreg_out(nb2, m),
EVEX_compress_addr(reg_aux_dst_b, offset));
- }
}
- mov(reg_aux_src, reg_ptr_src);
- mov(reg_aux_wei2, reg_aux_wei);
- mov(reg_K, jcp.k_chunks);
- L(K_loop_label[mb]); {
- for (int k = 0; k < jcp.k2_block; k += 4)
- {
- for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
- int wei_offset = jcp.typesize_in *
- ((nb2 * jcp.n_block) * jcp.K);
- vmovups(vreg_wei(nb2),
+ }
+ mov(reg_aux_src2, reg_aux_src);
+ mov(reg_aux_wei2, reg_aux_wei);
+ mov(reg_K, jcp.k_chunks);
+ L(K_loop_label);
+ {
+ for (int k = 0; k < jcp.k2_block; k += 4) {
+ for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
+ int wei_offset
+ = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K);
+ vmovups(vreg_wei(nb2),
EVEX_compress_addr(reg_aux_wei2, wei_offset));
- }
- for (int m = 0; m < jcp.m_block; m++) {
- int inp_offset = jcp.typesize_in *
- (m + mb * jcp.m_block) * jcp.K;
- vpbroadcastd(vreg_src,
- EVEX_compress_addr(reg_aux_src,inp_offset));
- for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
- compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src);
- }
- add(reg_aux_src, jcp.typesize_in * 4);
- add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block);
}
+ for (int m = 0; m < jcp.m_block; m++) {
+ int inp_offset = jcp.typesize_in * m * jcp.K;
+ vpbroadcastd(vreg_src,
+ EVEX_compress_addr(reg_aux_src2, inp_offset));
+ for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
+ compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src);
+ }
+ add(reg_aux_src2, jcp.typesize_in * 4);
+ add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block);
}
- dec(reg_K);
- cmp(reg_K, 0);
- jg(K_loop_label[mb], T_NEAR);
+ }
+ dec(reg_K);
+ jnz(K_loop_label, T_NEAR);
- for (int m = 0; m < jcp.m_block; m++) {
- for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
- int offset = jcp.typesize_acc *
- ((mb * jcp.m_block + m) * jcp.N + nb2 * jcp.n_block);
- vmovups(EVEX_compress_addr(reg_aux_dst,offset),
- vreg_out(nb2, m));
- }
+ for (int m = 0; m < jcp.m_block; m++) {
+ for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
+ int offset = jcp.typesize_acc * (m * jcp.N + nb2 * jcp.n_block);
+ vmovups(EVEX_compress_addr(reg_aux_dst2, offset),
+ vreg_out(nb2, m));
}
}
+ add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K);
+ add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N);
+ }
+ dec(reg_mb);
+ jnz(mb_loop_label, T_NEAR);
+
if (!jcp.small_mb) {
add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K);
dec(reg_nnb);
- cmp(reg_nnb, 0);
- jg(nnb_loop_label, T_NEAR);
+ jnz(nnb_loop_label, T_NEAR);
}
postamble();
}
+
status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
::init_conf(jit_conv_conf_2x3_wino_t &jcp,
const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
@@ -652,25 +700,27 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
if (mayiuse(avx512_core_vnni))
jcp.ver = ver_vnni;
+ // block sizes needed for GEMM kernel
jcp.ic_block = 4;
jcp.oc_block = 16;
bool ok = true
- && jcp.kh == 3 && jcp.kw == 3
&& jcp.ngroups == 1
+ && jcp.oc % load_block == 0 && jcp.ic % load_block == 0
&& jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
- && jcp.stride_h == 1 && jcp.stride_w == 1
- && jcp.dilate_h == 0 && jcp.dilate_w == 0
+ && everyone_is(3, jcp.kh, jcp.kw)
+ && everyone_is(1, jcp.stride_h, jcp.stride_w)
+ && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
&& jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad
- && jcp.t_pad < 2 && jcp.t_pad >= 0
- && jcp.l_pad < 2 && jcp.l_pad >= 0;
+ && one_of(jcp.t_pad, 0, 1)
+ && one_of(jcp.l_pad, 0, 1);
if (!ok) return status::unimplemented;
jcp.src_fmt = src_d.format();
jcp.with_bias = cd.bias_desc.format != memory_format::undef;
jcp.with_relu = with_relu;
jcp.relu_negative_slope = relu_negative_slope;
- if (!implication(with_relu, relu_negative_slope == 0.))
+ if (!IMPLICATION(with_relu, relu_negative_slope == 0.))
return status::unimplemented;
if (!post_ops_ok(jcp, attr))
return status::unimplemented;
@@ -692,29 +742,131 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
jcp.r = 3;
jcp.alpha = jcp.m + jcp.r - 1;
- jcp.yb = 1;
- int opt_val = 14, cur_val = 0;
- for (int i = 14; i >= 8; i -= 2) {
- cur_val = ((jcp.oh / i) * i + i) - jcp.oh;
- if (jcp.oh % i == 0) {
- jcp.yb = i; break;
- } else if (cur_val < opt_val) {
- jcp.yb = i;
- opt_val = cur_val;
+ int aa = jcp.alpha * jcp.alpha;
+ int nthr = mkldnn_get_max_threads();
+ int L1_cap = get_cache_size(1, true);
+ int L2_cap = get_cache_size(2, true);
+ // need 1 extra reg for bcast, and 2 tmp regs for non-vnni
+ int free_regs = jcp.ver == ver_vnni ? 31 : 29;
+
+ auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) {
+ float thr_eff;
+ float Z = (float)jcp.ic + jcp.oc;
+ float Y = (float)jcp.ic * jcp.oc;
+ if (small_mb == 0) { // outer par
+ int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix);
+ thr_eff = (float)nblocks / rnd_up(nblocks, nthr);
+ } else { // inner par
+ int tranw = iy * ix / jcp.alpha;
+ int gemmw = aa * (jcp.nb_oc / n2_b);
+ int tranw_r = rnd_up(tranw, nthr);
+ int gemmw_r = rnd_up(gemmw, nthr);
+ thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y);
}
- }
+ return thr_eff;
+ };
- const int nthreads = mkldnn_get_max_threads();
- jcp.xb = 4;
- int oh_blocks = (jcp.oh < jcp.yb) ? 1 : (jcp.oh / jcp.yb);
- int ow_blocks = (jcp.ow < jcp.xb) ? 1 : (jcp.ow / jcp.xb);
+ auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) {
+ float mem_eff, req_mem;
+ int M = ix * iy / jcp.alpha;
+ if (small_mb == 0) { // outer parallelization strategy
+ // memory for wino transforms (other memory has poor reuse)
+ req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc);
+ mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f;
+ } else { // inner parallelization strategy
+ // memory used during gemm
+ int N = jcp.oc_block * n2_b;
+ req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N;
+ mem_eff = nstl::min(1.f, L2_cap / req_mem);
+ // memory used during wino transforms
+ int M_per_thr = div_up(M, nthr);
+ req_mem = (float)aa * M_per_thr
+ * (jcp.ic + jcp.typesize_acc * jcp.oc);
+ if (req_mem > L2_cap)
+ mem_eff = 0.1f;
+ }
+ return mem_eff;
+ };
+
+ auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff,
+ float mem_eff, float reg_eff) {
+ // these coefficients are chosen empirically
+ float mem_fac = 0.1f, reg_fac = 0.2f;
+ // normalized overhead relative to memory and register components
+ float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff;
+ // thread and work components affect all others
+ tot_eff *= thr_eff * work_eff;
+ return tot_eff;
+ };
+
+ auto find_m_n2_blocks = [&](bool small_mb, int ix, int iy, float work_eff,
+ int &m_block, int &n2_block, float &tot_eff) {
+ int M = (ix * iy) / jcp.alpha;
+ int max_m_block = nstl::min(M, free_regs);
+ int max_n2_block = nstl::min(jcp.nb_oc, free_regs);
+ tot_eff = 0.f;
+ for (int im = max_m_block; im > 0; im--) {
+ if (M % im)
+ continue;
+ for (int in2 = max_n2_block; in2 > 0; in2--) {
+ int used_regs = (im + 1) * in2;
+ float mem_eff = get_mem_eff(small_mb, ix, iy, in2);
+ float reg_eff = (float)(im * in2) / (im + in2);
+ float thr_eff = get_thr_eff(small_mb, ix, iy, in2);
+ float cur_tot_eff = get_tot_eff(
+ small_mb, thr_eff, work_eff, mem_eff, reg_eff);
+ if (jcp.nb_oc % in2 || used_regs > free_regs
+ || cur_tot_eff <= tot_eff)
+ continue;
+ tot_eff = cur_tot_eff;
+ m_block = im;
+ n2_block = in2;
+ }
+ }
+ };
+
+ /* Selecting xb and yb blocking */
+ int min_yb = jcp.m;
+ int min_xb = jcp.m;
+ int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2));
+ int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2));
+ float best_eff = 0.f;
+ for (int ix = min_xb; ix <= max_xb; ix += 2) {
+ assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2);
+ for (int iy = max_yb; iy >= min_yb; iy -= 2) {
+ assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2);
+
+ int m_b[2];
+ int n2_b[2];
+ bool small_mb;
+ float inner_eff, outer_eff, work_eff;
+
+ int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix);
+ work_eff = (float)jcp.oh * jcp.ow / tiled_area;
+ if (best_eff > 0.f && work_eff < 4.f / 9.f)
+ continue; // no gain from Winograd transformation
+
+ /* outer parallelization */
+ find_m_n2_blocks(0, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff);
+
+ /* inner parallelization */
+ find_m_n2_blocks(1, ix, iy, work_eff, m_b[1], n2_b[1], inner_eff);
+
+ small_mb = inner_eff > outer_eff;
+ float eff = small_mb ? inner_eff : outer_eff;
+ if (eff > best_eff) {
+ best_eff = eff;
+ jcp.yb = iy;
+ jcp.xb = ix;
+ jcp.m_block = m_b[small_mb];
+ jcp.n2_block = n2_b[small_mb];
+ jcp.small_mb = small_mb;
+ }
+ }
+ }
- const int work_amount = jcp.mb * oh_blocks * ow_blocks;
- if (work_amount < nthreads && jcp.ow < 24) {
- jcp.small_mb = true;
- jcp.xb = (jcp.ow < 9) ? jcp.yb : 4;
- } else
- jcp.small_mb = false;
+ assert((jcp.m_block + 1) * jcp.n2_block <= free_regs);
+ assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);
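To make the tile-selection heuristic above concrete, a small standalone sketch (assumed shapes and thread count, not part of the patch): for the outer parallelization strategy the thread efficiency is the number of (mb, oh-block, ow-block) work items divided by that count rounded up to a multiple of the thread count, so shapes whose work amount is not a multiple of nthr leave threads idle on the last wave.

// Standalone sketch of the outer-parallel thread-efficiency metric used
// above; div_up/rnd_up mirror the mkl-dnn utility helpers of the same name.
#include <cstdio>

static int div_up(int a, int b) { return (a + b - 1) / b; }
static int rnd_up(int a, int b) { return div_up(a, b) * b; }

int main() {
    const int nthr = 28;                // assumed number of threads
    const int mb = 1, oh = 56, ow = 56; // assumed minibatch and output size
    const int yb = 14, xb = 14;         // one candidate (iy, ix) tile block
    const int nblocks = mb * div_up(oh, yb) * div_up(ow, xb); // 16 work items
    const float thr_eff = (float)nblocks / rnd_up(nblocks, nthr);
    printf("nblocks = %d, thr_eff = %.3f\n", nblocks, thr_eff); // ~0.571
    return 0;
}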
jcp.inp_stride = jcp.yb * jcp.xb / 4 * jcp.ic;
jcp.out_stride = jcp.yb * jcp.xb / 4 * jcp.oc;
@@ -725,31 +877,20 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
jcp.N = jcp.oc;
jcp.K = jcp.ic;
- jcp.m_block = jcp.xb * jcp.yb / 8;
jcp.n_block = jcp.oc_block;
jcp.k_block = jcp.ic_block;
- int n_nblock = jcp.N / jcp.n_block;
- jcp.n2_block = (!(n_nblock % 4))
- ? 4
- : (!(n_nblock % 2)) ? 2 : 1;
- const int skx_free_regs = 28;
- if (jcp.n2_block * jcp.m_block > (skx_free_regs - jcp.n2_block)) {
- jcp.n2_block /= 2;
- }
- jcp.n_chunks = n_nblock / jcp.n2_block;
+ jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block;
- int k_nblock = jcp.K / jcp.k_block;
- jcp.k2_block = 1;
- for (int i = 16; i >= 2; i /= 2)
- if (!(k_nblock % i)) {
- jcp.k2_block = i; break;
- }
+ // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4
+ // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is
+ // a multiple of load_block = 16, we just use that for now.
+ jcp.k2_block = load_block;
jcp.k_chunks = jcp.K / jcp.k2_block;
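A tiny check of the divisibility requirements stated in the comment above (a hypothetical sketch with an assumed channel count, not part of the patch): with ic a multiple of load_block = 16, k2_block = 16 is a multiple of k_block = ic_block = 4 and divides K = ic evenly.

// Hypothetical sketch of the k2_block divisibility constraints.
#include <cassert>

int main() {
    const int load_block = 16, ic_block = 4;
    const int ic = 64;                   // assumed; any multiple of 16 works
    const int k_block = ic_block, k2_block = load_block, K = ic;
    assert(k2_block % k_block == 0);     // 16 is a multiple of 4
    assert(K % k2_block == 0);           // ic is a multiple of load_block
    return 0;
}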
const auto &oscales = attr.output_scales_;
jcp.is_oc_scale = oscales.mask_ == 1 << 1;
- assert(utils::implication(!jcp.is_oc_scale, oscales.mask_ == 0));
+ assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
/* re-create weights primitive descriptor
and set weights wino_blocking */
@@ -767,6 +908,8 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
wd.oc_block = jcp.oc_block;
wd.oc2_block = jcp.n2_block;
wd.ic2_block = 1;
+ wd.adj_scale = adj_wei_scale;
+
size_t max_size = types::data_type_size(data_type::s8) *
jcp.alpha * jcp.alpha * jcp.ic * jcp.oc;
max_size += types::data_type_size(data_type::s32) *
@@ -797,7 +940,8 @@ _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, dst_data_type>::
_jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *pd,
const input_vector &inputs, const output_vector &outputs)
: cpu_primitive_t(&conf_, inputs, outputs)
- , conf_(*pd) {
+ , conf_(*pd)
+ , scratchpad_(nullptr) {
const int nthreads = mkldnn_get_max_threads();
kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
conf_.jcp_, *conf_.attr());
@@ -806,25 +950,27 @@ _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, dst_data_type>::
dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
conf_.jcp_, *conf_.attr());
- int wino_size_offset = (conf_.jcp_.yb / 2) * (conf_.jcp_.xb / 2)
- + (conf_.jcp_.xb);
- size_wino_wei = conf_.jcp_.alpha * conf_.jcp_.alpha * conf_.jcp_.oc
- * conf_.jcp_.ic;
- size_wino_src = (conf_.jcp_.ic * 16) * (wino_size_offset);
- size_wino_dst = (conf_.jcp_.oc * 16) * (wino_size_offset);
+ const int tilesize = conf_.jcp_.alpha * conf_.jcp_.alpha;
+ const int numtiles = (conf_.jcp_.yb / 2) * (conf_.jcp_.xb / 2);
+ const int alltiles = tilesize * numtiles;
+ size_wino_wei_ = tilesize * conf_.jcp_.oc * conf_.jcp_.ic;
+ size_wino_src_ = sizeof(src_data_t) * alltiles * conf_.jcp_.ic;
+ size_wino_src_ = rnd_up(size_wino_src_, PAGE_4K);
+ size_wino_src_ /= sizeof(src_data_t);
+ size_wino_dst_ = alltiles * conf_.jcp_.oc;
- size_t workspace_size = nthreads
- * (sizeof(src_data_t) * size_wino_src
- + sizeof(acc_data_t) * size_wino_dst);
+ size_t workspace_size = (conf_.jcp_.small_mb ? 1 : nthreads)
+ * (sizeof(src_data_t) * size_wino_src_
+ + sizeof(acc_data_t) * size_wino_dst_);
- workspace = malloc(workspace_size, 4096);
- char *_t = static_cast<char *>(workspace);
+ scratchpad_ = create_scratchpad(workspace_size);
+ assert(scratchpad_); // TODO: add proper check and raise exception?
- size_t shift = 0;
- wino_src_ = (src_data_t *)(_t + shift);
+ wino_shift_ = (conf_.jcp_.small_mb ? 1 : nthreads) * sizeof(src_data_t)
+ * size_wino_src_;
- shift += nthreads * sizeof(src_data_t) * size_wino_src;
- wino_dst_ = (acc_data_t *)(_t + shift);
+ updated_output_scales_ = conf_.attr()->output_scales_;
+ updated_output_scales_.scale(1.f / (adj_src_scale * adj_wei_scale));
}
template <bool with_relu, data_type_t dst_data_type>
@@ -833,8 +979,7 @@ _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
delete kernel_;
delete src_trans_;
delete dst_trans_;
-
- free(workspace);
+ delete scratchpad_;
}
template <bool with_relu, data_type_t dst_data_type>
@@ -856,30 +1001,32 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
auto dst = reinterpret_cast<dst_data_t *>(memory(0));
const auto &jcp = kernel_->jcp;
- const auto &oscales = conf_.attr()->output_scales_;
+ const auto &oscales = updated_output_scales_;
- wino_wei_ = wei;
- dst_bias_ = (const acc_data_t*)(wei + size_wino_wei);
+ auto wino_wei = wei;
+ auto dst_bias = (const acc_data_t *)(wei + size_wino_wei_);
+ auto wino_src_base = (src_data_t *)scratchpad_->get();
+ auto wino_dst_base = (acc_data_t *)(scratchpad_->get() + wino_shift_);
parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb),
- [&](int mb, int tile_y_b, int tile_x_b) {
+ [&](int mb, int tile_y_b, int tile_x_b) {
int tile_y = tile_y_b * jcp.yb;
int tile_x = tile_x_b * jcp.xb;
int ithr = mkldnn_get_thread_num();
- auto wino_src = wino_src_ + size_wino_src * ithr;
- auto wino_dst = wino_dst_ + size_wino_dst * ithr;
-
- auto src_trans_p = jit_avx512_core_u8s8s32x_wino_conv_src_trans_t
- ::call_params_t();
- auto dst_trans_p = jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t
- ::call_params_t();
- auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
- ::call_params_t();
-
- { /* transformation of input tensor to winograd domain */
- for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
+ auto wino_src = wino_src_base + size_wino_src_ * ithr;
+ auto wino_dst = wino_dst_base + size_wino_dst_ * ithr;
+
+ auto src_trans_p =
+ jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
+ auto dst_trans_p =
+ jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t();
+ auto gemm_p =
+ jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::call_params_t();
+
+ /* transformation of input tensor to winograd domain */
+ for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
unsigned short v_y_masks[4], v_x_masks[4];
@@ -889,19 +1036,20 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
int v_ys = nstl::max(0, jcp.t_pad - y);
int v_ye = nstl::min(jcp.alpha,
- nstl::max(0, jcp.ih + jcp.t_pad - y));
+ nstl::max(0, jcp.ih + jcp.t_pad - y));
int v_xs = nstl::max(0, jcp.l_pad - x);
int v_xe = nstl::min(jcp.alpha,
- nstl::max(0, jcp.iw + jcp.l_pad - x));
+ nstl::max(0, jcp.iw + jcp.l_pad - x));
- #pragma unroll(4)
+#pragma unroll(4)
for (int i = 0; i < jcp.alpha; i++) {
v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
}
- auto local_s = src + mb * jcp.ih * jcp.iw * jcp.ic
- + y * jcp.iw * jcp.ic + x * jcp.ic;
+ auto local_s = src
+ + mb * jcp.ih * jcp.iw * jcp.ic
+ + y * jcp.iw * jcp.ic + x * jcp.ic;
auto local_w = wino_src + m * jcp.ic;
src_trans_p.src = local_s;
@@ -910,20 +1058,22 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
src_trans_p.v_x_masks = v_x_masks;
src_trans_->ker_(&src_trans_p);
- }}
- }
- { /* gemms */
- for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
- gemm_p.src = wino_src + jcp.inp_stride * tile_ij;
- gemm_p.dst = wino_dst + jcp.out_stride * tile_ij;
- gemm_p.wei = wino_wei_ + jcp.wei_stride * tile_ij;
- gemm_p.dst_b = dst_bias_ + jcp.bia_stride * tile_ij;
-
- kernel_->ker_(&gemm_p);
}
}
- { /* transformation from winograd domain to output tensor */
- for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
+ /* gemms */
+ for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
+ // start threads at different GEMMs to help bring weights into LLC
+ int offset = (tile_ij + ithr) % 16;
+ gemm_p.src = wino_src + jcp.inp_stride * offset;
+ gemm_p.dst = wino_dst + jcp.out_stride * offset;
+ gemm_p.wei = wino_wei + jcp.wei_stride * offset;
+ gemm_p.dst_b = dst_bias + jcp.bia_stride * offset;
+
+ kernel_->ker_(&gemm_p);
+ }
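The rotation above staggers which of the 16 Winograd tiles each thread starts from; per the comment, this helps bring the weights into the last-level cache, since different threads pull in different weight tiles at the same time. A minimal sketch of the resulting visit order (assumed thread count, not part of the patch):

// Illustrative sketch: with 4 threads, thread ithr visits the 16 GEMM tiles
// starting at tile ithr and wrapping around, instead of all starting at 0.
#include <cstdio>

int main() {
    const int ntiles = 16, nthr = 4; // assumed
    for (int ithr = 0; ithr < nthr; ithr++) {
        printf("thread %d:", ithr);
        for (int tile_ij = 0; tile_ij < ntiles; tile_ij++)
            printf(" %2d", (tile_ij + ithr) % ntiles);
        printf("\n");
    }
    return 0;
}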
+
+ /* transformation from winograd domain to output tensor */
+ for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
unsigned short v_y_masks[2], v_x_masks[2];
@@ -931,13 +1081,14 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
int x = x_in_block + tile_x;
int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
- #pragma unroll(2)
+#pragma unroll(2)
for (int i = 0; i < jcp.m; i++) {
v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
}
- auto local_d = dst + mb * jcp.oh * jcp.ow * jcp.oc
- + y * jcp.ow * jcp.oc + x * jcp.oc;
+ auto local_d = dst
+ + mb * jcp.oh * jcp.ow * jcp.oc
+ + y * jcp.ow * jcp.oc + x * jcp.oc;
auto local_w = wino_dst + m * jcp.oc;
auto scales = oscales.scales_;
@@ -950,7 +1101,7 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
dst_trans_p.bias = bia;
dst_trans_->ker_(&dst_trans_p);
- }}
+ }
}
});
}
@@ -964,113 +1115,110 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
auto dst = reinterpret_cast<dst_data_t *>(memory(0));
const auto &jcp = kernel_->jcp;
- const auto &oscales = conf_.attr()->output_scales_;
+ const auto &oscales = updated_output_scales_;
- wino_wei_ = wei;
- dst_bias_ = (const acc_data_t*)(wei + size_wino_wei);
+ auto wino_wei = wei;
+ auto dst_bias = (const acc_data_t *)(wei + size_wino_wei_);
+ auto wino_src = (src_data_t *)scratchpad_->get();
+ auto wino_dst = (acc_data_t *)(scratchpad_->get() + wino_shift_);
for (int mb = 0; mb < jcp.mb; mb++) {
for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
- { /* transformation of input tensor to winograd domain */
-
- parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
- [&](int y_in_block_b, int x_in_block_b) {
-
- int y_in_block = y_in_block_b * 2;
- int x_in_block = x_in_block_b * 2;
- auto src_trans_p =
- jit_avx512_core_u8s8s32x_wino_conv_src_trans_t
- ::call_params_t();
+ /* transformation of input tensor to winograd domain */
+ parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
+ [&](int y_in_block_b, int x_in_block_b) {
+ int y_in_block = y_in_block_b * 2;
+ int x_in_block = x_in_block_b * 2;
- unsigned short v_y_masks[4], v_x_masks[4];
+ auto src_trans_p =
+ jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
- int y = y_in_block + tile_y;
- int x = x_in_block + tile_x;
- int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
+ unsigned short v_y_masks[4], v_x_masks[4];
- int v_ys = nstl::max(0, jcp.t_pad - y);
- int v_ye = nstl::min(jcp.alpha,
- nstl::max(0, jcp.ih + jcp.t_pad - y));
+ int y = y_in_block + tile_y;
+ int x = x_in_block + tile_x;
+ int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
- int v_xs = nstl::max(0, jcp.l_pad - x);
- int v_xe = nstl::min(jcp.alpha,
- nstl::max(0, jcp.iw + jcp.l_pad - x));
+ int v_ys = nstl::max(0, jcp.t_pad - y);
+ int v_ye = nstl::min(
+ jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y));
- #pragma unroll(4)
- for (int i = 0; i < jcp.alpha; i++) {
- v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
- v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
- }
- auto local_s = src + mb * jcp.ih * jcp.iw * jcp.ic
- + y * jcp.iw * jcp.ic + x * jcp.ic;
- auto local_w = wino_src_ + m * jcp.ic;
+ int v_xs = nstl::max(0, jcp.l_pad - x);
+ int v_xe = nstl::min(
+ jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x));
- src_trans_p.src = local_s;
- src_trans_p.wino_src = local_w;
- src_trans_p.v_y_masks = v_y_masks;
- src_trans_p.v_x_masks = v_x_masks;
-
- src_trans_->ker_(&src_trans_p);
- });
- }
- { /* gemms */
- parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) {
- auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
- ::call_params_t();
-
- auto _t_src = wino_src_ + jcp.inp_stride * tile_ij;
- auto _t_dst = wino_dst_ + jcp.out_stride * tile_ij;
- auto _t_wei = wino_wei_ + jcp.wei_stride * tile_ij;
- auto _t_dst_b = dst_bias_ + jcp.bia_stride * tile_ij;
-
- gemm_p.src = _t_src;
- gemm_p.dst = _t_dst + nnb * jcp.n2_block * jcp.n_block;
- gemm_p.wei = _t_wei + nnb * jcp.n2_block * jcp.n_block * jcp.K;
- gemm_p.dst_b = _t_dst_b + nnb * jcp.n2_block * jcp.n_block;
-
- kernel_->ker_(&gemm_p);
- });
- }
- { /* transformation from winograd domain to output tensor */
- parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
- [&](int y_in_block_b, int x_in_block_b) {
- int y_in_block = y_in_block_b * 2;
- int x_in_block = x_in_block_b * 2;
-
- auto dst_trans_p =
- jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t
- ::call_params_t();
-
- unsigned short v_y_masks[2], v_x_masks[2];
-
- int y = y_in_block + tile_y;
- int x = x_in_block + tile_x;
- int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
-
- #pragma unroll(2)
- for (int i = 0; i < jcp.m; i++) {
- v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
- v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
- }
- auto local_d = dst + mb * jcp.oh * jcp.ow * jcp.oc
- + y * jcp.ow * jcp.oc + x * jcp.oc;
- auto local_w = wino_dst_ + m * jcp.oc;
-
- auto scales = oscales.scales_;
- dst_trans_p.dst = local_d;
- dst_trans_p.wino_dst = local_w;
- dst_trans_p.v_y_masks = v_y_masks;
- dst_trans_p.v_x_masks = v_x_masks;
-
- dst_trans_p.scales = scales;
- dst_trans_p.bias = bia;
-
- dst_trans_->ker_(&dst_trans_p);
- });
- }
- }}
- }
+#pragma unroll(4)
+ for (int i = 0; i < jcp.alpha; i++) {
+ v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
+ v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
+ }
+ auto local_s = src
+ + mb * jcp.ih * jcp.iw * jcp.ic
+ + y * jcp.iw * jcp.ic + x * jcp.ic;
+ auto local_w = wino_src + m * jcp.ic;
+
+ src_trans_p.src = local_s;
+ src_trans_p.wino_src = local_w;
+ src_trans_p.v_y_masks = v_y_masks;
+ src_trans_p.v_x_masks = v_x_masks;
+
+ src_trans_->ker_(&src_trans_p);
+ });
+
+ /* gemms */
+ parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) {
+ auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::
+ call_params_t();
+
+ gemm_p.src = wino_src + jcp.inp_stride * tile_ij;
+ gemm_p.dst = wino_dst + jcp.out_stride * tile_ij
+ + nnb * jcp.n2_block * jcp.n_block;
+ gemm_p.wei = wino_wei + jcp.wei_stride * tile_ij
+ + nnb * jcp.n2_block * jcp.n_block * jcp.K;
+ gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij
+ + nnb * jcp.n2_block * jcp.n_block;
+
+ kernel_->ker_(&gemm_p);
+ });
+
+ /* transformation from winograd domain to output tensor */
+ parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
+ [&](int y_in_block_b, int x_in_block_b) {
+ int y_in_block = y_in_block_b * 2;
+ int x_in_block = x_in_block_b * 2;
+
+ auto dst_trans_p =
+ jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t();
+
+ unsigned short v_y_masks[2], v_x_masks[2];
+
+ int y = y_in_block + tile_y;
+ int x = x_in_block + tile_x;
+ int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
+
+#pragma unroll(2)
+ for (int i = 0; i < jcp.m; i++) {
+ v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
+ v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
+ }
+ auto local_d = dst
+ + mb * jcp.oh * jcp.ow * jcp.oc
+ + y * jcp.ow * jcp.oc + x * jcp.oc;
+ auto local_w = wino_dst + m * jcp.oc;
+
+ auto scales = oscales.scales_;
+ dst_trans_p.dst = local_d;
+ dst_trans_p.wino_dst = local_w;
+ dst_trans_p.v_y_masks = v_y_masks;
+ dst_trans_p.v_x_masks = v_x_masks;
+
+ dst_trans_p.scales = scales;
+ dst_trans_p.bias = bia;
+
+ dst_trans_->ker_(&dst_trans_p);
+ });
+ }}}
}
template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<true,