Diffstat (limited to 'compiler/nnc/unittests/optimizations/CombineTransposes.cpp')
-rw-r--r-- compiler/nnc/unittests/optimizations/CombineTransposes.cpp | 150
1 file changed, 150 insertions(+), 0 deletions(-)
diff --git a/compiler/nnc/unittests/optimizations/CombineTransposes.cpp b/compiler/nnc/unittests/optimizations/CombineTransposes.cpp
new file mode 100644
index 000000000..8d90bd20a
--- /dev/null
+++ b/compiler/nnc/unittests/optimizations/CombineTransposes.cpp
@@ -0,0 +1,150 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/CombineTransposes.h"
+#include "mir/ops/TransposeOp.h"
+#include "mir/ops/ReluOp.h"
+#include "mir/ops/OutputOp.h"
+#include "Util.h"
+#include <gtest/gtest.h>
+
+using namespace std;
+using namespace nnc;
+using namespace mir;
+
+namespace
+{
+
+TEST(OptPass, eliminateTransposesLinear)
+{
+ mir::Graph g;
+ /* Create graph:
+  *      [input]
+  *        ||
+  *  [Transpose 1]
+  *        ||
+  * [Transpose 1.5]
+  *        ||
+  *  [Transpose 2]
+  *        ||
+  *      [relu]
+  */
+ mir::TensorType input_type{mir::DataType::FLOAT32, Shape{1, 2, 3}};
+ Operation *input = g.create<ops::InputOp>(input_type);
+ Operation *tr1 = g.create<ops::TransposeOp>(input->getOutput(0), vector<size_t>{1, 0, 2});
+ Operation *tr15 = g.create<ops::TransposeOp>(tr1->getOutput(0), vector<size_t>{1, 0, 2});
+ Operation *tr2 = g.create<ops::TransposeOp>(tr15->getOutput(0), vector<size_t>{1, 0, 2});
+ Operation *relu = g.create<ops::ReluOp>(tr2->getOutput(0));
+
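+ // All three transposes use axis order {1, 0, 2}, an involution: any adjacent
+ // pair composes to the identity, so the pass can drop a pair and leave
+ // exactly one transpose between the input and the relu.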
+ // Run the pass and dump the resulting graph
+ std::stringstream ss;
+ DumpVisitor d(ss);
+ CombineTransposes pass;
+ pass.run(&g);
+ g.accept(&d);
+ // Assert only 1 transpose remains
+ ASSERT_EQ("i_0.t_1.r_4.", ss.str());
+}
+
+TEST(OptPass, combineTransposesLinear)
+{
+ mir::Graph g;
+ /* Create graph:
+  *      [input]
+  *        ||
+  *  [Transpose 1]
+  *        ||
+  *  [Transpose 2]
+  *        ||
+  *      [relu]
+  */
+ mir::TensorType input_type{mir::DataType::FLOAT32, Shape{1, 2, 3}};
+ Operation *input = g.create<ops::InputOp>(input_type);
+ Operation *tr1 = g.create<ops::TransposeOp>(input->getOutput(0), vector<size_t>{1, 0, 2});
+ Operation *tr2 = g.create<ops::TransposeOp>(tr1->getOutput(0), vector<size_t>{0, 2, 1});
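+ // The two permutations compose into a single one: combined[i] = first[second[i]],
+ // i.e. {1, 0, 2} applied before {0, 2, 1} yields {1, 2, 0}, checked below.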
+ Operation *relu = g.create<ops::ReluOp>(tr2->getOutput(0));
+
+ std::stringstream ss;
+ DumpVisitor d(ss);
+ CombineTransposes pass;
+ pass.run(&g);
+ g.accept(&d);
+
+ // Assert transposes are combined
+ ASSERT_EQ("i_0.t_4.r_3.", ss.str());
+ Operation::Use use = g.getInputs()[0]->getOutput(0)->getUses()[0];
+ auto *combined = dynamic_cast<ops::TransposeOp *>(use.getNode());
+ ASSERT_NE(combined, nullptr);
+ auto ax_ord_expected = vector<size_t>{1, 2, 0};
+ ASSERT_EQ(combined->getAxisOrder(), ax_ord_expected);
+}
+
+TEST(OptPass, combineTransposesBush)
+{
+ mir::Graph g;
+ /* Create graph:
+  *          [input]
+  *            ||
+  *       [Transpose 1]
+  *        //       \\
+  * [Transpose 2] [Transpose 3]
+  *        \\       //
+  *           [Add]
+  */
+ mir::TensorType input_type{mir::DataType::FLOAT32, Shape{1, 2, 3, 2}};
+ Operation *input = g.create<ops::InputOp>(input_type);
+ Operation *tr1 = g.create<ops::TransposeOp>(input->getOutput(0), vector<size_t>{1, 0, 2, 3});
+ Operation *tr2 = g.create<ops::TransposeOp>(tr1->getOutput(0), vector<size_t>{1, 0, 2, 3});
+ Operation *tr3 = g.create<ops::TransposeOp>(tr1->getOutput(0), vector<size_t>{1, 0, 2, 3});
+ Operation *elw = g.create<ops::AddOp>(tr2->getOutput(0), tr3->getOutput(0));
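+ // Each branch applies {1, 0, 2, 3} twice, which composes to the identity, so
+ // both transpose chains should disappear and Add should read the input directly.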
+ std::stringstream ss;
+ DumpVisitor d(ss);
+ CombineTransposes pass;
+ pass.run(&g);
+ g.accept(&d);
+ ASSERT_EQ("i_0.b_4.", ss.str());
+ ASSERT_EQ(elw->getInput(0)->getNode()->getType(), mir::Operation::Type::input);
+ ASSERT_EQ(elw->getInput(1)->getNode()->getType(), mir::Operation::Type::input);
+}
+
+TEST(OptPass, combineTransposesOpOrder)
+{
+ mir::Graph g;
+ /* Create graph:
+  *    [input]       [input2]
+  *       ||            ||
+  * [Transpose 0]  [Transpose 1]
+  *       ||            ||
+  * [Transpose 2]  [Transpose 3]
+  *        \\         //
+  *            [Add]
+  */
+ mir::TensorType input_type{mir::DataType::FLOAT32, Shape{1, 2, 3}};
+ Operation *in1 = g.create<ops::InputOp>(input_type);
+ Operation *in2 = g.create<ops::InputOp>(input_type);
+ Operation *tr0 = g.create<ops::TransposeOp>(in1->getOutput(0), vector<size_t>{1, 0, 2});
+ Operation *tr1 = g.create<ops::TransposeOp>(in2->getOutput(0), vector<size_t>{2, 1, 0});
+ Operation *tr2 = g.create<ops::TransposeOp>(tr0->getOutput(0), vector<size_t>{1, 0, 2});
+ Operation *tr3 = g.create<ops::TransposeOp>(tr1->getOutput(0), vector<size_t>{2, 1, 0});
+ Operation *elw = g.create<ops::AddOp>(tr2->getOutput(0), tr3->getOutput(0));
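+ // Both chains compose to the identity ({1, 0, 2} and {2, 1, 0} are involutions),
+ // so the transposes vanish; the test then checks that Add keeps its operand order.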
+ g.create<ops::OutputOp>(elw->getOutput(0));
+ int n1 = in1->getId();
+ int n2 = in2->getId();
+ CombineTransposes pass;
+ pass.run(&g);
+ ASSERT_EQ(g.getOutputs()[0]->getInput(0)->getNode()->getType(), mir::Operation::Type::add);
+ // Order is preserved
+ ASSERT_EQ(n1, elw->getInput(0)->getNode()->getId());
+ ASSERT_EQ(n2, elw->getInput(1)->getNode()->getId());
+}
+} // unnamed namespace