summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/BatchNormPatternFinder.test.cpp')
-rw-r--r--compiler/luci/pass/src/BatchNormPatternFinder.test.cpp107
1 files changed, 100 insertions, 7 deletions
diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
index 08e7fac1c..cc8c5615f 100644
--- a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
+++ b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
@@ -50,7 +50,7 @@ public:
auto channel_size = *last_it;
_add->shape(shape);
- _add_beta->shape({channel_size});
+ set_beta_shape(channel_size);
_add_beta->size<loco::DataType::FLOAT32>(channel_size);
for (uint32_t i = 0; i < channel_size; i++)
_add_beta->at<loco::DataType::FLOAT32>(i) = i;
@@ -63,10 +63,23 @@ public:
luci::CircleAdd *add() { return _add; }
protected:
+ virtual void set_beta_shape(uint32_t channel) = 0;
+
+protected:
luci::CircleAdd *_add = nullptr;
luci::CircleConst *_add_beta = nullptr;
};
+class AddRank1BetaGraphlet : public AddBetaGraphlet
+{
+ void set_beta_shape(uint32_t channel) final { _add_beta->shape({channel}); }
+};
+
+class AddRank4BetaGraphlet : public AddBetaGraphlet
+{
+ void set_beta_shape(uint32_t channel) final { _add_beta->shape({1, 1, 1, channel}); }
+};
+
/**
* @brief Graphlet with Mul and Const as gamma from BatchNorm
*/
@@ -90,7 +103,7 @@ public:
auto channel_size = *last_it;
_mul->shape(shape);
- _mul_gamma->shape({channel_size});
+ set_gamma_shape(channel_size);
_mul_gamma->size<loco::DataType::FLOAT32>(channel_size);
for (uint32_t i = 0; i < channel_size; i++)
_mul_gamma->at<loco::DataType::FLOAT32>(i) = i;
@@ -103,14 +116,27 @@ public:
luci::CircleMul *mul(void) { return _mul; }
protected:
+ virtual void set_gamma_shape(uint32_t channel) = 0;
+
+protected:
luci::CircleMul *_mul = nullptr;
luci::CircleConst *_mul_gamma = nullptr;
};
+class MulRank1GammaGraphlet : public MulGammaGraphlet
+{
+ void set_gamma_shape(uint32_t channel) final { _mul_gamma->shape({channel}); }
+};
+
+class MulRank4GammaGraphlet : public MulGammaGraphlet
+{
+ void set_gamma_shape(uint32_t channel) final { _mul_gamma->shape({1, 1, 1, channel}); }
+};
+
/**
* @brief Graph of Mul-Add pattern from BatchNorm
*/
-class MulAddGraph : public TestIOGraph, public AddBetaGraphlet, public MulGammaGraphlet
+class MulAddGraph : public TestIOGraph, public AddRank1BetaGraphlet, public MulRank1GammaGraphlet
{
public:
MulAddGraph() = default;
@@ -118,8 +144,30 @@ public:
void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
{
TestIOGraph::init(shape_in, shape_out);
- MulGammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
- AddBetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
+ MulRank1GammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
+ AddRank1BetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
+
+ // connect network
+ _mul->x(input());
+ _mul->y(_mul_gamma);
+ _add->x(_mul);
+ _add->y(_add_beta);
+ output()->from(_add);
+ }
+};
+
+class MulAddRank4Graph : public TestIOGraph,
+ public AddRank4BetaGraphlet,
+ public MulRank4GammaGraphlet
+{
+public:
+ MulAddRank4Graph() = default;
+
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ MulRank4GammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
+ AddRank4BetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
// connect network
_mul->x(input());
@@ -133,7 +181,7 @@ public:
/**
* @brief Graph of Add with Const
*/
-class AddGraph : public TestIOGraph, public AddBetaGraphlet
+class AddGraph : public TestIOGraph, public AddRank1BetaGraphlet
{
public:
AddGraph() = default;
@@ -141,7 +189,24 @@ public:
void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
{
TestIOGraph::init(shape_in, shape_out);
- AddBetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
+ AddRank1BetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
+
+ // connect network
+ _add->x(input());
+ _add->y(_add_beta);
+ output()->from(_add);
+ }
+};
+
+class AddRank4Graph : public TestIOGraph, public AddRank4BetaGraphlet
+{
+public:
+ AddRank4Graph() = default;
+
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ AddRank4BetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
// connect network
_add->x(input());
@@ -160,6 +225,7 @@ public:
protected:
luci::test::MulAddGraph _mag;
+ luci::test::MulAddRank4Graph _mag_r4;
};
class BatchNormPatternFinderAddTest : public ::testing::Test
@@ -169,6 +235,7 @@ public:
protected:
luci::test::AddGraph _ag;
+ luci::test::AddRank4Graph _ag_r4;
};
TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add)
@@ -192,6 +259,19 @@ TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add2)
ASSERT_TRUE(res);
}
+TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add_rank4)
+{
+ _mag_r4.init({1, 16, 16, 4}, {1, 16, 16, 4});
+
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *beta = nullptr;
+
+ auto res = luci::is_batchnorm_add(_mag_r4.add(), mul, beta);
+ ASSERT_TRUE(res);
+ ASSERT_NE(nullptr, mul);
+ ASSERT_NE(nullptr, beta);
+}
+
TEST_F(BatchNormPatternFinderAddTest, is_batchnorm_add_NEG)
{
_ag.init({1, 16, 16, 4}, {1, 16, 16, 4});
@@ -215,3 +295,16 @@ TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul)
ASSERT_NE(nullptr, pred);
ASSERT_NE(nullptr, gamma);
}
+
+TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul_rank4)
+{
+ _mag_r4.init({1, 16, 16, 4}, {1, 16, 16, 4});
+
+ luci::CircleNode *pred = nullptr;
+ luci::CircleConst *gamma = nullptr;
+
+ auto res = luci::is_batchnorm_mul(_mag_r4.mul(), pred, gamma);
+ ASSERT_TRUE(res);
+ ASSERT_NE(nullptr, pred);
+ ASSERT_NE(nullptr, gamma);
+}