summaryrefslogtreecommitdiff
path: root/runtime/neurun/core/src/ir/pass/PermutationOperationPass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/neurun/core/src/ir/pass/PermutationOperationPass.cc')
-rw-r--r--runtime/neurun/core/src/ir/pass/PermutationOperationPass.cc230
1 files changed, 230 insertions, 0 deletions
diff --git a/runtime/neurun/core/src/ir/pass/PermutationOperationPass.cc b/runtime/neurun/core/src/ir/pass/PermutationOperationPass.cc
new file mode 100644
index 000000000..41a1ad903
--- /dev/null
+++ b/runtime/neurun/core/src/ir/pass/PermutationOperationPass.cc
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PermutationOperationPass.h"
+
+#include "backend/Backend.h"
+#include "backend/IConfig.h"
+#include "ir/Graph.h"
+
+namespace neurun
+{
+namespace ir
+{
+namespace pass
+{
+
+// Per-operation hook invoked by the base pass for every node in the graph.
+// Dispatches through the visitor so that only the operation types with a
+// visit() overload below (FullyConnected, Gather, Reshape) are handled.
+void PermutationOperationPass::callback(const OperationIndex &, Operation &node)
+{
+  node.accept(*this);
+}
+
+// Forces the target node to run in the frontend layout when its op_seq was
+// lowered to a different backend layout. It does so by carving the node out
+// into its own op_seq lowered with the frontend layout, and patching the
+// PermuteFactors on the node's operands accordingly. No-op when the frontend
+// and backend layouts already agree.
+void PermutationOperationPass::changeToKeepLayout(const Operation &node)
+{
+  const auto &output_ind = node.getOutputs().at(0);
+  const auto &output_obj = _graph.operands().at(output_ind);
+
+  // The output must have exactly one defining node: the target node itself.
+  assert(output_obj.getDef().size() == 1);
+  const auto &node_index = output_obj.getDef().list().front();
+  const auto &subg_index = _graph.subgraphs().getOperation(node_index);
+
+  const auto frontend_layout = _graph.subgraphs().at(subg_index).getLayout();
+  const auto backend_layout = _graph.getLowerInfo(subg_index)->layout();
+
+  if (frontend_layout == backend_layout)
+  {
+    return;
+  }
+
+  // CPU supports only NHWC now
+  if (_graph.getLowerInfo(subg_index)->backend()->config()->id() != "cpu")
+  {
+    // TODO Change backend of this node
+    assert(frontend_layout == Layout::NHWC || backend_layout == Layout::UNKNOWN);
+  }
+
+  // Divide op_seq based on target operation
+  {
+    auto &above_subg = _graph.subgraphs().at(subg_index);
+
+    // If the target node is NOT the last node of its op_seq, move every node
+    // after it into a freshly created op_seq that inherits the same lowering.
+    auto it = above_subg.begin();
+    // Find iterator of target node in op_seq; after the loop `it` points to
+    // the element immediately AFTER the target node (post-increment).
+    while ((it++)->index != node_index)
+      ;
+    if (it != above_subg.end())
+    {
+      // Seed the new (below) op_seq with the first node after the target.
+      const auto &below_subg_index =
+          _graph.subgraphs().emplace(it->index, *it->node, above_subg.getLayout());
+      auto &below_subg = _graph.subgraphs().at(below_subg_index);
+      below_subg.setInputs(it->node->getInputs());
+      below_subg.setOutputs(it->node->getOutputs());
+
+      // Collect moved nodes; they cannot be removed while iterating above_subg.
+      std::vector<OperationIndex> remove_list;
+      remove_list.emplace_back(it->index);
+      while (++it != above_subg.end())
+      {
+        below_subg.appendOperation(it->index, *it->node);
+        // Outputs track the last appended node so the op_seq boundary is right.
+        below_subg.setOutputs(it->node->getOutputs());
+        remove_list.emplace_back(it->index);
+      }
+
+      // The target node is now the tail of the above op_seq.
+      above_subg.setOutputs(node.getOutputs());
+      for (const auto &index : remove_list)
+      {
+        above_subg.remove(index);
+      }
+
+      // The below op_seq keeps the original backend/layout lowering.
+      const auto subg_li = _graph.getLowerInfo(subg_index);
+      _graph.setLowerInfo(below_subg_index, nnfw::cpp14::make_unique<operation::LowerInfo>(
+                                                subg_li->backend(), subg_li->layout()));
+    }
+  }
+
+  // Remove target operation from op_seq and insert the target operation to new op_seq
+  {
+    const auto backend = _graph.getLowerInfo(subg_index)->backend();
+
+    // Remove target operation from its op_seq
+    _graph.subgraphs().removeFromSubgraph(node_index);
+
+    if (!_graph.subgraphs().exist(subg_index))
+    {
+      // Remove lowerinfo for op_seq of target operation if the op_seq does not exist
+      _graph.removeLowerInfo(subg_index);
+    }
+    else
+    {
+      // Update op_seq of target operation if the op_seq exists
+      auto &above_subg = _graph.subgraphs().at(subg_index);
+      const auto last_node = (--above_subg.end())->node;
+      above_subg.setOutputs(last_node->getOutputs());
+    }
+
+    // Create new op_seq holding only the target node, lowered with the
+    // frontend layout on the same backend — this is what "keep layout" means.
+    auto new_subg_index = _graph.subgraphs().emplace(node_index, node, frontend_layout);
+    auto &new_subg = _graph.subgraphs().at(new_subg_index);
+    new_subg.setInputs(node.getInputs());
+    new_subg.setOutputs(node.getOutputs());
+    _graph.setLowerInfo(new_subg_index,
+                        nnfw::cpp14::make_unique<operation::LowerInfo>(backend, frontend_layout));
+  }
+
+  // Change PermuteFactors of operands of target node
+  {
+    const auto &subg_index = _graph.subgraphs().getOperation(node_index);
+    const auto subg_li = _graph.getLowerInfo(subg_index);
+    const auto backend = subg_li->backend();
+    const operand::PermuteFactor removed_factor{backend, backend_layout};
+    const operand::PermuteFactor new_factor{backend, frontend_layout};
+    for (const auto &input : node.getInputs())
+    {
+      // The old use-factor may only be dropped if no OTHER user of this input
+      // still runs on the same backend with the old layout.
+      bool canRemove = true;
+      for (const auto &use : _graph.operands().at(input).getUses().list())
+      {
+        if (use != node_index)
+        {
+          const auto &use_subg_index = _graph.subgraphs().getOperation(use);
+          auto use_subg_li = _graph.getLowerInfo(use_subg_index);
+          if (use_subg_li->backend() == backend && use_subg_li->layout() == backend_layout)
+          {
+            canRemove = false;
+            break;
+          }
+        }
+      }
+
+      auto lower_info = _graph.getLowerInfo(input);
+      if (canRemove)
+      {
+        lower_info->removeUsePermuteFactor(removed_factor);
+      }
+      lower_info->addUsePermuteFactor(new_factor);
+
+      // Check whether node's input is an input of model or a constant
+      // (i.e. it has no defining operation, so its def-factor follows us).
+      if (_graph.operands().at(input).getDef().size() == 0)
+      {
+        assert(_graph.getInputs().contains(input) || _graph.operands().at(input).isConstant());
+        lower_info->removeDefPermuteFactor(removed_factor);
+        lower_info->addDefPermuteFactor(new_factor);
+      }
+    }
+
+    for (const auto &output : node.getOutputs())
+    {
+      // The target node is the sole definer of its outputs (asserted above),
+      // so the def-factor swap is unconditional.
+      auto lower_info = _graph.getLowerInfo(output);
+      lower_info->removeDefPermuteFactor(removed_factor);
+      lower_info->addDefPermuteFactor(new_factor);
+
+      // Check whether node's output is an output of model (no users).
+      if (_graph.operands().at(output).getUses().size() == 0)
+      {
+        assert(_graph.getOutputs().contains(output));
+        lower_info->removeUsePermuteFactor(removed_factor);
+        lower_info->addUsePermuteFactor(new_factor);
+      }
+    }
+  }
+}
+
+// FullyConnected implicitly flattens its input; a rank-4 input flattened in a
+// permuted layout would yield different element order, so keep the frontend
+// layout in that case.
+void PermutationOperationPass::visit(const operation::FullyConnected &node)
+{
+  const auto &in_index = node.getInputs().at(operation::FullyConnected::Input::INPUT);
+  const auto rank = _graph.operands().at(in_index).shape().rank();
+
+  if (rank == 4)
+  {
+    changeToKeepLayout(node);
+  }
+}
+
+// Gather's axis semantics would be disturbed by a layout permutation once a
+// 4-D (or higher) tensor is involved on either side, so keep the frontend
+// layout in that case.
+void PermutationOperationPass::visit(const operation::Gather &node)
+{
+  const auto &in_index = node.getInputs().at(operation::Gather::Input::INPUT);
+  const auto &out_index = node.getOutputs().at(0);
+
+  const auto in_rank = _graph.operands().at(in_index).shape().rank();
+  const auto out_rank = _graph.operands().at(out_index).shape().rank();
+
+  if (in_rank >= 4 || out_rank >= 4)
+  {
+    changeToKeepLayout(node);
+  }
+}
+
+// Reshape depends on the raw element order of its operands; permuting a 4-D
+// (or higher) tensor's layout would reorder elements, so keep the frontend
+// layout whenever input or output reaches rank 4.
+void PermutationOperationPass::visit(const operation::Reshape &node)
+{
+  const auto &in_index = node.getInputs().at(operation::Reshape::Input::INPUT);
+  const auto &out_index = node.getOutputs().at(0);
+
+  const auto in_rank = _graph.operands().at(in_index).shape().rank();
+  const auto out_rank = _graph.operands().at(out_index).shape().rank();
+
+  if (in_rank >= 4 || out_rank >= 4)
+  {
+    changeToKeepLayout(node);
+  }
+}
+
+} // namespace pass
+} // namespace ir
+} // namespace neurun