diff options
Diffstat (limited to 'compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp')
-rw-r--r-- | compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp index 903d4dcc9..bac033112 100644 --- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp +++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp @@ -141,6 +141,37 @@ TEST(ReplaceMulAddWithDepthwiseConv, simple) } } +TEST(ReplaceMulAddWithDepthwiseConv, simple_rank4) +{ + SimpleGraph g; + + const uint32_t channel_size = 16; + g.gamma->shape({1, 1, 1, channel_size}); + g.beta->shape({1, 1, 1, channel_size}); + + luci::ReplaceMulAddWithDepthwiseConvPass pass; + while (pass.run(&g.g)) + ; + + auto dwconv = dynamic_cast<luci::CircleDepthwiseConv2D *>(g.output->from()); + EXPECT_NE(nullptr, dwconv); + + auto weights = dynamic_cast<luci::CircleConst *>(dwconv->filter()); + auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias()); + EXPECT_NE(nullptr, weights); + EXPECT_EQ(4, weights->rank()); + EXPECT_EQ(channel_size, weights->dim(3).value()); + EXPECT_NE(nullptr, bias); + EXPECT_EQ(1, bias->rank()); + EXPECT_EQ(channel_size, bias->dim(0).value()); + + for (int i = 0; i < channel_size; i++) + { + EXPECT_FLOAT_EQ(i, weights->at<loco::DataType::FLOAT32>(i)); + EXPECT_FLOAT_EQ(i, bias->at<loco::DataType::FLOAT32>(i)); + } +} + TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG) { SimpleGraph g; @@ -154,3 +185,18 @@ TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG) EXPECT_EQ(false, changed); } + +TEST(ReplaceMulAddWithDepthwiseConv, rank3_NEG) +{ + SimpleGraph g; + + g.input->shape({4, 4, 16}); + g.mul->shape({4, 4, 16}); + g.add->shape({4, 4, 16}); + g.output->shape({4, 4, 16}); + + luci::ReplaceMulAddWithDepthwiseConvPass pass; + auto changed = pass.run(&g.g); + + EXPECT_EQ(false, changed); +} |