diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/src/cpu/ref_rnn.cpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/src/cpu/ref_rnn.cpp | 337 |
1 files changed, 151 insertions, 186 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/ref_rnn.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/ref_rnn.cpp index a3eeb6cd1..c1f0612da 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/ref_rnn.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/ref_rnn.cpp @@ -33,6 +33,7 @@ #include "mkldnn_thread.hpp" #include "mkldnn_traits.hpp" #include "type_helpers.hpp" +#include "gemm/gemm.hpp" #include "ref_rnn.hpp" @@ -79,14 +80,13 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::rnn_elemwise) { AOC<float, 3> ws_gates(ws_gates_, batch, n_gates, dic); AOC<const float, 2> bias(bias_, n_gates, dic); AOC<float, 3> states_t_l(states_t_l_, n_states, batch, wic); -#pragma omp parallel for - for (int i = 0; i < batch; i++) { + parallel_nd(batch, [&](int i) { for (int j = 0; j < dic; j++) { const float h = activation_func(0, ws_gates(i, 0, j) + bias(0, j), 0, 0); ws_gates(i, 0, j) = states_t_l(0, i, j) = h; } - } + }); } template <> @@ -96,15 +96,14 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::backward>::rnn_elemwise) { diff_states_tp1_l_, n_states + 1, batch, wic); AOC<float, 3> diff_states_t_lp1( diff_states_t_lp1_, n_states + 1, batch, wic); -#pragma omp parallel for - for (int i = 0; i < batch; ++i) { + parallel_nd(batch, [&](int i) { for (int j = 0; j < dic; ++j) { const float dH = diff_states_t_lp1(n_states, i, j) + diff_states_tp1_l(0, i, j); auto g = ws_gates(i, 0, j); ws_gates(i, 0, j) = activation_func(dH, g, 0, 0); } - } + }); } template <> @@ -114,9 +113,11 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::lstm_elemwise) { AOC<float, 3> states_t_l(states_t_l_, n_states, batch, wic); AOC<float, 3> states_tm1_l(states_tm1_l_, n_states, batch, wic); -#pragma omp parallel for - for (int i = 0; i < batch; i++) { + parallel_nd(batch, [&](int i) { +// WA. Loss of correctnes in case of simd loop unrolling with icc 18 +#if !defined(__INTEL_COMPILER) PRAGMA_OMP_SIMD() +#endif for (int j = 0; j < dic; j++) { ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); @@ -128,7 +129,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::lstm_elemwise) { states_t_l(0, i, j) = ws_gates(i, 2, j) * tanh_fwd(tmp); states_t_l(1, i, j) = tmp; } - } + }); } template <> @@ -145,8 +146,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::backward>::lstm_elemwise) { auto one_m_square = [](float a) -> float { return 1.0f - a * a; }; -#pragma omp parallel for - for (int i = 0; i < batch; i++) { + parallel_nd(batch, [&](int i) { PRAGMA_OMP_SIMD() for (int j = 0; j < dic; j++) { float Ct = states_t_l(1, i, j); @@ -173,7 +173,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::backward>::lstm_elemwise) { ws_gates(i, 2, j) = dG2; ws_gates(i, 3, j) = dG3; } - } + }); } template <prop_kind_t aprop> @@ -197,25 +197,31 @@ gemm_sig(_ref_rnn_common_t<aprop>::packed_gemm) { template <prop_kind_t aprop> gemm_sig(_ref_rnn_common_t<aprop>::gemm) { - cblas_sgemm(CblasColMajor, CblasNoTrans, - is_B_trans ? CblasTrans : CblasNoTrans, m, n, k, 1.0f, a_, - strideA_m, b_, is_B_trans ? strideB_n : strideB_k, beta, c_, - strideC_m); + float alpha = 1.f; + extended_sgemm("N", is_B_trans ? "T" : "N", &m, &n, &k, &alpha, + a_, &strideA_m, b_, is_B_trans ? &strideB_n : &strideB_k, &beta, + c_, &strideC_m); } template <prop_kind_t aprop> void _ref_rnn_common_t<aprop>::gates_reduction(int n_gates, int dic, int batch, const float *ws_gates_, float *diff_bias_) { -#if (_OPENMP >= 201307) + auto body = [&](int i, int k) { + for (int j = 0; j < batch; j++) + diff_bias_[i * dic + k] += ws_gates_[(j * n_gates + i) * dic + k]; + }; + + // @todo block k on simd-width +#if (_OPENMP >= 201307) \ + /* icc 17.0 has a problem with simd collapse */ \ + && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700)) #pragma omp parallel for simd collapse(2) -#else -#pragma omp parallel for collapse(2) ///@todo block k on simd-width -#endif for (int i = 0; i < n_gates; i++) for (int k = 0; k < dic; k++) - for (int j = 0; j < batch; j++) - diff_bias_[i * dic + k] - += ws_gates_[(j * n_gates + i) * dic + k]; + body(i, k); +#else + parallel_nd(n_gates, dic, body); +#endif } /// @todo template this function on fwd or bwd, if the overhead /// to pass argument for empty function is too big @@ -276,15 +282,14 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::forward>::cell_execution_gru) { ws_gates_, false, 1.0f); // 3. activation zt and rt + elemwise multiplication rt,ht-1 -#pragma omp parallel for - for (int i = 0; i < batch; i++) { + parallel_nd(batch, [&](int i) { PRAGMA_OMP_SIMD() for (int j = 0; j < dic; j++) { ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j); } - } + }); // 4. gemm Wh[2],h~t (this->*gemm_state_func)(dic, batch, sic, n_gates * dic, sic, @@ -292,15 +297,14 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::forward>::cell_execution_gru) { &(ws_gates(0, 2, 0)), false, 1.0f); // 5. activation h~t + calculate ht -#pragma omp parallel for - for (int i = 0; i < batch; i++) { + parallel_nd(batch, [&](int i) { PRAGMA_OMP_SIMD() for (int j = 0; j < dic; j++) { ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); } - } + }); } template <> @@ -312,8 +316,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::gru_lbr_elemwise) { AOC<float, 2> states_t_l(states_t_l_, batch, wic); AOC<float, 2> states_tm1_l(states_tm1_l_, batch, wic); AOC<float, 3> ws_gemm_state(ws_cell_, batch, n_gates, dic); -#pragma omp parallel for - for (int i = 0; i < batch; i++) { + parallel_nd(batch, [&](int i) { PRAGMA_OMP_SIMD() for (int j = 0; j < dic; j++) { float Wh_b = ws_gemm_state(i, 2, j) + bias(3, j); @@ -327,7 +330,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::gru_lbr_elemwise) { (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); if (is_training) ws_Wh_b(i, j) = Wh_b; } - } + }); } template <> @@ -359,8 +362,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::backward>::gru_lbr_elemwise) { // dG0 = (dht - G2) * dht * (1 - G0) * G0 // dG1 = (W*h + b) * dG2 * (1 - G1) * G1 // dG2 = (1 - G0) * dht * (1 - G2*G2) -#pragma omp parallel for - for (int i = 0; i < batch; i++) { + parallel_nd(batch, [&](int i) { PRAGMA_OMP_SIMD() for (int j = 0; j < dic; j++) { float h = states_tm1_l(i, j); @@ -379,7 +381,7 @@ elemwise_sig(_ref_rnn_common_t<prop_kind::backward>::gru_lbr_elemwise) { ws_gates(i, 0, j) = ws_gates_r(i, 0, j) = dG0; ws_gates(i, 1, j) = ws_gates_r(i, 1, j) = dG1; } - } + }); } template <> @@ -412,12 +414,11 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::backward>::cell_execution_gru_lb // db4 += e * (r * dG2) gates_reduction(n_gates, dic, batch, ws_gates_, diff_bias_); -#pragma omp parallel for - for (int j = 0; j < dic; j++) { + parallel_nd(dic, [&](int j) { for (int i = 0; i < batch; i++) { diff_bias_[3 * dic + j] += ws_gates_r(i, 2, j); } - } + }); } template <> @@ -440,8 +441,7 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::backward>::cell_execution_gru) { // dG2^ = dh * (1 - G0) * (1 - G2^2) // dG0^ = dh * (ht-1 - G2) * u * (1 - G0) // dht-1 (part) = dh * G0 -#pragma omp parallel for - for (int i = 0; i < batch; i++) { + parallel_nd(batch, [&](int i) { PRAGMA_OMP_SIMD() for (int j = 0; j < dic; j++) { float h = states_tm1_l(i, j); @@ -456,7 +456,7 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::backward>::cell_execution_gru) { ws_gates(i, 0, j) = dG0; ws_gates(i, 2, j) = dG2; } - } + }); //2. calculate intermediate d(hG1) //d(hG1) = dG2 * W2h^t @@ -468,8 +468,7 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::backward>::cell_execution_gru) { //dG1^ = d(hG1) * h * G1 * (1 - G1) //dht-1 (part) += d(hG1) * G1 //h * G1 (required for dWh) -#pragma omp parallel for - for (int i = 0; i < batch; i++) { + parallel_nd(batch, [&](int i) { PRAGMA_OMP_SIMD() for (int j = 0; j < dic; j++) { float h = states_tm1_l(i, j); @@ -478,7 +477,7 @@ cell_execution_sig(_ref_rnn_common_t<prop_kind::backward>::cell_execution_gru) { ws_gates(i, 1, j) = dhG1(i, j) * logistic_bwd(h, G1); hG1(i, j) = G1 * h; } - } + }); //4. calculate diff weights //dWx += [dG0 dG1 dG2] * [x] @@ -634,7 +633,7 @@ struct reversed_indexer : wavefront_indexer { default: return -1; } } - + private: original_indexer wd; }; @@ -668,7 +667,7 @@ grid_execution_sig(_ref_rnn_common_t<aprop>::wavefront_execution){// (int dic, i : (((exec_dir == b2t_l2r) || (exec_dir == b2t_r2l)) //nlayer is maxdim, we look for b2t ? (wavefront_indexer) wi_b2t_l2r_maxdim(n_layer) : (wavefront_indexer) wi_t2b_r2l_maxdim(n_layer)); - + wavefront_indexer wi_mindim = (!is_niter_maxdim) ? (((exec_dir == b2t_l2r) || (exec_dir == t2b_l2r)) //niter is mindim, we look for l2r ? (wavefront_indexer) wi_b2t_l2r_mindim(n_iter) @@ -676,7 +675,7 @@ grid_execution_sig(_ref_rnn_common_t<aprop>::wavefront_execution){// (int dic, i : (((exec_dir == b2t_l2r) || (exec_dir == b2t_r2l)) //nlayer is mindim, we look for b2t ? (wavefront_indexer) wi_b2t_l2r_mindim(n_layer) : (wavefront_indexer) wi_t2b_r2l_mindim(n_layer)); - + // auto get_offset = [=](wavefront_loop_index idx, int i, int j){ // int dim_min = wi_mindim.get(idx, i,j); // int dim_max = wi_maxdim.get(idx, i,j); @@ -734,8 +733,7 @@ void _ref_rnn_common_t<prop_kind::forward>::copy_init_layer(bool lr, bool rl, ws_states_, n_direction, n_iter + 1, n_states, batch, wic); auto xt_d = memory_desc_wrapper(conf_.src_pd(0)); -#pragma omp parallel for - for (int it = 0; it < n_iter; it++) { + parallel_nd(n_iter, [&](int it) { auto xxt = xt_ + xt_d.blk_off(it); if (lr) for (int b = 0; b < batch; b++) @@ -746,7 +744,7 @@ void _ref_rnn_common_t<prop_kind::forward>::copy_init_layer(bool lr, bool rl, for (int c = 0; c < slc; c++) ws_states(n_direction - 1, n_iter - it, 0, b, c) = *(xxt + b * slc + c); - } + }); } template <> @@ -761,47 +759,38 @@ void _ref_rnn_common_t<prop_kind::backward>::copy_init_layer(bool lr, bool rl, switch (conf_.direction()) { case mkldnn_bidirectional_concat: -#pragma omp parallel for collapse(2) - for (int it = 0; it < n_iter; it++) { - for (int b = 0; b < batch; b++) { - auto diff_dst_layer_x - = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); - for (int s = 0; s < dlc; s++) { - ws_diff_states(n_layer, 0, it, n_states, b, s) - = diff_dst_layer_x[s]; - ws_diff_states(n_layer, 1, it, n_states, b, s) - = diff_dst_layer_x[dic + s]; - } + parallel_nd(n_iter, batch, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < dlc; s++) { + ws_diff_states(n_layer, 0, it, n_states, b, s) + = diff_dst_layer_x[s]; + ws_diff_states(n_layer, 1, it, n_states, b, s) + = diff_dst_layer_x[dic + s]; } - } + }); break; case mkldnn_bidirectional_sum: -#pragma omp parallel for collapse(2) - for (int it = 0; it < n_iter; it++) { - for (int b = 0; b < batch; b++) { - auto diff_dst_layer_x - = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); - for (int s = 0; s < dic; s++) { - ws_diff_states(n_layer, 0, it, n_states, b, s) - = diff_dst_layer_x[s]; - ws_diff_states(n_layer, 1, it, n_states, b, s) - = diff_dst_layer_x[s]; - } + parallel_nd(n_iter, batch, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < dic; s++) { + ws_diff_states(n_layer, 0, it, n_states, b, s) + = diff_dst_layer_x[s]; + ws_diff_states(n_layer, 1, it, n_states, b, s) + = diff_dst_layer_x[s]; } - } + }); break; default: // assumes default is always unidirectional -#pragma omp parallel for collapse(2) - for (int it = 0; it < n_iter; it++) { - for (int b = 0; b < batch; b++) { - auto diff_dst_layer_x - = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); - for (int s = 0; s < dic; s++) { - ws_diff_states(n_layer, 0, it, n_states, b, s) - = diff_dst_layer_x[s]; - } + parallel_nd(n_iter, batch, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < dic; s++) { + ws_diff_states(n_layer, 0, it, n_states, b, s) + = diff_dst_layer_x[s]; } - } + }); break; } } @@ -815,25 +804,21 @@ void _ref_rnn_common_t<prop_kind::forward>::copy_init_iter(int n_layer, n_states, batch, wic); auto firstit_states_d = memory_desc_wrapper(conf_.src_pd(1)); if (firstit_states_) { -#pragma omp parallel for collapse(2) - for (int lay = 0; lay < n_layer; lay++) - for (int dir = 0; dir < n_direction; dir++) - for (int state = 0; state < n_states; state++) - for (int b = 0; b < batch; ++b) { - array_copy(&(ws_states(lay + 1, dir, 0, state, b, 0)), - firstit_states_ - + firstit_states_d.blk_off( - lay, dir, state, b), - sic); - } + parallel_nd(n_layer, n_direction, [&](int lay, int dir) { + for (int state = 0; state < n_states; state++) + for (int b = 0; b < batch; ++b) { + array_copy(&(ws_states(lay + 1, dir, 0, state, b, 0)), + firstit_states_ + firstit_states_d.blk_off( + lay, dir, state, b), sic); + } + }); } else { -#pragma omp parallel for collapse(2) - for (int lay = 0; lay < n_layer; lay++) - for (int dir = 0; dir < n_direction; dir++) - for (int state = 0; state < n_states; state++) - for (int i = 0; i < batch; i++) - for (int j = 0; j < sic; j++) - ws_states(lay + 1, dir, 0, state, i, j) = 0.0f; + parallel_nd(n_layer, n_direction, [&](int lay, int dir) { + for (int state = 0; state < n_states; state++) + for (int i = 0; i < batch; i++) + for (int j = 0; j < sic; j++) + ws_states(lay + 1, dir, 0, state, i, j) = 0.0f; + }); } } @@ -846,27 +831,18 @@ void _ref_rnn_common_t<prop_kind::backward>::copy_init_iter(int n_layer, n_iter + 1, n_states + 1, batch, wic); auto diff_dst_iter_d = memory_desc_wrapper(conf_.diff_dst_pd(1)); if (diff_dst_iter_) { -#pragma omp parallel for collapse(4) - for (int lay = 0; lay < n_layer; lay++) - for (int dir = 0; dir < n_direction; dir++) - for (int state = 0; state < n_states; state++) - for (int b = 0; b < batch; ++b) { - array_copy(&(ws_diff_states( - lay, dir, n_iter, state, b, 0)), - diff_dst_iter_ - + diff_dst_iter_d.blk_off( - lay, dir, state, b), - dic); - } + parallel_nd(n_layer, n_direction, n_states, batch, + [&](int lay, int dir, int state, int b) { + array_copy(&(ws_diff_states(lay, dir, n_iter, state, b, 0)), + diff_dst_iter_ + diff_dst_iter_d.blk_off(lay, dir, state, b), + dic); + }); } else { -#pragma omp parallel for collapse(4) - for (int lay = 0; lay < n_layer; lay++) - for (int dir = 0; dir < n_direction; dir++) - for (int state = 0; state < n_states; state++) - for (int i = 0; i < batch; i++) - for (int j = 0; j < dic; j++) - ws_diff_states(lay, dir, n_iter, state, i, j) - = 0.0f; + parallel_nd(n_layer, n_direction, n_states, batch, + [&](int lay, int dir, int state, int i) { + for (int j = 0; j < dic; j++) + ws_diff_states(lay, dir, n_iter, state, i, j) = 0.0f; + }); } } @@ -880,30 +856,28 @@ void _ref_rnn_common_t<prop_kind::forward>::copy_res_layer(bool lr, bool rl, auto dst_layer_d = memory_desc_wrapper(conf_.dst_pd(0)); AOC<const float, 6> ws_states(ws_states_, n_layer + 1, n_direction, n_iter + 1, n_states, batch, wic); -#pragma omp parallel for collapse(2) - for (int it = 0; it < n_iter; it++) { - for (int b = 0; b < batch; b++) { - int dir = 0; - if (lr) { - for (int s = 0; s < dic; s++) + + parallel_nd(n_iter, batch, [&](int it, int b) { + int dir = 0; + if (lr) { + for (int s = 0; s < dic; s++) + dst_layer_[dst_layer_d.blk_off(it, b, dir * dic + s)] + = ws_states(n_layer, dir, it + 1, 0, b, s); + dir = 1; + } + if (rl) { + for (int s = 0; s < dic; s++) + switch (direction) { + case mkldnn_bidirectional_sum: + dst_layer_[dst_layer_d.blk_off(it, b, s)] += ws_states( + n_layer, dir, n_iter - it, 0, b, s); + break; + default: dst_layer_[dst_layer_d.blk_off(it, b, dir * dic + s)] - = ws_states(n_layer, dir, it + 1, 0, b, s); - dir = 1; - } - if (rl) { - for (int s = 0; s < dic; s++) - switch (direction) { - case mkldnn_bidirectional_sum: - dst_layer_[dst_layer_d.blk_off(it, b, s)] += ws_states( - n_layer, dir, n_iter - it, 0, b, s); - break; - default: - dst_layer_[dst_layer_d.blk_off(it, b, dir * dic + s)] - = ws_states(n_layer, dir, n_iter - it, 0, b, s); - } - } + = ws_states(n_layer, dir, n_iter - it, 0, b, s); + } } - } + }); } template <> @@ -916,26 +890,24 @@ void _ref_rnn_common_t<prop_kind::backward>::copy_res_layer(bool lr, bool rl, auto diff_src_layer_d = memory_desc_wrapper(conf_.diff_src_pd(0)); AOC<const float, 6> ws_diff_states(ws_diff_states_, n_layer + 1, n_direction, n_iter + 1, n_states + 1, batch, wic); -#pragma omp parallel for collapse(2) - for (int it = 0; it < n_iter; it++) { - for (int b = 0; b < batch; b++) { - int dir = 0; - for (int s = 0; s < slc; s++) { - float *dst_addr = diff_src_layer_ - + diff_src_layer_d.blk_off( - (direction - == mkldnn_unidirectional_right2left) ? - n_iter - 1 - it : - it, - b, dir * slc + s); - float res = ws_diff_states(0, 0, it, n_states, b, s); - if (n_direction - 1) - res += ws_diff_states( - 0, 1, n_iter - 1 - it, n_states, b, s); - dst_addr[0] = res; - } + + parallel_nd(n_iter, batch, [&](int it, int b) { + int dir = 0; + for (int s = 0; s < slc; s++) { + float *dst_addr = diff_src_layer_ + + diff_src_layer_d.blk_off( + (direction + == mkldnn_unidirectional_right2left) ? + n_iter - 1 - it : + it, + b, dir * slc + s); + float res = ws_diff_states(0, 0, it, n_states, b, s); + if (n_direction - 1) + res += ws_diff_states( + 0, 1, n_iter - 1 - it, n_states, b, s); + dst_addr[0] = res; } - } + }); } template <> @@ -947,17 +919,13 @@ void _ref_rnn_common_t<prop_kind::forward>::copy_res_iter(int n_layer, AOC<const float, 6> ws_states(ws_states_, n_layer + 1, n_direction, n_iter + 1, n_states, batch, wic); if (dst_iter_) { -#pragma omp parallel for collapse(4) - for (int lay = 0; lay < n_layer; lay++) { - for (int dir = 0; dir < n_direction; dir++) - for (int state = 0; state < n_states; state++) - for (int b = 0; b < batch; b++) - for (int s = 0; s < dic; s++) { - dst_iter_[dst_iter_d.blk_off(lay, dir, state, b, s)] - = ws_states( - lay + 1, dir, n_iter, state, b, s); - } - } + parallel_nd(n_layer, n_direction, n_states, batch, + [&](int lay, int dir, int state, int b) { + for (int s = 0; s < dic; s++) { + dst_iter_[dst_iter_d.blk_off(lay, dir, state, b, s)] + = ws_states(lay + 1, dir, n_iter, state, b, s); + } + }); } } @@ -970,17 +938,14 @@ void _ref_rnn_common_t<prop_kind::backward>::copy_res_iter(int n_layer, AOC<const float, 6> ws_diff_states(ws_diff_states_, n_layer + 1, n_direction, n_iter + 1, n_states + 1, batch, wic); if (diff_src_iter_) { -#pragma omp parallel for collapse(4) - for (int lay = 0; lay < n_layer; lay++) { - for (int dir = 0; dir < n_direction; dir++) - for (int state = 0; state < n_states; state++) - for (int b = 0; b < batch; b++) - for (int s = 0; s < sic; s++) { - diff_src_iter_[diff_src_iter_d.blk_off( - lay, dir, state, b, s)] - = ws_diff_states(lay, dir, 0, state, b, s); - } - } + parallel_nd(n_layer, n_direction, n_states, batch, + [&](int lay, int dir, int state, int b) { + for (int s = 0; s < sic; s++) { + diff_src_iter_[diff_src_iter_d.blk_off( + lay, dir, state, b, s)] + = ws_diff_states(lay, dir, 0, state, b, s); + } + }); } } @@ -1127,7 +1092,7 @@ void _ref_rnn_common_t<aprop>::execute_() { auto diff_dst_layer = is_fwd ? nullptr : reinterpret_cast<const float *>(this->input_memory(input_idx++)); - auto diff_dst_iter = is_fwd || !conf_.with_src_iter() ? + auto diff_dst_iter = is_fwd || !conf_.with_dst_iter() ? nullptr : reinterpret_cast<const float *>(this->input_memory(input_idx++)); |