summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp')
-rw-r--r--inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp130
1 files changed, 67 insertions, 63 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp b/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp
index c4cb25094..788cfe160 100644
--- a/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/rnn/ref_rnn.cpp
@@ -16,6 +16,8 @@
#include <stdlib.h>
+#include "src/common/mkldnn_thread.hpp"
+
#include "rnn/rnn.hpp"
#include "rnn/rnn_aux.hpp"
@@ -32,8 +34,7 @@ void lstm_activation(int dic, int n_gates, int batch,
// float a[batch][n_gates * wc]
float *a) {
AOC<float> pa(a, batch, n_gates, dic);
-#pragma omp parallel for
- for (int ib = 0; ib < batch; ib++) {
+ mkldnn::impl::parallel_nd(batch, [&](int ib) {
for (int ig = 0; ig < 3; ig++) {
for (int ih = 0; ih < dic; ih++) {
pa(ib, ig, ih) = logistic(pa(ib, ig, ih));
@@ -47,7 +48,7 @@ void lstm_activation(int dic, int n_gates, int batch,
print(80, "activation 2 a[%d][%d][%d] = %.7f\n", ib, ig, j,
pa(ib, ig, j));
}
- }
+ });
}
float activation(activation_t f, float x, bool is_fwd = true) {
@@ -69,10 +70,10 @@ void rnn_fwd(activation_t f, int sic, int slc, int dic, int wc, int batch,
AOC<const float> bias(bias_, n_gates, dic);
AOC<float> gates(gates_, batch, n_gates, dic);
- gemm("N", "N", batch, n_gates * dic, slc, src_layer_, wc, weights_layer_,
- n_gates * dic, gates_, n_gates * dic, 0.0);
- gemm("N", "N", batch, n_gates * dic, sic, src_iter_h_, wc, weights_iter_h_,
- n_gates * dic, gates_, n_gates * dic, 1.0);
+ gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
+ weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
+ gemm("C", "N", "N", batch, n_gates * dic, sic, 1.0, src_iter_h_, wc,
+ weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic);
for (int i = 0; i < batch; i++)
for (int j = 0; j < n_gates; j++)
@@ -94,10 +95,10 @@ void gru_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates,
AOC<float> gates(gates_, batch, n_gates, dic);
AOC<float> h_dst(dst_iter_h_, batch, wc);
- gemm("N", "N", batch, n_gates * dic, slc, src_layer_, wc, weights_layer_,
- n_gates * dic, gates_, n_gates * dic, 0.0);
- gemm("N", "N", batch, (n_gates - 1) * dic, sic, src_iter_h_, wc, weights_iter_h_,
- n_gates * dic, gates_, n_gates * dic, 1.0);
+ gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
+ weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
+ gemm("C", "N", "N", batch, (n_gates - 1) * dic, sic, 1.0, src_iter_h_,
+ wc, weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic);
for (int i = 0; i < batch; i++)
for (int j = 0; j < n_gates - 1; j++)
for (int k = 0; k < dic; k++) {
@@ -109,8 +110,9 @@ void gru_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates,
h_dst(i, k) = src_iter_h(i, k) * gates(i, 1, k);
}
- gemm("N", "N", batch, dic, sic, dst_iter_h_, wc, &(weights_iter_h(0, 2, 0)),
- n_gates * dic, &(gates(0, 2, 0)), n_gates * dic, 1.0);
+ gemm("C", "N", "N", batch, dic, sic, 1.0, dst_iter_h_, wc,
+ &(weights_iter_h(0, 2, 0)), n_gates * dic, 1.0, &(gates(0, 2, 0)),
+ n_gates * dic);
for (int i = 0; i < batch; i++)
for (int k = 0; k < dic; k++) {
@@ -137,11 +139,11 @@ void gru_lbr_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates,
AOC<float> h_dst(dst_iter_h_, batch, wc);
AOC<float> tmp_ws(ws_local_, batch, n_gates, dic);
- gemm("N", "N", batch, n_gates * dic, slc, src_layer_, wc, weights_layer_,
- n_gates * dic, gates_, n_gates * dic, 0.0);
+ gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
+ weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
- gemm("N", "N", batch, n_gates * dic, sic, src_iter_h_, wc, weights_iter_h_,
- n_gates * dic, ws_local_, n_gates * dic, 0.0);
+ gemm("C", "N", "N", batch, n_gates * dic, sic, 1.0, src_iter_h_, wc,
+ weights_iter_h_, n_gates * dic, 0.0, ws_local_, n_gates * dic);
for (int i = 0; i < batch; i++)
for (int j = 0; j < n_gates - 1; j++)
@@ -181,10 +183,10 @@ void lstm_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates,
const int oho = 2;
const int ohc = 3;
- gemm("N", "N", batch, n_gates * dic, slc, src_layer_, wc, weights_layer_,
- n_gates * dic, gates_, n_gates * dic, 0.0);
- gemm("N", "N", batch, n_gates * dic, sic, src_iter_h_, wc, weights_iter_h_,
- n_gates * dic, gates_, n_gates * dic, 1.0);
+ gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
+ weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
+ gemm("C", "N", "N", batch, n_gates * dic, sic,1.0, src_iter_h_, wc,
+ weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic);
// add bias
for (int i = 0; i < batch; i++)
@@ -237,12 +239,13 @@ void copy(int dimc, int dimr, int ld_src, int ld_dst, const float *src_,
float *dst_, rnn_action_t action = action_copy) {
AOC<const float> src(src_, dimc, ld_src);
AOC<float> dst(dst_, dimc, ld_dst);
-#pragma omp parallel for
- for (int i = 0; i < dimc; i++)
+
+ mkldnn::impl::parallel_nd(dimc, [&](int i) {
for (int j = 0; j < dimr; j++) {
- dst(i, j) = (action == action_sum) ? dst(i, j) + src(i, j) :
- src(i, j);
+ dst(i, j) = action == action_sum
+ ? dst(i, j) + src(i, j) : src(i, j);
}
+ });
}
/* FIXME: separate copy_init ???
@@ -481,17 +484,17 @@ void rnn_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
b_gates(b, 0, h) = activation(f, g, false) * dd;
}
- gemm("T", "N", sic, n_gates * dic, batch, src_iter_, wc, b_gates_,
- n_gates * dic, diff_weights_iter_h_, n_gates * dic, 1.0);
- gemm("T", "N", slc, n_gates * dic, batch, src_layer_, wc, b_gates_,
- n_gates * dic, diff_weights_layer_, n_gates * dic, 1.0);
+ gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_, wc, b_gates_,
+ n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic);
+ gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, b_gates_,
+ n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic);
for (int b = 0; b < batch; ++b)
copy(n_gates, dic, dic, dic, &b_gates(b, 0, 0), diff_bias_, action_sum);
- gemm("N", "T", batch, slc, n_gates * dic, b_gates_, n_gates * dic,
- weights_layer_, n_gates * dic, diff_src_layer_, wc, 0.0);
- gemm("N", "T", batch, sic, n_gates * dic, b_gates_, n_gates * dic,
- weights_iter_h_, n_gates * dic, diff_src_iter_, wc, 0.0);
+ gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic,
+ weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc);
+ gemm("C", "N", "T", batch, sic, n_gates * dic, 1.0, b_gates_, n_gates * dic,
+ weights_iter_h_, n_gates * dic, 0.0, diff_src_iter_, wc);
}
void lstm_bwd(alg_t alg, int sic, int slc, int dic, int wc, int batch,
@@ -550,15 +553,15 @@ void lstm_bwd(alg_t alg, int sic, int slc, int dic, int wc, int batch,
b_gates(ib, ohc, ih) = dtanhf(hc) * dhc;
}
- gemm("T", "N", sic, n_gates * dic, batch, src_iter_h_, wc, b_gates_,
- n_gates * dic, diff_weights_iter_h_, n_gates * dic, 1.0);
- gemm("T", "N", slc, n_gates * dic, batch, src_layer_, wc, b_gates_,
- n_gates * dic, diff_weights_layer_, n_gates * dic, 1.0);
+ gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_h_, wc, b_gates_,
+ n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic);
+ gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, b_gates_,
+ n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic);
- gemm("N", "T", batch, sic, n_gates * dic, b_gates_, n_gates * dic,
- weights_iter_h_, n_gates * dic, diff_src_iter_h_, wc, 0.0);
- gemm("N", "T", batch, slc, n_gates * dic, b_gates_, n_gates * dic,
- weights_layer_, n_gates * dic, diff_src_layer_, wc, 0.0);
+ gemm("C", "N", "T", batch, sic, n_gates * dic, 1.0, b_gates_, n_gates * dic,
+ weights_iter_h_, n_gates * dic, 0.0, diff_src_iter_h_, wc);
+ gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic,
+ weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc);
for (int i = 0; i < batch; i++)
for (int j = 0; j < n_gates; j++)
@@ -611,8 +614,8 @@ void gru_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
b_gates(ib, ohc, ih) = dtanhf(c) * dc;
diff_src_iter(ib, ih) = dh * u;
}
- gemm("N", "T", batch, slc, dic, &(b_gates(0, 2, 0)), n_gates * dic,
- &(weights_layer(0, 2, 0)), n_gates * dic, dhr_, wc, 0.0);
+ gemm("C", "N", "T", batch, slc, dic, 1.0, &(b_gates(0, 2, 0)), n_gates * dic,
+ &(weights_layer(0, 2, 0)), n_gates * dic, 0.0, dhr_, wc);
for (int ib = 0; ib < batch; ib++)
for (int ih = 0; ih < dic; ih++) {
@@ -626,19 +629,20 @@ void gru_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
// dWx += xdu^ | xdr^ | xdc^
// dWh += hdu^ | ddr^ | (h * r)dc^
- gemm("T", "N", sic, (n_gates - 1) * dic, batch, src_iter_, wc, b_gates_,
- n_gates * dic, diff_weights_iter_h_, n_gates * dic, 1.0);
- gemm("T", "N", sic, dic, batch, hr_, wc, &(b_gates(0, 2, 0)),
- n_gates * dic, &(diff_weights_iter_h(0, 2, 0)), n_gates * dic, 1.0);
- gemm("T", "N", slc, n_gates * dic, batch, src_layer_, wc, b_gates_,
- n_gates * dic, diff_weights_layer_, n_gates * dic, 1.0);
+ gemm("C", "T", "N", sic, (n_gates - 1) * dic, batch, 1.0, src_iter_, wc,
+ b_gates_, n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic);
+ gemm("C", "T", "N", sic, dic, batch, 1.0, hr_, wc, &(b_gates(0, 2, 0)),
+ n_gates * dic, 1.0, &(diff_weights_iter_h(0, 2, 0)), n_gates * dic);
+ gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc,
+ b_gates_, n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic);
// dx_next = Wxudu^ + Wxrdr^ + Wxcdc^
// dh_next = dh * u + Whudu^ + Whzdz^ + r * Whcdc^
- gemm("N", "T", batch, sic, (n_gates - 1)* dic, b_gates_, n_gates * dic,
- weights_iter_h_, n_gates * dic, diff_src_iter_, wc, 1.0);
- gemm("N", "T", batch, slc, n_gates * dic, b_gates_, n_gates * dic,
- weights_layer_, n_gates * dic, diff_src_layer_, wc, 0.0);
+ gemm("C", "N", "T", batch, sic, (n_gates - 1)* dic, 1.0, b_gates_,
+ n_gates * dic, weights_iter_h_, n_gates * dic, 1.0, diff_src_iter_,
+ wc);
+ gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic,
+ weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc);
for (int i = 0; i < batch; i++)
for (int j = 0; j < n_gates; j++)
@@ -677,8 +681,8 @@ void gru_lbr_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
for (int ih = 0; ih < dic; ih++)
Wh_b(ib, ih) = bias(3, ih);
- gemm("N", "N", batch, dic, sic, src_iter_, wc, &weights_iter_h(0, 2, 0),
- dic, Wh_b_, dic, 1.0);
+ gemm("C", "N", "N", batch, dic, sic, 1.0, src_iter_, wc,
+ &weights_iter_h(0, 2, 0), dic, 1.0, Wh_b_, dic);
// dc = (1 - u) * dh; dc^ = dtanhf(c) * dc;
@@ -708,15 +712,15 @@ void gru_lbr_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
diff_src_iter(ib, ih) = dh * u;
}
- gemm("T", "N", sic, n_gates * dic, batch, src_iter_, wc, b_gates_r_,
- n_gates * dic, diff_weights_iter_h_, n_gates * dic, 1.0);
- gemm("T", "N", slc, n_gates * dic, batch, src_layer_, wc, b_gates_,
- n_gates * dic, diff_weights_layer_, n_gates * dic, 1.0);
+ gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_, wc, b_gates_r_,
+ n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic);
+ gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, b_gates_,
+ n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic);
- gemm("N", "T", batch, slc, n_gates * dic, b_gates_, n_gates * dic,
- weights_layer_, n_gates * dic, diff_src_layer_, wc, 0.0);
- gemm("N", "T", batch, sic, n_gates * dic, b_gates_r_, n_gates * dic,
- weights_iter_h_, n_gates * dic, diff_src_iter_, wc, 1.0);
+ gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic,
+ weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc);
+ gemm("C", "N", "T", batch, sic, n_gates * dic, 1.0, b_gates_r_, n_gates * dic,
+ weights_iter_h_, n_gates * dic, 1.0, diff_src_iter_, wc);
for (int i = 0; i < batch; i++)
for (int j = 0; j < n_gates; j++)