summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2022-09-07 19:04:21 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2022-09-07 19:04:21 +0900
commitc690d52bdd137ed6a17353aa7af35e8141ece77b (patch)
treedbb7dd99133132dfbffcb8c9e9af4f1ffc2f4808 /compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
parent3ad689f0803519e343c36d5700646e86059df961 (diff)
downloadnnfw-c690d52bdd137ed6a17353aa7af35e8141ece77b.tar.gz
nnfw-c690d52bdd137ed6a17353aa7af35e8141ece77b.tar.bz2
nnfw-c690d52bdd137ed6a17353aa7af35e8141ece77b.zip
Diffstat (limited to 'compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp')
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp525
1 files changed, 523 insertions, 2 deletions
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
index dd81d1380..6bb3d3268 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
@@ -16,6 +16,8 @@
#include <logo/Phase.h>
+#include <luci/test/TestIOGraph.h>
+
#include "luci/Pass/ConvertNCHWToNHWCPass.h"
#include "luci/Pass/CircleShapeInferencePass.h"
@@ -23,6 +25,8 @@
#include <gtest/gtest.h>
+using namespace luci::test;
+
namespace
{
@@ -202,6 +206,173 @@ public:
luci::CircleConst *post_shape = nullptr;
};
+/**
+ * Graph with pre-Reshape but no post-Transpose/Reshape.
+ *
+ * BEFORE
+ * [Input]
+ * |
+ * [Pre-Reshape]
+ * |
+ * [Relu]
+ * |
+ * [Output]
+ *
+ * AFTER
+ * [Input]
+ * |
+ * [Pre-Reshape]
+ * |
+ * [Pre-Transpose]
+ * |
+ * [Relu]
+ * |
+ * [Post-Transpose]
+ * |
+ * [Output]
+ */
+class NoPostReshapeGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ relu = g.nodes()->create<luci::CircleRelu>();
+ pre_reshape = g.nodes()->create<luci::CircleReshape>();
+ pre_shape = g.nodes()->create<luci::CircleConst>();
+
+ pre_shape->dtype(loco::DataType::S32);
+
+ uint32_t channel_size = 16;
+ auto in = loco::must_cast<luci::CircleNode *>(input);
+ in->shape({1, channel_size, 4, 4});
+ pre_shape->shape({4});
+
+ pre_shape->size<loco::DataType::S32>(4);
+ pre_shape->at<loco::DataType::S32>(0) = 1;
+ pre_shape->at<loco::DataType::S32>(1) = 4;
+ pre_shape->at<loco::DataType::S32>(2) = 4;
+ pre_shape->at<loco::DataType::S32>(3) = channel_size;
+
+ pre_reshape->tensor(input);
+ pre_reshape->shape(pre_shape);
+ relu->features(pre_reshape);
+
+ relu->name("Relu");
+ pre_reshape->name("pre-reshape");
+
+ return relu;
+ }
+
+public:
+ luci::CircleRelu *relu = nullptr;
+ luci::CircleReshape *pre_reshape = nullptr;
+ luci::CircleConst *pre_shape = nullptr;
+};
+
+/**
+ * Graph with two pre-Reshapes
+ *
+ * BEFORE
+ * [Input]
+ * |
+ * [Pre-Reshape]
+ * |
+ * [Relu]
+ * |
+ * [Pre-Reshape]
+ * |
+ * [Post-Reshape]
+ * |
+ * [Output]
+ *
+ * AFTER
+ * [Input]
+ * |
+ * [Pre-Reshape]
+ * |
+ * [Pre-Transpose]
+ * |
+ * [Relu]
+ * |
+ * [Post-Transpose]
+ * |
+ * [Pre-Reshape]
+ * |
+ * [Post-Reshape]
+ * |
+ * [Output]
+ */
+class ReluNotClosedGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ relu = g.nodes()->create<luci::CircleRelu>();
+ pre_reshape = g.nodes()->create<luci::CircleReshape>();
+ pre_reshape_2 = g.nodes()->create<luci::CircleReshape>();
+ post_reshape = g.nodes()->create<luci::CircleReshape>();
+ pre_shape = g.nodes()->create<luci::CircleConst>();
+ pre_shape_2 = g.nodes()->create<luci::CircleConst>();
+ post_shape = g.nodes()->create<luci::CircleConst>();
+
+ pre_shape->dtype(loco::DataType::S32);
+ pre_shape_2->dtype(loco::DataType::S32);
+ post_shape->dtype(loco::DataType::S32);
+
+ uint32_t channel_size = 16;
+ auto in = loco::must_cast<luci::CircleNode *>(input);
+ in->shape({1, channel_size, 4, 4});
+ pre_shape->shape({4});
+ pre_shape_2->shape({4});
+ post_shape->shape({4});
+
+ pre_shape->size<loco::DataType::S32>(4);
+ pre_shape->at<loco::DataType::S32>(0) = 1;
+ pre_shape->at<loco::DataType::S32>(1) = 4;
+ pre_shape->at<loco::DataType::S32>(2) = 4;
+ pre_shape->at<loco::DataType::S32>(3) = channel_size;
+
+ pre_shape_2->size<loco::DataType::S32>(4);
+ pre_shape_2->at<loco::DataType::S32>(0) = 1;
+ pre_shape_2->at<loco::DataType::S32>(1) = 4;
+ pre_shape_2->at<loco::DataType::S32>(2) = channel_size;
+ pre_shape_2->at<loco::DataType::S32>(3) = 4;
+
+ post_shape->size<loco::DataType::S32>(4);
+ post_shape->at<loco::DataType::S32>(0) = 1;
+ post_shape->at<loco::DataType::S32>(1) = 4;
+ post_shape->at<loco::DataType::S32>(2) = 4;
+ post_shape->at<loco::DataType::S32>(3) = channel_size;
+
+ pre_reshape->tensor(input);
+ pre_reshape->shape(pre_shape);
+
+ relu->features(pre_reshape);
+
+ pre_reshape_2->tensor(relu);
+ pre_reshape_2->shape(pre_shape_2);
+
+ post_reshape->tensor(pre_reshape_2);
+ post_reshape->shape(post_shape);
+
+ relu->name("Relu");
+ pre_reshape->name("pre-reshape");
+ pre_reshape->name("pre-reshape-2");
+ post_reshape->name("post-reshape");
+
+ return post_reshape;
+ }
+
+public:
+ luci::CircleRelu *relu = nullptr;
+ luci::CircleReshape *pre_reshape = nullptr;
+ luci::CircleReshape *pre_reshape_2 = nullptr;
+ luci::CircleReshape *post_reshape = nullptr;
+ luci::CircleConst *pre_shape = nullptr;
+ luci::CircleConst *pre_shape_2 = nullptr;
+ luci::CircleConst *post_shape = nullptr;
+};
+
class AddScalarGraph final : public SimpleGraph
{
protected:
@@ -312,6 +483,22 @@ public:
luci::CircleLogistic *logistic = nullptr;
};
+class LogSoftmaxGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ log_softmax = g.nodes()->create<luci::CircleLogSoftmax>();
+ log_softmax->logits(input);
+ log_softmax->name("log_softmax");
+
+ return log_softmax;
+ }
+
+public:
+ luci::CircleLogSoftmax *log_softmax = nullptr;
+};
+
class MaximumGraph final : public SimpleGraph
{
protected:
@@ -642,6 +829,51 @@ public:
luci::CircleConst *const_value = nullptr;
};
+class ReduceMaxGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ rm = g.nodes()->create<luci::CircleReduceMax>();
+ rindices = g.nodes()->create<luci::CircleConst>();
+
+ rm->dtype(loco::DataType::FLOAT32);
+ rindices->dtype(loco::DataType::S32);
+
+ rm->shape(_shape);
+ rindices->shape({static_cast<uint32_t>(_axes.size())});
+
+ rindices->size<loco::DataType::S32>(_axes.size());
+ for (uint32_t i = 0; i < _axes.size(); ++i)
+ {
+ rindices->at<loco::DataType::S32>(i) = _axes[i];
+ }
+
+ rm->input(input);
+ rm->reduction_indices(rindices);
+ rm->keep_dims(_keep_dims);
+
+ rm->name("reduce_max");
+ rindices->name("rindices");
+
+ return rm;
+ }
+
+public:
+ void keep_dims(bool val) { _keep_dims = val; }
+ void axes(std::vector<int32_t> val) { _axes = val; }
+ void shape(std::initializer_list<uint32_t> val) { _shape = val; }
+
+public:
+ luci::CircleReduceMax *rm = nullptr;
+ luci::CircleConst *rindices = nullptr;
+
+private:
+ bool _keep_dims = true;
+ std::vector<int32_t> _axes = {2, 3};
+ std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
+};
+
class ReluGraph final : public SimpleGraph
{
protected:
@@ -690,6 +922,111 @@ public:
luci::CircleRsqrt *rsqrt = nullptr;
};
+class SoftmaxGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ softmax = g.nodes()->create<luci::CircleSoftmax>();
+ softmax->logits(input);
+ softmax->name("softmax");
+
+ return softmax;
+ }
+
+public:
+ luci::CircleSoftmax *softmax = nullptr;
+};
+
+class SplitVGraphlet
+{
+public:
+ SplitVGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ // CircleCustom(SplitV)
+ _splitv = g->nodes()->create<luci::CircleSplitV>();
+ _splitv->shape({1, 2, 2, 192});
+ _splitv->dtype(loco::DataType::FLOAT32);
+ _splitv->name("splitv");
+
+ // CircleConst
+ auto size_splits = g->nodes()->create<luci::CircleConst>();
+ size_splits->dtype(loco::DataType::S32);
+ size_splits->shape({3});
+ size_splits->size<loco::DataType::S32>(3);
+ size_splits->at<loco::DataType::S32>(0) = 32;
+ size_splits->at<loco::DataType::S32>(1) = 32;
+ size_splits->at<loco::DataType::S32>(2) = 128;
+
+ // CircleConst
+ auto split_dim = g->nodes()->create<luci::CircleConst>();
+ split_dim->dtype(loco::DataType::S32);
+ split_dim->rank(0);
+ split_dim->size<loco::DataType::S32>(1);
+ split_dim->scalar<loco::DataType::S32>() = 3;
+
+ _splitv->size_splits(size_splits);
+ _splitv->split_dim(split_dim);
+ _splitv->num_split(3);
+
+ // CircleSplitVOut
+ _splitv_out1 = g->nodes()->create<luci::CircleSplitVOut>();
+ _splitv_out1->shape({1, 2, 2, 32});
+ _splitv_out1->dtype(loco::DataType::FLOAT32);
+ _splitv_out1->index(0);
+ _splitv_out1->input(_splitv);
+ _splitv_out1->name("splitv_out1");
+
+ // CircleSplitVOut
+ _splitv_out2 = g->nodes()->create<luci::CircleSplitVOut>();
+ _splitv_out2->shape({1, 2, 2, 32});
+ _splitv_out2->dtype(loco::DataType::FLOAT32);
+ _splitv_out2->index(1);
+ _splitv_out2->input(_splitv);
+ _splitv_out2->name("splitv_out2");
+
+ // CircleSplitVOut
+ _splitv_out3 = g->nodes()->create<luci::CircleSplitVOut>();
+ _splitv_out3->shape({1, 2, 2, 128});
+ _splitv_out3->dtype(loco::DataType::FLOAT32);
+ _splitv_out3->index(2);
+ _splitv_out3->input(_splitv);
+ _splitv_out3->name("splitv_out3");
+ }
+
+public:
+ luci::CircleSplitV *splitv() { return _splitv; }
+
+protected:
+ luci::CircleSplitV *_splitv = nullptr;
+ luci::CircleSplitVOut *_splitv_out1 = nullptr;
+ luci::CircleSplitVOut *_splitv_out2 = nullptr;
+ luci::CircleSplitVOut *_splitv_out3 = nullptr;
+};
+
+class SplitVGraph : public TestIGraphlet, public TestOsGraphlet<3>, public SplitVGraphlet
+{
+public:
+ SplitVGraph() = default;
+
+ void init(void)
+ {
+ TestIGraphlet::init(g(), {1, 2, 2, 192});
+ TestOsGraphlet<3>::init(g(), {{1, 2, 2, 32}, {1, 2, 2, 32}, {1, 2, 2, 128}});
+ SplitVGraphlet::init(g());
+
+ // connect graph
+ _splitv->input(input());
+
+ output(0)->from(_splitv_out1);
+ output(1)->from(_splitv_out2);
+ output(2)->from(_splitv_out3);
+ }
+};
+
class SquaredDifferenceGraph final : public SimpleGraph
{
protected:
@@ -929,8 +1266,11 @@ TEST(ConvertNCHWToNHWC, AddScalar)
auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
EXPECT_NE(nullptr, new_beta);
- EXPECT_EQ(1, new_beta->rank());
+ EXPECT_EQ(4, new_beta->rank());
EXPECT_EQ(1, new_beta->dim(0).value());
+ EXPECT_EQ(1, new_beta->dim(1).value());
+ EXPECT_EQ(1, new_beta->dim(2).value());
+ EXPECT_EQ(1, new_beta->dim(3).value());
check_pre_trans(g.output->from());
}
@@ -1017,6 +1357,26 @@ TEST(ConvertNCHWToNHWC, Logistic)
EXPECT_EQ(16, g.logistic->dim(3).value());
}
+TEST(ConvertNCHWToNHWC, LogSoftmax)
+{
+ LogSoftmaxGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.log_softmax->logits());
+
+ auto log_softmax_succs = loco::succs(g.log_softmax);
+ EXPECT_EQ(1, log_softmax_succs.size());
+ check_post_trans(*log_softmax_succs.begin());
+
+ // Check log_softmax shape
+ EXPECT_EQ(1, g.log_softmax->dim(0).value());
+ EXPECT_EQ(4, g.log_softmax->dim(1).value());
+ EXPECT_EQ(4, g.log_softmax->dim(2).value());
+ EXPECT_EQ(16, g.log_softmax->dim(3).value());
+}
+
TEST(ConvertNCHWToNHWC, Maximum)
{
MaximumGraph g;
@@ -1265,8 +1625,11 @@ TEST(ConvertNCHWToNHWC, MulScalar)
auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
EXPECT_NE(nullptr, new_multiplier);
- EXPECT_EQ(1, new_multiplier->rank());
+ EXPECT_EQ(4, new_multiplier->rank());
EXPECT_EQ(1, new_multiplier->dim(0).value());
+ EXPECT_EQ(1, new_multiplier->dim(1).value());
+ EXPECT_EQ(1, new_multiplier->dim(2).value());
+ EXPECT_EQ(1, new_multiplier->dim(3).value());
check_pre_trans(g.output->from());
}
@@ -1451,6 +1814,85 @@ TEST(ConvertNCHWToNHWC, Preserve_Input_Output)
}
}
+TEST(ConvertNCHWToNHWC, ReduceMax)
+{
+ ReduceMaxGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ check_pre_trans(g.rm->input());
+
+ auto rm_succs = loco::succs(g.rm);
+ EXPECT_EQ(1, rm_succs.size());
+ check_post_trans(*rm_succs.begin());
+
+ auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
+ EXPECT_NE(nullptr, new_rindices);
+ EXPECT_EQ(1, new_rindices->rank());
+ EXPECT_EQ(2, new_rindices->dim(0).value());
+ EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
+ EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
+ EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
+}
+
+TEST(ConvertNCHWToNHWC, ReduceMax_keep_dims_false)
+{
+ struct TC
+ {
+ std::vector<int32_t> nchw_ind;
+ std::vector<int32_t> nhwc_ind;
+ std::initializer_list<uint32_t> shape;
+ bool needs_transpose = false;
+ };
+
+ uint32_t n = 1;
+ uint32_t c = 16;
+ uint32_t h = 4;
+ uint32_t w = 4;
+
+ std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false},
+ {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true},
+ {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true},
+ {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false},
+ {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false},
+ {{0, 1, 2}, {0, 3, 1}, {w}, false}};
+
+ for (auto &tc : test_cases)
+ {
+ ReduceMaxGraph g;
+ g.keep_dims(false);
+ g.axes(tc.nchw_ind);
+ g.shape(tc.shape);
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.rm->input());
+
+ auto rm_succs = loco::succs(g.rm);
+ EXPECT_EQ(1, rm_succs.size());
+ if (tc.needs_transpose)
+ {
+ EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin()));
+ }
+ else
+ {
+ EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin()));
+ }
+
+ auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
+ EXPECT_NE(nullptr, new_rindices);
+ EXPECT_EQ(1, new_rindices->rank());
+ EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
+ EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
+ for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
+ {
+ EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
+ }
+ }
+}
+
TEST(ConvertNCHWToNHWC, Relu)
{
ReluGraph g;
@@ -1511,6 +1953,57 @@ TEST(ConvertNCHWToNHWC, Rsqrt)
EXPECT_EQ(16, g.rsqrt->dim(3).value());
}
+TEST(ConvertNCHWToNHWC, Softmax)
+{
+ SoftmaxGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.softmax->logits());
+
+ auto softmax_succs = loco::succs(g.softmax);
+ EXPECT_EQ(1, softmax_succs.size());
+ check_post_trans(*softmax_succs.begin());
+
+ // Check softmax shape
+ EXPECT_EQ(1, g.softmax->dim(0).value());
+ EXPECT_EQ(4, g.softmax->dim(1).value());
+ EXPECT_EQ(4, g.softmax->dim(2).value());
+ EXPECT_EQ(16, g.softmax->dim(3).value());
+}
+
+TEST(ConvertNCHWToNHWC, SplitV)
+{
+ SplitVGraph g;
+ g.init();
+
+ run_phase(g.g(), true, true);
+
+ check_pre_trans(g.splitv()->input());
+
+ auto splitv_succs = loco::succs(g.splitv());
+ for (auto svo : loco::succs(g.splitv()))
+ {
+ for (auto succ : loco::succs(svo))
+ {
+ check_post_trans(succ);
+ }
+ }
+
+ // Check splitv() shape
+ EXPECT_EQ(1, g.splitv()->dim(0).value());
+ EXPECT_EQ(2, g.splitv()->dim(1).value());
+ EXPECT_EQ(192, g.splitv()->dim(2).value());
+ EXPECT_EQ(2, g.splitv()->dim(3).value());
+
+ // Check axis
+ auto axis = dynamic_cast<luci::CircleConst *>(g.splitv()->split_dim());
+ EXPECT_NE(nullptr, axis);
+ EXPECT_EQ(1, axis->size<loco::DataType::S32>());
+ EXPECT_EQ(2, axis->at<loco::DataType::S32>(0));
+}
+
TEST(ConvertNCHWToNHWC, SquaredDifference)
{
SquaredDifferenceGraph g;
@@ -1602,3 +2095,31 @@ TEST(ConvertNCHWToNHWC, SubScalar)
check_pre_trans(g.output->from());
}
+
+TEST(ConvertNCHWToNHWC, Not_Closed_Case1_NEG)
+{
+ NoPostReshapeGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.relu->features());
+
+ auto relu_succs = loco::succs(g.relu);
+ EXPECT_EQ(1, relu_succs.size());
+ check_post_trans(*relu_succs.begin());
+}
+
+TEST(ConvertNCHWToNHWC, Not_Closed_Case2_NEG)
+{
+ ReluNotClosedGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.relu->features());
+
+ auto relu_succs = loco::succs(g.relu);
+ EXPECT_EQ(1, relu_succs.size());
+ check_post_trans(*relu_succs.begin());
+}