summary | refs | log | tree | commit | diff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp')
-rw-r--r-- inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp | 48
1 file changed, 29 insertions(+), 19 deletions(-)
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp
index 997a4a050..f31e072c1 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp
@@ -21,6 +21,7 @@
#include "cpu_convolution_pd.hpp"
#include "cpu_engine.hpp"
#include "scratchpad.hpp"
+#include "mkldnn_thread.hpp"
#include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
@@ -74,65 +75,65 @@ struct winograd_scratchpad_t {
private:
inline void get_scratchpad_size_(const jit_conv_winograd_conf_t &jcp) {
- nthreads_ = omp_get_max_threads();
+ nthreads_ = mkldnn_get_max_threads();
- U_sz_ = alpha * alpha * jcp.ic * jcp.oc * sizeof(float);
- V_sz_ = alpha * alpha * jcp.mb * jcp.ic
+ U_sz_ = (size_t)alpha * alpha * jcp.ic * jcp.oc * sizeof(float);
+ V_sz_ = (size_t)alpha * alpha * jcp.mb * jcp.ic
* (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding)
* sizeof(float);
- M_sz_ = alpha * alpha * jcp.mb * jcp.oc
+ M_sz_ = (size_t)alpha * alpha * jcp.mb * jcp.oc
* (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding)
* sizeof(float);
switch (jcp.sched_policy) {
case WSCHED_DATA_W_SGD:
- V_sz_ = nthreads_ * alpha * alpha
+ V_sz_ = (size_t)nthreads_ * alpha * alpha
* jcp.nb_tile_block_ur * jcp.tile_block_ur
* jcp.ic * sizeof(float);
- M_sz_ = nthreads_* alpha * alpha
+ M_sz_ = (size_t)nthreads_* alpha * alpha
* jcp.nb_tile_block_ur * jcp.tile_block_ur
* jcp.oc * sizeof(float);
break;
case WSCHED_WEI_SDGt_W:
- U_sz_ = nthreads_ * U_sz_;
- V_sz_ = nthreads_ * alpha * alpha
+ U_sz_ = (size_t)nthreads_ * U_sz_;
+ V_sz_ = (size_t)nthreads_ * alpha * alpha
* (jcp.nb_tile_block_ur * jcp.tile_block_ur
+ jcp.tile_4fma_padding)
* jcp.ic * sizeof(float);
- M_sz_ = nthreads_ * alpha * alpha
+ M_sz_ = (size_t)nthreads_ * alpha * alpha
* (jcp.nb_tile_block_ur * jcp.tile_block_ur
+ jcp.tile_4fma_padding)
* jcp.oc * sizeof(float);
bias_sz_ = nthreads_ * jcp.oc * sizeof(float);
break;
case WSCHED_WEI_SDGtWo:
- U_sz_ = nthreads_ * alpha * alpha
+ U_sz_ = (size_t)nthreads_ * alpha * alpha
* jcp.oc_block * jcp.oc_simd_block * jcp.ic * sizeof(float);
- M_sz_ = nthreads_ * alpha * alpha
+ M_sz_ = (size_t)nthreads_ * alpha * alpha
* (jcp.nb_tile_block_ur * jcp.tile_block_ur
+ jcp.tile_4fma_padding)
* jcp.oc_simd_block * jcp.oc_block * sizeof(float);
bias_sz_ = nthreads_ * jcp.oc * sizeof(float);
break;
case WSCHED_WEI_S_D_Giot_W:
- U_sz_ = (nthreads_ + 1) * alpha * alpha
+ U_sz_ = (size_t)(nthreads_ + 1) * alpha * alpha
* jcp.ic * jcp.oc * sizeof(float);
- V_sz_ = alpha * alpha
+ V_sz_ = (size_t)alpha * alpha
* (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding)
* jcp.ic * jcp.mb * sizeof(float);
- M_sz_ = alpha * alpha
+ M_sz_ = (size_t)alpha * alpha
* (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding)
* jcp.oc * jcp.mb * sizeof(float);
bias_sz_ = nthreads_ * jcp.oc * sizeof(float);
src_transpose_sz_ = jcp.ver == ver_4fma
- ? (nthreads_ * alpha * alpha
+ ? ((size_t)nthreads_ * alpha * alpha
* jcp.tile_4fma
* jcp.ic_simd_block * sizeof(float))
: 0;
break;
case WSCHED_WEI_S_D_G_W:
src_transpose_sz_ = jcp.ver == ver_4fma
- ? (nthreads_ * alpha * alpha
+ ? ((size_t)nthreads_ * alpha * alpha
* jcp.tile_4fma
* jcp.ic_simd_block * sizeof(float))
: 0;
@@ -227,12 +228,17 @@ struct _jit_avx512_common_convolution_winograd_fwd_t
&& utils::one_of(this->cdesc_().prop_kind, forward_training,
forward_inference)
&& this->cdesc_().alg_kind == alg_kind::convolution_winograd
+ && !this->has_zero_dim_memory()
&& utils::everyone_is(data_type::f32,
this->cdesc_().src_desc.data_type,
this->cdesc_().weights_desc.data_type,
this->cdesc_().dst_desc.data_type)
&& utils::implication(this->with_bias(), data_type::f32
- == this->cdesc_().bias_desc.data_type);
+ == this->cdesc_().bias_desc.data_type)
+ && mkldnn_thr_syncable();
+
+ ok = ok && this->dst_pd_.desc()->format == memory_format::nChw16c &&
+ this->src_pd_.desc()->format == memory_format::nChw16c;
if (!ok)
return status::unimplemented;
@@ -321,10 +327,12 @@ struct jit_avx512_common_convolution_winograd_bwd_data_t
bool ok = true && this->set_default_params() == status::success
&& utils::one_of(this->desc()->prop_kind, backward_data)
&& this->desc()->alg_kind == alg_kind::convolution_winograd
+ && !this->has_zero_dim_memory()
&& utils::everyone_is(data_type::f32,
this->desc()->diff_src_desc.data_type,
this->desc()->weights_desc.data_type,
- this->desc()->diff_dst_desc.data_type);
+ this->desc()->diff_dst_desc.data_type)
+ && mkldnn_thr_syncable();
if (!ok)
return status::unimplemented;
@@ -413,10 +421,12 @@ struct jit_avx512_common_convolution_winograd_bwd_weights_t
bool ok = true && this->set_default_params() == status::success
&& utils::one_of(this->desc()->prop_kind, backward_weights)
&& this->desc()->alg_kind == alg_kind::convolution_winograd
+ && !this->has_zero_dim_memory()
&& utils::everyone_is(data_type::f32,
this->desc()->src_desc.data_type,
this->desc()->diff_dst_desc.data_type,
- this->desc()->diff_weights_desc.data_type);
+ this->desc()->diff_weights_desc.data_type)
+ && mkldnn_thr_syncable();
if (!ok)
return status::unimplemented;