diff options
Diffstat (limited to 'compiler/luci/pass/src/BatchNormPatternFinder.test.cpp')
-rw-r--r-- | compiler/luci/pass/src/BatchNormPatternFinder.test.cpp | 107 |
1 files changed, 100 insertions, 7 deletions
diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp index 08e7fac1c..cc8c5615f 100644 --- a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp +++ b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp @@ -50,7 +50,7 @@ public: auto channel_size = *last_it; _add->shape(shape); - _add_beta->shape({channel_size}); + set_beta_shape(channel_size); _add_beta->size<loco::DataType::FLOAT32>(channel_size); for (uint32_t i = 0; i < channel_size; i++) _add_beta->at<loco::DataType::FLOAT32>(i) = i; @@ -63,10 +63,23 @@ public: luci::CircleAdd *add() { return _add; } protected: + virtual void set_beta_shape(uint32_t channel) = 0; + +protected: luci::CircleAdd *_add = nullptr; luci::CircleConst *_add_beta = nullptr; }; +class AddRank1BetaGraphlet : public AddBetaGraphlet +{ + void set_beta_shape(uint32_t channel) final { _add_beta->shape({channel}); } +}; + +class AddRank4BetaGraphlet : public AddBetaGraphlet +{ + void set_beta_shape(uint32_t channel) final { _add_beta->shape({1, 1, 1, channel}); } +}; + /** * @brief Graphlet with Mul and Const as gamma from BatchNorm */ @@ -90,7 +103,7 @@ public: auto channel_size = *last_it; _mul->shape(shape); - _mul_gamma->shape({channel_size}); + set_gamma_shape(channel_size); _mul_gamma->size<loco::DataType::FLOAT32>(channel_size); for (uint32_t i = 0; i < channel_size; i++) _mul_gamma->at<loco::DataType::FLOAT32>(i) = i; @@ -103,14 +116,27 @@ public: luci::CircleMul *mul(void) { return _mul; } protected: + virtual void set_gamma_shape(uint32_t channel) = 0; + +protected: luci::CircleMul *_mul = nullptr; luci::CircleConst *_mul_gamma = nullptr; }; +class MulRank1GammaGraphlet : public MulGammaGraphlet +{ + void set_gamma_shape(uint32_t channel) final { _mul_gamma->shape({channel}); } +}; + +class MulRank4GammaGraphlet : public MulGammaGraphlet +{ + void set_gamma_shape(uint32_t channel) final { _mul_gamma->shape({1, 1, 1, channel}); } +}; + /** * @brief Graph of Mul-Add pattern from BatchNorm */ -class MulAddGraph : public TestIOGraph, public AddBetaGraphlet, public MulGammaGraphlet +class MulAddGraph : public TestIOGraph, public AddRank1BetaGraphlet, public MulRank1GammaGraphlet { public: MulAddGraph() = default; @@ -118,8 +144,30 @@ public: void init(const ShapeU32 shape_in, const ShapeU32 shape_out) { TestIOGraph::init(shape_in, shape_out); - MulGammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE); - AddBetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU); + MulRank1GammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE); + AddRank1BetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU); + + // connect network + _mul->x(input()); + _mul->y(_mul_gamma); + _add->x(_mul); + _add->y(_add_beta); + output()->from(_add); + } +}; + +class MulAddRank4Graph : public TestIOGraph, + public AddRank4BetaGraphlet, + public MulRank4GammaGraphlet +{ +public: + MulAddRank4Graph() = default; + + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + MulRank4GammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE); + AddRank4BetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU); // connect network _mul->x(input()); @@ -133,7 +181,7 @@ public: /** * @brief Graph of Add with Const */ -class AddGraph : public TestIOGraph, public AddBetaGraphlet +class AddGraph : public TestIOGraph, public AddRank1BetaGraphlet { public: AddGraph() = default; @@ -141,7 +189,24 @@ public: void init(const ShapeU32 shape_in, const ShapeU32 shape_out) { TestIOGraph::init(shape_in, shape_out); - AddBetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU); + AddRank1BetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU); + + // connect network + _add->x(input()); + _add->y(_add_beta); + output()->from(_add); + } +}; + +class AddRank4Graph : public TestIOGraph, public AddRank4BetaGraphlet +{ +public: + AddRank4Graph() = default; + + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + AddRank4BetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU); // connect network _add->x(input()); @@ -160,6 +225,7 @@ public: protected: luci::test::MulAddGraph _mag; + luci::test::MulAddRank4Graph _mag_r4; }; class BatchNormPatternFinderAddTest : public ::testing::Test @@ -169,6 +235,7 @@ public: protected: luci::test::AddGraph _ag; + luci::test::AddRank4Graph _ag_r4; }; TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add) @@ -192,6 +259,19 @@ TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add2) ASSERT_TRUE(res); } +TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add_rank4) +{ + _mag_r4.init({1, 16, 16, 4}, {1, 16, 16, 4}); + + luci::CircleMul *mul = nullptr; + luci::CircleConst *beta = nullptr; + + auto res = luci::is_batchnorm_add(_mag_r4.add(), mul, beta); + ASSERT_TRUE(res); + ASSERT_NE(nullptr, mul); + ASSERT_NE(nullptr, beta); +} + TEST_F(BatchNormPatternFinderAddTest, is_batchnorm_add_NEG) { _ag.init({1, 16, 16, 4}, {1, 16, 16, 4}); @@ -215,3 +295,16 @@ TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul) ASSERT_NE(nullptr, pred); ASSERT_NE(nullptr, gamma); } + +TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul_rank4) +{ + _mag_r4.init({1, 16, 16, 4}, {1, 16, 16, 4}); + + luci::CircleNode *pred = nullptr; + luci::CircleConst *gamma = nullptr; + + auto res = luci::is_batchnorm_mul(_mag_r4.mul(), pred, gamma); + ASSERT_TRUE(res); + ASSERT_NE(nullptr, pred); + ASSERT_NE(nullptr, gamma); +} |