diff options
author | Alexey Suhov <asuhov@users.noreply.github.com> | 2018-11-23 16:19:43 +0300 |
---|---|---|
committer | openvino-pushbot <44090433+openvino-pushbot@users.noreply.github.com> | 2018-11-23 16:19:43 +0300 |
commit | 55a41d7570f78aaea0d6764d157dd7434730d56f (patch) | |
tree | ba022c71609b93d51119bcb25e5ccb8c7147dbd3 /inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp | |
parent | 54eab180361ec09fbd82e2bb62adfeb521275774 (diff) | |
download | dldt-55a41d7570f78aaea0d6764d157dd7434730d56f.tar.gz dldt-55a41d7570f78aaea0d6764d157dd7434730d56f.tar.bz2 dldt-55a41d7570f78aaea0d6764d157dd7434730d56f.zip |
Publishing R4 (#41)
* Publishing R4
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp | 48 |
1 files changed, 29 insertions, 19 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp index 997a4a050..f31e072c1 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp @@ -21,6 +21,7 @@ #include "cpu_convolution_pd.hpp" #include "cpu_engine.hpp" #include "scratchpad.hpp" +#include "mkldnn_thread.hpp" #include "jit_avx512_common_conv_winograd_kernel_f32.hpp" @@ -74,65 +75,65 @@ struct winograd_scratchpad_t { private: inline void get_scratchpad_size_(const jit_conv_winograd_conf_t &jcp) { - nthreads_ = omp_get_max_threads(); + nthreads_ = mkldnn_get_max_threads(); - U_sz_ = alpha * alpha * jcp.ic * jcp.oc * sizeof(float); - V_sz_ = alpha * alpha * jcp.mb * jcp.ic + U_sz_ = (size_t)alpha * alpha * jcp.ic * jcp.oc * sizeof(float); + V_sz_ = (size_t)alpha * alpha * jcp.mb * jcp.ic * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding) * sizeof(float); - M_sz_ = alpha * alpha * jcp.mb * jcp.oc + M_sz_ = (size_t)alpha * alpha * jcp.mb * jcp.oc * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding) * sizeof(float); switch (jcp.sched_policy) { case WSCHED_DATA_W_SGD: - V_sz_ = nthreads_ * alpha * alpha + V_sz_ = (size_t)nthreads_ * alpha * alpha * jcp.nb_tile_block_ur * jcp.tile_block_ur * jcp.ic * sizeof(float); - M_sz_ = nthreads_* alpha * alpha + M_sz_ = (size_t)nthreads_* alpha * alpha * jcp.nb_tile_block_ur * jcp.tile_block_ur * jcp.oc * sizeof(float); break; case WSCHED_WEI_SDGt_W: - U_sz_ = nthreads_ * U_sz_; - V_sz_ = nthreads_ * alpha * alpha + U_sz_ = (size_t)nthreads_ * U_sz_; + V_sz_ = (size_t)nthreads_ * alpha * alpha * (jcp.nb_tile_block_ur * jcp.tile_block_ur + jcp.tile_4fma_padding) * jcp.ic * sizeof(float); - M_sz_ = nthreads_ * alpha * alpha + M_sz_ = (size_t)nthreads_ * alpha * alpha * (jcp.nb_tile_block_ur * jcp.tile_block_ur + jcp.tile_4fma_padding) * jcp.oc * sizeof(float); bias_sz_ = nthreads_ * jcp.oc * sizeof(float); break; case WSCHED_WEI_SDGtWo: - U_sz_ = nthreads_ * alpha * alpha + U_sz_ = (size_t)nthreads_ * alpha * alpha * jcp.oc_block * jcp.oc_simd_block * jcp.ic * sizeof(float); - M_sz_ = nthreads_ * alpha * alpha + M_sz_ = (size_t)nthreads_ * alpha * alpha * (jcp.nb_tile_block_ur * jcp.tile_block_ur + jcp.tile_4fma_padding) * jcp.oc_simd_block * jcp.oc_block * sizeof(float); bias_sz_ = nthreads_ * jcp.oc * sizeof(float); break; case WSCHED_WEI_S_D_Giot_W: - U_sz_ = (nthreads_ + 1) * alpha * alpha + U_sz_ = (size_t)(nthreads_ + 1) * alpha * alpha * jcp.ic * jcp.oc * sizeof(float); - V_sz_ = alpha * alpha + V_sz_ = (size_t)alpha * alpha * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding) * jcp.ic * jcp.mb * sizeof(float); - M_sz_ = alpha * alpha + M_sz_ = (size_t)alpha * alpha * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding) * jcp.oc * jcp.mb * sizeof(float); bias_sz_ = nthreads_ * jcp.oc * sizeof(float); src_transpose_sz_ = jcp.ver == ver_4fma - ? (nthreads_ * alpha * alpha + ? ((size_t)nthreads_ * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block * sizeof(float)) : 0; break; case WSCHED_WEI_S_D_G_W: src_transpose_sz_ = jcp.ver == ver_4fma - ? (nthreads_ * alpha * alpha + ? ((size_t)nthreads_ * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block * sizeof(float)) : 0; @@ -227,12 +228,17 @@ struct _jit_avx512_common_convolution_winograd_fwd_t && utils::one_of(this->cdesc_().prop_kind, forward_training, forward_inference) && this->cdesc_().alg_kind == alg_kind::convolution_winograd + && !this->has_zero_dim_memory() && utils::everyone_is(data_type::f32, this->cdesc_().src_desc.data_type, this->cdesc_().weights_desc.data_type, this->cdesc_().dst_desc.data_type) && utils::implication(this->with_bias(), data_type::f32 - == this->cdesc_().bias_desc.data_type); + == this->cdesc_().bias_desc.data_type) + && mkldnn_thr_syncable(); + + ok = ok && this->dst_pd_.desc()->format == memory_format::nChw16c && + this->src_pd_.desc()->format == memory_format::nChw16c; if (!ok) return status::unimplemented; @@ -321,10 +327,12 @@ struct jit_avx512_common_convolution_winograd_bwd_data_t bool ok = true && this->set_default_params() == status::success && utils::one_of(this->desc()->prop_kind, backward_data) && this->desc()->alg_kind == alg_kind::convolution_winograd + && !this->has_zero_dim_memory() && utils::everyone_is(data_type::f32, this->desc()->diff_src_desc.data_type, this->desc()->weights_desc.data_type, - this->desc()->diff_dst_desc.data_type); + this->desc()->diff_dst_desc.data_type) + && mkldnn_thr_syncable(); if (!ok) return status::unimplemented; @@ -413,10 +421,12 @@ struct jit_avx512_common_convolution_winograd_bwd_weights_t bool ok = true && this->set_default_params() == status::success && utils::one_of(this->desc()->prop_kind, backward_weights) && this->desc()->alg_kind == alg_kind::convolution_winograd + && !this->has_zero_dim_memory() && utils::everyone_is(data_type::f32, this->desc()->src_desc.data_type, this->desc()->diff_dst_desc.data_type, - this->desc()->diff_weights_desc.data_type); + this->desc()->diff_weights_desc.data_type) + && mkldnn_thr_syncable(); if (!ok) return status::unimplemented; |