diff options
Diffstat (limited to 'runtimes/neurun/core/src/graph/pass/PermutationInsertionPass.cc')
-rw-r--r-- | runtimes/neurun/core/src/graph/pass/PermutationInsertionPass.cc | 210 |
1 file changed, 210 insertions, 0 deletions
/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "PermutationInsertionPass.h"

#include <cassert>
#include <utility>
#include <unordered_map>

#include "model/Operand.h"
#include "graph/operation/LowerInfo.h"
#include "graph/Graph.h"
#include "backend/IConfig.h"
#include "util/logging.h"
#include "cpp14/memory.h"
#include "model/operation/PermuteNode.h"
#include "graph/operand/Shape4DConvert.h"
#include "compiler/BackendResolver.h"

namespace neurun
{
namespace graph
{
namespace pass
{

// Per-operand pass callback: when an operand is consumed under a
// (backend, layout) permute factor that differs from the factor that
// defines it, insert a Permute operation for each such extra factor and
// rewire the consuming operations to read the permuted copy instead.
//
// @param index  Index of the operand being visited.
// @param object The operand itself; its use-list is edited in place.
void PermutationInsertionPass::callback(const model::OperandIndex &index, model::Operand &object)
{
  auto &&operand_li = _graph.getLowerInfo(index);
  assert(operand_li);

  // NOTE Later, constants also will have Def
  // Ignore constants (operands with no defining factor)
  if (operand_li->def_factors().size() == 0)
  {
    return;
  }

  // Indexes of the Permute operations inserted below; used to skip them
  // when rewiring users (a Permute must keep reading the original operand).
  std::list<model::OperationIndex> permute_indexes;

  // Build a map from each required permute factor to the operand index that
  // satisfies it: the original operand for the def factor, a freshly
  // permuted operand for every use-only factor.
  std::unordered_map<operand::PermuteFactor, model::OperandIndex> factor_to_index;
  {
    // Exactly one def factor is expected at this point in lowering.
    assert(operand_li->def_factors().size() == 1);
    for (auto factor : operand_li->def_factors())
    {
      factor_to_index.emplace(factor, index);
    }

    // Factors needed by users but not provided by the definition: each one
    // requires a Permute to be inserted.
    auto insert_set = operand_li->use_factors() - operand_li->def_factors();
    for (auto factor : insert_set)
    {
      const auto permute_operation_index = insertPermute(index, factor);
      permute_indexes.push_back(permute_operation_index);
      VERBOSE(PermutationInsertionPass) << "Insert 'Permute' operation for operand "
                                        << index.value() << std::endl;
      const auto &permute_operation = _graph.operations().at(permute_operation_index);
      // The Permute's single output is the operand holding the permuted data.
      const auto permuted_operand_index = permute_operation.getOutputs().at(0);
      factor_to_index.emplace(factor, permuted_operand_index);
    }
  }

  // Update operations' input that uses this operand
  {
    std::list<model::OperationIndex> remove_list;

    // NOTE getUses() is taken by value here so the iteration is over a
    // snapshot while the underlying use-list is modified below.
    auto uses = object.getUses();
    for (auto use : uses.list())
    {
      // If permute operation, ignore it
      if (std::find(permute_indexes.begin(), permute_indexes.end(), use) != permute_indexes.end())
        continue;

      auto &operation = _graph.operations().at(use);
      assert(_graph.subgraphs().containsOperation(use));
      auto subg_index = _graph.subgraphs().getOperation(use);
      auto subg_li = _graph.getLowerInfo(subg_index);
      assert(subg_li);
      const auto subg_layout = subg_li->layout();
      const backend::Backend *backend = subg_li->backend();
      assert(backend);
      auto use_node_inputs = operation.getInputs();
      assert(use_node_inputs.contains(index));

      // Pick the operand matching the user's (backend, layout) factor.
      // The map is complete by construction above, so at() cannot miss.
      auto new_index = factor_to_index.at({backend, subg_layout});
      if (index != new_index)
      {
        // Update from subgraph
        _graph.subgraphs().at(subg_index).replaceInput(index, new_index);

        // Update from operation
        operation.replaceInput(index, new_index);

        // Update from operand
        remove_list.push_back(
            use); // Removal should be done in another loop since we are in the loop
        _graph.operands().at(new_index).appendUse(use);
      }
    }

    // Deferred removal: drop the rewired users from this operand's use-list.
    for (auto &operation : remove_list)
    {
      object.removeUse(operation);
    }
  }
}

// Inserts a PermuteNode that converts @c operand_index into a new operand
// laid out/placed according to @c factor, and registers the new node in the
// graph (operations, subgraphs, lower-info, use/def links).
//
// @param operand_index Operand to be permuted (the Permute's input).
// @param factor        Target (backend, layout) the consumers require.
// @return Index of the newly inserted Permute operation.
model::OperationIndex
PermutationInsertionPass::insertPermute(const model::OperandIndex &operand_index,
                                        const operand::PermuteFactor &factor)
{
  // Graph topology must be finalized before this pass runs.
  assert(!_graph.isBuildingPhase());

  auto &operand = _graph.operands().at(operand_index);

  // Generate output operand and permute operation
  auto out_operand_index = _graph.addOperand(operand.shape(), operand.typeInfo());
  // change model output if operand_index is model output index
  auto &model_outputs = _graph.getOutputs();
  if (model_outputs.contains(operand_index))
  {
    model_outputs.replace(operand_index, out_operand_index);
  }

  // Find PermuteNode information
  auto input_backend = _graph.getLowerInfo(operand_index)->def_factors().getOnlyElement().backend();
  auto output_backend = factor.backend();
  // NOTE PermuteNode may not have specific layout because the layout of input and output may be
  // different.
  const auto permute_node_layout = model::Layout::UNKNOWN;
  // The Permute itself is executed on the default backend.
  const auto permute_node_backend = backend::BackendManager::instance().getDefault();
  const operand::PermuteFactor permute_node_factor{permute_node_backend, permute_node_layout};

  // Update LowerInfo of input operand: the target factor is now served by
  // the permuted copy, so the original operand is instead used by the
  // Permute node's own factor.
  auto operand_lower_info = _graph.getLowerInfo(operand_index);
  operand_lower_info->removeUsePermuteFactor(factor);
  operand_lower_info->addUsePermuteFactor(permute_node_factor);

  // Update LowerInfo of output operand
  auto out_operand_li =
      nnfw::cpp14::make_unique<operand::LowerInfo>(operand::asShape4D(operand.shape()));

  // The input and output factors of all nodes will be the same except PermuteNode. So the Tensor
  // allocator allocates memory using only the information of the def permutation factor now.
  // TODO Change param to permute_node_factor
  out_operand_li->addDefPermuteFactor(factor);
  out_operand_li->addUsePermuteFactor(factor);
  _graph.setLowerInfo(out_operand_index, std::move(out_operand_li));

  auto input_backend_ctx = _graph.backend_resolver()->getBackendContext(input_backend);
  auto output_backend_ctx = _graph.backend_resolver()->getBackendContext(output_backend);

  // Insert permute operation to the graph
  const auto input_layout =
      _graph.getLowerInfo(operand_index)->def_factors().getOnlyElement().layout();
  const auto output_layout = factor.layout();
  using PermuteNode = model::operation::PermuteNode;
  // Select the permute kind from the layout pair; same (or unknown) layouts
  // degrade to a plain copy between backends.
  const auto permute_type = [&]() {
    if (input_layout == model::Layout::NHWC && output_layout == model::Layout::NCHW)
    {
      return PermuteNode::Type::NHWC_TO_NCHW;
    }
    else if (input_layout == model::Layout::NCHW && output_layout == model::Layout::NHWC)
    {
      return PermuteNode::Type::NCHW_TO_NHWC;
    }
    else
    {
      return PermuteNode::Type::COPY;
    }
  }();
  auto insert_node = nnfw::cpp14::make_unique<PermuteNode>(
      operand_index, out_operand_index, input_backend_ctx, output_backend_ctx, permute_type);

  auto node_index = _graph.operations().push(std::move(insert_node));
  const auto &node = _graph.operations().at(node_index);

  // Subgraph: wrap the new node in its own single-operation subgraph with
  // the Permute's backend/layout lower-info.
  {
    auto subg_index = _graph.subgraphs().emplace(node_index, node, permute_node_layout);
    auto &subg = _graph.subgraphs().at(subg_index);
    subg.setInputs(node.getInputs());
    subg.setOutputs(node.getOutputs());
    _graph.setLowerInfo(subg_index, nnfw::cpp14::make_unique<graph::operation::LowerInfo>(
                                        permute_node_backend, permute_node_layout));
  }

  // Update Use/Def info
  {
    _graph.operands().at(operand_index).appendUse(node_index);
    _graph.operands().at(out_operand_index).appendDef(node_index);
  }
  return node_index;
}
} // namespace pass
} // namespace graph
} // namespace neurun