diff options
Diffstat (limited to 'runtime/neurun/core/src/ir/pass/PermutationOperationPass.cc')
-rw-r--r-- | runtime/neurun/core/src/ir/pass/PermutationOperationPass.cc | 230 |
1 files changed, 230 insertions, 0 deletions
diff --git a/runtime/neurun/core/src/ir/pass/PermutationOperationPass.cc b/runtime/neurun/core/src/ir/pass/PermutationOperationPass.cc new file mode 100644 index 000000000..41a1ad903 --- /dev/null +++ b/runtime/neurun/core/src/ir/pass/PermutationOperationPass.cc @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PermutationOperationPass.h" + +#include "backend/Backend.h" +#include "backend/IConfig.h" +#include "ir/Graph.h" + +namespace neurun +{ +namespace ir +{ +namespace pass +{ + +void PermutationOperationPass::callback(const OperationIndex &, Operation &node) +{ + node.accept(*this); +}; + +void PermutationOperationPass::changeToKeepLayout(const Operation &node) +{ + const auto &output_ind = node.getOutputs().at(0); + const auto &output_obj = _graph.operands().at(output_ind); + + assert(output_obj.getDef().size() == 1); + const auto &node_index = output_obj.getDef().list().front(); + const auto &subg_index = _graph.subgraphs().getOperation(node_index); + + const auto frontend_layout = _graph.subgraphs().at(subg_index).getLayout(); + const auto backend_layout = _graph.getLowerInfo(subg_index)->layout(); + + if (frontend_layout == backend_layout) + { + return; + } + + // CPU supports only NHWC now + if (_graph.getLowerInfo(subg_index)->backend()->config()->id() != "cpu") + { + // TODO Change backend of this node + assert(frontend_layout == Layout::NHWC || backend_layout == Layout::UNKNOWN); + } + + // Divide op_seq based on target operation + { + auto &above_subg = _graph.subgraphs().at(subg_index); + + // Create new op_seq and move information from existing op_seq to new op_seq if target + // node is the end of op_seq + auto it = above_subg.begin(); + // Find iterator of target node in op_seq + while ((it++)->index != node_index) + ; + if (it != above_subg.end()) + { + const auto &below_subg_index = + _graph.subgraphs().emplace(it->index, *it->node, above_subg.getLayout()); + auto &below_subg = _graph.subgraphs().at(below_subg_index); + below_subg.setInputs(it->node->getInputs()); + below_subg.setOutputs(it->node->getOutputs()); + + std::vector<OperationIndex> remove_list; + remove_list.emplace_back(it->index); + while (++it != above_subg.end()) + { + below_subg.appendOperation(it->index, *it->node); + below_subg.setOutputs(it->node->getOutputs()); + remove_list.emplace_back(it->index); + } + + above_subg.setOutputs(node.getOutputs()); + for (const auto &index : remove_list) + { + above_subg.remove(index); + } + + const auto subg_li = _graph.getLowerInfo(subg_index); + _graph.setLowerInfo(below_subg_index, nnfw::cpp14::make_unique<operation::LowerInfo>( + subg_li->backend(), subg_li->layout())); + } + } + + // Remove target operation from op_seq and insert the target operation to new op_seq + { + const auto backend = _graph.getLowerInfo(subg_index)->backend(); + + // Remove target operation from subraph + _graph.subgraphs().removeFromSubgraph(node_index); + + if (!_graph.subgraphs().exist(subg_index)) + { + // Remove lowerinfo for op_seq of target operation if the op_seq does not exist + _graph.removeLowerInfo(subg_index); + } + else + { + // Update op_seq of target operation if the op_seq exists + auto &above_subg = _graph.subgraphs().at(subg_index); + const auto last_node = (--above_subg.end())->node; + above_subg.setOutputs(last_node->getOutputs()); + } + + // Create new op_seq and set information to the op_seq + auto new_subg_index = _graph.subgraphs().emplace(node_index, node, frontend_layout); + auto &new_subg = _graph.subgraphs().at(new_subg_index); + new_subg.setInputs(node.getInputs()); + new_subg.setOutputs(node.getOutputs()); + _graph.setLowerInfo(new_subg_index, + nnfw::cpp14::make_unique<operation::LowerInfo>(backend, frontend_layout)); + } + + // Change PermuteFactors of operands of target node + { + const auto &subg_index = _graph.subgraphs().getOperation(node_index); + const auto subg_li = _graph.getLowerInfo(subg_index); + const auto backend = subg_li->backend(); + const operand::PermuteFactor removed_factor{backend, backend_layout}; + const operand::PermuteFactor new_factor{backend, frontend_layout}; + for (const auto &input : node.getInputs()) + { + bool canRemove = true; + for (const auto &use : _graph.operands().at(input).getUses().list()) + { + if (use != node_index) + { + const auto &use_subg_index = _graph.subgraphs().getOperation(use); + auto use_subg_li = _graph.getLowerInfo(use_subg_index); + if (use_subg_li->backend() == backend && use_subg_li->layout() == backend_layout) + { + canRemove = false; + break; + } + } + } + + auto lower_info = _graph.getLowerInfo(input); + if (canRemove) + { + lower_info->removeUsePermuteFactor(removed_factor); + } + lower_info->addUsePermuteFactor(new_factor); + + // Whether if node's input is an input of model or a constant + if (_graph.operands().at(input).getDef().size() == 0) + { + assert(_graph.getInputs().contains(input) || _graph.operands().at(input).isConstant()); + lower_info->removeDefPermuteFactor(removed_factor); + lower_info->addDefPermuteFactor(new_factor); + } + } + + for (const auto &output : node.getOutputs()) + { + auto lower_info = _graph.getLowerInfo(output); + lower_info->removeDefPermuteFactor(removed_factor); + lower_info->addDefPermuteFactor(new_factor); + + // Whether if node's output is an output of model + if (_graph.operands().at(output).getUses().size() == 0) + { + assert(_graph.getOutputs().contains(output)); + lower_info->removeUsePermuteFactor(removed_factor); + lower_info->addUsePermuteFactor(new_factor); + } + } + } +} + +void PermutationOperationPass::visit(const operation::FullyConnected &node) +{ + const auto &input_ind = node.getInputs().at(operation::FullyConnected::Input::INPUT); + const auto &input_obj = _graph.operands().at(input_ind); + const auto &input_shape = input_obj.shape(); + + if (input_shape.rank() == 4) + { + changeToKeepLayout(node); + } +} + +void PermutationOperationPass::visit(const operation::Gather &node) +{ + const auto &input_ind = node.getInputs().at(operation::Gather::Input::INPUT); + const auto &input_obj = _graph.operands().at(input_ind); + const auto &input_shape = input_obj.shape(); + + const auto &output_ind = node.getOutputs().at(0); + const auto &output_obj = _graph.operands().at(output_ind); + const auto &output_shape = output_obj.shape(); + + if (input_shape.rank() >= 4 || output_shape.rank() >= 4) + { + changeToKeepLayout(node); + } +} + +void PermutationOperationPass::visit(const operation::Reshape &node) +{ + const auto &input_ind = node.getInputs().at(operation::Reshape::Input::INPUT); + const auto &input_obj = _graph.operands().at(input_ind); + const auto &input_shape = input_obj.shape(); + + const auto &output_ind = node.getOutputs().at(0); + const auto &output_obj = _graph.operands().at(output_ind); + const auto &output_shape = output_obj.shape(); + + if (input_shape.rank() >= 4 || output_shape.rank() >= 4) + { + changeToKeepLayout(node); + } +} + +} // namespace pass +} // namespace ir +} // namespace neurun |