diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp | 27 |
1 files changed, 19 insertions, 8 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp index fe8eedeee..008602a50 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_conv_kernel.hpp @@ -62,6 +62,11 @@ private: typesize = sizeof(float), ker_reg_base_idx = 28, }; + enum { + no_last_block, + last_ic_block, + last_sp_block, + }; reg64_t reg_inp = r8; reg64_t reg_ker = r9; @@ -69,17 +74,19 @@ private: reg64_t aux_reg_inp = r11; reg64_t reg_ptr_sum_scale = r11; reg64_t aux_reg_ker = r12; - reg64_t reg_acc_s32 = r13; reg64_t reg_scratch = r14; - reg64_t reg_kj = rax; + reg64_t reg_kj = rax; reg64_t reg_ptr_scales = rax; - reg64_t reg_oi = rbx; + reg64_t reg_oi = rbx; reg64_t reg_bias = rdx; - reg64_t reg_kh = abi_not_param1; - reg64_t param = abi_param1; - reg64_t reg_channel = r15; + reg64_t reg_kh = abi_not_param1; + reg64_t param = abi_param1; reg64_t reg_tmp = rbp; reg64_t imm_addr64 = r15; + reg64_t reg_oc_blocks = rsi; + reg64_t reg_icb = reg_bias; + + Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); zmm_t zmm_tmp = zmm_t(28); zmm_t zmm_one = zmm_t(29); @@ -115,9 +122,13 @@ private: } bool maybe_relu(int position); void prepare_output(int ur_w); - void store_output(int ur_w); - void compute_loop(int ur_w, int pad_l, int pad_r); + void store_output(int ur_w, int last_oc_block_flag); + void compute_ker(int ur_w, int pad_l, int pad_r, int last_ic_block_flag); + void compute_loop( + int ur_w, int pad_l, int pad_r, bool is_last_spatial_block); void generate(); + void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op, + bool mask_flag); }; } |