diff options
Diffstat (limited to 'compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp')
-rw-r--r-- | compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp index c9412fbb1..dd81d1380 100644 --- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp @@ -264,6 +264,22 @@ public: luci::CircleConst *input2 = nullptr; }; +class EluGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + elu = g.nodes()->create<luci::CircleElu>(); + elu->features(input); + elu->name("elu"); + + return elu; + } + +public: + luci::CircleElu *elu = nullptr; +}; + class LeakyReluGraph final : public SimpleGraph { protected: @@ -941,6 +957,26 @@ TEST(ConvertNCHWToNHWC, Concatenation) EXPECT_EQ(3, g.concat->axis()); } +TEST(ConvertNCHWToNHWC, Elu) +{ + EluGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.elu->features()); + + auto elu_succs = loco::succs(g.elu); + EXPECT_EQ(1, elu_succs.size()); + check_post_trans(*elu_succs.begin()); + + // Check elu shape + EXPECT_EQ(1, g.elu->dim(0).value()); + EXPECT_EQ(4, g.elu->dim(1).value()); + EXPECT_EQ(4, g.elu->dim(2).value()); + EXPECT_EQ(16, g.elu->dim(3).value()); +} + TEST(ConvertNCHWToNHWC, LeakyRelu) { LeakyReluGraph g; |