summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2021-08-23 13:25:15 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2021-08-23 13:25:15 +0900
commitf4cf19e579a19c5346ccb2aad55bfd251065e447 (patch)
tree5d436b11f89be0e8a8289ea82b773da6402c1add /compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
parent589bb1db6db6784efe21b3fbbfbfdb79aaa5f14e (diff)
downloadnnfw-f4cf19e579a19c5346ccb2aad55bfd251065e447.tar.gz
nnfw-f4cf19e579a19c5346ccb2aad55bfd251065e447.tar.bz2
nnfw-f4cf19e579a19c5346ccb2aad55bfd251065e447.zip
Diffstat (limited to 'compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp')
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp821
1 files changed, 821 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
index 831d5f89a..d844246f8 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
@@ -134,6 +134,93 @@ public:
luci::CircleConst *beta = nullptr;
};
+class NHWCReluGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ relu = g.nodes()->create<luci::CircleRelu>();
+ pre_reshape = g.nodes()->create<luci::CircleReshape>();
+ post_reshape = g.nodes()->create<luci::CircleReshape>();
+ pre_shape = g.nodes()->create<luci::CircleConst>();
+ post_shape = g.nodes()->create<luci::CircleConst>();
+
+ pre_shape->dtype(loco::DataType::S32);
+ post_shape->dtype(loco::DataType::S32);
+
+ uint32_t channel_size = 16;
+ auto in = loco::must_cast<luci::CircleNode *>(input);
+ in->shape({1, channel_size, 4, 4});
+ pre_shape->shape({4});
+ post_shape->shape({4});
+
+ pre_shape->size<loco::DataType::S32>(4);
+ pre_shape->at<loco::DataType::S32>(0) = 1;
+ pre_shape->at<loco::DataType::S32>(1) = 4;
+ pre_shape->at<loco::DataType::S32>(2) = 4;
+ pre_shape->at<loco::DataType::S32>(3) = channel_size;
+
+ post_shape->size<loco::DataType::S32>(4);
+ post_shape->at<loco::DataType::S32>(0) = 1;
+ post_shape->at<loco::DataType::S32>(1) = channel_size;
+ post_shape->at<loco::DataType::S32>(2) = 4;
+ post_shape->at<loco::DataType::S32>(3) = 4;
+
+ pre_reshape->tensor(input);
+ pre_reshape->shape(pre_shape);
+
+ relu->features(pre_reshape);
+
+ post_reshape->tensor(relu);
+ post_reshape->shape(post_shape);
+
+ relu->name("Relu");
+ pre_reshape->name("pre-reshape");
+ post_reshape->name("post-reshape");
+
+ return post_reshape;
+ }
+
+public:
+ luci::CircleRelu *relu = nullptr;
+ luci::CircleReshape *pre_reshape = nullptr;
+ luci::CircleReshape *post_reshape = nullptr;
+ luci::CircleConst *pre_shape = nullptr;
+ luci::CircleConst *post_shape = nullptr;
+};
+
+class AddScalarGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ add = g.nodes()->create<luci::CircleAdd>();
+ beta = g.nodes()->create<luci::CircleConst>();
+
+ add->dtype(loco::DataType::FLOAT32);
+ beta->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ add->shape({1, channel_size, 4, 4});
+ beta->shape({1});
+
+ beta->size<loco::DataType::FLOAT32>(1);
+ beta->at<loco::DataType::FLOAT32>(0) = 3.14;
+
+ add->x(input);
+ add->y(beta);
+
+ add->name("add");
+ beta->name("beta");
+
+ return add;
+ }
+
+public:
+ luci::CircleAdd *add = nullptr;
+ luci::CircleConst *beta = nullptr;
+};
+
class ConcatenationGraph final : public SimpleGraph
{
protected:
@@ -180,6 +267,129 @@ public:
luci::CircleLeakyRelu *leakyrelu = nullptr;
};
+class LogisticGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ logistic = g.nodes()->create<luci::CircleLogistic>();
+ logistic->x(input);
+ logistic->name("logistic");
+
+ return logistic;
+ }
+
+public:
+ luci::CircleLogistic *logistic = nullptr;
+};
+
+class MaximumGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ max = g.nodes()->create<luci::CircleMaximum>();
+ limit = g.nodes()->create<luci::CircleConst>();
+
+ max->dtype(loco::DataType::FLOAT32);
+ limit->dtype(loco::DataType::FLOAT32);
+
+ max->shape({1, 16, 4, 4});
+ limit->shape({});
+
+ limit->size<loco::DataType::FLOAT32>(1);
+ limit->at<loco::DataType::FLOAT32>(0) = 100;
+
+ max->x(input);
+ max->y(limit);
+
+ max->name("max");
+ limit->name("limit");
+
+ return max;
+ }
+
+public:
+ luci::CircleMaximum *max = nullptr;
+ luci::CircleConst *limit = nullptr;
+};
+
+class MeanGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ mean = g.nodes()->create<luci::CircleMean>();
+ rindices = g.nodes()->create<luci::CircleConst>();
+
+ mean->dtype(loco::DataType::FLOAT32);
+ rindices->dtype(loco::DataType::S32);
+
+ mean->shape(_shape);
+ rindices->shape({static_cast<uint32_t>(_axes.size())});
+
+ rindices->size<loco::DataType::S32>(_axes.size());
+ for (uint32_t i = 0; i < _axes.size(); ++i)
+ {
+ rindices->at<loco::DataType::S32>(i) = _axes[i];
+ }
+
+ mean->input(input);
+ mean->reduction_indices(rindices);
+ mean->keep_dims(_keep_dims);
+
+ mean->name("mean");
+ rindices->name("rindices");
+
+ return mean;
+ }
+
+public:
+ void keep_dims(bool val) { _keep_dims = val; }
+ void axes(std::vector<int32_t> val) { _axes = val; }
+ void shape(std::initializer_list<uint32_t> val) { _shape = val; }
+
+public:
+ luci::CircleMean *mean = nullptr;
+ luci::CircleConst *rindices = nullptr;
+
+private:
+ bool _keep_dims = true;
+ std::vector<int32_t> _axes = {2, 3};
+ std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
+};
+
+class MinimumGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ min = g.nodes()->create<luci::CircleMinimum>();
+ limit = g.nodes()->create<luci::CircleConst>();
+
+ min->dtype(loco::DataType::FLOAT32);
+ limit->dtype(loco::DataType::FLOAT32);
+
+ min->shape({1, 16, 4, 4});
+ limit->shape({});
+
+ limit->size<loco::DataType::FLOAT32>(1);
+ limit->at<loco::DataType::FLOAT32>(0) = 100;
+
+ min->x(input);
+ min->y(limit);
+
+ min->name("min");
+ limit->name("limit");
+
+ return min;
+ }
+
+public:
+ luci::CircleMinimum *min = nullptr;
+ luci::CircleConst *limit = nullptr;
+};
+
class MulGraph final : public SimpleGraph
{
protected:
@@ -215,6 +425,62 @@ public:
luci::CircleConst *multiplier = nullptr;
};
+class MulScalarGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ mul = g.nodes()->create<luci::CircleMul>();
+ multiplier = g.nodes()->create<luci::CircleConst>();
+
+ mul->dtype(loco::DataType::FLOAT32);
+ multiplier->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ mul->shape({1, channel_size, 4, 4});
+ multiplier->shape({1});
+
+ multiplier->size<loco::DataType::FLOAT32>(1);
+ multiplier->at<loco::DataType::FLOAT32>(0) = 2;
+
+ mul->x(input);
+ mul->y(multiplier);
+
+ mul->name("mul");
+ multiplier->name("multiplier");
+
+ return mul;
+ }
+
+public:
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *multiplier = nullptr;
+};
+
+class MulBothNormGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ mul = g.nodes()->create<luci::CircleMul>();
+
+ mul->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ mul->shape({1, channel_size, 4, 4});
+
+ mul->x(input);
+ mul->y(input);
+
+ mul->name("mul");
+
+ return mul;
+ }
+
+public:
+ luci::CircleMul *mul = nullptr;
+};
+
class NegGraph final : public SimpleGraph
{
protected:
@@ -278,6 +544,62 @@ public:
luci::CircleConst *paddings = nullptr;
};
+class PadV2Graph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ pad = g.nodes()->create<luci::CirclePadV2>();
+ paddings = g.nodes()->create<luci::CircleConst>();
+ const_value = g.nodes()->create<luci::CircleConst>();
+
+ pad->dtype(loco::DataType::FLOAT32);
+ paddings->dtype(loco::DataType::S32);
+ const_value->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ pad->shape({1, channel_size, 4, 4});
+ paddings->shape({4, 2});
+ const_value->shape({1});
+
+ // paddings data (NCHW)
+ // [[0,0], [0,0], [1,1], [2,2]]
+ paddings->size<loco::DataType::S32>(8);
+ for (uint32_t dim = 0; dim < 4; dim++)
+ {
+ for (uint32_t i = 0; i < 2; i++)
+ {
+ int32_t data = 0;
+
+ if (dim == 2)
+ data = 1;
+ else if (dim == 3)
+ data = 2;
+
+ paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
+ }
+ }
+
+ const_value->size<loco::DataType::FLOAT32>(1);
+ const_value->at<loco::DataType::FLOAT32>(0) = -3.4;
+
+ pad->input(input);
+ pad->paddings(paddings);
+ pad->constant_values(paddings);
+
+ pad->name("padV2");
+ paddings->name("paddings");
+ const_value->name("constant_values");
+
+ return pad;
+ }
+
+public:
+ luci::CirclePadV2 *pad = nullptr;
+ luci::CircleConst *paddings = nullptr;
+ luci::CircleConst *const_value = nullptr;
+};
+
class ReluGraph final : public SimpleGraph
{
protected:
@@ -310,6 +632,106 @@ public:
luci::CircleRelu6 *relu6 = nullptr;
};
+class RsqrtGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ rsqrt = g.nodes()->create<luci::CircleRsqrt>();
+ rsqrt->x(input);
+ rsqrt->name("rsqrt");
+
+ return rsqrt;
+ }
+
+public:
+ luci::CircleRsqrt *rsqrt = nullptr;
+};
+
+class SquaredDifferenceGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ sqdiff = g.nodes()->create<luci::CircleSquaredDifference>();
+ sqdiff->x(input);
+ sqdiff->y(input);
+ sqdiff->name("sqdiff");
+
+ return sqdiff;
+ }
+
+public:
+ luci::CircleSquaredDifference *sqdiff = nullptr;
+};
+
+class SubGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ sub = g.nodes()->create<luci::CircleSub>();
+ beta = g.nodes()->create<luci::CircleConst>();
+
+ sub->dtype(loco::DataType::FLOAT32);
+ beta->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ sub->shape({1, channel_size, 4, 4});
+ beta->shape({1, channel_size, 1, 1});
+
+ beta->size<loco::DataType::FLOAT32>(channel_size);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ beta->at<loco::DataType::FLOAT32>(i) = i;
+ }
+
+ sub->x(input);
+ sub->y(beta);
+
+ sub->name("sub");
+ beta->name("beta");
+
+ return sub;
+ }
+
+public:
+ luci::CircleSub *sub = nullptr;
+ luci::CircleConst *beta = nullptr;
+};
+
+class SubScalarGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ sub = g.nodes()->create<luci::CircleSub>();
+ beta = g.nodes()->create<luci::CircleConst>();
+
+ sub->dtype(loco::DataType::FLOAT32);
+ beta->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ sub->shape({1, channel_size, 4, 4});
+ beta->shape({1});
+
+ beta->size<loco::DataType::FLOAT32>(1);
+ beta->at<loco::DataType::FLOAT32>(0) = 5;
+
+ sub->x(beta);
+ sub->y(input);
+
+ sub->name("sub");
+ beta->name("beta");
+
+ return sub;
+ }
+
+public:
+ luci::CircleSub *sub = nullptr;
+ luci::CircleConst *beta = nullptr;
+};
+
void check_pre_trans(loco::Node *node)
{
auto pre_trans = dynamic_cast<luci::CircleTranspose *>(node);
@@ -393,6 +815,47 @@ TEST(ConvertNCHWToNHWC, Add)
check_pre_trans(g.output->from());
}
+TEST(ConvertNCHWToNHWC, NHWC_Relu)
+{
+ // Relu is already NHWC, so it should not be converted
+ // i.e., the graph is not changed
+ NHWCReluGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ EXPECT_EQ(g.pre_reshape, g.relu->features());
+
+ auto relu_succs = loco::succs(g.relu);
+ EXPECT_EQ(1, relu_succs.size());
+ EXPECT_EQ(g.post_reshape, *relu_succs.begin());
+}
+
+TEST(ConvertNCHWToNHWC, AddScalar)
+{
+ AddScalarGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.add->x());
+
+ auto add_succs = loco::succs(g.add);
+ EXPECT_EQ(1, add_succs.size());
+ check_post_trans(*add_succs.begin());
+
+ auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
+ EXPECT_NE(nullptr, new_beta);
+ EXPECT_EQ(1, new_beta->rank());
+ EXPECT_EQ(1, new_beta->dim(0).value());
+
+ check_pre_trans(g.output->from());
+}
+
TEST(ConvertNCHWToNHWC, Concatenation)
{
ConcatenationGraph g;
@@ -435,6 +898,202 @@ TEST(ConvertNCHWToNHWC, LeakyRelu)
EXPECT_EQ(16, g.leakyrelu->dim(3).value());
}
+TEST(ConvertNCHWToNHWC, Logistic)
+{
+ LogisticGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.logistic->x());
+
+ auto logistic_succs = loco::succs(g.logistic);
+ EXPECT_EQ(1, logistic_succs.size());
+ check_post_trans(*logistic_succs.begin());
+
+ // Check logistic shape
+ EXPECT_EQ(1, g.logistic->dim(0).value());
+ EXPECT_EQ(4, g.logistic->dim(1).value());
+ EXPECT_EQ(4, g.logistic->dim(2).value());
+ EXPECT_EQ(16, g.logistic->dim(3).value());
+}
+
+TEST(ConvertNCHWToNHWC, Maximum)
+{
+ MaximumGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.max->x());
+
+ auto max_succs = loco::succs(g.max);
+ EXPECT_EQ(1, max_succs.size());
+ check_post_trans(*max_succs.begin());
+
+ check_pre_trans(g.output->from());
+}
+
+TEST(ConvertNCHWToNHWC, Mean)
+{
+ MeanGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ check_pre_trans(g.mean->input());
+
+ auto mean_succs = loco::succs(g.mean);
+ EXPECT_EQ(1, mean_succs.size());
+ check_post_trans(*mean_succs.begin());
+
+ auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
+ EXPECT_NE(nullptr, new_rindices);
+ EXPECT_EQ(1, new_rindices->rank());
+ EXPECT_EQ(2, new_rindices->dim(0).value());
+ EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
+ EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
+ EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
+}
+
+TEST(ConvertNCHWToNHWC, Mean_keep_dims_false)
+{
+ struct TC
+ {
+ std::vector<int32_t> nchw_ind;
+ std::vector<int32_t> nhwc_ind;
+ std::initializer_list<uint32_t> shape;
+ bool needs_transpose = false;
+ };
+
+ uint32_t n = 1;
+ uint32_t c = 16;
+ uint32_t h = 4;
+ uint32_t w = 4;
+
+ std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false},
+ {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true},
+ {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true},
+ {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false},
+ {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false},
+ {{0, 1, 2}, {0, 3, 1}, {w}, false}};
+
+ for (auto &tc : test_cases)
+ {
+ MeanGraph g;
+ g.keep_dims(false);
+ g.axes(tc.nchw_ind);
+ g.shape(tc.shape);
+ g.init();
+
+ run_phase(&g.g, false, true);
+
+ check_pre_trans(g.mean->input());
+
+ auto mean_succs = loco::succs(g.mean);
+ EXPECT_EQ(1, mean_succs.size());
+ if (tc.needs_transpose)
+ {
+ EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*mean_succs.begin()));
+ }
+ else
+ {
+ EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*mean_succs.begin()));
+ }
+
+ auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
+ EXPECT_NE(nullptr, new_rindices);
+ EXPECT_EQ(1, new_rindices->rank());
+ EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
+ EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
+ for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
+ {
+ EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
+ }
+ }
+}
+
+TEST(ConvertNCHWToNHWC, ConvertNCHWToNHWC_Mean_keep_dims_false_NEG)
+{
+ loco::Graph g;
+ auto input = g.nodes()->create<luci::CircleInput>();
+ auto output = g.nodes()->create<luci::CircleOutput>();
+ input->name("input");
+ output->name("output");
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ graph_input->dtype(loco::DataType::FLOAT32);
+ input->dtype(loco::DataType::FLOAT32);
+ output->dtype(loco::DataType::FLOAT32);
+ graph_output->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ graph_input->shape({channel_size, 4, 4});
+ input->shape({channel_size, 4, 4});
+ output->shape({channel_size});
+ graph_output->shape({channel_size});
+
+ auto mean = g.nodes()->create<luci::CircleMean>();
+ auto rindices = g.nodes()->create<luci::CircleConst>();
+
+ mean->dtype(loco::DataType::FLOAT32);
+ rindices->dtype(loco::DataType::S32);
+
+ mean->shape({channel_size});
+ rindices->shape({2});
+
+ rindices->size<loco::DataType::S32>(2);
+ rindices->at<loco::DataType::S32>(0) = 1;
+ rindices->at<loco::DataType::S32>(1) = 2;
+
+ mean->input(input);
+ mean->reduction_indices(rindices);
+ mean->keep_dims(false);
+
+ mean->name("mean");
+ rindices->name("rindices");
+
+ output->from(mean);
+
+ run_phase(&g, true, true);
+
+ auto new_rindices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
+ EXPECT_NE(nullptr, new_rindices);
+ EXPECT_EQ(1, new_rindices->rank());
+ EXPECT_EQ(2, new_rindices->dim(0).value());
+ EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
+ EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
+ EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
+}
+
+TEST(ConvertNCHWToNHWC, Minimum)
+{
+ MinimumGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.min->x());
+
+ auto min_succs = loco::succs(g.min);
+ EXPECT_EQ(1, min_succs.size());
+ check_post_trans(*min_succs.begin());
+
+ check_pre_trans(g.output->from());
+}
+
TEST(ConvertNCHWToNHWC, Mul)
{
MulGraph g;
@@ -464,6 +1123,52 @@ TEST(ConvertNCHWToNHWC, Mul)
check_pre_trans(g.output->from());
}
+TEST(ConvertNCHWToNHWC, MulScalar)
+{
+ MulScalarGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.mul->x());
+
+ auto mul_succs = loco::succs(g.mul);
+ EXPECT_EQ(1, mul_succs.size());
+ check_post_trans(*mul_succs.begin());
+
+ auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
+ EXPECT_NE(nullptr, new_multiplier);
+ EXPECT_EQ(1, new_multiplier->rank());
+ EXPECT_EQ(1, new_multiplier->dim(0).value());
+
+ check_pre_trans(g.output->from());
+}
+
+TEST(ConvertNCHWToNHWC, MulBothNorm)
+{
+ MulBothNormGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.mul->x());
+ check_pre_trans(g.mul->y());
+
+ auto mul_succs = loco::succs(g.mul);
+ EXPECT_EQ(1, mul_succs.size());
+ check_post_trans(*mul_succs.begin());
+
+ check_pre_trans(g.output->from());
+}
+
TEST(ConvertNCHWToNHWC, Neg)
{
NegGraph g;
@@ -518,6 +1223,34 @@ TEST(ConvertNCHWToNHWC, Pad)
check_pre_trans(g.output->from());
}
+TEST(ConvertNCHWToNHWC, PadV2)
+{
+ PadV2Graph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ check_pre_trans(g.pad->input());
+
+ auto pad_succs = loco::succs(g.pad);
+ EXPECT_EQ(1, pad_succs.size());
+ check_post_trans(*pad_succs.begin());
+
+ auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
+ EXPECT_NE(nullptr, new_paddings);
+ EXPECT_EQ(2, new_paddings->rank());
+ EXPECT_EQ(4, new_paddings->dim(0).value());
+ EXPECT_EQ(2, new_paddings->dim(1).value());
+ EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
+ EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
+ EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
+ EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
+ EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
+ EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
+ EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
+ EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
+}
+
TEST(ConvertNCHWToNHWC, Unknown_Shape_NEG)
{
AddGraph g;
@@ -634,3 +1367,91 @@ TEST(ConvertNCHWToNHWC, Relu6)
EXPECT_EQ(4, g.relu6->dim(2).value());
EXPECT_EQ(16, g.relu6->dim(3).value());
}
+
+TEST(ConvertNCHWToNHWC, Rsqrt)
+{
+ RsqrtGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.rsqrt->x());
+
+ auto rsqrt_succs = loco::succs(g.rsqrt);
+ EXPECT_EQ(1, rsqrt_succs.size());
+ check_post_trans(*rsqrt_succs.begin());
+
+ // Check rsqrt shape
+ EXPECT_EQ(1, g.rsqrt->dim(0).value());
+ EXPECT_EQ(4, g.rsqrt->dim(1).value());
+ EXPECT_EQ(4, g.rsqrt->dim(2).value());
+ EXPECT_EQ(16, g.rsqrt->dim(3).value());
+}
+
+TEST(ConvertNCHWToNHWC, SquaredDifference)
+{
+ SquaredDifferenceGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.sqdiff->x());
+ check_pre_trans(g.sqdiff->y());
+
+ auto sqdiff_succs = loco::succs(g.sqdiff);
+ EXPECT_EQ(1, sqdiff_succs.size());
+ check_post_trans(*sqdiff_succs.begin());
+}
+
+TEST(ConvertNCHWToNHWC, Sub)
+{
+ SubGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.sub->x());
+
+ auto add_succs = loco::succs(g.sub);
+ EXPECT_EQ(1, add_succs.size());
+ check_post_trans(*add_succs.begin());
+
+ uint32_t channel_size = 16;
+ auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
+ EXPECT_NE(nullptr, new_beta);
+ EXPECT_EQ(4, new_beta->rank());
+ EXPECT_EQ(1, new_beta->dim(0).value());
+ EXPECT_EQ(1, new_beta->dim(1).value());
+ EXPECT_EQ(1, new_beta->dim(2).value());
+ EXPECT_EQ(channel_size, new_beta->dim(3).value());
+
+ check_pre_trans(g.output->from());
+}
+
+TEST(ConvertNCHWToNHWC, SubScalar)
+{
+ SubScalarGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.sub->y());
+
+ auto add_succs = loco::succs(g.sub);
+ EXPECT_EQ(1, add_succs.size());
+ check_post_trans(*add_succs.begin());
+
+ auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->x());
+ EXPECT_NE(nullptr, new_beta);
+ EXPECT_EQ(1, new_beta->rank());
+
+ check_pre_trans(g.output->from());
+}