summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp')
-rw-r--r--compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp175
1 files changed, 175 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp
new file mode 100644
index 000000000..e7738aadb
--- /dev/null
+++ b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ResolveCustomOpSplitVPass.h"
+
+#include <luci/test/TestIOGraph.h>
+
+#include <luci/IR/CircleNodes.h>
+#include <gtest/gtest.h>
+
+using namespace luci::test;
+
+namespace
+{
+
+/**
+ * graph having Custom operator SplitV
+ *
+ * [Input] [Const] [Const]
+ * \ | /
+ * [Custom(SplitV)]
+ * / | \
+ * [CustomOut] [CustomOut] [CustomOut]
+ * | | |
+ * [Output] [Output] [Output]
+ */
+class SplitVGraphlet
+{
+public:
+ SplitVGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ // CircleCustom(SplitV)
+ _splitv = g->nodes()->create<luci::CircleCustom>(3, 3);
+ _splitv->custom_code("SplitV");
+ _splitv->shape({1, 2, 2, 192});
+ _splitv->dtype(loco::DataType::FLOAT32);
+ _splitv->name("splitv");
+
+ // CircleConst
+ auto size_splits = g->nodes()->create<luci::CircleConst>();
+ size_splits->dtype(loco::DataType::S64);
+ size_splits->shape({3});
+ size_splits->size<loco::DataType::S64>(3);
+ size_splits->at<loco::DataType::S64>(0) = 32;
+ size_splits->at<loco::DataType::S64>(1) = 32;
+ size_splits->at<loco::DataType::S64>(2) = 128;
+
+ // CircleConst
+ auto split_dim = g->nodes()->create<luci::CircleConst>();
+ split_dim->dtype(loco::DataType::S32);
+ split_dim->rank(0);
+ split_dim->size<loco::DataType::S32>(1);
+ split_dim->scalar<loco::DataType::S32>() = 3;
+
+ _splitv->inputs(1, size_splits);
+ _splitv->inputs(2, split_dim);
+
+ // CircleCustomOut
+ _splitv_out1 = g->nodes()->create<luci::CircleCustomOut>();
+ _splitv_out1->shape({1, 2, 2, 32});
+ _splitv_out1->dtype(loco::DataType::FLOAT32);
+ _splitv_out1->index(0);
+ _splitv_out1->input(_splitv);
+
+ // CircleCustomOut
+ _splitv_out2 = g->nodes()->create<luci::CircleCustomOut>();
+ _splitv_out2->shape({1, 2, 2, 32});
+ _splitv_out2->dtype(loco::DataType::FLOAT32);
+ _splitv_out2->index(1);
+ _splitv_out2->input(_splitv);
+
+ // CircleCustomOut
+ _splitv_out3 = g->nodes()->create<luci::CircleCustomOut>();
+ _splitv_out3->shape({1, 2, 2, 128});
+ _splitv_out3->dtype(loco::DataType::FLOAT32);
+ _splitv_out3->index(2);
+ _splitv_out3->input(_splitv);
+ }
+
+public:
+ luci::CircleCustom *splitv() { return _splitv; }
+
+protected:
+ luci::CircleCustom *_splitv = nullptr;
+ luci::CircleCustomOut *_splitv_out1 = nullptr;
+ luci::CircleCustomOut *_splitv_out2 = nullptr;
+ luci::CircleCustomOut *_splitv_out3 = nullptr;
+};
+
+class SplitVGraph : public TestIGraphlet, public TestOsGraphlet<3>, public SplitVGraphlet
+{
+public:
+ SplitVGraph() = default;
+
+ void init(void)
+ {
+ TestIGraphlet::init(g(), {1, 2, 2, 192});
+ TestOsGraphlet<3>::init(g(), {{1, 2, 2, 32}, {1, 2, 2, 32}, {1, 2, 2, 128}});
+ SplitVGraphlet::init(g());
+
+ // connect graph
+ _splitv->inputs(0, input());
+
+ output(0)->from(_splitv_out1);
+ output(1)->from(_splitv_out2);
+ output(2)->from(_splitv_out3);
+ }
+};
+
+class SplitVGraphTest : public ::testing::Test
+{
+public:
+ SplitVGraph g;
+ luci::ResolveCustomOpSplitVPass pass;
+};
+
+} // namespace
+
+TEST_F(SplitVGraphTest, simple_test)
+{
+ g.init();
+
+ auto ret = pass.run(g.g());
+ EXPECT_EQ(true, ret);
+
+ auto svo_1 = dynamic_cast<luci::CircleSplitVOut *>(g.output(0)->from());
+ EXPECT_NE(nullptr, svo_1);
+ auto svo_2 = dynamic_cast<luci::CircleSplitVOut *>(g.output(1)->from());
+ EXPECT_NE(nullptr, svo_2);
+ auto svo_3 = dynamic_cast<luci::CircleSplitVOut *>(g.output(2)->from());
+ EXPECT_NE(nullptr, svo_3);
+
+ auto sv = dynamic_cast<luci::CircleSplitV *>(svo_1->input());
+ EXPECT_NE(nullptr, sv);
+ sv = dynamic_cast<luci::CircleSplitV *>(svo_2->input());
+ EXPECT_NE(nullptr, sv);
+ sv = dynamic_cast<luci::CircleSplitV *>(svo_3->input());
+ EXPECT_NE(nullptr, sv);
+
+ auto size_splits = loco::must_cast<luci::CircleConst *>(sv->size_splits());
+ EXPECT_EQ(loco::DataType::S32, size_splits->dtype());
+ EXPECT_EQ(32, size_splits->at<loco::DataType::S32>(0));
+ EXPECT_EQ(32, size_splits->at<loco::DataType::S32>(1));
+ EXPECT_EQ(128, size_splits->at<loco::DataType::S32>(2));
+
+ auto split_dim = loco::must_cast<luci::CircleConst *>(sv->split_dim());
+ EXPECT_EQ(loco::DataType::S32, split_dim->dtype());
+ EXPECT_EQ(3, split_dim->scalar<loco::DataType::S32>());
+}
+
+TEST_F(SplitVGraphTest, wrong_op_NEG)
+{
+ g.init();
+
+ g.splitv()->custom_code("AddV2");
+
+ auto ret = pass.run(g.g());
+ EXPECT_EQ(false, ret);
+}