summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp')
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp36
1 files changed, 36 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
index c9412fbb1..dd81d1380 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
@@ -264,6 +264,22 @@ public:
luci::CircleConst *input2 = nullptr;
};
+class EluGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ elu = g.nodes()->create<luci::CircleElu>();
+ elu->features(input);
+ elu->name("elu");
+
+ return elu;
+ }
+
+public:
+ luci::CircleElu *elu = nullptr;
+};
+
class LeakyReluGraph final : public SimpleGraph
{
protected:
@@ -941,6 +957,26 @@ TEST(ConvertNCHWToNHWC, Concatenation)
EXPECT_EQ(3, g.concat->axis());
}
+TEST(ConvertNCHWToNHWC, Elu)
+{
+ EluGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.elu->features());
+
+ auto elu_succs = loco::succs(g.elu);
+ EXPECT_EQ(1, elu_succs.size());
+ check_post_trans(*elu_succs.begin());
+
+ // Check elu shape
+ EXPECT_EQ(1, g.elu->dim(0).value());
+ EXPECT_EQ(4, g.elu->dim(1).value());
+ EXPECT_EQ(4, g.elu->dim(2).value());
+ EXPECT_EQ(16, g.elu->dim(3).value());
+}
+
TEST(ConvertNCHWToNHWC, LeakyRelu)
{
LeakyReluGraph g;