summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp')
-rw-r--r--compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp46
1 files changed, 46 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
index 903d4dcc9..bac033112 100644
--- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
@@ -141,6 +141,37 @@ TEST(ReplaceMulAddWithDepthwiseConv, simple)
}
}
+TEST(ReplaceMulAddWithDepthwiseConv, simple_rank4)
+{
+ SimpleGraph g;
+
+ const uint32_t channel_size = 16;
+ g.gamma->shape({1, 1, 1, channel_size});
+ g.beta->shape({1, 1, 1, channel_size});
+
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ auto dwconv = dynamic_cast<luci::CircleDepthwiseConv2D *>(g.output->from());
+ EXPECT_NE(nullptr, dwconv);
+
+ auto weights = dynamic_cast<luci::CircleConst *>(dwconv->filter());
+ auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias());
+ EXPECT_NE(nullptr, weights);
+ EXPECT_EQ(4, weights->rank());
+ EXPECT_EQ(channel_size, weights->dim(3).value());
+ EXPECT_NE(nullptr, bias);
+ EXPECT_EQ(1, bias->rank());
+ EXPECT_EQ(channel_size, bias->dim(0).value());
+
+ for (int i = 0; i < channel_size; i++)
+ {
+ EXPECT_FLOAT_EQ(i, weights->at<loco::DataType::FLOAT32>(i));
+ EXPECT_FLOAT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));
+ }
+}
+
TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG)
{
SimpleGraph g;
@@ -154,3 +185,18 @@ TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG)
EXPECT_EQ(false, changed);
}
+
+TEST(ReplaceMulAddWithDepthwiseConv, rank3_NEG)
+{
+ SimpleGraph g;
+
+ g.input->shape({4, 4, 16});
+ g.mul->shape({4, 4, 16});
+ g.add->shape({4, 4, 16});
+ g.output->shape({4, 4, 16});
+
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ auto changed = pass.run(&g.g);
+
+ EXPECT_EQ(false, changed);
+}