summaryrefslogtreecommitdiff
path: root/runtimes/neurun/core/src/graph/pass/PermutationInsertionPass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/neurun/core/src/graph/pass/PermutationInsertionPass.cc')
-rw-r--r--runtimes/neurun/core/src/graph/pass/PermutationInsertionPass.cc | 210
1 file changed, 210 insertions(+), 0 deletions(-)
diff --git a/runtimes/neurun/core/src/graph/pass/PermutationInsertionPass.cc b/runtimes/neurun/core/src/graph/pass/PermutationInsertionPass.cc
new file mode 100644
index 000000000..0f07b47fe
--- /dev/null
+++ b/runtimes/neurun/core/src/graph/pass/PermutationInsertionPass.cc
@@ -0,0 +1,210 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
#include "PermutationInsertionPass.h"

#include <algorithm>
#include <cassert>
#include <list>
#include <unordered_map>
#include <utility>

#include "model/Operand.h"
#include "graph/operation/LowerInfo.h"
#include "graph/Graph.h"
#include "backend/IConfig.h"
#include "util/logging.h"
#include "cpp14/memory.h"
#include "model/operation/PermuteNode.h"
#include "graph/operand/Shape4DConvert.h"
#include "compiler/BackendResolver.h"
+
+namespace neurun
+{
+namespace graph
+{
+namespace pass
+{
+
/**
 * @brief Per-operand visitor: make this operand available to every
 *        (backend, layout) combination that consumes it.
 *
 * For each permute factor that appears in the operand's use-factors but not
 * in its def-factors, a Permute operation is inserted via insertPermute(),
 * and every consuming operation is rewired to read the matching (original or
 * permuted) operand instead.
 *
 * @param index  Index of the operand being visited
 * @param object The operand itself; its use-list is edited in place
 */
void PermutationInsertionPass::callback(const model::OperandIndex &index, model::Operand &object)
{
  auto &&operand_li = _graph.getLowerInfo(index);
  assert(operand_li);

  // NOTE Later, constants also will have Def
  // Ignore constants
  if (operand_li->def_factors().size() == 0)
  {
    return;
  }

  std::list<model::OperationIndex> permute_indexes;

  // Build a map for all necessary type of operands
  // (permute factor -> operand index that holds data in that factor's form)
  std::unordered_map<operand::PermuteFactor, model::OperandIndex> factor_to_index;
  {
    // Exactly one definition factor is expected here
    assert(operand_li->def_factors().size() == 1);
    for (auto factor : operand_li->def_factors())
    {
      // The original operand already satisfies its own def factor
      factor_to_index.emplace(factor, index);
    }

    // Factors that are used but not defined need a Permute in between
    auto insert_set = operand_li->use_factors() - operand_li->def_factors();
    for (auto factor : insert_set)
    {
      const auto permute_operation_index = insertPermute(index, factor);
      permute_indexes.push_back(permute_operation_index);
      VERBOSE(PermutationInsertionPass) << "Insert 'Permute' operation for operand "
                                        << index.value() << std::endl;
      // The Permute's output operand serves all users of this factor
      const auto &permute_operation = _graph.operations().at(permute_operation_index);
      const auto permuted_operand_index = permute_operation.getOutputs().at(0);
      factor_to_index.emplace(factor, permuted_operand_index);
    }
  }

  // Update operations' input that uses this operand
  {
    std::list<model::OperationIndex> remove_list;

    auto uses = object.getUses();
    for (auto use : uses.list())
    {
      // If permute operation, ignore it
      // (the freshly inserted Permutes must keep reading the original operand)
      if (std::find(permute_indexes.begin(), permute_indexes.end(), use) != permute_indexes.end())
        continue;

      auto &operation = _graph.operations().at(use);
      assert(_graph.subgraphs().containsOperation(use));
      auto subg_index = _graph.subgraphs().getOperation(use);
      auto subg_li = _graph.getLowerInfo(subg_index);
      assert(subg_li);
      const auto subg_layout = subg_li->layout();
      const backend::Backend *backend = subg_li->backend();
      assert(backend);
      auto use_node_inputs = operation.getInputs();
      assert(use_node_inputs.contains(index));

      // Pick the operand variant matching this user's (backend, layout) pair
      auto new_index = factor_to_index.at({backend, subg_layout});
      if (index != new_index)
      {
        // Update from subgraph
        _graph.subgraphs().at(subg_index).replaceInput(index, new_index);

        // Update from operation
        operation.replaceInput(index, new_index);

        // Update from operand
        remove_list.push_back(
            use); // Removal should be done in another loop since we are in the loop
        _graph.operands().at(new_index).appendUse(use);
      }
    }

    // Deferred removal: safe now that iteration over the use-list is done
    for (auto &operation : remove_list)
    {
      object.removeUse(operation);
    }
  }
}
+
+model::OperationIndex
+PermutationInsertionPass::insertPermute(const model::OperandIndex &operand_index,
+ const operand::PermuteFactor &factor)
+{
+ assert(!_graph.isBuildingPhase());
+
+ auto &operand = _graph.operands().at(operand_index);
+
+ // Generate output operand and permute operation
+ auto out_operand_index = _graph.addOperand(operand.shape(), operand.typeInfo());
+ // change model output if operand_index is model output index
+ auto &model_outputs = _graph.getOutputs();
+ if (model_outputs.contains(operand_index))
+ {
+ model_outputs.replace(operand_index, out_operand_index);
+ }
+
+ // Find PermuteNode information
+ auto input_backend = _graph.getLowerInfo(operand_index)->def_factors().getOnlyElement().backend();
+ auto output_backend = factor.backend();
+ // NOTE PermuteNode may not have specific layout because the layout of input and output may be
+ // different.
+ const auto permute_node_layout = model::Layout::UNKNOWN;
+ const auto permute_node_backend = backend::BackendManager::instance().getDefault();
+ const operand::PermuteFactor permute_node_factor{permute_node_backend, permute_node_layout};
+
+ // Update LowerInfo of input operand
+ auto operand_lower_info = _graph.getLowerInfo(operand_index);
+ operand_lower_info->removeUsePermuteFactor(factor);
+ operand_lower_info->addUsePermuteFactor(permute_node_factor);
+
+ // Update LowerInfo of output operand
+ auto out_operand_li =
+ nnfw::cpp14::make_unique<operand::LowerInfo>(operand::asShape4D(operand.shape()));
+
+ // The input and output factors of all nodes will be the same except PermuteNode. So Tensor's
+ // allocators allocates memory using only the information of def permutation factor now.
+ // TODO Change param to permute_node_factor
+ out_operand_li->addDefPermuteFactor(factor);
+ out_operand_li->addUsePermuteFactor(factor);
+ _graph.setLowerInfo(out_operand_index, std::move(out_operand_li));
+
+ auto input_backend_ctx = _graph.backend_resolver()->getBackendContext(input_backend);
+ auto output_backend_ctx = _graph.backend_resolver()->getBackendContext(output_backend);
+
+ // Insert permute operation to the graph
+ const auto input_layout =
+ _graph.getLowerInfo(operand_index)->def_factors().getOnlyElement().layout();
+ const auto output_layout = factor.layout();
+ using PermuteNode = model::operation::PermuteNode;
+ const auto permute_type = [&]() {
+ if (input_layout == model::Layout::NHWC && output_layout == model::Layout::NCHW)
+ {
+ return PermuteNode::Type::NHWC_TO_NCHW;
+ }
+ else if (input_layout == model::Layout::NCHW && output_layout == model::Layout::NHWC)
+ {
+ return PermuteNode::Type::NCHW_TO_NHWC;
+ }
+ else
+ {
+ return PermuteNode::Type::COPY;
+ }
+ }();
+ auto insert_node = nnfw::cpp14::make_unique<PermuteNode>(
+ operand_index, out_operand_index, input_backend_ctx, output_backend_ctx, permute_type);
+
+ auto node_index = _graph.operations().push(std::move(insert_node));
+ const auto &node = _graph.operations().at(node_index);
+
+ // Subgraph
+ {
+ auto subg_index = _graph.subgraphs().emplace(node_index, node, permute_node_layout);
+ auto &subg = _graph.subgraphs().at(subg_index);
+ subg.setInputs(node.getInputs());
+ subg.setOutputs(node.getOutputs());
+ _graph.setLowerInfo(subg_index, nnfw::cpp14::make_unique<graph::operation::LowerInfo>(
+ permute_node_backend, permute_node_layout));
+ }
+
+ // Update Use/Def info
+ {
+ _graph.operands().at(operand_index).appendUse(node_index);
+ _graph.operands().at(out_operand_index).appendDef(node_index);
+ }
+ return node_index;
+}
+} // namespace pass
+} // namespace graph
+} // namespace neurun