summary refs log tree commit diff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp')
-rw-r--r-- inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp 54
1 files changed, 31 insertions, 23 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp
index 0c189bdd5..a138cb790 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/nspc_batch_normalization.cpp
@@ -40,11 +40,11 @@ nspc_batch_normalization_fwd_t::nspc_batch_normalization_fwd_t(const pd_t *pd,
tmp_mean_(nullptr), tmp_variance_(nullptr), conf_(*pd) {
if (!conf_.stats_is_src()) {
this->stats_reduction_ = (data_t *)malloc(
- nstl::max(conf_.C(), 16) * omp_get_max_threads() * sizeof(data_t), 64);
- this->tmp_mean_ = (data_t *)malloc(omp_get_max_threads() *
+ nstl::max(conf_.C(), 16) * mkldnn_get_max_threads() * sizeof(data_t), 64);
+ this->tmp_mean_ = (data_t *)malloc(mkldnn_get_max_threads() *
nstl::max(conf_.C(), 16) * sizeof(data_t), 64);
this->tmp_variance_
- = (data_t *)malloc(omp_get_max_threads() *
+ = (data_t *)malloc(mkldnn_get_max_threads() *
nstl::max(conf_.C(), 16) * sizeof(data_t), 64);
}
}
@@ -88,16 +88,15 @@ void nspc_batch_normalization_fwd_t::execute_forward() {
const int N = conf_.MB();
const int C = conf_.C();
- int SP = conf_.H() * conf_.W() * conf_.D();
+ const int SP = conf_.H() * conf_.W() * conf_.D();
const float eps = conf_.desc()->batch_norm_epsilon;
const bool use_scaleshift = conf_.use_scaleshift();
- ;
auto maybe_post_op
= [&](data_t res) { return (with_relu && res < 0) ? 0 : res; };
-#pragma omp parallel
- {
- int nthr = omp_get_max_threads(), ithr = omp_get_thread_num();
+
+ assert(mkldnn_thr_syncable());
+ parallel(0, [&](const int ithr, const int nthr) {
int N_s = 0, N_e = 0, C_s = 0, C_e = 0;
balance211(N, nthr, ithr, N_s, N_e);
balance211(C, nthr, ithr, C_s, C_e);
@@ -115,14 +114,17 @@ void nspc_batch_normalization_fwd_t::execute_forward() {
ws_reduce[C * ithr + c] += src[(size_t)n * SP * C
+ sp * C + c];
-#pragma omp barrier
+ mkldnn_thr_barrier();
+
for (int c = C_s; c < C_e; c++) {
mean[c] = 0;
for (int n = 0; n < nthr; n++)
mean[c] += ws_reduce[C * n + c];
mean[c] /= SP * N;
}
-#pragma omp barrier
+
+ mkldnn_thr_barrier();
+
for (int c = 0; c < C; c++) {
mean_loc[c] = mean[c];
ws_reduce[C * ithr + c] = 0.;
@@ -136,14 +138,18 @@ void nspc_batch_normalization_fwd_t::execute_forward() {
- mean_loc[c];
ws_reduce[C * ithr + c] += m * m;
}
-#pragma omp barrier
+
+ mkldnn_thr_barrier();
+
for (int c = C_s; c < C_e; c++) {
variance[c] = 0;
for (int n = 0; n < nthr; n++)
variance[c] += ws_reduce[C * n + c];
variance[c] /= SP * N;
}
-#pragma omp barrier
+
+ mkldnn_thr_barrier();
+
for (int c = 0; c < C; c++)
variance_loc[c] = variance[c];
} else {
@@ -178,16 +184,16 @@ void nspc_batch_normalization_fwd_t::execute_forward() {
}
}
}
- }
+ });
}
nspc_batch_normalization_bwd_t::nspc_batch_normalization_bwd_t(const pd_t *pd,
const input_vector &inputs, const output_vector &outputs)
: cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {
this->stats_reduction_ = (data_t *)malloc(
- conf_.C() * 2 * omp_get_max_threads() * sizeof(data_t), 64);
+ conf_.C() * 2 * mkldnn_get_max_threads() * sizeof(data_t), 64);
this->tmp_diff_scaleshift_
- = (data_t *)malloc((omp_get_max_threads() + 1) * conf_.C() * 2 *
+ = (data_t *)malloc((mkldnn_get_max_threads() + 1) * conf_.C() * 2 *
sizeof(data_t), 64);
}
nspc_batch_normalization_bwd_t::~nspc_batch_normalization_bwd_t() {
@@ -212,8 +218,7 @@ void nspc_batch_normalization_bwd_t::execute_backward() {
const int N = conf_.MB();
const int C = conf_.C();
- int SP = conf_.D() * conf_.H() * conf_.W();
- int nthr = omp_get_max_threads();
+ const int SP = conf_.D() * conf_.H() * conf_.W();
data_t *diff_gamma = diff_scaleshift, *diff_beta = diff_scaleshift + C;
data_t *ws_reduce = this->stats_reduction_;
@@ -221,9 +226,9 @@ void nspc_batch_normalization_bwd_t::execute_backward() {
const bool use_scaleshift = conf_.use_scaleshift();
const bool calculate_diff_stats = !conf_.omit_stats();
const bool fuse_bn_relu = conf_.fuse_bn_relu();
-#pragma omp parallel
- {
- int ithr = omp_get_thread_num();
+
+ assert(mkldnn_thr_syncable());
+ parallel(0, [&](const int ithr, const int nthr) {
int N_s = 0, N_e = 0, C_s = 0, C_e = 0;
balance211(N, nthr, ithr, N_s, N_e);
balance211(C, nthr, ithr, C_s, C_e);
@@ -253,7 +258,8 @@ void nspc_batch_normalization_bwd_t::execute_backward() {
ws_reduce[C * nthr + C * ithr + c] += dd;
}
-#pragma omp barrier
+ mkldnn_thr_barrier();
+
for (int c = C_s; c < C_e; c++) {
data_t sqrt_variance
= static_cast<data_t>(1.0f / sqrtf(variance[c] + eps));
@@ -265,7 +271,9 @@ void nspc_batch_normalization_bwd_t::execute_backward() {
}
diff_gamma[c] *= sqrt_variance;
}
-#pragma omp barrier
+
+ mkldnn_thr_barrier();
+
for (int c = 0; c < C; c++) {
diff_gamma_loc[c] = diff_gamma[c];
diff_beta_loc[c] = diff_beta[c];
@@ -296,7 +304,7 @@ void nspc_batch_normalization_bwd_t::execute_backward() {
}
}
}
- }
+ });
}
}
}