summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/RemoveRedundantTranspose.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/RemoveRedundantTranspose.cpp')
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTranspose.cpp127
1 files changed, 127 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/RemoveRedundantTranspose.cpp b/compiler/luci/pass/src/RemoveRedundantTranspose.cpp
new file mode 100644
index 000000000..33cb76520
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantTranspose.cpp
@@ -0,0 +1,127 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveRedundantTransposePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+/// @brief Return true if first_perm[second_perm[i]] == i
+bool check_perm(const luci::CircleConst *first_perm, const luci::CircleConst *second_perm)
+{
+ assert(first_perm->rank() == 1);
+ assert(second_perm->rank() == 1);
+ assert(second_perm->size<loco::DataType::S32>() == first_perm->size<loco::DataType::S32>());
+ for (int32_t i = 0; i < static_cast<int32_t>(first_perm->size<loco::DataType::S32>()); i++)
+ {
+ if (first_perm->at<loco::DataType::S32>(second_perm->at<loco::DataType::S32>(i)) != i)
+ return false;
+ }
+ return true;
+}
+
+bool remove_consecutive_transpose_function(luci::CircleNode *node)
+{
+ auto target_node = dynamic_cast<luci::CircleTranspose *>(node);
+ if (target_node == nullptr)
+ return false;
+ auto pred_node = dynamic_cast<luci::CircleTranspose *>(target_node->a());
+ if (pred_node == nullptr)
+ return false;
+ if (loco::succs(pred_node).size() != 1)
+ return false;
+
+ auto pred_perm = dynamic_cast<luci::CircleConst *>(target_node->perm());
+ if (pred_perm == nullptr)
+ return false;
+
+ auto main_perm = dynamic_cast<luci::CircleConst *>(pred_node->perm());
+ if (main_perm == nullptr)
+ return false;
+
+ auto main_node = loco::must_cast<luci::CircleNode *>(pred_node->a());
+ if (check_perm(pred_perm, main_perm))
+ {
+ replace(node).with(main_node);
+ }
+ else
+ {
+ auto g = main_perm->graph();
+ auto new_const_node = g->nodes()->create<luci::CircleConst>();
+
+ new_const_node->dtype(loco::DataType::S32);
+ new_const_node->rank(1);
+ new_const_node->dim(0) = main_perm->dim(0);
+ new_const_node->size<loco::DataType::S32>(main_perm->dim(0).value());
+ new_const_node->shape_status(luci::ShapeStatus::VALID);
+ for (uint32_t i = 0; i < main_perm->size<loco::DataType::S32>(); i++)
+ {
+ new_const_node->at<loco::DataType::S32>(i) =
+ pred_perm->at<loco::DataType::S32>(main_perm->at<loco::DataType::S32>(i));
+ }
+ pred_node->perm(new_const_node);
+ replace(node).with(pred_node);
+ }
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+/**
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst]
+ * (main_node) (main_perm)
+ * \ /
+ * [CircleTranspose] [CircleConst]
+ * (pred_node) (pred_perm)
+ * \ /
+ * [CircleTranspose]
+ * (target_node)
+ * |
+ *
+ * AFTER
+ * <Optional Case>
+ *
+ * | | |
+ * [CircleNode] [CircleConst] |
+ * (main_node) (new_const_node) |
+ * \ / or [CircleNode]
+ * [CircleTranspose] (main_node)
+ * (pred_node) |
+ * | |
+ *
+ */
+bool RemoveRedundantTransposePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (remove_consecutive_transpose_function(circle_node))
+ {
+ changed = true;
+ break;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci