diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/tests/gtests/test_softmax_backward.cpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/tests/gtests/test_softmax_backward.cpp | 42 |
1 files changed, 22 insertions, 20 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_softmax_backward.cpp b/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_softmax_backward.cpp index 98125df3e..06d0a1da3 100644 --- a/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_softmax_backward.cpp +++ b/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_softmax_backward.cpp @@ -51,10 +51,7 @@ void check_softmax_bwd(memory& dst, memory& diff_dst, memory &diff_src, int axis int IN = 1; for (int d = axis + 1; d < ndims; ++d) IN *= diff_dst_pd.data.dims[d]; -# pragma omp parallel for collapse(2) - for (int ou = 0; ou < OU; ++ou) - for (int in = 0; in < IN; ++in) - { + mkldnn::impl::parallel_nd(OU, IN, [&](int ou, int in) { const int idx_start = ou * C * IN + in; float sbr = 0.0; @@ -70,7 +67,7 @@ void check_softmax_bwd(memory& dst, memory& diff_dst, memory &diff_src, int axis diff_src_ref_ptr[off_dd] = dst_ptr[off_d] * (diff_dst_ptr[off_dd] - sbr); } - } + }); // Actual check for (int i=0; i < OU*C*IN; ++i) @@ -121,14 +118,6 @@ protected: auto diff_src = memory(diff_mem_prim_desc, src_diff.get()); auto diff_dst = memory(diff_mem_prim_desc, dst_diff.get()); - // Fill the softmax forward input - fill_data<data_t>(data_mem_prim_desc.get_size(), - (data_t *)src.get_data_handle(), data_t(0), data_t(1)); - - // Fill the softmax backward diffs eg. data diff that comes from upper primitive/layer - fill_data<data_t>(diff_mem_prim_desc.get_size(), - (data_t *)diff_dst.get_data_handle(), data_t(0), data_t(1)); - // Create softmax backward descriptor // before forward so its exceptions can be tested auto softmax_desc @@ -146,13 +135,23 @@ protected: = softmax_backward::primitive_desc(softmax_desc, eng, softmax_fwd_pdesc); auto softmax_bwd = softmax_backward(softmax_prim_desc, dst, diff_dst, diff_src); - std::vector<primitive> pipeline; - pipeline.push_back(softmax); - pipeline.push_back(softmax_bwd); - auto s = stream(stream::kind::lazy); - s.submit(pipeline).wait(); + auto test_with_given_fill = [&](data_t mean, data_t var) { + // Fill the softmax forward input + fill_data<data_t>(data_mem_prim_desc.get_size(), + (data_t *)src.get_data_handle(), mean, var); + + // Fill the softmax backward diffs + // eg. data diff that comes from upper primitive/layer + fill_data<data_t>(diff_mem_prim_desc.get_size(), + (data_t *)diff_dst.get_data_handle(), data_t(0), data_t(1)); + + stream(stream::kind::lazy).submit({softmax, softmax_bwd}).wait(); + check_softmax_bwd<data_t>(dst, diff_dst, diff_src, p.axis); + }; - check_softmax_bwd<data_t>(dst, diff_dst, diff_src, p.axis); + test_with_given_fill(-200, 1); + test_with_given_fill( 0, 1); + test_with_given_fill( 200, 1); } }; @@ -162,7 +161,10 @@ using softmax_bwd_test_params_float = softmax_test_params<float>; TEST_P(softmax_backward_test_float, TestsSoftmax) { } INSTANTIATE_TEST_CASE_P(TestSoftmaxBackward, softmax_backward_test_float, ::testing::Values( - softmax_bwd_test_params_float{ engine::kind::cpu, memory::format::nchw, memory::format::nchw, {2, 0, 128, 256}, 0, true, mkldnn_invalid_arguments}, + softmax_bwd_test_params_float{ engine::kind::cpu, memory::format::nchw, memory::format::nchw, {2, -2, 128, 256}, 0, true, mkldnn_invalid_arguments}, + softmax_bwd_test_params_float{ engine::kind::cpu, memory::format::nchw, memory::format::nchw, {2, 19, 128, 256}, 5, true, mkldnn_invalid_arguments}, + softmax_bwd_test_params_float{ engine::kind::cpu, memory::format::nchw, memory::format::nchw, {2, 0, 5, 5}, 0}, + softmax_bwd_test_params_float{ engine::kind::cpu, memory::format::nchw, memory::format::nchw, {2, 0, 5, 5}, 1}, softmax_bwd_test_params_float{ engine::kind::cpu, memory::format::nchw, memory::format::nchw, {2, 19, 128, 256}, 0}, softmax_bwd_test_params_float{ engine::kind::cpu, memory::format::nchw, memory::format::nchw, {2, 19, 128, 256}, 2}, softmax_bwd_test_params_float{ engine::kind::cpu, memory::format::nchw, memory::format::nchw, {2, 19, 128, 256}, 3}, |