summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp')
-rw-r--r--inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp27
1 files changed, 19 insertions, 8 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp
index fe8eedeee..008602a50 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp
@@ -62,6 +62,11 @@ private:
typesize = sizeof(float),
ker_reg_base_idx = 28,
};
+ enum {
+ no_last_block,
+ last_ic_block,
+ last_sp_block,
+ };
reg64_t reg_inp = r8;
reg64_t reg_ker = r9;
@@ -69,17 +74,19 @@ private:
reg64_t aux_reg_inp = r11;
reg64_t reg_ptr_sum_scale = r11;
reg64_t aux_reg_ker = r12;
- reg64_t reg_acc_s32 = r13;
reg64_t reg_scratch = r14;
- reg64_t reg_kj = rax;
+ reg64_t reg_kj = rax;
reg64_t reg_ptr_scales = rax;
- reg64_t reg_oi = rbx;
+ reg64_t reg_oi = rbx;
reg64_t reg_bias = rdx;
- reg64_t reg_kh = abi_not_param1;
- reg64_t param = abi_param1;
- reg64_t reg_channel = r15;
+ reg64_t reg_kh = abi_not_param1;
+ reg64_t param = abi_param1;
reg64_t reg_tmp = rbp;
reg64_t imm_addr64 = r15;
+ reg64_t reg_oc_blocks = rsi;
+ reg64_t reg_icb = reg_bias;
+
+ Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
zmm_t zmm_tmp = zmm_t(28);
zmm_t zmm_one = zmm_t(29);
@@ -115,9 +122,13 @@ private:
}
bool maybe_relu(int position);
void prepare_output(int ur_w);
- void store_output(int ur_w);
- void compute_loop(int ur_w, int pad_l, int pad_r);
+ void store_output(int ur_w, int last_oc_block_flag);
+ void compute_ker(int ur_w, int pad_l, int pad_r, int last_ic_block_flag);
+ void compute_loop(
+ int ur_w, int pad_l, int pad_r, bool is_last_spatial_block);
void generate();
+ void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op,
+ bool mask_flag);
};
}