summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2021-10-19 11:32:46 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2021-10-19 11:32:46 +0900
commit33ae5d70a1ed85d215c1293ed63afbf3517b07d5 (patch)
tree9f1ace0f4760a8f7903ef15e2e92f1d1401e4b1e /compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
parentf4cf19e579a19c5346ccb2aad55bfd251065e447 (diff)
downloadnnfw-33ae5d70a1ed85d215c1293ed63afbf3517b07d5.tar.gz
nnfw-33ae5d70a1ed85d215c1293ed63afbf3517b07d5.tar.bz2
nnfw-33ae5d70a1ed85d215c1293ed63afbf3517b07d5.zip
Diffstat (limited to 'compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp')
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp111
1 files changed, 111 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
index d844246f8..c9412fbb1 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
@@ -130,6 +130,19 @@ protected:
}
public:
+ void update_const_shape_to_nchw(void)
+ {
+ uint32_t channel_size = 16;
+ beta->shape({1, channel_size, 4, 4});
+
+ beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ beta->at<loco::DataType::FLOAT32>(i) = i;
+ }
+ }
+
+public:
luci::CircleAdd *add = nullptr;
luci::CircleConst *beta = nullptr;
};
@@ -421,6 +434,19 @@ protected:
}
public:
+ void update_const_shape_to_nchw(void)
+ {
+ uint32_t channel_size = 16;
+ multiplier->shape({1, channel_size, 4, 4});
+
+ multiplier->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ multiplier->at<loco::DataType::FLOAT32>(i) = i;
+ }
+ }
+
+public:
luci::CircleMul *mul = nullptr;
luci::CircleConst *multiplier = nullptr;
};
@@ -696,6 +722,19 @@ protected:
}
public:
+ void update_const_shape_to_nchw(void)
+ {
+ uint32_t channel_size = 16;
+ beta->shape({1, channel_size, 4, 4});
+
+ beta->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ beta->at<loco::DataType::FLOAT32>(i) = i;
+ }
+ }
+
+public:
luci::CircleSub *sub = nullptr;
luci::CircleConst *beta = nullptr;
};
@@ -815,6 +854,30 @@ TEST(ConvertNCHWToNHWC, Add)
check_pre_trans(g.output->from());
}
+TEST(ConvertNCHWToNHWC, Add_NCHW_const)
+{
+ AddGraph g;
+ g.init();
+ g.update_const_shape_to_nchw();
+
+ run_phase(&g.g, false, false);
+
+ check_pre_trans(g.add->x());
+
+ auto add_succs = loco::succs(g.add);
+ EXPECT_EQ(1, add_succs.size());
+ check_post_trans(*add_succs.begin());
+
+ uint32_t channel_size = 16;
+ auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
+ EXPECT_NE(nullptr, new_beta);
+ EXPECT_EQ(4, new_beta->rank());
+ EXPECT_EQ(1, new_beta->dim(0).value());
+ EXPECT_EQ(4, new_beta->dim(1).value());
+ EXPECT_EQ(4, new_beta->dim(2).value());
+ EXPECT_EQ(channel_size, new_beta->dim(3).value());
+}
+
TEST(ConvertNCHWToNHWC, NHWC_Relu)
{
// Relu is already NHWC, so it should not be converted
@@ -1123,6 +1186,30 @@ TEST(ConvertNCHWToNHWC, Mul)
check_pre_trans(g.output->from());
}
+TEST(ConvertNCHWToNHWC, Mul_NCHW_const)
+{
+ MulGraph g;
+ g.init();
+ g.update_const_shape_to_nchw();
+
+ run_phase(&g.g, false, false);
+
+ check_pre_trans(g.mul->x());
+
+ auto mul_succs = loco::succs(g.mul);
+ EXPECT_EQ(1, mul_succs.size());
+ check_post_trans(*mul_succs.begin());
+
+ uint32_t channel_size = 16;
+ auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
+ EXPECT_NE(nullptr, new_multiplier);
+ EXPECT_EQ(4, new_multiplier->rank());
+ EXPECT_EQ(1, new_multiplier->dim(0).value());
+ EXPECT_EQ(4, new_multiplier->dim(1).value());
+ EXPECT_EQ(4, new_multiplier->dim(2).value());
+ EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
+}
+
TEST(ConvertNCHWToNHWC, MulScalar)
{
MulScalarGraph g;
@@ -1432,6 +1519,30 @@ TEST(ConvertNCHWToNHWC, Sub)
check_pre_trans(g.output->from());
}
+TEST(ConvertNCHWToNHWC, Sub_NCHW_const)
+{
+ SubGraph g;
+ g.init();
+ g.update_const_shape_to_nchw();
+
+ run_phase(&g.g, false, false);
+
+ check_pre_trans(g.sub->x());
+
+ auto sub_succs = loco::succs(g.sub);
+ EXPECT_EQ(1, sub_succs.size());
+ check_post_trans(*sub_succs.begin());
+
+ uint32_t channel_size = 16;
+ auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
+ EXPECT_NE(nullptr, new_beta);
+ EXPECT_EQ(4, new_beta->rank());
+ EXPECT_EQ(1, new_beta->dim(0).value());
+ EXPECT_EQ(4, new_beta->dim(1).value());
+ EXPECT_EQ(4, new_beta->dim(2).value());
+ EXPECT_EQ(channel_size, new_beta->dim(3).value());
+}
+
TEST(ConvertNCHWToNHWC, SubScalar)
{
SubScalarGraph g;