diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp | 133 |
1 files changed, 75 insertions, 58 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp index 389245176..7178e1a0c 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp @@ -176,21 +176,49 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled( } template <cpu_isa_t isa> -void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_activation(int ur_ch_blocks, int ur_w) { - if (this->jcp.with_eltwise) { - inject(eltwise_generator.prepareConstants(jcp.eltwise_alpha, jcp.eltwise_beta)); +void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_postprocess(int ur_ch_blocks, int ur_w) { + int repeats = isa == sse42 ? 2 : 1; + + int eltwise_inj_idx = 0; + int depthwise_inj_idx = 0; + const auto &p = attr_.post_ops_; + + if (p.len_ == 0 && eltwise_injectors.size() == 1) { + int start_idx = get_acc_reg(0).getIdx(); + int end_idx = get_acc_reg(repeats * ur_w * ur_ch_blocks).getIdx(); + + eltwise_injectors[0]->compute_vector_range(start_idx, end_idx); + } + + for (int i = 0; i < p.len_; i++) { + auto& post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + int start_idx = get_acc_reg(0).getIdx(); + int end_idx = get_acc_reg(repeats * ur_w * ur_ch_blocks).getIdx(); + + eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, end_idx); + eltwise_inj_idx++; + } else if (post_op.is_depthwise()) { + mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data)); + mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data)); + + add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]); + add(reg_d_bias, ptr[this->param1 + GET_OFF(oc_off)]); - // TODO (dmitrygo): need to find appropriate way to share labels. - mov(imm_addr64, l_table); - int repeats = isa == sse42 ? 2 : 1; - for (int i = 0; i < repeats; i++) { for (int ch = 0; ch < ur_ch_blocks; ch++) { - for (int ow = 0; ow < ur_w; ow++) { - Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow); + for (int k = 0; k < repeats; k++) { + int start_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ur_w * ch).getIdx(); + int end_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ur_w * ch + ur_w).getIdx(); - inject(eltwise_generator.computeVector(vmm_dst, vmm_dst)); + depthwise_injectors[depthwise_inj_idx]->compute_vector_range( + start_idx, end_idx, reg_d_weights, reg_d_bias); + + add(reg_d_weights, jcp.ch_block / repeats * sizeof(float)); + add(reg_d_bias, jcp.ch_block / repeats * sizeof(float)); } } + + depthwise_inj_idx++; } } } @@ -230,7 +258,7 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) { load_src(ur_ch_blocks, ur_w); apply_filter_unrolled(ur_ch_blocks, ur_w); - apply_activation(ur_ch_blocks, ur_w); + apply_postprocess(ur_ch_blocks, ur_w); store_dst(ur_ch_blocks, ur_w); add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); @@ -251,7 +279,7 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) { load_src(ur_ch_blocks, ur_w); apply_filter(ur_ch_blocks, ur_w); - apply_activation(ur_ch_blocks, ur_w); + apply_postprocess(ur_ch_blocks, ur_w); store_dst(ur_ch_blocks, ur_w); add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); @@ -267,18 +295,29 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) { template <cpu_isa_t isa> void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() { - nstl::vector<int> shared_vecs; - shared_vecs.push_back(0); - shared_vecs.push_back(1); - shared_vecs.push_back(2); - shared_vecs.push_back(3); - if (isa == avx512_common) - shared_vecs.push_back(31); - - nstl::vector<Reg64> shared_regs; - shared_regs.push_back(imm_addr64); + if (jcp.with_eltwise) { + eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>( + this, jcp.eltwise_alg, jcp.eltwise_alpha, 0 + )); + } - eltwise_generator.init(jcp.eltwise_alg, shared_vecs, shared_regs); + const auto &p = attr_.post_ops_; + for (int i = 0; i < p.len_; i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_eltwise()) { + eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>( + this, + post_op.eltwise.alg, + post_op.eltwise.alpha, + post_op.eltwise.beta + )); + } else if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>( + this, + post_op.depthwise.alg + )); + } + } this->preamble(); @@ -315,11 +354,8 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() this->postamble(); - // TODO (dmitrygo): need to find appropriate way to share labels. - align(64); - L(l_table); - inject(eltwise_generator.prepareTable()); - eltwise_generator.release(); + for (auto& inj : eltwise_injectors) + inj->prepare_table(); } template <cpu_isa_t isa> @@ -327,13 +363,20 @@ bool jit_uni_dw_conv_fwd_kernel_f32<isa>::post_ops_ok( jit_conv_conf_t &jcp, const primitive_attr_t &attr) { const auto &p = attr.post_ops_; + auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); }; auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); }; switch (p.len_) { case 0: return true; // no post_ops - case 1: return !jcp.with_eltwise && (is_eltwise(0) || is_sum(0)); // sum OR relu - case 2: return !jcp.with_eltwise && (is_sum(0) && is_eltwise(1)); // sum->relu + case 1: return true // sum OR eltwise OR deptwise + && !jcp.with_eltwise && (is_simple(0) || is_sum(0)); + case 2: return true // sum->relu OR sum->depthwise OR eltwise->depthwise OR depthwise->depthwise + && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) || + (is_simple(0) && is_simple(1))); + case 3: return true // sum->eltwise->depthwise OR sum->depthwise->eltwise OR sum->depthwise->depthwise + && !jcp.with_eltwise && ((is_sum(0) && is_simple(1) && is_simple(2))); default: return false; } @@ -392,21 +435,11 @@ status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp, const auto &p = attr.post_ops_; jcp.with_sum = p.find(primitive_kind::sum) != -1; - if (!jcp.with_eltwise) { - int eltwise_ind = p.find(primitive_kind::eltwise); - if (eltwise_ind != -1) { - jcp.with_eltwise = true; - jcp.eltwise_alg = p.entry_[eltwise_ind].eltwise.alg; - jcp.eltwise_alpha = p.entry_[eltwise_ind].eltwise.alpha; - jcp.eltwise_beta = p.entry_[eltwise_ind].eltwise.beta; - jcp.eltwise_scale = p.entry_[eltwise_ind].eltwise.scale; - } - } bool ok_to_pad_channels = true && jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups - && isa == avx512_common; + && one_of(isa, avx512_common, avx2, sse42); if (ok_to_pad_channels) { jcp.oc = rnd_up(jcp.oc, simd_w); jcp.ic = rnd_up(jcp.oc, simd_w); @@ -437,22 +470,6 @@ status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp, if (jcp.nb_ch < jcp.nb_ch_blocking) jcp.nb_ch_blocking = jcp.nb_ch; - if (jcp.with_eltwise) { - int nvecs_elt = jit_uni_eltwise_vector_f32<isa>::sharedVecsCount(jcp.eltwise_alg); - int nvecs_conv = isa == avx512_common ? 32 - nvecs_elt : 16 - nvecs_elt; - int isa_mult = isa == sse42 ? 2 : 1; - while (isa_mult * jcp.ur_w * jcp.nb_ch_blocking > nvecs_conv) { - if (jcp.nb_ch_blocking <= 1) { - break; - } - - jcp.nb_ch_blocking -= 1; - } - - if (isa_mult * jcp.ur_w * jcp.nb_ch_blocking > nvecs_conv) - return status::unimplemented; - } - return status::success; } @@ -700,7 +717,7 @@ status_t jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_conf( bool ok_to_pad_channels = true && jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups - && isa == avx512_common; + && one_of(isa, avx512_common, avx2); if (ok_to_pad_channels) { jcp.oc = rnd_up(jcp.oc, simd_w); jcp.ic = rnd_up(jcp.oc, simd_w); |