diff options
author | Chunseok Lee <chunseok.lee@samsung.com> | 2021-08-23 13:25:15 +0900 |
---|---|---|
committer | Chunseok Lee <chunseok.lee@samsung.com> | 2021-08-23 13:25:15 +0900 |
commit | f4cf19e579a19c5346ccb2aad55bfd251065e447 (patch) | |
tree | 5d436b11f89be0e8a8289ea82b773da6402c1add /compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp | |
parent | 589bb1db6db6784efe21b3fbbfbfdb79aaa5f14e (diff) | |
download | nnfw-f4cf19e579a19c5346ccb2aad55bfd251065e447.tar.gz nnfw-f4cf19e579a19c5346ccb2aad55bfd251065e447.tar.bz2 nnfw-f4cf19e579a19c5346ccb2aad55bfd251065e447.zip |
Imported Upstream version 1.17.0upstream/1.17.0submit/tizen/20210823.054833submit/tizen/20210823.045832submit/tizen/20210823.044411accepted/tizen/unified/20210823.124210
Diffstat (limited to 'compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp')
-rw-r--r-- | compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp | 821 |
1 files changed, 821 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp index 831d5f89a..d844246f8 100644 --- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp @@ -134,6 +134,93 @@ public: luci::CircleConst *beta = nullptr; }; +class NHWCReluGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + relu = g.nodes()->create<luci::CircleRelu>(); + pre_reshape = g.nodes()->create<luci::CircleReshape>(); + post_reshape = g.nodes()->create<luci::CircleReshape>(); + pre_shape = g.nodes()->create<luci::CircleConst>(); + post_shape = g.nodes()->create<luci::CircleConst>(); + + pre_shape->dtype(loco::DataType::S32); + post_shape->dtype(loco::DataType::S32); + + uint32_t channel_size = 16; + auto in = loco::must_cast<luci::CircleNode *>(input); + in->shape({1, channel_size, 4, 4}); + pre_shape->shape({4}); + post_shape->shape({4}); + + pre_shape->size<loco::DataType::S32>(4); + pre_shape->at<loco::DataType::S32>(0) = 1; + pre_shape->at<loco::DataType::S32>(1) = 4; + pre_shape->at<loco::DataType::S32>(2) = 4; + pre_shape->at<loco::DataType::S32>(3) = channel_size; + + post_shape->size<loco::DataType::S32>(4); + post_shape->at<loco::DataType::S32>(0) = 1; + post_shape->at<loco::DataType::S32>(1) = channel_size; + post_shape->at<loco::DataType::S32>(2) = 4; + post_shape->at<loco::DataType::S32>(3) = 4; + + pre_reshape->tensor(input); + pre_reshape->shape(pre_shape); + + relu->features(pre_reshape); + + post_reshape->tensor(relu); + post_reshape->shape(post_shape); + + relu->name("Relu"); + pre_reshape->name("pre-reshape"); + post_reshape->name("post-reshape"); + + return post_reshape; + } + +public: + luci::CircleRelu *relu = nullptr; + luci::CircleReshape *pre_reshape = nullptr; + luci::CircleReshape *post_reshape = nullptr; + luci::CircleConst *pre_shape = nullptr; + luci::CircleConst *post_shape = nullptr; +}; + +class AddScalarGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + add = g.nodes()->create<luci::CircleAdd>(); + beta = g.nodes()->create<luci::CircleConst>(); + + add->dtype(loco::DataType::FLOAT32); + beta->dtype(loco::DataType::FLOAT32); + + uint32_t channel_size = 16; + add->shape({1, channel_size, 4, 4}); + beta->shape({1}); + + beta->size<loco::DataType::FLOAT32>(1); + beta->at<loco::DataType::FLOAT32>(0) = 3.14; + + add->x(input); + add->y(beta); + + add->name("add"); + beta->name("beta"); + + return add; + } + +public: + luci::CircleAdd *add = nullptr; + luci::CircleConst *beta = nullptr; +}; + class ConcatenationGraph final : public SimpleGraph { protected: @@ -180,6 +267,129 @@ public: luci::CircleLeakyRelu *leakyrelu = nullptr; }; +class LogisticGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + logistic = g.nodes()->create<luci::CircleLogistic>(); + logistic->x(input); + logistic->name("logistic"); + + return logistic; + } + +public: + luci::CircleLogistic *logistic = nullptr; +}; + +class MaximumGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + max = g.nodes()->create<luci::CircleMaximum>(); + limit = g.nodes()->create<luci::CircleConst>(); + + max->dtype(loco::DataType::FLOAT32); + limit->dtype(loco::DataType::FLOAT32); + + max->shape({1, 16, 4, 4}); + limit->shape({}); + + limit->size<loco::DataType::FLOAT32>(1); + limit->at<loco::DataType::FLOAT32>(0) = 100; + + max->x(input); + max->y(limit); + + max->name("max"); + limit->name("limit"); + + return max; + } + +public: + luci::CircleMaximum *max = nullptr; + luci::CircleConst *limit = nullptr; +}; + +class MeanGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + mean = g.nodes()->create<luci::CircleMean>(); + rindices = g.nodes()->create<luci::CircleConst>(); + + mean->dtype(loco::DataType::FLOAT32); + rindices->dtype(loco::DataType::S32); + + mean->shape(_shape); + rindices->shape({static_cast<uint32_t>(_axes.size())}); + + rindices->size<loco::DataType::S32>(_axes.size()); + for (uint32_t i = 0; i < _axes.size(); ++i) + { + rindices->at<loco::DataType::S32>(i) = _axes[i]; + } + + mean->input(input); + mean->reduction_indices(rindices); + mean->keep_dims(_keep_dims); + + mean->name("mean"); + rindices->name("rindices"); + + return mean; + } + +public: + void keep_dims(bool val) { _keep_dims = val; } + void axes(std::vector<int32_t> val) { _axes = val; } + void shape(std::initializer_list<uint32_t> val) { _shape = val; } + +public: + luci::CircleMean *mean = nullptr; + luci::CircleConst *rindices = nullptr; + +private: + bool _keep_dims = true; + std::vector<int32_t> _axes = {2, 3}; + std::initializer_list<uint32_t> _shape = {1, 16, 1, 1}; +}; + +class MinimumGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + min = g.nodes()->create<luci::CircleMinimum>(); + limit = g.nodes()->create<luci::CircleConst>(); + + min->dtype(loco::DataType::FLOAT32); + limit->dtype(loco::DataType::FLOAT32); + + min->shape({1, 16, 4, 4}); + limit->shape({}); + + limit->size<loco::DataType::FLOAT32>(1); + limit->at<loco::DataType::FLOAT32>(0) = 100; + + min->x(input); + min->y(limit); + + min->name("min"); + limit->name("limit"); + + return min; + } + +public: + luci::CircleMinimum *min = nullptr; + luci::CircleConst *limit = nullptr; +}; + class MulGraph final : public SimpleGraph { protected: @@ -215,6 +425,62 @@ public: luci::CircleConst *multiplier = nullptr; }; +class MulScalarGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + mul = g.nodes()->create<luci::CircleMul>(); + multiplier = g.nodes()->create<luci::CircleConst>(); + + mul->dtype(loco::DataType::FLOAT32); + multiplier->dtype(loco::DataType::FLOAT32); + + uint32_t channel_size = 16; + mul->shape({1, channel_size, 4, 4}); + multiplier->shape({1}); + + multiplier->size<loco::DataType::FLOAT32>(1); + multiplier->at<loco::DataType::FLOAT32>(0) = 2; + + mul->x(input); + mul->y(multiplier); + + mul->name("mul"); + multiplier->name("multiplier"); + + return mul; + } + +public: + luci::CircleMul *mul = nullptr; + luci::CircleConst *multiplier = nullptr; +}; + +class MulBothNormGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + mul = g.nodes()->create<luci::CircleMul>(); + + mul->dtype(loco::DataType::FLOAT32); + + uint32_t channel_size = 16; + mul->shape({1, channel_size, 4, 4}); + + mul->x(input); + mul->y(input); + + mul->name("mul"); + + return mul; + } + +public: + luci::CircleMul *mul = nullptr; +}; + class NegGraph final : public SimpleGraph { protected: @@ -278,6 +544,62 @@ public: luci::CircleConst *paddings = nullptr; }; +class PadV2Graph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + pad = g.nodes()->create<luci::CirclePadV2>(); + paddings = g.nodes()->create<luci::CircleConst>(); + const_value = g.nodes()->create<luci::CircleConst>(); + + pad->dtype(loco::DataType::FLOAT32); + paddings->dtype(loco::DataType::S32); + const_value->dtype(loco::DataType::FLOAT32); + + uint32_t channel_size = 16; + pad->shape({1, channel_size, 4, 4}); + paddings->shape({4, 2}); + const_value->shape({1}); + + // paddings data (NCHW) + // [[0,0], [0,0], [1,1], [2,2]] + paddings->size<loco::DataType::S32>(8); + for (uint32_t dim = 0; dim < 4; dim++) + { + for (uint32_t i = 0; i < 2; i++) + { + int32_t data = 0; + + if (dim == 2) + data = 1; + else if (dim == 3) + data = 2; + + paddings->at<loco::DataType::S32>(dim * 2 + i) = data; + } + } + + const_value->size<loco::DataType::FLOAT32>(1); + const_value->at<loco::DataType::FLOAT32>(0) = -3.4; + + pad->input(input); + pad->paddings(paddings); + pad->constant_values(paddings); + + pad->name("padV2"); + paddings->name("paddings"); + const_value->name("constant_values"); + + return pad; + } + +public: + luci::CirclePadV2 *pad = nullptr; + luci::CircleConst *paddings = nullptr; + luci::CircleConst *const_value = nullptr; +}; + class ReluGraph final : public SimpleGraph { protected: @@ -310,6 +632,106 @@ public: luci::CircleRelu6 *relu6 = nullptr; }; +class RsqrtGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + rsqrt = g.nodes()->create<luci::CircleRsqrt>(); + rsqrt->x(input); + rsqrt->name("rsqrt"); + + return rsqrt; + } + +public: + luci::CircleRsqrt *rsqrt = nullptr; +}; + +class SquaredDifferenceGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + sqdiff = g.nodes()->create<luci::CircleSquaredDifference>(); + sqdiff->x(input); + sqdiff->y(input); + sqdiff->name("sqdiff"); + + return sqdiff; + } + +public: + luci::CircleSquaredDifference *sqdiff = nullptr; +}; + +class SubGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + sub = g.nodes()->create<luci::CircleSub>(); + beta = g.nodes()->create<luci::CircleConst>(); + + sub->dtype(loco::DataType::FLOAT32); + beta->dtype(loco::DataType::FLOAT32); + + uint32_t channel_size = 16; + sub->shape({1, channel_size, 4, 4}); + beta->shape({1, channel_size, 1, 1}); + + beta->size<loco::DataType::FLOAT32>(channel_size); + for (uint32_t i = 0; i < channel_size; i++) + { + beta->at<loco::DataType::FLOAT32>(i) = i; + } + + sub->x(input); + sub->y(beta); + + sub->name("sub"); + beta->name("beta"); + + return sub; + } + +public: + luci::CircleSub *sub = nullptr; + luci::CircleConst *beta = nullptr; +}; + +class SubScalarGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + sub = g.nodes()->create<luci::CircleSub>(); + beta = g.nodes()->create<luci::CircleConst>(); + + sub->dtype(loco::DataType::FLOAT32); + beta->dtype(loco::DataType::FLOAT32); + + uint32_t channel_size = 16; + sub->shape({1, channel_size, 4, 4}); + beta->shape({1}); + + beta->size<loco::DataType::FLOAT32>(1); + beta->at<loco::DataType::FLOAT32>(0) = 5; + + sub->x(beta); + sub->y(input); + + sub->name("sub"); + beta->name("beta"); + + return sub; + } + +public: + luci::CircleSub *sub = nullptr; + luci::CircleConst *beta = nullptr; +}; + void check_pre_trans(loco::Node *node) { auto pre_trans = dynamic_cast<luci::CircleTranspose *>(node); @@ -393,6 +815,47 @@ TEST(ConvertNCHWToNHWC, Add) check_pre_trans(g.output->from()); } +TEST(ConvertNCHWToNHWC, NHWC_Relu) +{ + // Relu is already NHWC, so it should not be converted + // i.e., the graph is not changed + NHWCReluGraph g; + g.init(); + + run_phase(&g.g, false, false); + + EXPECT_EQ(g.pre_reshape, g.relu->features()); + + auto relu_succs = loco::succs(g.relu); + EXPECT_EQ(1, relu_succs.size()); + EXPECT_EQ(g.post_reshape, *relu_succs.begin()); +} + +TEST(ConvertNCHWToNHWC, AddScalar) +{ + AddScalarGraph g; + g.init(); + + run_phase(&g.g, false, false); + + auto input_succs = loco::succs(g.input); + EXPECT_EQ(1, input_succs.size()); + check_post_trans(*input_succs.begin()); + + check_pre_trans(g.add->x()); + + auto add_succs = loco::succs(g.add); + EXPECT_EQ(1, add_succs.size()); + check_post_trans(*add_succs.begin()); + + auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y()); + EXPECT_NE(nullptr, new_beta); + EXPECT_EQ(1, new_beta->rank()); + EXPECT_EQ(1, new_beta->dim(0).value()); + + check_pre_trans(g.output->from()); +} + TEST(ConvertNCHWToNHWC, Concatenation) { ConcatenationGraph g; @@ -435,6 +898,202 @@ TEST(ConvertNCHWToNHWC, LeakyRelu) EXPECT_EQ(16, g.leakyrelu->dim(3).value()); } +TEST(ConvertNCHWToNHWC, Logistic) +{ + LogisticGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.logistic->x()); + + auto logistic_succs = loco::succs(g.logistic); + EXPECT_EQ(1, logistic_succs.size()); + check_post_trans(*logistic_succs.begin()); + + // Check logistic shape + EXPECT_EQ(1, g.logistic->dim(0).value()); + EXPECT_EQ(4, g.logistic->dim(1).value()); + EXPECT_EQ(4, g.logistic->dim(2).value()); + EXPECT_EQ(16, g.logistic->dim(3).value()); +} + +TEST(ConvertNCHWToNHWC, Maximum) +{ + MaximumGraph g; + g.init(); + + run_phase(&g.g, false, false); + + auto input_succs = loco::succs(g.input); + EXPECT_EQ(1, input_succs.size()); + check_post_trans(*input_succs.begin()); + + check_pre_trans(g.max->x()); + + auto max_succs = loco::succs(g.max); + EXPECT_EQ(1, max_succs.size()); + check_post_trans(*max_succs.begin()); + + check_pre_trans(g.output->from()); +} + +TEST(ConvertNCHWToNHWC, Mean) +{ + MeanGraph g; + g.init(); + + run_phase(&g.g, false, false); + + check_pre_trans(g.mean->input()); + + auto mean_succs = loco::succs(g.mean); + EXPECT_EQ(1, mean_succs.size()); + check_post_trans(*mean_succs.begin()); + + auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices()); + EXPECT_NE(nullptr, new_rindices); + EXPECT_EQ(1, new_rindices->rank()); + EXPECT_EQ(2, new_rindices->dim(0).value()); + EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>()); + EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0)); + EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1)); +} + +TEST(ConvertNCHWToNHWC, Mean_keep_dims_false) +{ + struct TC + { + std::vector<int32_t> nchw_ind; + std::vector<int32_t> nhwc_ind; + std::initializer_list<uint32_t> shape; + bool needs_transpose = false; + }; + + uint32_t n = 1; + uint32_t c = 16; + uint32_t h = 4; + uint32_t w = 4; + + std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false}, + {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true}, + {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true}, + {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false}, + {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false}, + {{0, 1, 2}, {0, 3, 1}, {w}, false}}; + + for (auto &tc : test_cases) + { + MeanGraph g; + g.keep_dims(false); + g.axes(tc.nchw_ind); + g.shape(tc.shape); + g.init(); + + run_phase(&g.g, false, true); + + check_pre_trans(g.mean->input()); + + auto mean_succs = loco::succs(g.mean); + EXPECT_EQ(1, mean_succs.size()); + if (tc.needs_transpose) + { + EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*mean_succs.begin())); + } + else + { + EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*mean_succs.begin())); + } + + auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices()); + EXPECT_NE(nullptr, new_rindices); + EXPECT_EQ(1, new_rindices->rank()); + EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value()); + EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>()); + for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i) + { + EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i)); + } + } +} + +TEST(ConvertNCHWToNHWC, ConvertNCHWToNHWC_Mean_keep_dims_false_NEG) +{ + loco::Graph g; + auto input = g.nodes()->create<luci::CircleInput>(); + auto output = g.nodes()->create<luci::CircleOutput>(); + input->name("input"); + output->name("output"); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + graph_input->dtype(loco::DataType::FLOAT32); + input->dtype(loco::DataType::FLOAT32); + output->dtype(loco::DataType::FLOAT32); + graph_output->dtype(loco::DataType::FLOAT32); + + uint32_t channel_size = 16; + graph_input->shape({channel_size, 4, 4}); + input->shape({channel_size, 4, 4}); + output->shape({channel_size}); + graph_output->shape({channel_size}); + + auto mean = g.nodes()->create<luci::CircleMean>(); + auto rindices = g.nodes()->create<luci::CircleConst>(); + + mean->dtype(loco::DataType::FLOAT32); + rindices->dtype(loco::DataType::S32); + + mean->shape({channel_size}); + rindices->shape({2}); + + rindices->size<loco::DataType::S32>(2); + rindices->at<loco::DataType::S32>(0) = 1; + rindices->at<loco::DataType::S32>(1) = 2; + + mean->input(input); + mean->reduction_indices(rindices); + mean->keep_dims(false); + + mean->name("mean"); + rindices->name("rindices"); + + output->from(mean); + + run_phase(&g, true, true); + + auto new_rindices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices()); + EXPECT_NE(nullptr, new_rindices); + EXPECT_EQ(1, new_rindices->rank()); + EXPECT_EQ(2, new_rindices->dim(0).value()); + EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>()); + EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0)); + EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1)); +} + +TEST(ConvertNCHWToNHWC, Minimum) +{ + MinimumGraph g; + g.init(); + + run_phase(&g.g, false, false); + + auto input_succs = loco::succs(g.input); + EXPECT_EQ(1, input_succs.size()); + check_post_trans(*input_succs.begin()); + + check_pre_trans(g.min->x()); + + auto min_succs = loco::succs(g.min); + EXPECT_EQ(1, min_succs.size()); + check_post_trans(*min_succs.begin()); + + check_pre_trans(g.output->from()); +} + TEST(ConvertNCHWToNHWC, Mul) { MulGraph g; @@ -464,6 +1123,52 @@ TEST(ConvertNCHWToNHWC, Mul) check_pre_trans(g.output->from()); } +TEST(ConvertNCHWToNHWC, MulScalar) +{ + MulScalarGraph g; + g.init(); + + run_phase(&g.g, false, false); + + auto input_succs = loco::succs(g.input); + EXPECT_EQ(1, input_succs.size()); + check_post_trans(*input_succs.begin()); + + check_pre_trans(g.mul->x()); + + auto mul_succs = loco::succs(g.mul); + EXPECT_EQ(1, mul_succs.size()); + check_post_trans(*mul_succs.begin()); + + auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y()); + EXPECT_NE(nullptr, new_multiplier); + EXPECT_EQ(1, new_multiplier->rank()); + EXPECT_EQ(1, new_multiplier->dim(0).value()); + + check_pre_trans(g.output->from()); +} + +TEST(ConvertNCHWToNHWC, MulBothNorm) +{ + MulBothNormGraph g; + g.init(); + + run_phase(&g.g, false, false); + + auto input_succs = loco::succs(g.input); + EXPECT_EQ(1, input_succs.size()); + check_post_trans(*input_succs.begin()); + + check_pre_trans(g.mul->x()); + check_pre_trans(g.mul->y()); + + auto mul_succs = loco::succs(g.mul); + EXPECT_EQ(1, mul_succs.size()); + check_post_trans(*mul_succs.begin()); + + check_pre_trans(g.output->from()); +} + TEST(ConvertNCHWToNHWC, Neg) { NegGraph g; @@ -518,6 +1223,34 @@ TEST(ConvertNCHWToNHWC, Pad) check_pre_trans(g.output->from()); } +TEST(ConvertNCHWToNHWC, PadV2) +{ + PadV2Graph g; + g.init(); + + run_phase(&g.g, false, false); + + check_pre_trans(g.pad->input()); + + auto pad_succs = loco::succs(g.pad); + EXPECT_EQ(1, pad_succs.size()); + check_post_trans(*pad_succs.begin()); + + auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings()); + EXPECT_NE(nullptr, new_paddings); + EXPECT_EQ(2, new_paddings->rank()); + EXPECT_EQ(4, new_paddings->dim(0).value()); + EXPECT_EQ(2, new_paddings->dim(1).value()); + EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0)); + EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1)); + EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2)); + EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3)); + EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4)); + EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5)); + EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6)); + EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7)); +} + TEST(ConvertNCHWToNHWC, Unknown_Shape_NEG) { AddGraph g; @@ -634,3 +1367,91 @@ TEST(ConvertNCHWToNHWC, Relu6) EXPECT_EQ(4, g.relu6->dim(2).value()); EXPECT_EQ(16, g.relu6->dim(3).value()); } + +TEST(ConvertNCHWToNHWC, Rsqrt) +{ + RsqrtGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.rsqrt->x()); + + auto rsqrt_succs = loco::succs(g.rsqrt); + EXPECT_EQ(1, rsqrt_succs.size()); + check_post_trans(*rsqrt_succs.begin()); + + // Check rsqrt shape + EXPECT_EQ(1, g.rsqrt->dim(0).value()); + EXPECT_EQ(4, g.rsqrt->dim(1).value()); + EXPECT_EQ(4, g.rsqrt->dim(2).value()); + EXPECT_EQ(16, g.rsqrt->dim(3).value()); +} + +TEST(ConvertNCHWToNHWC, SquaredDifference) +{ + SquaredDifferenceGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.sqdiff->x()); + check_pre_trans(g.sqdiff->y()); + + auto sqdiff_succs = loco::succs(g.sqdiff); + EXPECT_EQ(1, sqdiff_succs.size()); + check_post_trans(*sqdiff_succs.begin()); +} + +TEST(ConvertNCHWToNHWC, Sub) +{ + SubGraph g; + g.init(); + + run_phase(&g.g, false, false); + + auto input_succs = loco::succs(g.input); + EXPECT_EQ(1, input_succs.size()); + check_post_trans(*input_succs.begin()); + + check_pre_trans(g.sub->x()); + + auto add_succs = loco::succs(g.sub); + EXPECT_EQ(1, add_succs.size()); + check_post_trans(*add_succs.begin()); + + uint32_t channel_size = 16; + auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y()); + EXPECT_NE(nullptr, new_beta); + EXPECT_EQ(4, new_beta->rank()); + EXPECT_EQ(1, new_beta->dim(0).value()); + EXPECT_EQ(1, new_beta->dim(1).value()); + EXPECT_EQ(1, new_beta->dim(2).value()); + EXPECT_EQ(channel_size, new_beta->dim(3).value()); + + check_pre_trans(g.output->from()); +} + +TEST(ConvertNCHWToNHWC, SubScalar) +{ + SubScalarGraph g; + g.init(); + + run_phase(&g.g, false, false); + + auto input_succs = loco::succs(g.input); + EXPECT_EQ(1, input_succs.size()); + check_post_trans(*input_succs.begin()); + + check_pre_trans(g.sub->y()); + + auto add_succs = loco::succs(g.sub); + EXPECT_EQ(1, add_succs.size()); + check_post_trans(*add_succs.begin()); + + auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->x()); + EXPECT_NE(nullptr, new_beta); + EXPECT_EQ(1, new_beta->rank()); + + check_pre_trans(g.output->from()); +} |