summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp')
-rw-r--r--inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp95
1 files changed, 73 insertions, 22 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp b/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp
index ee052e5d9..48d5bfcc1 100644
--- a/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_batch_normalization.cpp
@@ -38,6 +38,8 @@ struct test_bnrm_params_t {
test_bnrm_sizes_t sizes;
float eps;
int ndims;
+ bool expect_to_fail;
+ mkldnn_status_t expected_status;
};
template <typename data_t>
@@ -45,6 +47,9 @@ void check_bnrm_fwd(const test_bnrm_params_t &p,
const memory &src, const memory &mean, const memory &variance,
const memory &weights, const memory &dst, unsigned flags, prop_kind pk)
{
+ const test_bnrm_sizes_t &bp = p.sizes;
+ if (bp.mb * bp.c * bp.d * bp.h * bp.w == 0) return;
+
const bool use_weights = flags & use_scale_shift;
const bool calculate_stats = !(flags & use_global_stats);
const bool is_training = (pk == prop_kind::forward_training);
@@ -60,14 +65,12 @@ void check_bnrm_fwd(const test_bnrm_params_t &p,
const memory::desc src_d = src.get_primitive_desc().desc();
const memory::desc dst_d = dst.get_primitive_desc().desc();
- test_bnrm_sizes_t bp = p.sizes;
data_t eps = static_cast<data_t>(1.e-4 * bp.mb * bp.d * bp.h * bp.w);
size_t padded_c = src.get_primitive_desc().desc().data.layout_desc
.blocking.padding_dims[1];
-#pragma omp parallel for
- for (int c = 0; c < bp.c; c++) {
+ mkldnn::impl::parallel_nd(bp.c, [&](int c) {
data_t ref_mean = calculate_stats ? data_t(0) : mean_data[c];
data_t ref_variance = calculate_stats ? data_t(0) : variance_data[c];
if (calculate_stats) {
@@ -138,7 +141,7 @@ void check_bnrm_fwd(const test_bnrm_params_t &p,
EXPECT_NEAR((out - ref_dst) / norm_max, 0., eps);
}
}
- }
+ });
}
template <typename data_t>
@@ -147,6 +150,7 @@ void check_bnrm_bwd(const test_bnrm_params_t &p,
const memory &variance, const memory &weights, const memory &diff_src,
const memory &diff_weights, unsigned flags, prop_kind pk)
{
+ const test_bnrm_sizes_t &bp = p.sizes;
const bool use_weights = flags & use_scale_shift;
const bool calculate_diff_stats = !(flags & omit_stats);
@@ -165,13 +169,22 @@ void check_bnrm_bwd(const test_bnrm_params_t &p,
const memory::desc diff_src_d = diff_src.get_primitive_desc().desc();
const memory::desc diff_weights_d = diff_weights.get_primitive_desc().desc();
- test_bnrm_sizes_t bp = p.sizes;
+ if (bp.mb * bp.c * bp.d * bp.h * bp.w == 0) {
+ if (pk == backward) {
+ for (int c = 0; c < bp.c; ++c) {
+ auto dg = diff_weights_data[map_index(diff_weights_d, c)];
+ auto db = diff_weights_data[map_index(diff_weights_d, bp.c + c)];
+ EXPECT_NEAR(dg, 0., 1e-7);
+ EXPECT_NEAR(db, 0., 1e-7);
+ }
+ }
+ return;
+ }
const data_t eps = static_cast<data_t>(1.e-4 * bp.mb * bp.d * bp.h * bp.w);
size_t padded_c = src.get_primitive_desc().desc().data.layout_desc.blocking.padding_dims[1];
-#pragma omp parallel for
- for (int c = 0; c < bp.c; c++) {
+ mkldnn::impl::parallel_nd(bp.c, [&](int c) {
data_t ref_diff_gamma = data_t(0);
data_t ref_diff_beta = data_t(0);
@@ -223,7 +236,7 @@ void check_bnrm_bwd(const test_bnrm_params_t &p,
if (norm_max < eps) norm_max = data_t(1);
EXPECT_NEAR((out_diff_src - ref_diff_src) / norm_max, 0., eps);
}
- }
+ });
}
template <typename data_t>
@@ -249,6 +262,12 @@ private:
protected:
virtual void SetUp() {
p = ::testing::TestWithParam<decltype(p)>::GetParam();
+ catch_expected_failures([=](){Test();}, p.expect_to_fail,
+ p.expected_status);
+ }
+
+ void Test() {
+ p = ::testing::TestWithParam<decltype(p)>::GetParam();
ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
eng.reset(new engine(p.engine_kind, 0));
@@ -448,26 +467,40 @@ TEST_P(bnrm_test_float, TestsBnrm)
#define ENGINE engine::kind::cpu
#define EPS 1e-5f
-#define PARAMS(data, diff, mb, c, h, w, eps) \
- test_bnrm_params_t { ENGINE, \
- EXPAND_FORMATS(data, diff), EXPAND_SIZES_2D(mb, c, h, w), eps, 4 }
+#define PARAMS(data, diff, mb, c, h, w, eps, ef, st) \
+ test_bnrm_params_t { ENGINE, EXPAND_FORMATS(data, diff), \
+ EXPAND_SIZES_2D(mb, c, h, w), eps, 4, ef, st }
-#define PARAMS_3D(data, diff, mb, c, d, h, w, eps) \
- test_bnrm_params_t { ENGINE, \
- EXPAND_FORMATS(data, diff), EXPAND_SIZES_3D(mb, c, d, h, w), eps, 5 }
+#define PARAMS_3D(data, diff, mb, c, d, h, w, eps, ef, st) \
+ test_bnrm_params_t { ENGINE, EXPAND_FORMATS(data, diff), \
+ EXPAND_SIZES_3D(mb, c, d, h, w), eps, 5, ef, st }
-#define PARAMS_N_3D(...) EXPAND_ARGS(PARAMS_3D(ncdhw, ncdhw, __VA_ARGS__))
-#define PARAMS_B16_3D(...) EXPAND_ARGS(PARAMS_3D(nCdhw16c, nCdhw16c, __VA_ARGS__))
-#define PARAMS_N(...) EXPAND_ARGS(PARAMS(nchw, nchw, __VA_ARGS__))
-#define PARAMS_NHWC(...) EXPAND_ARGS(PARAMS(nhwc, nhwc, __VA_ARGS__))
-#define PARAMS_NC(...) EXPAND_ARGS(PARAMS(nc, nc, __VA_ARGS__))
-#define PARAMS_B8(...) EXPAND_ARGS(PARAMS(nChw8c, nChw8c, __VA_ARGS__))
-#define PARAMS_B16(...) EXPAND_ARGS(PARAMS(nChw16c, nChw16c, __VA_ARGS__))
+#define PARAMS_N_3D(...) EXPAND_ARGS(PARAMS_3D(ncdhw, ncdhw, __VA_ARGS__, false, mkldnn_success))
+#define PARAMS_B8_3D(...) EXPAND_ARGS(PARAMS_3D(nCdhw8c, nCdhw8c, __VA_ARGS__, false, mkldnn_success))
+#define PARAMS_B16_3D(...) EXPAND_ARGS(PARAMS_3D(nCdhw16c, nCdhw16c, __VA_ARGS__, false, mkldnn_success))
+#define PARAMS_N(...) EXPAND_ARGS(PARAMS(nchw, nchw, __VA_ARGS__, false, mkldnn_success))
+#define PARAMS_NHWC(...) EXPAND_ARGS(PARAMS(nhwc, nhwc, __VA_ARGS__, false, mkldnn_success))
+#define PARAMS_NC(...) EXPAND_ARGS(PARAMS(nc, nc, __VA_ARGS__, false, mkldnn_success))
+#define PARAMS_B8(...) EXPAND_ARGS(PARAMS(nChw8c, nChw8c, __VA_ARGS__, false, mkldnn_success))
+#define PARAMS_B16(...) EXPAND_ARGS(PARAMS(nChw16c, nChw16c, __VA_ARGS__, false, mkldnn_success))
+#define PARAMS_EF(...) EXPAND_ARGS(PARAMS(nchw, nchw, __VA_ARGS__))
#define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
str, bnrm_test_float, ::testing::Values(__VA_ARGS__))
-INST_TEST_CASE(Simple_Blocked_padded,
+INST_TEST_CASE(SimpleZeroDim,
+ PARAMS_N(0, 27, 9, 10, EPS),
+ PARAMS_N(1, 0, 10, 9, EPS),
+ PARAMS_N(4, 20, 0, 12, EPS)
+);
+
+INST_TEST_CASE(SimpleExpectedFails,
+ PARAMS_EF(-1, 27, 9, 10, EPS, true, mkldnn_invalid_arguments),
+ PARAMS_EF(1, -12, 10, 9, EPS, true, mkldnn_invalid_arguments),
+ PARAMS_EF(4, 20, -12, 12, EPS, true, mkldnn_invalid_arguments)
+);
+
+INST_TEST_CASE(Simple_nChw16c_padded,
PARAMS_B16(1, 27, 9, 10, EPS),
PARAMS_B16(1, 12, 10, 9, EPS),
PARAMS_B16(4, 20, 12, 12, EPS),
@@ -481,6 +514,13 @@ INST_TEST_CASE(Simple_nCdhw16c_padded,
PARAMS_B16_3D(2, 27, 10, 8, 4, EPS)
);
+INST_TEST_CASE(Simple_nChw8c_padded,
+ PARAMS_B8(1, 27, 9, 10, EPS),
+ PARAMS_B8(1, 12, 10, 9, EPS),
+ PARAMS_B8(4, 20, 12, 12, EPS),
+ PARAMS_B8(4, 7, 16, 16, EPS)
+);
+
INST_TEST_CASE(Simple_nCdhw16c,
PARAMS_B16_3D(2, 32, 4, 4, 4, EPS),
@@ -493,6 +533,17 @@ INST_TEST_CASE(Simple_nCdhw16c,
PARAMS_B16_3D(2, 32, 10, 8, 4, EPS)
);
+INST_TEST_CASE(Simple_nCdhw8c,
+ PARAMS_B8_3D(2, 32, 4, 4, 4, EPS),
+ PARAMS_B8_3D(2, 32, 4, 4, 4, EPS),
+ PARAMS_B8_3D(2, 32, 8, 8, 8, EPS),
+ PARAMS_B8_3D(2, 32, 8, 8, 8, EPS),
+ PARAMS_B8_3D(2, 32, 16, 8, 20, EPS),
+ PARAMS_B8_3D(2, 32, 16, 8, 20, EPS),
+ PARAMS_B8_3D(2, 32, 10, 8, 4, EPS),
+ PARAMS_B8_3D(2, 32, 10, 8, 4, EPS)
+);
+
INST_TEST_CASE(Simple_NC,
PARAMS_NC(2, 8, 1, 1, EPS),
PARAMS_NC(2, 10, 1, 1, EPS),