summaryrefslogtreecommitdiff
path: root/compiler/moco/pass/src/ConstantFoldAdd.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/moco/pass/src/ConstantFoldAdd.test.cpp')
-rw-r--r--compiler/moco/pass/src/ConstantFoldAdd.test.cpp109
1 files changed, 109 insertions, 0 deletions
diff --git a/compiler/moco/pass/src/ConstantFoldAdd.test.cpp b/compiler/moco/pass/src/ConstantFoldAdd.test.cpp
new file mode 100644
index 000000000..bc9489fbd
--- /dev/null
+++ b/compiler/moco/pass/src/ConstantFoldAdd.test.cpp
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "moco/Pass/Passes/ConstantFoldAdd.h"
+#include "TestHelper.h"
+
+#include <moco/IR/TFNodes.h>
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+using namespace moco::test;
+
+namespace
+{
+
+moco::TFConst *const_vector_init(loco::Graph *graph, std::vector<int32_t> values)
+{
+ auto const_node = graph->nodes()->create<moco::TFConst>();
+ auto dim = values.size();
+
+ const_node->dtype(loco::DataType::S32);
+ const_node->rank(1);
+ const_node->dim(0).set(dim);
+
+ const_node->size<loco::DataType::S32>(dim);
+ for (int32_t i = 0; i < dim; ++i)
+ const_node->at<loco::DataType::S32>(i) = values[i];
+
+ return const_node;
+}
+
+} // namespace
+
+TEST(ConstantFoldAdd, basic_vector)
+{
+ loco::Graph graph;
+
+ auto add_node = graph.nodes()->create<moco::TFAdd>();
+ {
+ auto const_from_ss = const_vector_init(&graph, {1, 3, 5});
+ add_node->x(const_from_ss);
+
+ auto const_y = const_vector_init(&graph, {2});
+ add_node->y(const_y);
+ }
+ setup_output_node(&graph, add_node);
+
+ auto pass = stdex::make_unique<moco::ConstantFoldAdd>();
+ bool cont = true;
+ while (cont)
+ {
+ cont = pass->run(&graph);
+ }
+
+ auto ssnode = find_first_node_bytype<moco::TFAdd>(&graph);
+ ASSERT_EQ(ssnode, nullptr);
+
+ auto ssconst = find_first_node_bytype<moco::TFConst>(&graph);
+ ASSERT_NE(ssconst, nullptr);
+ ASSERT_EQ(ssconst->size<loco::DataType::S32>(), 3);
+ ASSERT_EQ(ssconst->at<loco::DataType::S32>(0), 3);
+ ASSERT_EQ(ssconst->at<loco::DataType::S32>(1), 5);
+ ASSERT_EQ(ssconst->at<loco::DataType::S32>(2), 7);
+}
+
+TEST(ConstantFoldAdd, basic_refinedet_1)
+{
+ loco::Graph graph;
+
+ auto add_node = graph.nodes()->create<moco::TFAdd>();
+ {
+ auto const_from_ss = const_vector_init(&graph, {10});
+ add_node->x(const_from_ss);
+
+ auto const_y = const_vector_init(&graph, {0});
+ add_node->y(const_y);
+ }
+ setup_output_node(&graph, add_node);
+
+ auto pass = stdex::make_unique<moco::ConstantFoldAdd>();
+ bool cont = true;
+ while (cont)
+ {
+ cont = pass->run(&graph);
+ }
+
+ auto ssnode = find_first_node_bytype<moco::TFAdd>(&graph);
+ ASSERT_EQ(ssnode, nullptr);
+
+ auto ssconst = find_first_node_bytype<moco::TFConst>(&graph);
+ ASSERT_NE(ssconst, nullptr);
+ ASSERT_EQ(ssconst->size<loco::DataType::S32>(), 1);
+ ASSERT_EQ(ssconst->at<loco::DataType::S32>(0), 10);
+}