diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp | 130 |
1 files changed, 67 insertions, 63 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp b/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp index c4cb25094..788cfe160 100644 --- a/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp +++ b/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp @@ -16,6 +16,8 @@ #include <stdlib.h> +#include "src/common/mkldnn_thread.hpp" + #include "rnn/rnn.hpp" #include "rnn/rnn_aux.hpp" @@ -32,8 +34,7 @@ void lstm_activation(int dic, int n_gates, int batch, // float a[batch][n_gates * wc] float *a) { AOC<float> pa(a, batch, n_gates, dic); -#pragma omp parallel for - for (int ib = 0; ib < batch; ib++) { + mkldnn::impl::parallel_nd(batch, [&](int ib) { for (int ig = 0; ig < 3; ig++) { for (int ih = 0; ih < dic; ih++) { pa(ib, ig, ih) = logistic(pa(ib, ig, ih)); @@ -47,7 +48,7 @@ void lstm_activation(int dic, int n_gates, int batch, print(80, "activation 2 a[%d][%d][%d] = %.7f\n", ib, ig, j, pa(ib, ig, j)); } - } + }); } float activation(activation_t f, float x, bool is_fwd = true) { @@ -69,10 +70,10 @@ void rnn_fwd(activation_t f, int sic, int slc, int dic, int wc, int batch, AOC<const float> bias(bias_, n_gates, dic); AOC<float> gates(gates_, batch, n_gates, dic); - gemm("N", "N", batch, n_gates * dic, slc, src_layer_, wc, weights_layer_, - n_gates * dic, gates_, n_gates * dic, 0.0); - gemm("N", "N", batch, n_gates * dic, sic, src_iter_h_, wc, weights_iter_h_, - n_gates * dic, gates_, n_gates * dic, 1.0); + gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc, + weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic); + gemm("C", "N", "N", batch, n_gates * dic, sic, 1.0, src_iter_h_, wc, + weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic); for (int i = 0; i < batch; i++) for (int j = 0; j < n_gates; j++) @@ -94,10 +95,10 @@ void gru_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates, AOC<float> gates(gates_, batch, n_gates, dic); AOC<float> h_dst(dst_iter_h_, batch, wc); - gemm("N", "N", batch, n_gates * dic, slc, src_layer_, wc, weights_layer_, - n_gates * dic, gates_, n_gates * dic, 0.0); - gemm("N", "N", batch, (n_gates - 1) * dic, sic, src_iter_h_, wc, weights_iter_h_, - n_gates * dic, gates_, n_gates * dic, 1.0); + gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc, + weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic); + gemm("C", "N", "N", batch, (n_gates - 1) * dic, sic, 1.0, src_iter_h_, + wc, weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic); for (int i = 0; i < batch; i++) for (int j = 0; j < n_gates - 1; j++) for (int k = 0; k < dic; k++) { @@ -109,8 +110,9 @@ void gru_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates, h_dst(i, k) = src_iter_h(i, k) * gates(i, 1, k); } - gemm("N", "N", batch, dic, sic, dst_iter_h_, wc, &(weights_iter_h(0, 2, 0)), - n_gates * dic, &(gates(0, 2, 0)), n_gates * dic, 1.0); + gemm("C", "N", "N", batch, dic, sic, 1.0, dst_iter_h_, wc, + &(weights_iter_h(0, 2, 0)), n_gates * dic, 1.0, &(gates(0, 2, 0)), + n_gates * dic); for (int i = 0; i < batch; i++) for (int k = 0; k < dic; k++) { @@ -137,11 +139,11 @@ void gru_lbr_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates, AOC<float> h_dst(dst_iter_h_, batch, wc); AOC<float> tmp_ws(ws_local_, batch, n_gates, dic); - gemm("N", "N", batch, n_gates * dic, slc, src_layer_, wc, weights_layer_, - n_gates * dic, gates_, n_gates * dic, 0.0); + gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc, + weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic); - gemm("N", "N", batch, n_gates * dic, sic, src_iter_h_, wc, weights_iter_h_, - n_gates * dic, ws_local_, n_gates * dic, 0.0); + gemm("C", "N", "N", batch, n_gates * dic, sic, 1.0, src_iter_h_, wc, + weights_iter_h_, n_gates * dic, 0.0, ws_local_, n_gates * dic); for (int i = 0; i < batch; i++) for (int j = 0; j < n_gates - 1; j++) @@ -181,10 +183,10 @@ void lstm_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates, const int oho = 2; const int ohc = 3; - gemm("N", "N", batch, n_gates * dic, slc, src_layer_, wc, weights_layer_, - n_gates * dic, gates_, n_gates * dic, 0.0); - gemm("N", "N", batch, n_gates * dic, sic, src_iter_h_, wc, weights_iter_h_, - n_gates * dic, gates_, n_gates * dic, 1.0); + gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc, + weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic); + gemm("C", "N", "N", batch, n_gates * dic, sic,1.0, src_iter_h_, wc, + weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic); // add bias for (int i = 0; i < batch; i++) @@ -237,12 +239,13 @@ void copy(int dimc, int dimr, int ld_src, int ld_dst, const float *src_, float *dst_, rnn_action_t action = action_copy) { AOC<const float> src(src_, dimc, ld_src); AOC<float> dst(dst_, dimc, ld_dst); -#pragma omp parallel for - for (int i = 0; i < dimc; i++) + + mkldnn::impl::parallel_nd(dimc, [&](int i) { for (int j = 0; j < dimr; j++) { - dst(i, j) = (action == action_sum) ? dst(i, j) + src(i, j) : - src(i, j); + dst(i, j) = action == action_sum + ? dst(i, j) + src(i, j) : src(i, j); } + }); } /* FIXME: separate copy_init ??? @@ -481,17 +484,17 @@ void rnn_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc, b_gates(b, 0, h) = activation(f, g, false) * dd; } - gemm("T", "N", sic, n_gates * dic, batch, src_iter_, wc, b_gates_, - n_gates * dic, diff_weights_iter_h_, n_gates * dic, 1.0); - gemm("T", "N", slc, n_gates * dic, batch, src_layer_, wc, b_gates_, - n_gates * dic, diff_weights_layer_, n_gates * dic, 1.0); + gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_, wc, b_gates_, + n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic); + gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, b_gates_, + n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic); for (int b = 0; b < batch; ++b) copy(n_gates, dic, dic, dic, &b_gates(b, 0, 0), diff_bias_, action_sum); - gemm("N", "T", batch, slc, n_gates * dic, b_gates_, n_gates * dic, - weights_layer_, n_gates * dic, diff_src_layer_, wc, 0.0); - gemm("N", "T", batch, sic, n_gates * dic, b_gates_, n_gates * dic, - weights_iter_h_, n_gates * dic, diff_src_iter_, wc, 0.0); + gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic, + weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc); + gemm("C", "N", "T", batch, sic, n_gates * dic, 1.0, b_gates_, n_gates * dic, + weights_iter_h_, n_gates * dic, 0.0, diff_src_iter_, wc); } void lstm_bwd(alg_t alg, int sic, int slc, int dic, int wc, int batch, @@ -550,15 +553,15 @@ void lstm_bwd(alg_t alg, int sic, int slc, int dic, int wc, int batch, b_gates(ib, ohc, ih) = dtanhf(hc) * dhc; } - gemm("T", "N", sic, n_gates * dic, batch, src_iter_h_, wc, b_gates_, - n_gates * dic, diff_weights_iter_h_, n_gates * dic, 1.0); - gemm("T", "N", slc, n_gates * dic, batch, src_layer_, wc, b_gates_, - n_gates * dic, diff_weights_layer_, n_gates * dic, 1.0); + gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_h_, wc, b_gates_, + n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic); + gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, b_gates_, + n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic); - gemm("N", "T", batch, sic, n_gates * dic, b_gates_, n_gates * dic, - weights_iter_h_, n_gates * dic, diff_src_iter_h_, wc, 0.0); - gemm("N", "T", batch, slc, n_gates * dic, b_gates_, n_gates * dic, - weights_layer_, n_gates * dic, diff_src_layer_, wc, 0.0); + gemm("C", "N", "T", batch, sic, n_gates * dic, 1.0, b_gates_, n_gates * dic, + weights_iter_h_, n_gates * dic, 0.0, diff_src_iter_h_, wc); + gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic, + weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc); for (int i = 0; i < batch; i++) for (int j = 0; j < n_gates; j++) @@ -611,8 +614,8 @@ void gru_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc, b_gates(ib, ohc, ih) = dtanhf(c) * dc; diff_src_iter(ib, ih) = dh * u; } - gemm("N", "T", batch, slc, dic, &(b_gates(0, 2, 0)), n_gates * dic, - &(weights_layer(0, 2, 0)), n_gates * dic, dhr_, wc, 0.0); + gemm("C", "N", "T", batch, slc, dic, 1.0, &(b_gates(0, 2, 0)), n_gates * dic, + &(weights_layer(0, 2, 0)), n_gates * dic, 0.0, dhr_, wc); for (int ib = 0; ib < batch; ib++) for (int ih = 0; ih < dic; ih++) { @@ -626,19 +629,20 @@ void gru_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc, // dWx += xdu^ | xdr^ | xdc^ // dWh += hdu^ | ddr^ | (h * r)dc^ - gemm("T", "N", sic, (n_gates - 1) * dic, batch, src_iter_, wc, b_gates_, - n_gates * dic, diff_weights_iter_h_, n_gates * dic, 1.0); - gemm("T", "N", sic, dic, batch, hr_, wc, &(b_gates(0, 2, 0)), - n_gates * dic, &(diff_weights_iter_h(0, 2, 0)), n_gates * dic, 1.0); - gemm("T", "N", slc, n_gates * dic, batch, src_layer_, wc, b_gates_, - n_gates * dic, diff_weights_layer_, n_gates * dic, 1.0); + gemm("C", "T", "N", sic, (n_gates - 1) * dic, batch, 1.0, src_iter_, wc, + b_gates_, n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic); + gemm("C", "T", "N", sic, dic, batch, 1.0, hr_, wc, &(b_gates(0, 2, 0)), + n_gates * dic, 1.0, &(diff_weights_iter_h(0, 2, 0)), n_gates * dic); + gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, + b_gates_, n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic); // dx_next = Wxudu^ + Wxrdr^ + Wxcdc^ // dh_next = dh * u + Whudu^ + Whzdz^ + r * Whcdc^ - gemm("N", "T", batch, sic, (n_gates - 1)* dic, b_gates_, n_gates * dic, - weights_iter_h_, n_gates * dic, diff_src_iter_, wc, 1.0); - gemm("N", "T", batch, slc, n_gates * dic, b_gates_, n_gates * dic, - weights_layer_, n_gates * dic, diff_src_layer_, wc, 0.0); + gemm("C", "N", "T", batch, sic, (n_gates - 1)* dic, 1.0, b_gates_, + n_gates * dic, weights_iter_h_, n_gates * dic, 1.0, diff_src_iter_, + wc); + gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic, + weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc); for (int i = 0; i < batch; i++) for (int j = 0; j < n_gates; j++) @@ -677,8 +681,8 @@ void gru_lbr_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc, for (int ih = 0; ih < dic; ih++) Wh_b(ib, ih) = bias(3, ih); - gemm("N", "N", batch, dic, sic, src_iter_, wc, &weights_iter_h(0, 2, 0), - dic, Wh_b_, dic, 1.0); + gemm("C", "N", "N", batch, dic, sic, 1.0, src_iter_, wc, + &weights_iter_h(0, 2, 0), dic, 1.0, Wh_b_, dic); // dc = (1 - u) * dh; dc^ = dtanhf(c) * dc; @@ -708,15 +712,15 @@ void gru_lbr_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc, diff_src_iter(ib, ih) = dh * u; } - gemm("T", "N", sic, n_gates * dic, batch, src_iter_, wc, b_gates_r_, - n_gates * dic, diff_weights_iter_h_, n_gates * dic, 1.0); - gemm("T", "N", slc, n_gates * dic, batch, src_layer_, wc, b_gates_, - n_gates * dic, diff_weights_layer_, n_gates * dic, 1.0); + gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_, wc, b_gates_r_, + n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic); + gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, b_gates_, + n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic); - gemm("N", "T", batch, slc, n_gates * dic, b_gates_, n_gates * dic, - weights_layer_, n_gates * dic, diff_src_layer_, wc, 0.0); - gemm("N", "T", batch, sic, n_gates * dic, b_gates_r_, n_gates * dic, - weights_iter_h_, n_gates * dic, diff_src_iter_, wc, 1.0); + gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic, + weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc); + gemm("C", "N", "T", batch, sic, n_gates * dic, 1.0, b_gates_r_, n_gates * dic, + weights_iter_h_, n_gates * dic, 1.0, diff_src_iter_, wc); for (int i = 0; i < batch; i++) for (int j = 0; j < n_gates; j++) |