Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/ref_rnn.cpp')
-rw-r--r--  inference-engine/thirdparty/mkl-dnn/src/cpu/ref_rnn.cpp | 337
1 file changed, 151 insertions(+), 186 deletions(-)
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/ref_rnn.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/ref_rnn.cpp
index a3eeb6cd1..c1f0612da 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/ref_rnn.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/ref_rnn.cpp
@@ -33,6 +33,7 @@
#include "mkldnn_thread.hpp"
#include "mkldnn_traits.hpp"
#include "type_helpers.hpp"
+#include "gemm/gemm.hpp"
#include "ref_rnn.hpp"
@@ -79,14 +80,13 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::rnn_elemwise) {
AOC<float, 3> ws_gates(ws_gates_, batch, n_gates, dic);
AOC<const float, 2> bias(bias_, n_gates, dic);
AOC<float, 3> states_t_l(states_t_l_, n_states, batch, wic);
-#pragma omp parallel for
- for (int i = 0; i < batch; i++) {
+ parallel_nd(batch, [&](int i) {
for (int j = 0; j < dic; j++) {
const float h
= activation_func(0, ws_gates(i, 0, j) + bias(0, j), 0, 0);
ws_gates(i, 0, j) = states_t_l(0, i, j) = h;
}
- }
+ });
}
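The pattern above repeats throughout the patch: each raw #pragma omp parallel for (and, further down, each collapse(2)/collapse(4) variant) is replaced by mkl-dnn's parallel_nd helper from mkldnn_thread.hpp, with the loop body captured in a lambda, so the kernels no longer depend on OpenMP directly and run on whichever threading runtime the library was built against. Below is a minimal stand-alone sketch of the equivalence; parallel_nd_sketch is a hypothetical serial stand-in used only to show the shape of the call, not the real scheduler.

// Hypothetical serial stand-in for mkl-dnn's parallel_nd: the real helper
// splits the outer index range across the configured threading runtime.
template <typename F>
static void parallel_nd_sketch(int dim0, F f) {
    for (int i = 0; i < dim0; ++i)
        f(i);
}

// Before: #pragma omp parallel for over i.  After: the body becomes a lambda
// over i, as in rnn_elemwise above; j stays a plain serial inner loop.
static void scale_rows(float *data, int batch, int dic) {
    parallel_nd_sketch(batch, [&](int i) {
        for (int j = 0; j < dic; ++j)
            data[i * dic + j] *= 0.5f;
    });
}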
template <>
@@ -96,15 +96,14 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::backward>::rnn_elemwise) {
diff_states_tp1_l_, n_states + 1, batch, wic);
AOC<float, 3> diff_states_t_lp1(
diff_states_t_lp1_, n_states + 1, batch, wic);
-#pragma omp parallel for
- for (int i = 0; i < batch; ++i) {
+ parallel_nd(batch, [&](int i) {
for (int j = 0; j < dic; ++j) {
const float dH = diff_states_t_lp1(n_states, i, j)
+ diff_states_tp1_l(0, i, j);
auto g = ws_gates(i, 0, j);
ws_gates(i, 0, j) = activation_func(dH, g, 0, 0);
}
- }
+ });
}
template <>
@@ -114,9 +113,11 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::lstm_elemwise) {
AOC<float, 3> states_t_l(states_t_l_, n_states, batch, wic);
AOC<float, 3> states_tm1_l(states_tm1_l_, n_states, batch, wic);
-#pragma omp parallel for
- for (int i = 0; i < batch; i++) {
+ parallel_nd(batch, [&](int i) {
+// WA. Loss of correctness in case of simd loop unrolling with icc 18
+#if !defined(__INTEL_COMPILER)
PRAGMA_OMP_SIMD()
+#endif
for (int j = 0; j < dic; j++) {
ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j));
ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j));
@@ -128,7 +129,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::lstm_elemwise) {
states_t_l(0, i, j) = ws_gates(i, 2, j) * tanh_fwd(tmp);
states_t_l(1, i, j) = tmp;
}
- }
+ });
}
template <>
@@ -145,8 +146,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::backward>::lstm_elemwise) {
auto one_m_square = [](float a) -> float { return 1.0f - a * a; };
-#pragma omp parallel for
- for (int i = 0; i < batch; i++) {
+ parallel_nd(batch, [&](int i) {
PRAGMA_OMP_SIMD()
for (int j = 0; j < dic; j++) {
float Ct = states_t_l(1, i, j);
@@ -173,7 +173,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::backward>::lstm_elemwise) {
ws_gates(i, 2, j) = dG2;
ws_gates(i, 3, j) = dG3;
}
- }
+ });
}
template <prop_kind_t aprop>
@@ -197,25 +197,31 @@ gemm_sig(_ref_rnn_common_t<aprop>::packed_gemm) {
template <prop_kind_t aprop>
gemm_sig(_ref_rnn_common_t<aprop>::gemm) {
- cblas_sgemm(CblasColMajor, CblasNoTrans,
- is_B_trans ? CblasTrans : CblasNoTrans, m, n, k, 1.0f, a_,
- strideA_m, b_, is_B_trans ? strideB_n : strideB_k, beta, c_,
- strideC_m);
+ float alpha = 1.f;
+ extended_sgemm("N", is_B_trans ? "T" : "N", &m, &n, &k, &alpha,
+ a_, &strideA_m, b_, is_B_trans ? &strideB_n : &strideB_k, &beta,
+ c_, &strideC_m);
}
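The gemm hunk drops the hard CBLAS dependency: cblas_sgemm (column-major, enum transpose flags, scalars by value) gives way to extended_sgemm from the newly included gemm/gemm.hpp, which uses the Fortran-style convention of character transpose flags with every argument passed by pointer. The sketch below is a naive column-major reference for what that call computes, useful when checking the argument mapping; strideA_m, strideB_k/strideB_n and strideC_m play the role of the leading dimensions, and the function name here is illustrative, not part of the library.

// Naive column-major reference for C = alpha * A * op(B) + beta * C, the
// operation performed by the extended_sgemm("N", is_B_trans ? "T" : "N", ...)
// call above.  A is m x k (leading dimension lda), op(B) is k x n, C is m x n.
static void sgemm_ref(bool b_trans, int m, int n, int k, float alpha,
        const float *a, int lda, const float *b, int ldb,
        float beta, float *c, int ldc) {
    for (int j = 0; j < n; ++j)
        for (int i = 0; i < m; ++i) {
            float acc = 0.f;
            for (int p = 0; p < k; ++p) {
                // When B is transposed it is stored n x k, so B(p, j) = b[j + p * ldb].
                float b_pj = b_trans ? b[j + p * ldb] : b[p + j * ldb];
                acc += a[i + p * lda] * b_pj;
            }
            c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
        }
}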
template <prop_kind_t aprop>
void _ref_rnn_common_t<aprop>::gates_reduction(int n_gates, int dic, int batch,
const float *ws_gates_, float *diff_bias_) {
-#if (_OPENMP >= 201307)
+ auto body = [&](int i, int k) {
+ for (int j = 0; j < batch; j++)
+ diff_bias_[i * dic + k] += ws_gates_[(j * n_gates + i) * dic + k];
+ };
+
+ // @todo block k on simd-width
+#if (_OPENMP >= 201307) \
+ /* icc 17.0 has a problem with simd collapse */ \
+ && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700))
#pragma omp parallel for simd collapse(2)
-#else
-#pragma omp parallel for collapse(2) ///@todo block k on simd-width
-#endif
for (int i = 0; i < n_gates; i++)
for (int k = 0; k < dic; k++)
- for (int j = 0; j < batch; j++)
- diff_bias_[i * dic + k]
- += ws_gates_[(j * n_gates + i) * dic + k];
+ body(i, k);
+#else
+ parallel_nd(n_gates, dic, body);
+#endif
}
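gates_reduction keeps two code paths on purpose: the omp parallel for simd collapse(2) pragma where the compiler handles it (with the icc 17.0 exclusion), and parallel_nd(n_gates, dic, body) otherwise. In both paths the parallel dimensions are the gate index and the channel index, while the sum over the batch stays serial inside the body, so each diff_bias element is written by exactly one thread and no atomics or reduction clauses are needed. A small self-contained sketch of that layout, using the same indexing as the body lambda above:

// ws_gates is laid out as [batch][n_gates][dic], diff_bias as [n_gates][dic].
// Parallelizing over (i, k) gives each thread exclusive ownership of one
// output element; the batch loop over j is a private serial accumulation.
static void gates_reduction_sketch(int n_gates, int dic, int batch,
        const float *ws_gates, float *diff_bias) {
    for (int i = 0; i < n_gates; ++i)       // outer parallel_nd dimension
        for (int k = 0; k < dic; ++k) {     // inner parallel_nd dimension
            float acc = diff_bias[i * dic + k];
            for (int j = 0; j < batch; ++j)
                acc += ws_gates[(j * n_gates + i) * dic + k];
            diff_bias[i * dic + k] = acc;
        }
}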
/// @todo template this function on fwd or bwd, if the overhead
/// to pass argument for empty function is too big
@@ -276,15 +282,14 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::forward>::cell_execution_gru) {
ws_gates_, false, 1.0f);
// 3. activation zt and rt + elemwise multiplication rt,ht-1
-#pragma omp parallel for
- for (int i = 0; i < batch; i++) {
+ parallel_nd(batch, [&](int i) {
PRAGMA_OMP_SIMD()
for (int j = 0; j < dic; j++) {
ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j));
ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j));
states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j);
}
- }
+ });
// 4. gemm Wh[2],h~t
(this->*gemm_state_func)(dic, batch, sic, n_gates * dic, sic,
@@ -292,15 +297,14 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::forward>::cell_execution_gru) {
&(ws_gates(0, 2, 0)), false, 1.0f);
// 5. activation h~t + calculate ht
-#pragma omp parallel for
- for (int i = 0; i < batch; i++) {
+ parallel_nd(batch, [&](int i) {
PRAGMA_OMP_SIMD()
for (int j = 0; j < dic; j++) {
ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j));
states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) +
(1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j);
}
- }
+ });
}
template <>
@@ -312,8 +316,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::gru_lbr_elemwise) {
AOC<float, 2> states_t_l(states_t_l_, batch, wic);
AOC<float, 2> states_tm1_l(states_tm1_l_, batch, wic);
AOC<float, 3> ws_gemm_state(ws_cell_, batch, n_gates, dic);
-#pragma omp parallel for
- for (int i = 0; i < batch; i++) {
+ parallel_nd(batch, [&](int i) {
PRAGMA_OMP_SIMD()
for (int j = 0; j < dic; j++) {
float Wh_b = ws_gemm_state(i, 2, j) + bias(3, j);
@@ -327,7 +330,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::gru_lbr_elemwise) {
(1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j);
if (is_training) ws_Wh_b(i, j) = Wh_b;
}
- }
+ });
}
template <>
@@ -359,8 +362,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::backward>::gru_lbr_elemwise) {
// dG0 = (dht - G2) * dht * (1 - G0) * G0
// dG1 = (W*h + b) * dG2 * (1 - G1) * G1
// dG2 = (1 - G0) * dht * (1 - G2*G2)
-#pragma omp parallel for
- for (int i = 0; i < batch; i++) {
+ parallel_nd(batch, [&](int i) {
PRAGMA_OMP_SIMD()
for (int j = 0; j < dic; j++) {
float h = states_tm1_l(i, j);
@@ -379,7 +381,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::backward>::gru_lbr_elemwise) {
ws_gates(i, 0, j) = ws_gates_r(i, 0, j) = dG0;
ws_gates(i, 1, j) = ws_gates_r(i, 1, j) = dG1;
}
- }
+ });
}
template <>
@@ -412,12 +414,11 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::backward>::cell_execution_gru_lb
// db4 += e * (r * dG2)
gates_reduction(n_gates, dic, batch, ws_gates_, diff_bias_);
-#pragma omp parallel for
- for (int j = 0; j < dic; j++) {
+ parallel_nd(dic, [&](int j) {
for (int i = 0; i < batch; i++) {
diff_bias_[3 * dic + j] += ws_gates_r(i, 2, j);
}
- }
+ });
}
template <>
@@ -440,8 +441,7 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::backward>::cell_execution_gru) {
// dG2^ = dh * (1 - G0) * (1 - G2^2)
// dG0^ = dh * (ht-1 - G2) * u * (1 - G0)
// dht-1 (part) = dh * G0
-#pragma omp parallel for
- for (int i = 0; i < batch; i++) {
+ parallel_nd(batch, [&](int i) {
PRAGMA_OMP_SIMD()
for (int j = 0; j < dic; j++) {
float h = states_tm1_l(i, j);
@@ -456,7 +456,7 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::backward>::cell_execution_gru) {
ws_gates(i, 0, j) = dG0;
ws_gates(i, 2, j) = dG2;
}
- }
+ });
//2. calculate intermediate d(hG1)
//d(hG1) = dG2 * W2h^t
@@ -468,8 +468,7 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::backward>::cell_execution_gru) {
//dG1^ = d(hG1) * h * G1 * (1 - G1)
//dht-1 (part) += d(hG1) * G1
//h * G1 (required for dWh)
-#pragma omp parallel for
- for (int i = 0; i < batch; i++) {
+ parallel_nd(batch, [&](int i) {
PRAGMA_OMP_SIMD()
for (int j = 0; j < dic; j++) {
float h = states_tm1_l(i, j);
@@ -478,7 +477,7 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::backward>::cell_execution_gru) {
ws_gates(i, 1, j) = dhG1(i, j) * logistic_bwd(h, G1);
hG1(i, j) = G1 * h;
}
- }
+ });
//4. calculate diff weights
//dWx += [dG0 dG1 dG2] * [x]
@@ -634,7 +633,7 @@ struct reversed_indexer : wavefront_indexer {
default: return -1;
}
}
-
+
private:
original_indexer wd;
};
@@ -668,7 +667,7 @@ grid_execution_sig(_ref_rnn_common_t<aprop>::wavefront_execution){// (int dic, i
: (((exec_dir == b2t_l2r) || (exec_dir == b2t_r2l)) //nlayer is maxdim, we look for b2t
? (wavefront_indexer) wi_b2t_l2r_maxdim(n_layer)
: (wavefront_indexer) wi_t2b_r2l_maxdim(n_layer));
-
+
wavefront_indexer wi_mindim = (!is_niter_maxdim)
? (((exec_dir == b2t_l2r) || (exec_dir == t2b_l2r)) //niter is mindim, we look for l2r
? (wavefront_indexer) wi_b2t_l2r_mindim(n_iter)
@@ -676,7 +675,7 @@ grid_execution_sig(_ref_rnn_common_t<aprop>::wavefront_execution){// (int dic, i
: (((exec_dir == b2t_l2r) || (exec_dir == b2t_r2l)) //nlayer is mindim, we look for b2t
? (wavefront_indexer) wi_b2t_l2r_mindim(n_layer)
: (wavefront_indexer) wi_t2b_r2l_mindim(n_layer));
-
+
// auto get_offset = [=](wavefront_loop_index idx, int i, int j){
// int dim_min = wi_mindim.get(idx, i,j);
// int dim_max = wi_maxdim.get(idx, i,j);
@@ -734,8 +733,7 @@ void _ref_rnn_common_t<prop_kind::forward>::copy_init_layer(bool lr, bool rl,
ws_states_, n_direction, n_iter + 1, n_states, batch, wic);
auto xt_d = memory_desc_wrapper(conf_.src_pd(0));
-#pragma omp parallel for
- for (int it = 0; it < n_iter; it++) {
+ parallel_nd(n_iter, [&](int it) {
auto xxt = xt_ + xt_d.blk_off(it);
if (lr)
for (int b = 0; b < batch; b++)
@@ -746,7 +744,7 @@ void _ref_rnn_common_t<prop_kind::forward>::copy_init_layer(bool lr, bool rl,
for (int c = 0; c < slc; c++)
ws_states(n_direction - 1, n_iter - it, 0, b, c)
= *(xxt + b * slc + c);
- }
+ });
}
template <>
@@ -761,47 +759,38 @@ void _ref_rnn_common_t<prop_kind::backward>::copy_init_layer(bool lr, bool rl,
switch (conf_.direction()) {
case mkldnn_bidirectional_concat:
-#pragma omp parallel for collapse(2)
- for (int it = 0; it < n_iter; it++) {
- for (int b = 0; b < batch; b++) {
- auto diff_dst_layer_x
- = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
- for (int s = 0; s < dlc; s++) {
- ws_diff_states(n_layer, 0, it, n_states, b, s)
- = diff_dst_layer_x[s];
- ws_diff_states(n_layer, 1, it, n_states, b, s)
- = diff_dst_layer_x[dic + s];
- }
+ parallel_nd(n_iter, batch, [&](int it, int b) {
+ auto diff_dst_layer_x
+ = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
+ for (int s = 0; s < dlc; s++) {
+ ws_diff_states(n_layer, 0, it, n_states, b, s)
+ = diff_dst_layer_x[s];
+ ws_diff_states(n_layer, 1, it, n_states, b, s)
+ = diff_dst_layer_x[dic + s];
}
- }
+ });
break;
case mkldnn_bidirectional_sum:
-#pragma omp parallel for collapse(2)
- for (int it = 0; it < n_iter; it++) {
- for (int b = 0; b < batch; b++) {
- auto diff_dst_layer_x
- = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
- for (int s = 0; s < dic; s++) {
- ws_diff_states(n_layer, 0, it, n_states, b, s)
- = diff_dst_layer_x[s];
- ws_diff_states(n_layer, 1, it, n_states, b, s)
- = diff_dst_layer_x[s];
- }
+ parallel_nd(n_iter, batch, [&](int it, int b) {
+ auto diff_dst_layer_x
+ = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
+ for (int s = 0; s < dic; s++) {
+ ws_diff_states(n_layer, 0, it, n_states, b, s)
+ = diff_dst_layer_x[s];
+ ws_diff_states(n_layer, 1, it, n_states, b, s)
+ = diff_dst_layer_x[s];
}
- }
+ });
break;
default: // assumes default is always unidirectional
-#pragma omp parallel for collapse(2)
- for (int it = 0; it < n_iter; it++) {
- for (int b = 0; b < batch; b++) {
- auto diff_dst_layer_x
- = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
- for (int s = 0; s < dic; s++) {
- ws_diff_states(n_layer, 0, it, n_states, b, s)
- = diff_dst_layer_x[s];
- }
+ parallel_nd(n_iter, batch, [&](int it, int b) {
+ auto diff_dst_layer_x
+ = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
+ for (int s = 0; s < dic; s++) {
+ ws_diff_states(n_layer, 0, it, n_states, b, s)
+ = diff_dst_layer_x[s];
}
- }
+ });
break;
}
}
@@ -815,25 +804,21 @@ void _ref_rnn_common_t<prop_kind::forward>::copy_init_iter(int n_layer,
n_states, batch, wic);
auto firstit_states_d = memory_desc_wrapper(conf_.src_pd(1));
if (firstit_states_) {
-#pragma omp parallel for collapse(2)
- for (int lay = 0; lay < n_layer; lay++)
- for (int dir = 0; dir < n_direction; dir++)
- for (int state = 0; state < n_states; state++)
- for (int b = 0; b < batch; ++b) {
- array_copy(&(ws_states(lay + 1, dir, 0, state, b, 0)),
- firstit_states_
- + firstit_states_d.blk_off(
- lay, dir, state, b),
- sic);
- }
+ parallel_nd(n_layer, n_direction, [&](int lay, int dir) {
+ for (int state = 0; state < n_states; state++)
+ for (int b = 0; b < batch; ++b) {
+ array_copy(&(ws_states(lay + 1, dir, 0, state, b, 0)),
+ firstit_states_ + firstit_states_d.blk_off(
+ lay, dir, state, b), sic);
+ }
+ });
} else {
-#pragma omp parallel for collapse(2)
- for (int lay = 0; lay < n_layer; lay++)
- for (int dir = 0; dir < n_direction; dir++)
- for (int state = 0; state < n_states; state++)
- for (int i = 0; i < batch; i++)
- for (int j = 0; j < sic; j++)
- ws_states(lay + 1, dir, 0, state, i, j) = 0.0f;
+ parallel_nd(n_layer, n_direction, [&](int lay, int dir) {
+ for (int state = 0; state < n_states; state++)
+ for (int i = 0; i < batch; i++)
+ for (int j = 0; j < sic; j++)
+ ws_states(lay + 1, dir, 0, state, i, j) = 0.0f;
+ });
}
}
@@ -846,27 +831,18 @@ void _ref_rnn_common_t<prop_kind::backward>::copy_init_iter(int n_layer,
n_iter + 1, n_states + 1, batch, wic);
auto diff_dst_iter_d = memory_desc_wrapper(conf_.diff_dst_pd(1));
if (diff_dst_iter_) {
-#pragma omp parallel for collapse(4)
- for (int lay = 0; lay < n_layer; lay++)
- for (int dir = 0; dir < n_direction; dir++)
- for (int state = 0; state < n_states; state++)
- for (int b = 0; b < batch; ++b) {
- array_copy(&(ws_diff_states(
- lay, dir, n_iter, state, b, 0)),
- diff_dst_iter_
- + diff_dst_iter_d.blk_off(
- lay, dir, state, b),
- dic);
- }
+ parallel_nd(n_layer, n_direction, n_states, batch,
+ [&](int lay, int dir, int state, int b) {
+ array_copy(&(ws_diff_states(lay, dir, n_iter, state, b, 0)),
+ diff_dst_iter_ + diff_dst_iter_d.blk_off(lay, dir, state, b),
+ dic);
+ });
} else {
-#pragma omp parallel for collapse(4)
- for (int lay = 0; lay < n_layer; lay++)
- for (int dir = 0; dir < n_direction; dir++)
- for (int state = 0; state < n_states; state++)
- for (int i = 0; i < batch; i++)
- for (int j = 0; j < dic; j++)
- ws_diff_states(lay, dir, n_iter, state, i, j)
- = 0.0f;
+ parallel_nd(n_layer, n_direction, n_states, batch,
+ [&](int lay, int dir, int state, int i) {
+ for (int j = 0; j < dic; j++)
+ ws_diff_states(lay, dir, n_iter, state, i, j) = 0.0f;
+ });
}
}
@@ -880,30 +856,28 @@ void _ref_rnn_common_t<prop_kind::forward>::copy_res_layer(bool lr, bool rl,
auto dst_layer_d = memory_desc_wrapper(conf_.dst_pd(0));
AOC<const float, 6> ws_states(ws_states_, n_layer + 1, n_direction,
n_iter + 1, n_states, batch, wic);
-#pragma omp parallel for collapse(2)
- for (int it = 0; it < n_iter; it++) {
- for (int b = 0; b < batch; b++) {
- int dir = 0;
- if (lr) {
- for (int s = 0; s < dic; s++)
+
+ parallel_nd(n_iter, batch, [&](int it, int b) {
+ int dir = 0;
+ if (lr) {
+ for (int s = 0; s < dic; s++)
+ dst_layer_[dst_layer_d.blk_off(it, b, dir * dic + s)]
+ = ws_states(n_layer, dir, it + 1, 0, b, s);
+ dir = 1;
+ }
+ if (rl) {
+ for (int s = 0; s < dic; s++)
+ switch (direction) {
+ case mkldnn_bidirectional_sum:
+ dst_layer_[dst_layer_d.blk_off(it, b, s)] += ws_states(
+ n_layer, dir, n_iter - it, 0, b, s);
+ break;
+ default:
dst_layer_[dst_layer_d.blk_off(it, b, dir * dic + s)]
- = ws_states(n_layer, dir, it + 1, 0, b, s);
- dir = 1;
- }
- if (rl) {
- for (int s = 0; s < dic; s++)
- switch (direction) {
- case mkldnn_bidirectional_sum:
- dst_layer_[dst_layer_d.blk_off(it, b, s)] += ws_states(
- n_layer, dir, n_iter - it, 0, b, s);
- break;
- default:
- dst_layer_[dst_layer_d.blk_off(it, b, dir * dic + s)]
- = ws_states(n_layer, dir, n_iter - it, 0, b, s);
- }
- }
+ = ws_states(n_layer, dir, n_iter - it, 0, b, s);
+ }
}
- }
+ });
}
template <>
@@ -916,26 +890,24 @@ void _ref_rnn_common_t<prop_kind::backward>::copy_res_layer(bool lr, bool rl,
auto diff_src_layer_d = memory_desc_wrapper(conf_.diff_src_pd(0));
AOC<const float, 6> ws_diff_states(ws_diff_states_, n_layer + 1,
n_direction, n_iter + 1, n_states + 1, batch, wic);
-#pragma omp parallel for collapse(2)
- for (int it = 0; it < n_iter; it++) {
- for (int b = 0; b < batch; b++) {
- int dir = 0;
- for (int s = 0; s < slc; s++) {
- float *dst_addr = diff_src_layer_
- + diff_src_layer_d.blk_off(
- (direction
- == mkldnn_unidirectional_right2left) ?
- n_iter - 1 - it :
- it,
- b, dir * slc + s);
- float res = ws_diff_states(0, 0, it, n_states, b, s);
- if (n_direction - 1)
- res += ws_diff_states(
- 0, 1, n_iter - 1 - it, n_states, b, s);
- dst_addr[0] = res;
- }
+
+ parallel_nd(n_iter, batch, [&](int it, int b) {
+ int dir = 0;
+ for (int s = 0; s < slc; s++) {
+ float *dst_addr = diff_src_layer_
+ + diff_src_layer_d.blk_off(
+ (direction
+ == mkldnn_unidirectional_right2left) ?
+ n_iter - 1 - it :
+ it,
+ b, dir * slc + s);
+ float res = ws_diff_states(0, 0, it, n_states, b, s);
+ if (n_direction - 1)
+ res += ws_diff_states(
+ 0, 1, n_iter - 1 - it, n_states, b, s);
+ dst_addr[0] = res;
}
- }
+ });
}
template <>
@@ -947,17 +919,13 @@ void _ref_rnn_common_t<prop_kind::forward>::copy_res_iter(int n_layer,
AOC<const float, 6> ws_states(ws_states_, n_layer + 1, n_direction,
n_iter + 1, n_states, batch, wic);
if (dst_iter_) {
-#pragma omp parallel for collapse(4)
- for (int lay = 0; lay < n_layer; lay++) {
- for (int dir = 0; dir < n_direction; dir++)
- for (int state = 0; state < n_states; state++)
- for (int b = 0; b < batch; b++)
- for (int s = 0; s < dic; s++) {
- dst_iter_[dst_iter_d.blk_off(lay, dir, state, b, s)]
- = ws_states(
- lay + 1, dir, n_iter, state, b, s);
- }
- }
+ parallel_nd(n_layer, n_direction, n_states, batch,
+ [&](int lay, int dir, int state, int b) {
+ for (int s = 0; s < dic; s++) {
+ dst_iter_[dst_iter_d.blk_off(lay, dir, state, b, s)]
+ = ws_states(lay + 1, dir, n_iter, state, b, s);
+ }
+ });
}
}
@@ -970,17 +938,14 @@ void _ref_rnn_common_t<prop_kind::backward>::copy_res_iter(int n_layer,
AOC<const float, 6> ws_diff_states(ws_diff_states_, n_layer + 1,
n_direction, n_iter + 1, n_states + 1, batch, wic);
if (diff_src_iter_) {
-#pragma omp parallel for collapse(4)
- for (int lay = 0; lay < n_layer; lay++) {
- for (int dir = 0; dir < n_direction; dir++)
- for (int state = 0; state < n_states; state++)
- for (int b = 0; b < batch; b++)
- for (int s = 0; s < sic; s++) {
- diff_src_iter_[diff_src_iter_d.blk_off(
- lay, dir, state, b, s)]
- = ws_diff_states(lay, dir, 0, state, b, s);
- }
- }
+ parallel_nd(n_layer, n_direction, n_states, batch,
+ [&](int lay, int dir, int state, int b) {
+ for (int s = 0; s < sic; s++) {
+ diff_src_iter_[diff_src_iter_d.blk_off(
+ lay, dir, state, b, s)]
+ = ws_diff_states(lay, dir, 0, state, b, s);
+ }
+ });
}
}
@@ -1127,7 +1092,7 @@ void _ref_rnn_common_t<aprop>::execute_() {
auto diff_dst_layer = is_fwd ?
nullptr :
reinterpret_cast<const float *>(this->input_memory(input_idx++));
- auto diff_dst_iter = is_fwd || !conf_.with_src_iter() ?
+ auto diff_dst_iter = is_fwd || !conf_.with_dst_iter() ?
nullptr :
reinterpret_cast<const float *>(this->input_memory(input_idx++));