summary | refs | log | tree | commit | diff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp')
-rw-r--r-- inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp 133
1 files changed, 75 insertions, 58 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp
index 389245176..7178e1a0c 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp
@@ -176,21 +176,49 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled(
}
template <cpu_isa_t isa>
-void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_activation(int ur_ch_blocks, int ur_w) {
- if (this->jcp.with_eltwise) {
- inject(eltwise_generator.prepareConstants(jcp.eltwise_alpha, jcp.eltwise_beta));
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_postprocess(int ur_ch_blocks, int ur_w) { // Emits post-processing over the registers returned by get_acc_reg(): eltwise and depthwise post-ops, visited in attr_.post_ops_ order.
+ int repeats = isa == sse42 ? 2 : 1; // 2 for sse42, 1 otherwise; scales the register ranges and the per-step pointer advance below
+
+ int eltwise_inj_idx = 0; // next entry of eltwise_injectors to use
+ int depthwise_inj_idx = 0; // next entry of depthwise_injectors to use
+ const auto &p = attr_.post_ops_;
+
+ if (p.len_ == 0 && eltwise_injectors.size() == 1) { // no attribute post-ops, but exactly one eltwise injector exists (generate() pushes one when jcp.with_eltwise is set)
+ int start_idx = get_acc_reg(0).getIdx();
+ int end_idx = get_acc_reg(repeats * ur_w * ur_ch_blocks).getIdx(); // register at offset repeats*ur_w*ur_ch_blocks — presumably an exclusive upper bound; TODO confirm against compute_vector_range
+
+ eltwise_injectors[0]->compute_vector_range(start_idx, end_idx);
+ }
+
+ for (int i = 0; i < p.len_; i++) {
+ auto& post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ int start_idx = get_acc_reg(0).getIdx();
+ int end_idx = get_acc_reg(repeats * ur_w * ur_ch_blocks).getIdx(); // same full register range as the no-post-ops branch above
+
+ eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, end_idx);
+ eltwise_inj_idx++; // injectors are consumed in the order eltwise entries appear in post_ops_
+ } else if (post_op.is_depthwise()) {
+ mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data)); // pointer value is fixed at code-generation time and emitted as an immediate
+ mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
+
+ add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]); // add the oc_off value read via param1 at run time — presumably a byte offset for the current channel; TODO confirm
+ add(reg_d_bias, ptr[this->param1 + GET_OFF(oc_off)]);
- // TODO (dmitrygo): need to find appropriate way to share labels.
- mov(imm_addr64, l_table);
- int repeats = isa == sse42 ? 2 : 1;
- for (int i = 0; i < repeats; i++) {
for (int ch = 0; ch < ur_ch_blocks; ch++) {
- for (int ow = 0; ow < ur_w; ow++) {
- Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
+ for (int k = 0; k < repeats; k++) {
+ int start_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ur_w * ch).getIdx(); // first of the ur_w registers for channel block ch, repeat k
+ int end_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ur_w * ch + ur_w).getIdx(); // start offset + ur_w
- inject(eltwise_generator.computeVector(vmm_dst, vmm_dst));
+ depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
+ start_idx, end_idx, reg_d_weights, reg_d_bias);
+
+ add(reg_d_weights, jcp.ch_block / repeats * sizeof(float)); // advance both pointers by ch_block / repeats floats after each (ch, k) step
+ add(reg_d_bias, jcp.ch_block / repeats * sizeof(float));
}
}
+
+ depthwise_inj_idx++; // the next depthwise post-op uses the next injector
}
}
}
@@ -230,7 +258,7 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) {
load_src(ur_ch_blocks, ur_w);
apply_filter_unrolled(ur_ch_blocks, ur_w);
- apply_activation(ur_ch_blocks, ur_w);
+ apply_postprocess(ur_ch_blocks, ur_w);
store_dst(ur_ch_blocks, ur_w);
add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
@@ -251,7 +279,7 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) {
load_src(ur_ch_blocks, ur_w);
apply_filter(ur_ch_blocks, ur_w);
- apply_activation(ur_ch_blocks, ur_w);
+ apply_postprocess(ur_ch_blocks, ur_w);
store_dst(ur_ch_blocks, ur_w);
add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
@@ -267,18 +295,29 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) {
template <cpu_isa_t isa>
void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate()
{
- nstl::vector<int> shared_vecs;
- shared_vecs.push_back(0);
- shared_vecs.push_back(1);
- shared_vecs.push_back(2);
- shared_vecs.push_back(3);
- if (isa == avx512_common)
- shared_vecs.push_back(31);
-
- nstl::vector<Reg64> shared_regs;
- shared_regs.push_back(imm_addr64);
+ if (jcp.with_eltwise) {
+ eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
+ this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
+ ));
+ }
- eltwise_generator.init(jcp.eltwise_alg, shared_vecs, shared_regs);
+ const auto &p = attr_.post_ops_;
+ for (int i = 0; i < p.len_; i++) {
+ auto &post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
+ this,
+ post_op.eltwise.alg,
+ post_op.eltwise.alpha,
+ post_op.eltwise.beta
+ ));
+ } else if (post_op.is_depthwise()) {
+ depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
+ this,
+ post_op.depthwise.alg
+ ));
+ }
+ }
this->preamble();
@@ -315,11 +354,8 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate()
this->postamble();
- // TODO (dmitrygo): need to find appropriate way to share labels.
- align(64);
- L(l_table);
- inject(eltwise_generator.prepareTable());
- eltwise_generator.release();
+ for (auto& inj : eltwise_injectors)
+ inj->prepare_table();
}
template <cpu_isa_t isa>
@@ -327,13 +363,20 @@ bool jit_uni_dw_conv_fwd_kernel_f32<isa>::post_ops_ok(
jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
const auto &p = attr.post_ops_;
+ auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+ auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
switch (p.len_) {
case 0: return true; // no post_ops
- case 1: return !jcp.with_eltwise && (is_eltwise(0) || is_sum(0)); // sum OR relu
- case 2: return !jcp.with_eltwise && (is_sum(0) && is_eltwise(1)); // sum->relu
+ case 1: return true // sum OR eltwise OR depthwise
+ && !jcp.with_eltwise && (is_simple(0) || is_sum(0));
+ case 2: return true // sum->(eltwise|depthwise) OR (eltwise|depthwise)->(eltwise|depthwise)
+ && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) ||
+ (is_simple(0) && is_simple(1)));
+ case 3: return true // sum->(eltwise|depthwise)->(eltwise|depthwise)
+ && !jcp.with_eltwise && ((is_sum(0) && is_simple(1) && is_simple(2)));
default: return false;
}
@@ -392,21 +435,11 @@ status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
const auto &p = attr.post_ops_;
jcp.with_sum = p.find(primitive_kind::sum) != -1;
- if (!jcp.with_eltwise) {
- int eltwise_ind = p.find(primitive_kind::eltwise);
- if (eltwise_ind != -1) {
- jcp.with_eltwise = true;
- jcp.eltwise_alg = p.entry_[eltwise_ind].eltwise.alg;
- jcp.eltwise_alpha = p.entry_[eltwise_ind].eltwise.alpha;
- jcp.eltwise_beta = p.entry_[eltwise_ind].eltwise.beta;
- jcp.eltwise_scale = p.entry_[eltwise_ind].eltwise.scale;
- }
- }
bool ok_to_pad_channels = true
&& jcp.oc == jcp.ngroups
&& jcp.ic == jcp.ngroups
- && isa == avx512_common;
+ && one_of(isa, avx512_common, avx2, sse42);
if (ok_to_pad_channels) {
jcp.oc = rnd_up(jcp.oc, simd_w);
jcp.ic = rnd_up(jcp.oc, simd_w);
@@ -437,22 +470,6 @@ status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
if (jcp.nb_ch < jcp.nb_ch_blocking)
jcp.nb_ch_blocking = jcp.nb_ch;
- if (jcp.with_eltwise) {
- int nvecs_elt = jit_uni_eltwise_vector_f32<isa>::sharedVecsCount(jcp.eltwise_alg);
- int nvecs_conv = isa == avx512_common ? 32 - nvecs_elt : 16 - nvecs_elt;
- int isa_mult = isa == sse42 ? 2 : 1;
- while (isa_mult * jcp.ur_w * jcp.nb_ch_blocking > nvecs_conv) {
- if (jcp.nb_ch_blocking <= 1) {
- break;
- }
-
- jcp.nb_ch_blocking -= 1;
- }
-
- if (isa_mult * jcp.ur_w * jcp.nb_ch_blocking > nvecs_conv)
- return status::unimplemented;
- }
-
return status::success;
}
@@ -700,7 +717,7 @@ status_t jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_conf(
bool ok_to_pad_channels = true
&& jcp.oc == jcp.ngroups
&& jcp.ic == jcp.ngroups
- && isa == avx512_common;
+ && one_of(isa, avx512_common, avx2);
if (ok_to_pad_channels) {
jcp.oc = rnd_up(jcp.oc, simd_w);
jcp.ic = rnd_up(jcp.oc, simd_w);