summary | refs | log | tree | commit | diff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp')
-rw-r--r-- inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp | 615
1 files changed, 374 insertions, 241 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp
index d16cd1ac9..80206ceb7 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp
@@ -74,14 +74,15 @@ void jit_avx512_common_conv_fwd_kernel::prepare_output(int ur_w)
for (int j = 0; j < ur_w; j++) {
Zmm zmm = zmm_out(j, k);
vpxord(zmm, zmm, zmm);
- int aux_output_offset = get_output_offset(j, k);
- mic_prefetcht1(EVEX_compress_addr(reg_out_prf, aux_output_offset));
+ size_t aux_output_offset = get_output_offset(j, k);
+ mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf,
+ aux_output_offset, reg_out_long_offt));
}
}
void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w)
{
- Label no_update_label, store_label, relu_label;
+ Label no_update_label, store_label, postproc_label;
mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
if (jcp.with_bias) {
@@ -96,15 +97,16 @@ void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w)
for (int k = 0; k < jcp.nb_oc_blocking; k++)
for (int j = 0; j < ur_w; j++) {
Zmm zmm = zmm_out(j, k);
- int aux_output_offset = get_output_offset(j, k);
- vadd(zmm, reg_out, aux_output_offset);
+ size_t aux_output_offset = get_output_offset(j, k);
+ vadd(zmm,
+ make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt));
}
if (!jcp.with_sum) {
- jmp(relu_label, T_NEAR);
+ jmp(postproc_label, T_NEAR);
} else {
cmp(reg_channel, 0);
- jne(relu_label, T_NEAR);
+ jne(postproc_label, T_NEAR);
}
L(no_update_label);
@@ -113,26 +115,51 @@ void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w)
int bias_offset = jcp.typesize_out * k * jcp.oc_block;
for (int j = 0; j < ur_w; j++) {
Zmm zmm = zmm_out(j, k);
- vadd(zmm, reg_bias, bias_offset);
+ vadd(zmm, EVEX_compress_addr(reg_bias, bias_offset));
}
mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64));
}
}
- L(relu_label);
- if (this->jcp.with_eltwise) {
- cmp(reg_channel, jcp.nb_ic - 1);
- jl(store_label, T_NEAR);
+ L(postproc_label);
- inject(eltwise_generator.prepareConstants(jcp.eltwise_alpha, jcp.eltwise_beta));
+ cmp(reg_channel, jcp.nb_ic - 1);
+ jl(store_label, T_NEAR);
- // TODO (dmitrygo): need to find appropriate way to share labels.
- mov(imm_addr64, l_table);
- for (int k = 0; k < jcp.nb_oc_blocking; k++) {
- for (int j = 0; j < ur_w; j++) {
- Zmm zmm_reg_out = zmm_out(j, k);
- inject(eltwise_generator.computeVector(zmm_reg_out, zmm_reg_out));
+ int eltwise_inj_idx = 0;
+ int depthwise_inj_idx = 0;
+ const auto &p = attr_.post_ops_;
+
+ if (p.len_ == 0 && eltwise_injectors.size() == 1) {
+ for (int k = 0; k < jcp.nb_oc_blocking; k++)
+ eltwise_injectors[0]->compute_vector_range(
+ k*jcp.ur_w, k*jcp.ur_w + ur_w);
+ }
+
+ for (int i = 0; i < p.len_; i++) {
+ auto& post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ for (int k = 0; k < jcp.nb_oc_blocking; k++)
+ eltwise_injectors[eltwise_inj_idx]->compute_vector_range(
+ k*jcp.ur_w, k*jcp.ur_w + ur_w);
+
+ eltwise_inj_idx++;
+ } else if (post_op.is_depthwise()) {
+ mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
+ mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
+
+ add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]);
+ add(reg_d_bias, ptr[this->param1 + GET_OFF(oc_off)]);
+
+ for (int k = 0; k < jcp.nb_oc_blocking; k++) {
+ depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
+ k*jcp.ur_w, k*jcp.ur_w + ur_w, reg_d_weights, reg_d_bias);
+
+ add(reg_d_weights, jcp.oc_block * sizeof(float));
+ add(reg_d_bias, jcp.oc_block * sizeof(float));
}
+
+ depthwise_inj_idx++;
}
}
@@ -140,10 +167,12 @@ void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w)
for (int k = 0; k < jcp.nb_oc_blocking; k++)
for (int j = 0; j < ur_w; j++) {
Zmm zmm = zmm_out(j, k);
- int aux_output_offset
- = typesize * (k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
- vmovups(EVEX_compress_addr(reg_out, aux_output_offset), zmm);
- mic_prefetcht0(EVEX_compress_addr(reg_out_prf, aux_output_offset));
+ size_t aux_output_offset = (size_t)typesize *
+ ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
+ vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset,
+ reg_out_long_offt), zmm);
+ mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf,
+ aux_output_offset, reg_out_long_offt));
}
}
@@ -160,13 +189,19 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
Label kh_label, kd_label, skip_kd_loop;
+ prepare_output(ur_w);
+
if (jcp.ndims == 4) {
mov(aux_reg_inp, reg_inp);
mov(aux_reg_ker, reg_ker);
mov(aux_reg_inp_prf, reg_inp_prf);
}
- prepare_output(ur_w);
+ size_t max_input_offset = (size_t)jcp.typesize_in
+ * ((size_t)(kw * (jcp.dilate_w + 1) + ur_w * stride_w - pad_l)
+ + (size_t)ic_block * iw * ih * jcp.id);
+ assert(reg_inp_prf == reg_long_offt);
+ if (max_input_offset > INT_MAX) push(reg_inp_prf);
if (jcp.ndims == 5) {
push(reg_out_prf);
@@ -177,7 +212,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
mov(aux_reg_inp_d, reg_inp);
mov(aux_reg_inp_d_prf, reg_inp_prf);
- if ((jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
+ if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
cmp(reg_ki, 0);
je(skip_kd_loop, T_NEAR);
}
@@ -185,7 +220,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
}
mov(reg_kj, reg_kh);
Label skip_kh_loop;
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) {
+ if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
cmp(reg_kj, 0);
je(skip_kh_loop, T_NEAR);
}
@@ -214,11 +249,12 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
int j_end = get_ow_end(ur_w, ki, pad_r);
for (int j = j_start, prf_count=0; j < j_end; j++) {
- int aux_input_offset = jcp.typesize_in
- * ((ki * (jcp.dilate_w + 1) + j * stride_w - pad_l)
- + ic * iw * ih * jcp.id);
+ size_t aux_input_offset = (size_t)jcp.typesize_in
+ * ((size_t)(ki * (jcp.dilate_w + 1) + j * stride_w
+ - pad_l) + (size_t)ic * iw * ih * jcp.id);
v4fmaddps(zmm_out(j, 0), zmm_ker(0),
- EVEX_compress_addr(aux_reg_inp, aux_input_offset));
+ EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset,
+ reg_long_offt));
if (ki + prf_count < kw && prf_count < 4
&& ((ki < 2 && j % 4) || j % 2)) {
int aux_ker_offset = jcp.typesize_in
@@ -230,13 +266,13 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
}
if (ki == 0
&& j % (64 / (stride_w * jcp.typesize_in)) == 0) {
- mic_prefetcht0(EVEX_compress_addr(aux_reg_inp_prf,
- aux_input_offset));
+ mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf,
+ aux_input_offset, reg_long_offt));
}
if (ki == 1
&& j % (64 / (stride_w * jcp.typesize_in)) == 0) {
- mic_prefetcht0(EVEX_compress_addr(aux_reg_inp,
- aux_input_offset+jcp.typesize_in * iw));
+ mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp,
+ aux_input_offset+jcp.typesize_in * iw, reg_long_offt));
}
}
}
@@ -266,6 +302,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
}
store_output(ur_w);
+ if (max_input_offset > INT_MAX) pop(reg_inp_prf);
}
void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
@@ -287,13 +324,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
assert(jcp.oc % jcp.nb_oc_blocking == 0);
- if (jcp.ndims == 4) {
- mov(aux_reg_inp, reg_inp);
- mov(aux_reg_ker, reg_ker);
- mov(aux_reg_ker_prf, reg_ker_prf);
- mov(aux_reg_inp_prf, reg_inp_prf);
- }
-
auto kernel_offset = [=](int ocb, int ic, int ki) {
int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki;
int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
@@ -322,6 +352,13 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
prepare_output(ur_w);
+ if (jcp.ndims == 4) {
+ mov(aux_reg_inp, reg_inp);
+ mov(aux_reg_ker, reg_ker);
+ mov(aux_reg_ker_prf, reg_ker_prf);
+ mov(aux_reg_inp_prf, reg_inp_prf);
+ }
+
if (jcp.ndims == 5) {
push(reg_out_prf);
push(reg_out);
@@ -332,7 +369,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
mov(aux_reg_inp_d_prf, reg_inp_prf);
mov(aux_reg_ker_d_prf, reg_ker_prf);
- if ((jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
+ if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
cmp(reg_ki, 0);
je(skip_kd_loop, T_NEAR);
}
@@ -342,7 +379,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
mov(reg_kj, reg_kh);
}
Label skip_kh_loop;
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) {
+ if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
cmp(reg_kj, 0);
je(skip_kh_loop, T_NEAR);
}
@@ -548,6 +585,11 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
mov(aux_reg_ker_prf, reg_ker_prf);
}
+ size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id;
+ assert(reg_inp_prf == reg_long_offt);
+ if (max_input_offset > INT_MAX) push(reg_inp_prf);
+
+
if (jcp.ndims == 5) {
push(reg_out_prf);
push(reg_out);
@@ -558,7 +600,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
mov(aux_reg_inp_d_prf, reg_inp_prf);
mov(aux_reg_ker_d_prf, reg_ker_prf);
- if ((jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
+ if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
cmp(reg_ki, 0);
je(skip_kd_loop, T_NEAR);
}
@@ -568,7 +610,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
mov(reg_kj, reg_kh);
}
Label skip_kh_loop;
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) {
+ if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
cmp(reg_kj, 0);
je(skip_kh_loop, T_NEAR);
}
@@ -609,10 +651,10 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
int j_start = get_ow_start(ki, pad_l);
int j_end = get_ow_end(ur_w, ki, pad_r);
for (int j = j_start; j < j_end; j++) {
- int aux_input_offset = get_input_offset(ki, ic, j, pad_l);
- vfmadd231ps(zmm_out(j, 0), zmm_kernel,
- EVEX_compress_addr(
- aux_reg_inp, aux_input_offset, true));
+ size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l);
+ auto addr = EVEX_compress_addr_safe(aux_reg_inp,
+ aux_input_offset, reg_long_offt, true);
+ vfmadd231ps(zmm_out(j, 0), zmm_kernel, addr);
int fma_idx = step * ur_w + j;
int prf_slot_idx = fma_idx / prf_inst_spacing;
if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
@@ -627,8 +669,8 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
} else if (prf_inp) {
int inp_prf_idx = prf_slot_idx - ker_prfs;
if (inp_prf_idx < num_inp_prfs) {
- int inp_prf_stride = nstl::max(kw, stride_w);
- int inp_prf_offset;
+ size_t inp_prf_stride = nstl::max(kw, stride_w);
+ size_t inp_prf_offset;
if (!jcp.is_1stconv) {
inp_prf_offset
= ic_block * jcp.typesize_in
@@ -636,17 +678,18 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
* inp_prf_stride
+ (inp_prf_idx % kw));
} else {
- int ic_prf_stride
- = jcp.typesize_in * iw * ih * id;
- int iw_prf_stride
+ size_t ic_prf_stride =
+ (size_t)jcp.typesize_in * iw * ih * id;
+ size_t iw_prf_stride
= jcp.typesize_in * simd_w;
inp_prf_offset = ((inp_prf_idx / ic_block)
* iw_prf_stride
+ (inp_prf_idx % ic_block)
* ic_prf_stride);
}
- mic_prefetcht0(EVEX_compress_addr(
- aux_reg_inp_prf, inp_prf_offset));
+ mic_prefetcht0(EVEX_compress_addr_safe(
+ aux_reg_inp_prf, inp_prf_offset,
+ reg_long_offt));
}
}
}
@@ -686,7 +729,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
pop(reg_out);
pop(reg_out_prf);
}
-
+ if (max_input_offset > INT_MAX) pop(reg_inp_prf);
store_output(ur_w);
}
@@ -707,18 +750,19 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
auto input_offset = [=](int oi, int ic, int ki) {
- return jcp.typesize_in
- * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) * inp_mul
- + ic * (!jcp.is_1stconv ? 1 : jcp.iw * jcp.ih * jcp.id));
+ return (size_t)jcp.typesize_in
+ * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
+ * inp_mul + (size_t)ic
+ * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id));
};
+ prepare_output(ur_w);
+
if (jcp.ndims == 4) {
mov(aux_reg_inp, reg_inp);
mov(aux_reg_ker, reg_ker);
}
- prepare_output(ur_w);
-
if (jcp.ndims == 5) {
push(reg_out);
@@ -726,7 +770,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
mov(aux_reg_inp_d, reg_inp);
- if ((jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
+ if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
cmp(reg_ki, 0);
je(skip_kd_loop, T_NEAR);
}
@@ -735,7 +779,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
} else {
mov(reg_kj, reg_kh);
}
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) {
+ if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
cmp(reg_kj, 0);
je(skip_kh_loop, T_NEAR);
}
@@ -753,9 +797,10 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
for (int ic = 0; ic < ic_block; ic++) {
if (jcp.kernel_kind == expl_bcast) {
for (int jj = jj_start; jj < jj_end; jj++) {
- int aux_input_offset = input_offset(jj, ic, ki);
+ size_t aux_input_offset = input_offset(jj, ic, ki);
vbroadcastss(zmm_inp(jj, nb_oc_block),
- ptr[aux_reg_inp + aux_input_offset]);
+ EVEX_compress_addr_safe(aux_reg_inp,
+ aux_input_offset, reg_long_offt));
}
}
for (int ii = 0; ii < nb_oc_block; ii++) {
@@ -769,10 +814,12 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
if (jcp.kernel_kind == expl_bcast)
vfmadd231ps(zmm_out(jj, ii),
zmm_inp(jj, nb_oc_block), zmm_wei);
- else
+ else {
+ size_t aux_input_offset = input_offset(jj, ic, ki);
vfmadd231ps(zmm_out(jj, ii), zmm_wei,
- EVEX_compress_addr(aux_reg_inp,
- input_offset(jj, ic, ki), true));
+ EVEX_compress_addr_safe(aux_reg_inp,
+ aux_input_offset, reg_long_offt, true));
+ }
}
}
}
@@ -813,6 +860,13 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
const int shift_input_ptr
= jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;
+ size_t max_input_offset = (size_t)jcp.typesize_in
+ * jcp.ic_block * jcp.iw * jcp.ih * jcp.id;
+ assert(reg_inp_prf == reg_long_offt);
+ if (max_input_offset > INT_MAX) push(reg_inp_prf);
+
+ prepare_output(ur_w);
+
if (jcp.ndims == 4) {
mov(aux_reg_inp, reg_inp);
mov(aux_reg_ker, reg_ker);
@@ -820,8 +874,6 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
mov(aux_reg_inp_prf, reg_inp_prf);
}
- prepare_output(ur_w);
-
Label skip_kh_loop, skip_kd_loop;
if (jcp.ndims == 5) {
@@ -834,7 +886,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
mov(aux_reg_inp_d_prf, reg_inp_prf);
mov(aux_reg_ker_d_prf, reg_ker_prf);
- if (jcp.kd <= jcp.f_pad) {
+ if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
cmp(reg_ki, 0);
je(skip_kd_loop, T_NEAR);
}
@@ -843,7 +895,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
} else {
mov(reg_kj, reg_kh);
}
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) {
+ if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
cmp(reg_kj, 0);
je(skip_kh_loop, T_NEAR);
}
@@ -861,9 +913,10 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
for (int ic = 0; ic < jcp.ic_block / 2; ic += channel_inc) {
if (jcp.kernel_kind == expl_bcast) {
for (int oi = ow_start; oi < ow_end; oi++) {
- int input_offset = get_input_offset(ki, ic, oi, pad_l);
+ size_t input_offset = get_input_offset(ki, ic, oi, pad_l);
vpbroadcastd(zmm_inp(oi, jcp.nb_oc_blocking),
- ptr[aux_reg_inp + input_offset]);
+ EVEX_compress_addr_safe(aux_reg_inp, input_offset,
+ reg_long_offt));
}
}
for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
@@ -881,14 +934,14 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
}
}
for (int oi = ow_start, prf_count = 0; oi < ow_end; oi++) {
- int input_offset = get_input_offset(ki, ic, oi, pad_l);
-
+ size_t input_offset = get_input_offset(ki, ic, oi, pad_l);
if (jcp.kernel_kind == expl_bcast) {
vpdpwssd(zmm_out(oi, kk), zmm_wei,
zmm_inp(oi, jcp.nb_oc_blocking));
} else {
vpXdpwssd(zmm_out(oi, kk), Zmm(ker_reg_base_idx),
- aux_reg_inp, input_offset);
+ EVEX_compress_addr_safe(aux_reg_inp, input_offset,
+ reg_long_offt, jcp.ver != ver_4vnni));
}
if ((oi % 2) && (prf_count < ker_load_number)) {
int kernel_offset = get_kernel_offset(
@@ -897,12 +950,12 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
kernel_offset));
}
if (!(oi % 2) && ki == 0 && ic == 0 && kk == 0) {
- mic_prefetcht1(EVEX_compress_addr(aux_reg_inp_prf,
- input_offset));
+ mic_prefetcht1(EVEX_compress_addr_safe(
+ aux_reg_inp_prf, input_offset, reg_long_offt));
}
if (!(oi % 2) && ki == 1 && ic == 0 && kk == 0) {
- mic_prefetcht0(EVEX_compress_addr(aux_reg_inp,
- input_offset + shift_input_ptr));
+ mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp,
+ input_offset + shift_input_ptr, reg_long_offt));
}
}
}
@@ -936,7 +989,7 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
pop(reg_out);
pop(reg_out_prf);
}
-
+ if (max_input_offset > INT_MAX) pop(reg_inp_prf);
store_output(ur_w);
}
@@ -967,6 +1020,30 @@ void jit_avx512_common_conv_fwd_kernel::compute_loop(int ur_w,
void jit_avx512_common_conv_fwd_kernel::generate()
{
+ if (jcp.with_eltwise) {
+ eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
+ this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
+ ));
+ }
+
+ const auto &p = attr_.post_ops_;
+ for (int i = 0; i < p.len_; i++) {
+ auto &post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
+ this,
+ post_op.eltwise.alg,
+ post_op.eltwise.alpha,
+ post_op.eltwise.beta
+ ));
+ } else if (post_op.is_depthwise()) {
+ depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
+ this,
+ post_op.depthwise.alg
+ ));
+ }
+ }
+
int iw = jcp.iw;
int ow = jcp.ow;
int kw = jcp.kw;
@@ -975,25 +1052,11 @@ void jit_avx512_common_conv_fwd_kernel::generate()
int ur_w_tail = jcp.ur_w_tail;
int dilate_w = jcp.dilate_w + 1;
int stride_w = jcp.stride_w;
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- int inp_mult = !jcp.is_1stconv ? ic_block : 1;
+ int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult;
- int inp_shift = jcp.typesize_in * (ur_w * stride_w * inp_mult);
- int out_shift = jcp.typesize_out * (ur_w * oc_block);
-
- nstl::vector<int> shared_vecs;
- shared_vecs.push_back(27);
- shared_vecs.push_back(28);
- shared_vecs.push_back(29);
- shared_vecs.push_back(30);
- shared_vecs.push_back(31);
-
- nstl::vector<Reg64> shared_regs;
- shared_regs.push_back(imm_addr64);
-
- eltwise_generator.init(jcp.eltwise_alg, shared_vecs, shared_regs);
+ int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult;
+ int out_shift = jcp.typesize_out * ur_w * jcp.oc_block;
preamble();
mov(reg_inp, ptr[param1 + GET_OFF(src)]);
@@ -1069,11 +1132,8 @@ void jit_avx512_common_conv_fwd_kernel::generate()
postamble();
- // TODO (dmitrygo): need to find appropriate way to share labels.
- align(64);
- L(l_table);
- inject(eltwise_generator.prepareTable());
- eltwise_generator.release();
+ for (auto& inj : eltwise_injectors)
+ inj->prepare_table();
}
bool jit_avx512_common_conv_fwd_kernel::post_ops_ok(
@@ -1081,16 +1141,22 @@ bool jit_avx512_common_conv_fwd_kernel::post_ops_ok(
const auto &p = attr.post_ops_;
auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+ auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
switch (p.len_) {
case 0: return true; // no post_ops
case 1:
- return true // sum OR relu
- && !jcp.with_eltwise && (is_eltwise(0) || is_sum(0));
+ return true // sum OR eltwise OR depthwise
+ && !jcp.with_eltwise && (is_simple(0) || is_sum(0));
case 2:
return true // sum->relu
- && !jcp.with_eltwise && (is_sum(0) && is_eltwise(1));
+ && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) ||
+ (is_simple(0) && is_simple(1)));
+ case 3:
+ return true // sum->relu
+ && !jcp.with_eltwise && (is_sum(0) && is_simple(1) && is_simple(2));
default: return false;
}
@@ -1145,12 +1211,18 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
jcp.stride_w = cd.strides[ndims-3];
jcp.src_fmt = src_d.format();
jcp.with_eltwise = with_relu;
+ jcp.eltwise_alg = mkldnn_eltwise_relu;
jcp.eltwise_alpha = relu_negative_slope;
jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
jcp.dilate_h = cd.dilates[ndims-4];
jcp.dilate_w = cd.dilates[ndims-3];
+ jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
+ - (jcp.ih + jcp.t_pad - 1);
+ jcp.back_pad = (jcp.od - 1) * jcp.stride_d
+ + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
+
jcp.is_1stconv = is_1stconv(jcp);
jcp.oc_block = simd_w;
@@ -1165,34 +1237,17 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
}
-
-
- jcp.with_eltwise = with_relu;
- jcp.eltwise_alg = mkldnn_eltwise_relu;
- jcp.eltwise_alpha = relu_negative_slope;
+ bool args_ok = true
+ && jcp.oc % jcp.oc_block == 0
+ && jcp.ic % jcp.ic_block == 0;
+ if (!args_ok)
+ return status::unimplemented;
if (!post_ops_ok(jcp, attr))
return status::unimplemented;
const auto &p = attr.post_ops_;
jcp.with_sum = p.find(primitive_kind::sum) != -1;
- if (!jcp.with_eltwise) {
- int eltwise_ind = p.find(primitive_kind::eltwise);
- if (eltwise_ind != -1) {
- jcp.with_eltwise = true;
- jcp.eltwise_alg = p.entry_[eltwise_ind].eltwise.alg;
- jcp.eltwise_alpha = p.entry_[eltwise_ind].eltwise.alpha;
- jcp.eltwise_beta = p.entry_[eltwise_ind].eltwise.beta;
- jcp.eltwise_scale = p.entry_[eltwise_ind].eltwise.scale;
- }
- }
-
- if (jcp.with_eltwise) {
- int nvecs_elt = jit_uni_eltwise_vector_f32<avx512_common>::sharedVecsCount(jcp.eltwise_alg);
- int elt_regs = 32 - nvecs_elt;
-
- regs = nstl::min(regs, elt_regs);
- }
jcp.is_1stconv = is_1stconv(jcp);
if (jcp.ic % simd_w != 0 && !jcp.is_1stconv)
@@ -1368,7 +1423,8 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
if (one_of(jcp.ver, ver_4vnni, ver_4fma) && !jcp.is_1stconv) {
if (jcp.kw == 3 && jcp.kh == 3 && jcp.ow == 7 && jcp.oh == 7) {
- jcp.nb_oc_blocking = 2;
+ if (jcp.nb_oc % 2 == 0)
+ jcp.nb_oc_blocking = 2;
} else {
for (int i = jcp.nb_oc; i > 0; i--)
if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) {
@@ -1398,6 +1454,7 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
jcp.kernel_kind = embd_bcast;
unsigned int inp_size = jcp.mb * (jcp.ih / jcp.stride_h)
* (jcp.iw / jcp.stride_w) * jcp.ic;
+ if (inp_size == 0) inp_size = 1;
unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw;
// Estimate whether we need to limit the number of threads
@@ -1466,9 +1523,7 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
jcp.ur_w_tail = jcp.ow % jcp.ur_w;
- bool args_ok = true
- && jcp.oc % jcp.oc_block == 0
- && jcp.ic % jcp.ic_block == 0
+ args_ok = true
&& jcp.l_pad <= jcp.ur_w
&& jcp.ic <= src_d.blocking_desc().padding_dims[1]
&& jcp.oc <= dst_d.blocking_desc().padding_dims[1]
@@ -1491,11 +1546,12 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
if (jcp.ver == ver_4fma) {
for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic;
divf++) {
- int l2_src = jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id;
- int l2_dst = jcp.ow * jcp.oc_block * jcp.nb_oc_blocking
+ size_t l2_src
+ = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id;
+ size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking
* jcp.oh * jcp.od;
- int l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh
- * jcp.nb_oc_blocking * temp_nb * jcp.kd;
+ size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block
+ * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd;
if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
if (jcp.kh == 3 && jcp.oh == 7) {
jcp.nb_ic_L2 = 1;
@@ -1510,10 +1566,6 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(
}
}
- // TODO (dmitrygo): we need at least 5 vector registers to fuse eltwise. Need to adapt unrolling scheme
- if (jcp.with_eltwise && (jcp.nb_oc_blocking * jcp.ur_w) > 27)
- return status::unimplemented;
-
return status::success;
}
@@ -1523,9 +1575,11 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w)
for (int j = 0; j < ur_w; j++) {
Zmm zmm = zmm_out(j, k);
vpxord(zmm, zmm, zmm);
- int aux_src_offset
- = typesize * (k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
- mic_prefetcht1(EVEX_compress_addr(reg_src_prf, aux_src_offset));
+ size_t aux_src_offset
+ = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j)
+ * jcp.ic_block;
+ mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset,
+ reg_long_offt));
}
}
}
@@ -1540,9 +1594,10 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w)
for (int k = 0; k < jcp.nb_ic_blocking; k++) {
for (int j = 0; j < ur_w; j++) {
Zmm zmm = zmm_out(j, k);
- int aux_src_offset
- = typesize * (k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
- vadd(zmm, reg_src, aux_src_offset);
+ size_t aux_src_offset = (size_t)typesize
+ * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
+ vadd(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset,
+ reg_long_offt));
}
}
@@ -1550,10 +1605,12 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w)
for (int k = 0; k < jcp.nb_ic_blocking; k++) {
for (int j = 0; j < ur_w; j++) {
Zmm zmm = zmm_out(j, k);
- int aux_src_offset
- = typesize * (k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
- vmovups(EVEX_compress_addr(reg_src, aux_src_offset), zmm);
- mic_prefetcht0(EVEX_compress_addr(reg_src_prf, aux_src_offset));
+ size_t aux_src_offset = (size_t)typesize
+ * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
+ vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset,
+ reg_long_offt), zmm);
+ mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset,
+ reg_long_offt));
}
}
}
@@ -1574,13 +1631,6 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma(
bool check_last_kh = (jcp.kh > 3);
- if (jcp.ndims == 4) {
- mov(aux_reg_dst, reg_dst);
- mov(aux_reg_ker, reg_ker);
- mov(aux_reg_dst_prf, reg_dst_prf);
- mov(aux_reg_ker_prf, reg_ker_prf);
- }
-
auto kernel_offset = [=](int icb, int oc, int ki) {
int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki;
int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
@@ -1606,6 +1656,13 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma(
prepare_output(ur_w);
+ if (jcp.ndims == 4) {
+ mov(aux_reg_dst, reg_dst);
+ mov(aux_reg_ker, reg_ker);
+ mov(aux_reg_dst_prf, reg_dst_prf);
+ mov(aux_reg_ker_prf, reg_ker_prf);
+ }
+
if (jcp.ndims == 5) {
push(reg_src_prf);
push(reg_src);
@@ -1764,11 +1821,6 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_vnni(
const int ker_load_number = jcp.ver == ver_4vnni ? 4 : 1;
Label kh_label;
- mov(aux_reg_dst, reg_dst);
- mov(aux_reg_ker, reg_ker);
- mov(aux_reg_dst_prf, reg_dst_prf);
- mov(aux_reg_ker_prf, reg_ker_prf);
-
auto kernel_offset = [=](int icb, int oc, int ki) {
int blk_idx = icb * jcp.kh * jcp.kw + ki;
int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
@@ -1778,6 +1830,11 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_vnni(
prepare_output(ur_w);
+ mov(aux_reg_dst, reg_dst);
+ mov(aux_reg_ker, reg_ker);
+ mov(aux_reg_dst_prf, reg_dst_prf);
+ mov(aux_reg_ker_prf, reg_ker_prf);
+
mov(reg_kj, reg_kh);
L(kh_label); {
for (int ki = 0; ki < kw; ki++) {
@@ -2047,13 +2104,13 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
return typesize * (blk_offset + oc_offset);
};
+ prepare_output(ur_w);
+
if (jcp.ndims == 4) {
mov(aux_reg_dst, reg_dst);
mov(aux_reg_ker, reg_ker);
}
- prepare_output(ur_w);
-
if (jcp.ndims == 5) {
push(reg_src_prf);
push(reg_src);
@@ -2067,10 +2124,8 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
} else {
mov(reg_kj, reg_kh);
}
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < jcp.t_pad) {
- cmp(reg_kj, 0);
- je(skip_kh_loop, T_NEAR);
- }
+ cmp(reg_kj, 0);
+ je(skip_kh_loop, T_NEAR);
if (jcp.ndims == 5) {
mov(aux_reg_dst, aux_reg_dst_d);
@@ -2095,7 +2150,7 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
if (jj_end - jj_start > 0)
vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker,
aux_kernel_offset));
- for (int jj = jj_start; jj < jj_end; jj++)
+ for (int jj = jj_start; jj < jj_end; jj += stride_w)
if (jcp.kernel_kind == expl_bcast)
vfmadd231ps(zmm_out(jj, ii),
zmm_inp(jj, nb_ic_block), zmm_wei);
@@ -2496,11 +2551,11 @@ status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) {
for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc;
divf++) {
- int l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih
+ size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih
* jcp.id;
- int l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od;
- int l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh * jcp.kd
- * jcp.nb_ic_blocking * temp_nb;
+ size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od;
+ size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh
+ * jcp.kd * jcp.nb_ic_blocking * temp_nb;
if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
if (jcp.kh == 3 && jcp.ih == 7) {
jcp.nb_oc_L2 = 1;
@@ -2535,7 +2590,8 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
int iw = (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni))
? jcp.tr_iw : jcp.iw;
- sub(aux_reg_input, jcp.typesize_in * jcp.ih * iw * inp_mult);
+ sub(aux_reg_input,
+ jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult);
sub(aux_reg_kernel,
jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block);
dec(kj);
@@ -2552,7 +2608,7 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
int iw = (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni))
? jcp.tr_iw : jcp.iw;
- sub(reg_input, jcp.typesize_in * iw * inp_mult);
+ sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult);
sub(reg_kernel,
jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block);
dec(kj);
@@ -2596,19 +2652,21 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma(
+ output_offset));
for (int i_kw = 0; i_kw < kw; i_kw++) {
- int i_iw = i_ur * jcp.stride_w + i_kw;
- if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1
- - pad_r) continue;
+ int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1);
+ if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w +
+ (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue;
for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
- const int i_offset = input_offset
- + typesize * (jcp.ver == ver_4fma
+ const size_t i_offset = (size_t)input_offset
+ + (size_t)typesize * (jcp.ver == ver_4fma
? (i_iw - pad_l + i_ic * jcp.tr_iw)
: (jcp.is_1stconv
- ? (i_iw - pad_l) + i_ic * (jcp.ih*jcp.iw*jcp.id)
+ ? (i_iw - pad_l) + (size_t)i_ic
+ * ((size_t)jcp.ih*jcp.iw*jcp.id)
: (i_iw - pad_l) * ic_block + i_ic));
vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic),
Zmm(kw * ic_block_step + i_ur % 4),
- EVEX_compress_addr(reg_input, i_offset, true));
+ EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt,
+ true));
}
}
}
@@ -2868,8 +2926,8 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
? jcp.tr_iw : jcp.iw;
int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
- int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + jcp.kw - 1
- - (jcp.iw + jcp.l_pad - 1));
+ int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
int l_pad = jcp.l_pad;
if (jcp.ndims == 5) {
@@ -2892,15 +2950,16 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0,
i_b_ic + ic_block_step >= jcp.ic_block);
}
- add(reg_input, jcp.typesize_in * iw * inp_mul);
- add(reg_kernel, jcp.typesize_out * (jcp.kw) * ic_block * oc_block);
+ add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
+ add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block);
dec(kj);
cmp(kj, 0);
jg(kh_label, T_NEAR);
}
if (jcp.ndims == 5) {
- add(aux_reg_input, jcp.typesize_in * jcp.ih * iw * inp_mul);
+ add(aux_reg_input,
+ jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul);
add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
* oc_block);
dec(ki);
@@ -2923,7 +2982,7 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
int r_pad = nstl::max(0,
- (ow - 1) * jcp.stride_w + jcp.kw - 1
+ (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
- (jcp.iw + jcp.l_pad - 1));
int l_pad = jcp.l_pad;
@@ -2943,10 +3002,13 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
L(ic_block_label); {
compute_ic_block_step(ow, l_pad, r_pad, ic_block_step,
0, 0, 0);
- int inp_icblk_stride = jcp.is_1stconv ? jcp.ih * jcp.iw * jcp.id
+ size_t inp_icblk_stride = jcp.is_1stconv
+ ? (size_t)jcp.ih * jcp.iw * jcp.id
: (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
? jcp.tr_iw : 1);
- add(reg_input, jcp.typesize_in * ic_block_step * inp_icblk_stride);
+ size_t input_offset
+ = inp_icblk_stride * jcp.typesize_in * ic_block_step;
+ safe_add(reg_input, input_offset, reg_long_offt);
add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
add(b_ic, ic_block_step);
cmp(b_ic, jcp.ic_block);
@@ -2954,10 +3016,13 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
}
if (jcp.is_1stconv) {
- sub(reg_input, jcp.typesize_in * jcp.ih*jcp.iw*jcp.id * ic_block);
- add(reg_input, jcp.typesize_in * jcp.iw);
+ size_t input_offset
+ = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
+ safe_sub(reg_input, input_offset, reg_long_offt);
+ add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
} else if (!utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
- add(reg_input, jcp.typesize_in * (jcp.iw - 1) * ic_block);
+ add(reg_input, jcp.typesize_in
+ * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block);
}
add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
dec(kj);
@@ -2965,8 +3030,8 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
jg(kh_label, T_NEAR);
}
if (jcp.ndims == 5) {
- add(aux_reg_input, jcp.typesize_in * jcp.ih * jcp.iw
- * ((jcp.is_1stconv) ? 1 : ic_block));
+ add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih
+ * jcp.iw * (jcp.is_1stconv ? 1 : ic_block));
add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
* oc_block);
dec(ki);
@@ -2985,8 +3050,8 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
int oc_block = jcp.oc_block;
int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
- int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + jcp.kw - 1
- - (jcp.iw + jcp.l_pad - 1));
+ int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
int l_pad = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni
|| jcp.ver == ver_vnni) ? 0 : jcp.l_pad;
@@ -3053,7 +3118,9 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
? jcp.ih * jcp.iw * jcp.id
: (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
? jcp.tr_iw : 1);
- add(reg_input, jcp.typesize_in * ic_block_step * inp_icblk_stride);
+ size_t input_offset
+ = inp_icblk_stride * jcp.typesize_in * ic_block_step;
+ safe_add(reg_input, input_offset, reg_long_offt);
add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
add(b_ic, ic_block_step);
@@ -3061,10 +3128,13 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
jl(ic_block_label, T_NEAR);
}
if (jcp.is_1stconv) {
- sub(reg_input, jcp.typesize_in * jcp.ih*jcp.iw*jcp.id * ic_block);
- add(reg_input, jcp.typesize_in * jcp.iw);
+ size_t input_offset
+ = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
+ safe_sub(reg_input, input_offset, reg_long_offt);
+ add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
} else if (!utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
- add(reg_input, jcp.typesize_in * (jcp.iw - 1) * ic_block);
+ add(reg_input, jcp.typesize_in
+ * ((jcp.dilate_h + 1 ) * jcp.iw - 1) * ic_block);
}
add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
dec(kj);
@@ -3072,15 +3142,14 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
jg(kh_label, T_NEAR);
}
if (jcp.ndims == 5) {
- add(aux_reg_input, jcp.typesize_in * jcp.ih * jcp.iw *
- ((jcp.is_1stconv) ? 1 : ic_block ));
+ add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih
+ * jcp.iw * (jcp.is_1stconv ? 1 : ic_block));
add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
* oc_block);
dec(ki);
cmp(ki, 0);
jg(kd_label, T_NEAR);
}
-
}
void jit_avx512_common_conv_bwd_weights_kernel_f32
@@ -3179,19 +3248,19 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel()
void jit_avx512_common_conv_bwd_weights_kernel_f32
::compute_oh_loop_common()
{
- int back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - 1
- - (jcp.id + jcp.f_pad - 1));
-
- int b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - 1
- - (jcp.ih + jcp.t_pad - 1));
+ int back_pad = jcp.back_pad;
+ int b_pad = jcp.b_pad;
int t_pad = jcp.t_pad;
+ bool is_dilated = jcp.dilate_h != 0;
+ int dilate_h = jcp.dilate_h + 1;
int stride_h = jcp.stride_h;
- int idp = jcp.id + jcp.f_pad + jcp.back_pad;
+ int idp = jcp.id + jcp.f_pad + back_pad;
const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
int iw = utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni) ? jcp.tr_iw
: jcp.iw;
Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label,
- oh_bpad_label, oh_bpad_label_end, od_label, od_label_end;
+ oh_bpad_label, oh_bpad_label_end, od_label, od_label_end,
+ oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end;
maybe_zero_kernel();
if (jcp.ndims == 5 && jcp.with_bias) bias_kernel();
@@ -3218,30 +3287,60 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
xor_(reg_ih_count, reg_ih_count);
xor_(reg_oj, reg_oj);
if (t_pad > 0) {
- mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih);
- add(reg_kernel, jcp.typesize_out * t_pad * jcp.kw * jcp.ic_block
+ const int kh_range = 1 + (jcp.kh - 1) * dilate_h;
+ const int overflow
+ = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
+ const int underflow = div_up(t_pad, dilate_h);
+ const int initial_inp_ker_overlap = jcp.kh - overflow - underflow;
+ mov(reg_kh, initial_inp_ker_overlap);
+ add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block
* jcp.oc_block);
// generate loop to process kernel while it remains within t_pad + ih
- if (jcp.kh < t_pad + jcp.ih) {
+ if (kh_range < t_pad + jcp.ih) {
+ if (is_dilated) {
+ const int tail = t_pad % dilate_h;
+ const int shift = tail == 0 ? 0 : dilate_h - tail;
+ mov(reg_tmp, shift);
+ if (tail != 0)
+ add(reg_input, jcp.typesize_in * shift * iw * inp_mult);
+ }
L(oh_tpad_label); {
compute_oh_step_disp();
add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
+ if (is_dilated) {
+ inc(reg_tmp);
+ cmp(reg_tmp, dilate_h);
+ jl(oh_dilate_label_shift, T_NEAR);
+ // unshift input as new kernel element enters
+ sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult);
+ xor_(reg_tmp, reg_tmp);
+ }
+ // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
* jcp.ic_block * jcp.oc_block);
-
+ add(reg_kh, stride_h);
+ if (is_dilated) {
+ jmp(oh_dilate_label_noshift, T_NEAR);
+ L(oh_dilate_label_shift);
+ // shift input as old kernel element progresses
+ add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
+ L(oh_dilate_label_noshift);
+ }
inc(reg_oj);
add(reg_ih_count, stride_h);
- add(reg_kh, stride_h);
- // at least jcp.ih cells of kernel ultimately overlap with input
- const int final_inp_ker_overlap = nstl::min(jcp.kh, jcp.ih);
+ // final number of kernel elements that overlap with input
+ const int final_inp_ker_overlap
+ = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h));
cmp(reg_kh, final_inp_ker_overlap);
jl(oh_tpad_label, T_NEAR);
}
}
// need second loop to process kernel if it is larger than the input
- if (jcp.kh >= jcp.ih + (t_pad % stride_h == 0 ? stride_h :
+ // (does not apply to dilations as they must have unit stride)
+ if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h :
t_pad % stride_h)) {
+ assert(!is_dilated);
mov(reg_kh, jcp.ih);
L(oh_tpad_tail_label); {
compute_oh_step_disp();
@@ -3257,9 +3356,12 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
}
}
// correct any excess shifts to kernel and input
+ // (does not apply to dilations as they must have unit stride,
+ // kernel must fit inside input, and padding is smaller than input)
if (t_pad <= jcp.oh * stride_h) {
// kernel has moved beyond padding (adjust for stride effects)
if (t_pad % stride_h != 0) {
+ assert(!is_dilated);
int inp_corr = stride_h - t_pad % stride_h;
add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw
* jcp.ic_block * jcp.oc_block);
@@ -3267,12 +3369,13 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
}
} else {
// kernel still overlaps padding (complete reset)
+ assert(!is_dilated);
sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h)
* jcp.kw * jcp.ic_block * jcp.oc_block);
}
}
- cmp(reg_ih_count, jcp.ihp - b_pad - jcp.kh + 1);
+ cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
jge(oh_label_end, T_NEAR);
cmp(reg_oj, jcp.oh);
jge(oh_label, T_NEAR);
@@ -3286,7 +3389,7 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
inc(reg_oj);
add(reg_ih_count, stride_h);
- cmp(reg_ih_count, jcp.ihp - b_pad - jcp.kh + 1);
+ cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
jge(oh_label_end, T_NEAR);
cmp(reg_oj, jcp.oh);
@@ -3298,17 +3401,29 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
cmp(reg_oj, jcp.oh);
jge(oh_bpad_label_end, T_NEAR);
- mov(reg_kh, jcp.ihp - b_pad);
- sub(reg_kh, reg_ih_count);
+ if (is_dilated) {
+ mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations
+ mov(reg_tmp, 0);
+ } else {
+ mov(reg_kh, jcp.ihp - b_pad);
+ sub(reg_kh, reg_ih_count);
+ }
L(oh_bpad_label);
{
compute_oh_step_disp();
add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
-
+ if (is_dilated) {
+ inc(reg_tmp);
+ cmp(reg_tmp, dilate_h);
+ jl(oh_dilate_label_end, T_NEAR);
+ xor_(reg_tmp, reg_tmp);
+ }
sub(reg_kh, stride_h);
cmp(reg_kh, 0);
jle(oh_bpad_label_end, T_NEAR);
+ if (is_dilated)
+ L(oh_dilate_label_end);
inc(reg_oj);
cmp(reg_oj, jcp.oh);
@@ -3323,13 +3438,13 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32
pop(reg_output_d);
pop(reg_input_d);
- add(reg_input_d, jcp.typesize_in *jcp.stride_d*jcp.ih * iw * inp_mult);
- add(reg_output_d, jcp.typesize_in * ow * jcp.oh * jcp.oc_block);
+ add(reg_input_d, jcp.typesize_in * jcp.stride_d * jcp.ih * iw * inp_mult);
+ add(reg_output_d, jcp.typesize_in * jcp.oh * ow * jcp.oc_block);
dec(reg_oi);
add(reg_id_count, jcp.stride_d);
- cmp(reg_id_count, idp - back_pad - jcp.kd + 1);
+ cmp(reg_id_count, idp - back_pad - (jcp.kd - 1) * (jcp.dilate_d + 1));
jge(od_label_end, T_NEAR);
cmp(reg_oi, 0);
@@ -3345,6 +3460,7 @@ bool jit_avx512_common_conv_bwd_weights_kernel_f32
// FIXME: use register mapping from the class declaration
bool ok = one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
&& (jcp.ver == ver_4fma || !one_of(1, jcp.kh, jcp.kw))
+ && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
&& everyone_is(1, jcp.stride_h, jcp.stride_w);
if (!ok) return false;
if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2)
@@ -3901,10 +4017,11 @@ bool jit_avx512_common_conv_bwd_weights_kernel_f32
return true;
}
-bool jit_avx512_common_conv_bwd_weights_kernel_f32
+bool jit_avx512_common_conv_bwd_weights_kernel_f32
::flat_4ops_compute() {
const auto &j = jcp;
- const bool ok = j.ver == ver_4fma && j.is_1stconv;
+ const bool ok = j.ver == ver_4fma && j.is_1stconv
+ && everyone_is(0, j.dilate_h, j.dilate_w);
if (!ok) return false;
Reg64 reg_ptr_tr_src = r8;
@@ -4160,15 +4277,24 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
jcp.dilate_h = cd.dilates[ndims-4];
jcp.dilate_w = cd.dilates[ndims-3];
- if (jcp.dilate_h != 0 || jcp.dilate_w != 0 || jcp.dilate_d != 0)
+
+ const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1);
+ bool ok = true
+ // general condition to simplify dilations
+ && implication(jcp.dilate_d != 0, jcp.stride_d == 1)
+ && implication(jcp.dilate_h != 0, jcp.stride_h == 1)
+ && implication(jcp.dilate_w != 0, jcp.stride_w == 1)
+ // special condition to simplify dilations in compute_oh_loop_common
+ && implication(jcp.dilate_h != 0, kh_range <= jcp.ih);
+ if (!ok)
return status::unimplemented;
- jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw
- - jcp.l_pad);
- jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih
- - jcp.t_pad);
- jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - jcp.id
- - jcp.f_pad);
+ jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
+ jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h
+ + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1));
+ jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d
+ + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1));
if ( ndims == 5 )
if (jcp.f_pad != 0 || jcp.back_pad != 0)
@@ -4219,9 +4345,10 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
/* kernel applicability check wrt boundaries
* the conditions are quite general across the kernels we have,
* but ideally the check should belong to a specific kernel... */
+ const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2;
const bool boundaries_ok = true
- && jcp.t_pad <= jcp.kh / 2
- && jcp.b_pad <= jcp.kh / 2;
+ && jcp.t_pad <= max_pad
+ && jcp.b_pad <= max_pad;
if (!boundaries_ok)
return status::unimplemented;
@@ -4261,6 +4388,8 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
const bool use_4fma = true
&& ndims == 4
&& mayiuse(avx512_mic_4ops)
+ && mkldnn_thr_syncable()
+ && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
&& everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad)
&& jcp.kw <= 28 - jcp.with_bias
&& jcp.stride_w == 4
@@ -4309,8 +4438,10 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
jcp.nb_ic = jcp.ic / jcp.ic_block;
jcp.src_fmt = src_d.format();
if ((mayiuse(avx512_mic_4ops) || mayiuse(avx512_core_vnni))
+ && mkldnn_thr_syncable()
&& ndims == 4
&& jcp.stride_w == 1
+ && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
&& ((src_d.data_type() == data_type::s16
&& diff_weights_d.data_type() == data_type::s32
&& diff_dst_d.data_type() == data_type::s16))) {
@@ -4321,7 +4452,9 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
src_d.data_type(), diff_weights_d.data_type(),
diff_dst_d.data_type())) {
jcp.ver = ver_fma;
- if (ndims == 4 && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1) {
+ if (ndims == 4 && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1 &&
+ everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) &&
+ mkldnn_thr_syncable()) {
jcp.ver = ver_4fma;
}
} else {