From c690d52bdd137ed6a17353aa7af35e8141ece77b Mon Sep 17 00:00:00 2001 From: Chunseok Lee Date: Wed, 7 Sep 2022 19:04:21 +0900 Subject: Imported Upstream version 1.21.0 --- .../luci/pass/src/ConvertNCHWToNHWCPass.test.cpp | 525 ++++++++++++++++++++- 1 file changed, 523 insertions(+), 2 deletions(-) (limited to 'compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp') diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp index dd81d1380..6bb3d3268 100644 --- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp @@ -16,6 +16,8 @@ #include +#include + #include "luci/Pass/ConvertNCHWToNHWCPass.h" #include "luci/Pass/CircleShapeInferencePass.h" @@ -23,6 +25,8 @@ #include +using namespace luci::test; + namespace { @@ -202,6 +206,173 @@ public: luci::CircleConst *post_shape = nullptr; }; +/** + * Graph with pre-Reshape but no post-Transpose/Reshape. + * + * BEFORE + * [Input] + * | + * [Pre-Reshape] + * | + * [Relu] + * | + * [Output] + * + * AFTER + * [Input] + * | + * [Pre-Reshape] + * | + * [Pre-Transpose] + * | + * [Relu] + * | + * [Post-Transpose] + * | + * [Output] + */ +class NoPostReshapeGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + relu = g.nodes()->create(); + pre_reshape = g.nodes()->create(); + pre_shape = g.nodes()->create(); + + pre_shape->dtype(loco::DataType::S32); + + uint32_t channel_size = 16; + auto in = loco::must_cast(input); + in->shape({1, channel_size, 4, 4}); + pre_shape->shape({4}); + + pre_shape->size(4); + pre_shape->at(0) = 1; + pre_shape->at(1) = 4; + pre_shape->at(2) = 4; + pre_shape->at(3) = channel_size; + + pre_reshape->tensor(input); + pre_reshape->shape(pre_shape); + relu->features(pre_reshape); + + relu->name("Relu"); + pre_reshape->name("pre-reshape"); + + return relu; + } + +public: + luci::CircleRelu *relu = nullptr; + luci::CircleReshape *pre_reshape = nullptr; + luci::CircleConst *pre_shape = nullptr; +}; + +/** + * Graph with two pre-Reshapes + * + * BEFORE + * [Input] + * | + * [Pre-Reshape] + * | + * [Relu] + * | + * [Pre-Reshape] + * | + * [Post-Reshape] + * | + * [Output] + * + * AFTER + * [Input] + * | + * [Pre-Reshape] + * | + * [Pre-Transpose] + * | + * [Relu] + * | + * [Post-Transpose] + * | + * [Pre-Reshape] + * | + * [Post-Reshape] + * | + * [Output] + */ +class ReluNotClosedGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + relu = g.nodes()->create(); + pre_reshape = g.nodes()->create(); + pre_reshape_2 = g.nodes()->create(); + post_reshape = g.nodes()->create(); + pre_shape = g.nodes()->create(); + pre_shape_2 = g.nodes()->create(); + post_shape = g.nodes()->create(); + + pre_shape->dtype(loco::DataType::S32); + pre_shape_2->dtype(loco::DataType::S32); + post_shape->dtype(loco::DataType::S32); + + uint32_t channel_size = 16; + auto in = loco::must_cast(input); + in->shape({1, channel_size, 4, 4}); + pre_shape->shape({4}); + pre_shape_2->shape({4}); + post_shape->shape({4}); + + pre_shape->size(4); + pre_shape->at(0) = 1; + pre_shape->at(1) = 4; + pre_shape->at(2) = 4; + pre_shape->at(3) = channel_size; + + pre_shape_2->size(4); + pre_shape_2->at(0) = 1; + pre_shape_2->at(1) = 4; + pre_shape_2->at(2) = channel_size; + pre_shape_2->at(3) = 4; + + post_shape->size(4); + post_shape->at(0) = 1; + post_shape->at(1) = 4; + post_shape->at(2) = 4; + post_shape->at(3) = channel_size; + + pre_reshape->tensor(input); + pre_reshape->shape(pre_shape); + + relu->features(pre_reshape); + + pre_reshape_2->tensor(relu); + pre_reshape_2->shape(pre_shape_2); + + post_reshape->tensor(pre_reshape_2); + post_reshape->shape(post_shape); + + relu->name("Relu"); + pre_reshape->name("pre-reshape"); + pre_reshape->name("pre-reshape-2"); + post_reshape->name("post-reshape"); + + return post_reshape; + } + +public: + luci::CircleRelu *relu = nullptr; + luci::CircleReshape *pre_reshape = nullptr; + luci::CircleReshape *pre_reshape_2 = nullptr; + luci::CircleReshape *post_reshape = nullptr; + luci::CircleConst *pre_shape = nullptr; + luci::CircleConst *pre_shape_2 = nullptr; + luci::CircleConst *post_shape = nullptr; +}; + class AddScalarGraph final : public SimpleGraph { protected: @@ -312,6 +483,22 @@ public: luci::CircleLogistic *logistic = nullptr; }; +class LogSoftmaxGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + log_softmax = g.nodes()->create(); + log_softmax->logits(input); + log_softmax->name("log_softmax"); + + return log_softmax; + } + +public: + luci::CircleLogSoftmax *log_softmax = nullptr; +}; + class MaximumGraph final : public SimpleGraph { protected: @@ -642,6 +829,51 @@ public: luci::CircleConst *const_value = nullptr; }; +class ReduceMaxGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + rm = g.nodes()->create(); + rindices = g.nodes()->create(); + + rm->dtype(loco::DataType::FLOAT32); + rindices->dtype(loco::DataType::S32); + + rm->shape(_shape); + rindices->shape({static_cast(_axes.size())}); + + rindices->size(_axes.size()); + for (uint32_t i = 0; i < _axes.size(); ++i) + { + rindices->at(i) = _axes[i]; + } + + rm->input(input); + rm->reduction_indices(rindices); + rm->keep_dims(_keep_dims); + + rm->name("reduce_max"); + rindices->name("rindices"); + + return rm; + } + +public: + void keep_dims(bool val) { _keep_dims = val; } + void axes(std::vector val) { _axes = val; } + void shape(std::initializer_list val) { _shape = val; } + +public: + luci::CircleReduceMax *rm = nullptr; + luci::CircleConst *rindices = nullptr; + +private: + bool _keep_dims = true; + std::vector _axes = {2, 3}; + std::initializer_list _shape = {1, 16, 1, 1}; +}; + class ReluGraph final : public SimpleGraph { protected: @@ -690,6 +922,111 @@ public: luci::CircleRsqrt *rsqrt = nullptr; }; +class SoftmaxGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + softmax = g.nodes()->create(); + softmax->logits(input); + softmax->name("softmax"); + + return softmax; + } + +public: + luci::CircleSoftmax *softmax = nullptr; +}; + +class SplitVGraphlet +{ +public: + SplitVGraphlet() = default; + +public: + void init(loco::Graph *g) + { + // CircleCustom(SplitV) + _splitv = g->nodes()->create(); + _splitv->shape({1, 2, 2, 192}); + _splitv->dtype(loco::DataType::FLOAT32); + _splitv->name("splitv"); + + // CircleConst + auto size_splits = g->nodes()->create(); + size_splits->dtype(loco::DataType::S32); + size_splits->shape({3}); + size_splits->size(3); + size_splits->at(0) = 32; + size_splits->at(1) = 32; + size_splits->at(2) = 128; + + // CircleConst + auto split_dim = g->nodes()->create(); + split_dim->dtype(loco::DataType::S32); + split_dim->rank(0); + split_dim->size(1); + split_dim->scalar() = 3; + + _splitv->size_splits(size_splits); + _splitv->split_dim(split_dim); + _splitv->num_split(3); + + // CircleSplitVOut + _splitv_out1 = g->nodes()->create(); + _splitv_out1->shape({1, 2, 2, 32}); + _splitv_out1->dtype(loco::DataType::FLOAT32); + _splitv_out1->index(0); + _splitv_out1->input(_splitv); + _splitv_out1->name("splitv_out1"); + + // CircleSplitVOut + _splitv_out2 = g->nodes()->create(); + _splitv_out2->shape({1, 2, 2, 32}); + _splitv_out2->dtype(loco::DataType::FLOAT32); + _splitv_out2->index(1); + _splitv_out2->input(_splitv); + _splitv_out2->name("splitv_out2"); + + // CircleSplitVOut + _splitv_out3 = g->nodes()->create(); + _splitv_out3->shape({1, 2, 2, 128}); + _splitv_out3->dtype(loco::DataType::FLOAT32); + _splitv_out3->index(2); + _splitv_out3->input(_splitv); + _splitv_out3->name("splitv_out3"); + } + +public: + luci::CircleSplitV *splitv() { return _splitv; } + +protected: + luci::CircleSplitV *_splitv = nullptr; + luci::CircleSplitVOut *_splitv_out1 = nullptr; + luci::CircleSplitVOut *_splitv_out2 = nullptr; + luci::CircleSplitVOut *_splitv_out3 = nullptr; +}; + +class SplitVGraph : public TestIGraphlet, public TestOsGraphlet<3>, public SplitVGraphlet +{ +public: + SplitVGraph() = default; + + void init(void) + { + TestIGraphlet::init(g(), {1, 2, 2, 192}); + TestOsGraphlet<3>::init(g(), {{1, 2, 2, 32}, {1, 2, 2, 32}, {1, 2, 2, 128}}); + SplitVGraphlet::init(g()); + + // connect graph + _splitv->input(input()); + + output(0)->from(_splitv_out1); + output(1)->from(_splitv_out2); + output(2)->from(_splitv_out3); + } +}; + class SquaredDifferenceGraph final : public SimpleGraph { protected: @@ -929,8 +1266,11 @@ TEST(ConvertNCHWToNHWC, AddScalar) auto new_beta = dynamic_cast(g.add->y()); EXPECT_NE(nullptr, new_beta); - EXPECT_EQ(1, new_beta->rank()); + EXPECT_EQ(4, new_beta->rank()); EXPECT_EQ(1, new_beta->dim(0).value()); + EXPECT_EQ(1, new_beta->dim(1).value()); + EXPECT_EQ(1, new_beta->dim(2).value()); + EXPECT_EQ(1, new_beta->dim(3).value()); check_pre_trans(g.output->from()); } @@ -1017,6 +1357,26 @@ TEST(ConvertNCHWToNHWC, Logistic) EXPECT_EQ(16, g.logistic->dim(3).value()); } +TEST(ConvertNCHWToNHWC, LogSoftmax) +{ + LogSoftmaxGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.log_softmax->logits()); + + auto log_softmax_succs = loco::succs(g.log_softmax); + EXPECT_EQ(1, log_softmax_succs.size()); + check_post_trans(*log_softmax_succs.begin()); + + // Check log_softmax shape + EXPECT_EQ(1, g.log_softmax->dim(0).value()); + EXPECT_EQ(4, g.log_softmax->dim(1).value()); + EXPECT_EQ(4, g.log_softmax->dim(2).value()); + EXPECT_EQ(16, g.log_softmax->dim(3).value()); +} + TEST(ConvertNCHWToNHWC, Maximum) { MaximumGraph g; @@ -1265,8 +1625,11 @@ TEST(ConvertNCHWToNHWC, MulScalar) auto new_multiplier = dynamic_cast(g.mul->y()); EXPECT_NE(nullptr, new_multiplier); - EXPECT_EQ(1, new_multiplier->rank()); + EXPECT_EQ(4, new_multiplier->rank()); EXPECT_EQ(1, new_multiplier->dim(0).value()); + EXPECT_EQ(1, new_multiplier->dim(1).value()); + EXPECT_EQ(1, new_multiplier->dim(2).value()); + EXPECT_EQ(1, new_multiplier->dim(3).value()); check_pre_trans(g.output->from()); } @@ -1451,6 +1814,85 @@ TEST(ConvertNCHWToNHWC, Preserve_Input_Output) } } +TEST(ConvertNCHWToNHWC, ReduceMax) +{ + ReduceMaxGraph g; + g.init(); + + run_phase(&g.g, false, false); + + check_pre_trans(g.rm->input()); + + auto rm_succs = loco::succs(g.rm); + EXPECT_EQ(1, rm_succs.size()); + check_post_trans(*rm_succs.begin()); + + auto new_rindices = dynamic_cast(g.rm->reduction_indices()); + EXPECT_NE(nullptr, new_rindices); + EXPECT_EQ(1, new_rindices->rank()); + EXPECT_EQ(2, new_rindices->dim(0).value()); + EXPECT_EQ(2, new_rindices->size()); + EXPECT_EQ(1, new_rindices->at(0)); + EXPECT_EQ(2, new_rindices->at(1)); +} + +TEST(ConvertNCHWToNHWC, ReduceMax_keep_dims_false) +{ + struct TC + { + std::vector nchw_ind; + std::vector nhwc_ind; + std::initializer_list shape; + bool needs_transpose = false; + }; + + uint32_t n = 1; + uint32_t c = 16; + uint32_t h = 4; + uint32_t w = 4; + + std::vector test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false}, + {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true}, + {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true}, + {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false}, + {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false}, + {{0, 1, 2}, {0, 3, 1}, {w}, false}}; + + for (auto &tc : test_cases) + { + ReduceMaxGraph g; + g.keep_dims(false); + g.axes(tc.nchw_ind); + g.shape(tc.shape); + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.rm->input()); + + auto rm_succs = loco::succs(g.rm); + EXPECT_EQ(1, rm_succs.size()); + if (tc.needs_transpose) + { + EXPECT_NE(nullptr, dynamic_cast(*rm_succs.begin())); + } + else + { + EXPECT_NE(nullptr, dynamic_cast(*rm_succs.begin())); + } + + auto new_rindices = dynamic_cast(g.rm->reduction_indices()); + EXPECT_NE(nullptr, new_rindices); + EXPECT_EQ(1, new_rindices->rank()); + EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value()); + EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size()); + for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i) + { + EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at(i)); + } + } +} + TEST(ConvertNCHWToNHWC, Relu) { ReluGraph g; @@ -1511,6 +1953,57 @@ TEST(ConvertNCHWToNHWC, Rsqrt) EXPECT_EQ(16, g.rsqrt->dim(3).value()); } +TEST(ConvertNCHWToNHWC, Softmax) +{ + SoftmaxGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.softmax->logits()); + + auto softmax_succs = loco::succs(g.softmax); + EXPECT_EQ(1, softmax_succs.size()); + check_post_trans(*softmax_succs.begin()); + + // Check softmax shape + EXPECT_EQ(1, g.softmax->dim(0).value()); + EXPECT_EQ(4, g.softmax->dim(1).value()); + EXPECT_EQ(4, g.softmax->dim(2).value()); + EXPECT_EQ(16, g.softmax->dim(3).value()); +} + +TEST(ConvertNCHWToNHWC, SplitV) +{ + SplitVGraph g; + g.init(); + + run_phase(g.g(), true, true); + + check_pre_trans(g.splitv()->input()); + + auto splitv_succs = loco::succs(g.splitv()); + for (auto svo : loco::succs(g.splitv())) + { + for (auto succ : loco::succs(svo)) + { + check_post_trans(succ); + } + } + + // Check splitv() shape + EXPECT_EQ(1, g.splitv()->dim(0).value()); + EXPECT_EQ(2, g.splitv()->dim(1).value()); + EXPECT_EQ(192, g.splitv()->dim(2).value()); + EXPECT_EQ(2, g.splitv()->dim(3).value()); + + // Check axis + auto axis = dynamic_cast(g.splitv()->split_dim()); + EXPECT_NE(nullptr, axis); + EXPECT_EQ(1, axis->size()); + EXPECT_EQ(2, axis->at(0)); +} + TEST(ConvertNCHWToNHWC, SquaredDifference) { SquaredDifferenceGraph g; @@ -1602,3 +2095,31 @@ TEST(ConvertNCHWToNHWC, SubScalar) check_pre_trans(g.output->from()); } + +TEST(ConvertNCHWToNHWC, Not_Closed_Case1_NEG) +{ + NoPostReshapeGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.relu->features()); + + auto relu_succs = loco::succs(g.relu); + EXPECT_EQ(1, relu_succs.size()); + check_post_trans(*relu_succs.begin()); +} + +TEST(ConvertNCHWToNHWC, Not_Closed_Case2_NEG) +{ + ReluNotClosedGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.relu->features()); + + auto relu_succs = loco::succs(g.relu); + EXPECT_EQ(1, relu_succs.size()); + check_post_trans(*relu_succs.begin()); +} -- cgit v1.2.3