diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp | 54 |
1 files changed, 31 insertions, 23 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp index 0c189bdd5..a138cb790 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp @@ -40,11 +40,11 @@ nspc_batch_normalization_fwd_t::nspc_batch_normalization_fwd_t(const pd_t *pd, tmp_mean_(nullptr), tmp_variance_(nullptr), conf_(*pd) { if (!conf_.stats_is_src()) { this->stats_reduction_ = (data_t *)malloc( - nstl::max(conf_.C(), 16) * omp_get_max_threads() * sizeof(data_t), 64); - this->tmp_mean_ = (data_t *)malloc(omp_get_max_threads() * + nstl::max(conf_.C(), 16) * mkldnn_get_max_threads() * sizeof(data_t), 64); + this->tmp_mean_ = (data_t *)malloc(mkldnn_get_max_threads() * nstl::max(conf_.C(), 16) * sizeof(data_t), 64); this->tmp_variance_ - = (data_t *)malloc(omp_get_max_threads() * + = (data_t *)malloc(mkldnn_get_max_threads() * nstl::max(conf_.C(), 16) * sizeof(data_t), 64); } } @@ -88,16 +88,15 @@ void nspc_batch_normalization_fwd_t::execute_forward() { const int N = conf_.MB(); const int C = conf_.C(); - int SP = conf_.H() * conf_.W() * conf_.D(); + const int SP = conf_.H() * conf_.W() * conf_.D(); const float eps = conf_.desc()->batch_norm_epsilon; const bool use_scaleshift = conf_.use_scaleshift(); - ; auto maybe_post_op = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; }; -#pragma omp parallel - { - int nthr = omp_get_max_threads(), ithr = omp_get_thread_num(); + + assert(mkldnn_thr_syncable()); + parallel(0, [&](const int ithr, const int nthr) { int N_s = 0, N_e = 0, C_s = 0, C_e = 0; balance211(N, nthr, ithr, N_s, N_e); balance211(C, nthr, ithr, C_s, C_e); @@ -115,14 +114,17 @@ void nspc_batch_normalization_fwd_t::execute_forward() { ws_reduce[C * ithr + c] += src[(size_t)n * SP * C + sp * C + c]; -#pragma omp barrier + mkldnn_thr_barrier(); + for (int c = C_s; c < C_e; c++) { mean[c] = 0; for (int n = 0; n < nthr; n++) mean[c] += ws_reduce[C * n + c]; mean[c] /= SP * N; } -#pragma omp barrier + + mkldnn_thr_barrier(); + for (int c = 0; c < C; c++) { mean_loc[c] = mean[c]; ws_reduce[C * ithr + c] = 0.; @@ -136,14 +138,18 @@ void nspc_batch_normalization_fwd_t::execute_forward() { - mean_loc[c]; ws_reduce[C * ithr + c] += m * m; } -#pragma omp barrier + + mkldnn_thr_barrier(); + for (int c = C_s; c < C_e; c++) { variance[c] = 0; for (int n = 0; n < nthr; n++) variance[c] += ws_reduce[C * n + c]; variance[c] /= SP * N; } -#pragma omp barrier + + mkldnn_thr_barrier(); + for (int c = 0; c < C; c++) variance_loc[c] = variance[c]; } else { @@ -178,16 +184,16 @@ void nspc_batch_normalization_fwd_t::execute_forward() { } } } - } + }); } nspc_batch_normalization_bwd_t::nspc_batch_normalization_bwd_t(const pd_t *pd, const input_vector &inputs, const output_vector &outputs) : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) { this->stats_reduction_ = (data_t *)malloc( - conf_.C() * 2 * omp_get_max_threads() * sizeof(data_t), 64); + conf_.C() * 2 * mkldnn_get_max_threads() * sizeof(data_t), 64); this->tmp_diff_scaleshift_ - = (data_t *)malloc((omp_get_max_threads() + 1) * conf_.C() * 2 * + = (data_t *)malloc((mkldnn_get_max_threads() + 1) * conf_.C() * 2 * sizeof(data_t), 64); } nspc_batch_normalization_bwd_t::~nspc_batch_normalization_bwd_t() { @@ -212,8 +218,7 @@ void nspc_batch_normalization_bwd_t::execute_backward() { const int N = conf_.MB(); const int C = conf_.C(); - int SP = conf_.D() * conf_.H() * conf_.W(); - int nthr = omp_get_max_threads(); + const int SP = conf_.D() * conf_.H() * conf_.W(); data_t *diff_gamma = diff_scaleshift, *diff_beta = diff_scaleshift + C; data_t *ws_reduce = this->stats_reduction_; @@ -221,9 +226,9 @@ void nspc_batch_normalization_bwd_t::execute_backward() { const bool use_scaleshift = conf_.use_scaleshift(); const bool calculate_diff_stats = !conf_.omit_stats(); const bool fuse_bn_relu = conf_.fuse_bn_relu(); -#pragma omp parallel - { - int ithr = omp_get_thread_num(); + + assert(mkldnn_thr_syncable()); + parallel(0, [&](const int ithr, const int nthr) { int N_s = 0, N_e = 0, C_s = 0, C_e = 0; balance211(N, nthr, ithr, N_s, N_e); balance211(C, nthr, ithr, C_s, C_e); @@ -253,7 +258,8 @@ void nspc_batch_normalization_bwd_t::execute_backward() { ws_reduce[C * nthr + C * ithr + c] += dd; } -#pragma omp barrier + mkldnn_thr_barrier(); + for (int c = C_s; c < C_e; c++) { data_t sqrt_variance = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps)); @@ -265,7 +271,9 @@ void nspc_batch_normalization_bwd_t::execute_backward() { } diff_gamma[c] *= sqrt_variance; } -#pragma omp barrier + + mkldnn_thr_barrier(); + for (int c = 0; c < C; c++) { diff_gamma_loc[c] = diff_gamma[c]; diff_beta_loc[c] = diff_beta[c]; @@ -296,7 +304,7 @@ void nspc_batch_normalization_bwd_t::execute_backward() { } } } - } + }); } } } |