diff options
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp')
-rw-r--r-- | inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp | 95 |
1 files changed, 73 insertions, 22 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp b/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp index ee052e5d9..48d5bfcc1 100644 --- a/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp +++ b/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp @@ -38,6 +38,8 @@ struct test_bnrm_params_t { test_bnrm_sizes_t sizes; float eps; int ndims; + bool expect_to_fail; + mkldnn_status_t expected_status; }; template <typename data_t> @@ -45,6 +47,9 @@ void check_bnrm_fwd(const test_bnrm_params_t &p, const memory &src, const memory &mean, const memory &variance, const memory &weights, const memory &dst, unsigned flags, prop_kind pk) { + const test_bnrm_sizes_t &bp = p.sizes; + if (bp.mb * bp.c * bp.d * bp.h * bp.w == 0) return; + const bool use_weights = flags & use_scale_shift; const bool calculate_stats = !(flags & use_global_stats); const bool is_training = (pk == prop_kind::forward_training); @@ -60,14 +65,12 @@ void check_bnrm_fwd(const test_bnrm_params_t &p, const memory::desc src_d = src.get_primitive_desc().desc(); const memory::desc dst_d = dst.get_primitive_desc().desc(); - test_bnrm_sizes_t bp = p.sizes; data_t eps = static_cast<data_t>(1.e-4 * bp.mb * bp.d * bp.h * bp.w); size_t padded_c = src.get_primitive_desc().desc().data.layout_desc .blocking.padding_dims[1]; -#pragma omp parallel for - for (int c = 0; c < bp.c; c++) { + mkldnn::impl::parallel_nd(bp.c, [&](int c) { data_t ref_mean = calculate_stats ? data_t(0) : mean_data[c]; data_t ref_variance = calculate_stats ? data_t(0) : variance_data[c]; if (calculate_stats) { @@ -138,7 +141,7 @@ void check_bnrm_fwd(const test_bnrm_params_t &p, EXPECT_NEAR((out - ref_dst) / norm_max, 0., eps); } } - } + }); } template <typename data_t> @@ -147,6 +150,7 @@ void check_bnrm_bwd(const test_bnrm_params_t &p, const memory &variance, const memory &weights, const memory &diff_src, const memory &diff_weights, unsigned flags, prop_kind pk) { + const test_bnrm_sizes_t &bp = p.sizes; const bool use_weights = flags & use_scale_shift; const bool calculate_diff_stats = !(flags & omit_stats); @@ -165,13 +169,22 @@ void check_bnrm_bwd(const test_bnrm_params_t &p, const memory::desc diff_src_d = diff_src.get_primitive_desc().desc(); const memory::desc diff_weights_d = diff_weights.get_primitive_desc().desc(); - test_bnrm_sizes_t bp = p.sizes; + if (bp.mb * bp.c * bp.d * bp.h * bp.w == 0) { + if (pk == backward) { + for (int c = 0; c < bp.c; ++c) { + auto dg = diff_weights_data[map_index(diff_weights_d, c)]; + auto db = diff_weights_data[map_index(diff_weights_d, bp.c + c)]; + EXPECT_NEAR(dg, 0., 1e-7); + EXPECT_NEAR(db, 0., 1e-7); + } + } + return; + } const data_t eps = static_cast<data_t>(1.e-4 * bp.mb * bp.d * bp.h * bp.w); size_t padded_c = src.get_primitive_desc().desc().data.layout_desc.blocking.padding_dims[1]; -#pragma omp parallel for - for (int c = 0; c < bp.c; c++) { + mkldnn::impl::parallel_nd(bp.c, [&](int c) { data_t ref_diff_gamma = data_t(0); data_t ref_diff_beta = data_t(0); @@ -223,7 +236,7 @@ void check_bnrm_bwd(const test_bnrm_params_t &p, if (norm_max < eps) norm_max = data_t(1); EXPECT_NEAR((out_diff_src - ref_diff_src) / norm_max, 0., eps); } - } + }); } template <typename data_t> @@ -249,6 +262,12 @@ private: protected: virtual void SetUp() { p = ::testing::TestWithParam<decltype(p)>::GetParam(); + catch_expected_failures([=](){Test();}, p.expect_to_fail, + p.expected_status); + } + + void Test() { + p = ::testing::TestWithParam<decltype(p)>::GetParam(); ASSERT_TRUE(p.engine_kind == engine::kind::cpu); eng.reset(new engine(p.engine_kind, 0)); @@ -448,26 +467,40 @@ TEST_P(bnrm_test_float, TestsBnrm) #define ENGINE engine::kind::cpu #define EPS 1e-5f -#define PARAMS(data, diff, mb, c, h, w, eps) \ - test_bnrm_params_t { ENGINE, \ - EXPAND_FORMATS(data, diff), EXPAND_SIZES_2D(mb, c, h, w), eps, 4 } +#define PARAMS(data, diff, mb, c, h, w, eps, ef, st) \ + test_bnrm_params_t { ENGINE, EXPAND_FORMATS(data, diff), \ + EXPAND_SIZES_2D(mb, c, h, w), eps, 4, ef, st } -#define PARAMS_3D(data, diff, mb, c, d, h, w, eps) \ - test_bnrm_params_t { ENGINE, \ - EXPAND_FORMATS(data, diff), EXPAND_SIZES_3D(mb, c, d, h, w), eps, 5 } +#define PARAMS_3D(data, diff, mb, c, d, h, w, eps, ef, st) \ + test_bnrm_params_t { ENGINE, EXPAND_FORMATS(data, diff), \ + EXPAND_SIZES_3D(mb, c, d, h, w), eps, 5, ef, st } -#define PARAMS_N_3D(...) EXPAND_ARGS(PARAMS_3D(ncdhw, ncdhw, __VA_ARGS__)) -#define PARAMS_B16_3D(...) EXPAND_ARGS(PARAMS_3D(nCdhw16c, nCdhw16c, __VA_ARGS__)) -#define PARAMS_N(...) EXPAND_ARGS(PARAMS(nchw, nchw, __VA_ARGS__)) -#define PARAMS_NHWC(...) EXPAND_ARGS(PARAMS(nhwc, nhwc, __VA_ARGS__)) -#define PARAMS_NC(...) EXPAND_ARGS(PARAMS(nc, nc, __VA_ARGS__)) -#define PARAMS_B8(...) EXPAND_ARGS(PARAMS(nChw8c, nChw8c, __VA_ARGS__)) -#define PARAMS_B16(...) EXPAND_ARGS(PARAMS(nChw16c, nChw16c, __VA_ARGS__)) +#define PARAMS_N_3D(...) EXPAND_ARGS(PARAMS_3D(ncdhw, ncdhw, __VA_ARGS__, false, mkldnn_success)) +#define PARAMS_B8_3D(...) EXPAND_ARGS(PARAMS_3D(nCdhw8c, nCdhw8c, __VA_ARGS__, false, mkldnn_success)) +#define PARAMS_B16_3D(...) EXPAND_ARGS(PARAMS_3D(nCdhw16c, nCdhw16c, __VA_ARGS__, false, mkldnn_success)) +#define PARAMS_N(...) EXPAND_ARGS(PARAMS(nchw, nchw, __VA_ARGS__, false, mkldnn_success)) +#define PARAMS_NHWC(...) EXPAND_ARGS(PARAMS(nhwc, nhwc, __VA_ARGS__, false, mkldnn_success)) +#define PARAMS_NC(...) EXPAND_ARGS(PARAMS(nc, nc, __VA_ARGS__, false, mkldnn_success)) +#define PARAMS_B8(...) EXPAND_ARGS(PARAMS(nChw8c, nChw8c, __VA_ARGS__, false, mkldnn_success)) +#define PARAMS_B16(...) EXPAND_ARGS(PARAMS(nChw16c, nChw16c, __VA_ARGS__, false, mkldnn_success)) +#define PARAMS_EF(...) EXPAND_ARGS(PARAMS(nchw, nchw, __VA_ARGS__)) #define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \ str, bnrm_test_float, ::testing::Values(__VA_ARGS__)) -INST_TEST_CASE(Simple_Blocked_padded, +INST_TEST_CASE(SimpleZeroDim, + PARAMS_N(0, 27, 9, 10, EPS), + PARAMS_N(1, 0, 10, 9, EPS), + PARAMS_N(4, 20, 0, 12, EPS) +); + +INST_TEST_CASE(SimpleExpectedFails, + PARAMS_EF(-1, 27, 9, 10, EPS, true, mkldnn_invalid_arguments), + PARAMS_EF(1, -12, 10, 9, EPS, true, mkldnn_invalid_arguments), + PARAMS_EF(4, 20, -12, 12, EPS, true, mkldnn_invalid_arguments) +); + +INST_TEST_CASE(Simple_nChw16c_padded, PARAMS_B16(1, 27, 9, 10, EPS), PARAMS_B16(1, 12, 10, 9, EPS), PARAMS_B16(4, 20, 12, 12, EPS), @@ -481,6 +514,13 @@ INST_TEST_CASE(Simple_nCdhw16c_padded, PARAMS_B16_3D(2, 27, 10, 8, 4, EPS) ); +INST_TEST_CASE(Simple_nChw8c_padded, + PARAMS_B8(1, 27, 9, 10, EPS), + PARAMS_B8(1, 12, 10, 9, EPS), + PARAMS_B8(4, 20, 12, 12, EPS), + PARAMS_B8(4, 7, 16, 16, EPS) +); + INST_TEST_CASE(Simple_nCdhw16c, PARAMS_B16_3D(2, 32, 4, 4, 4, EPS), @@ -493,6 +533,17 @@ INST_TEST_CASE(Simple_nCdhw16c, PARAMS_B16_3D(2, 32, 10, 8, 4, EPS) ); +INST_TEST_CASE(Simple_nCdhw8c, + PARAMS_B8_3D(2, 32, 4, 4, 4, EPS), + PARAMS_B8_3D(2, 32, 4, 4, 4, EPS), + PARAMS_B8_3D(2, 32, 8, 8, 8, EPS), + PARAMS_B8_3D(2, 32, 8, 8, 8, EPS), + PARAMS_B8_3D(2, 32, 16, 8, 20, EPS), + PARAMS_B8_3D(2, 32, 16, 8, 20, EPS), + PARAMS_B8_3D(2, 32, 10, 8, 4, EPS), + PARAMS_B8_3D(2, 32, 10, 8, 4, EPS) +); + INST_TEST_CASE(Simple_NC, PARAMS_NC(2, 8, 1, 1, EPS), PARAMS_NC(2, 10, 1, 1, EPS), |