diff options
Diffstat (limited to 'runtime/neurun/core/src/ir/pass/PermutationEliminationPass.cc')
-rw-r--r-- | runtime/neurun/core/src/ir/pass/PermutationEliminationPass.cc | 195 |
1 files changed, 195 insertions, 0 deletions
diff --git a/runtime/neurun/core/src/ir/pass/PermutationEliminationPass.cc b/runtime/neurun/core/src/ir/pass/PermutationEliminationPass.cc new file mode 100644 index 000000000..71f3d7e82 --- /dev/null +++ b/runtime/neurun/core/src/ir/pass/PermutationEliminationPass.cc @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PermutationEliminationPass.h" + +#include "ir/Operand.h" +#include "ir/operand/LowerInfo.h" +#include "ir/Graph.h" +#include "backend/IConfig.h" +#include "util/logging.h" +#include "compiler/BackendResolver.h" + +namespace neurun +{ +namespace ir +{ +namespace pass +{ +void PermutationEliminationPass::callback(const OperandIndex &inp_index, Operand &object) +{ + if (_graph.getInputs().contains(inp_index)) + { + eliminateInput(inp_index, object); + } + else if (_graph.getOutputs().contains(inp_index)) + { + eliminateOutput(inp_index, object); + } +} + +void PermutationEliminationPass::eliminateInput(const OperandIndex &inp_index, Operand &object) +{ + auto &model_inputs = _graph.getInputs(); + + // get uses of the model's given input + auto uses = object.getUses(); + + // input must be used just by permutation + if (uses.size() != 1) + { + return; + } + + for (auto input_use : uses.list()) + { + auto &perm_operation = _graph.operations().at(input_use); + auto perm_inputs = perm_operation.getInputs(); + + auto perm_outputs = perm_operation.getOutputs(); + + if (!isPermuteLayerToEliminate(perm_inputs, perm_outputs, true)) + { + return; + } + + assert(perm_inputs.at(0) == inp_index); + + VERBOSE(PermutationEliminationPass::EliminateInput) << "remove NHWC_TO_NCHW permutation\n"; + + // set model's new input, which was output of permutation + model_inputs.replace(inp_index, perm_outputs.at(0)); + + // remove model's input, which is also input of permutation + _graph.removeOperand(inp_index); + + // remove permutation operation + assert(_graph.subgraphs().containsOperation(input_use)); + auto subg_idx = _graph.subgraphs().getOperation(input_use); + _graph.subgraphs().remove(subg_idx); + _graph.operations().remove(input_use); + + VERBOSE(PermutationEliminationPass::EliminateInput) + << inp_index.value() << " is model's input and is removed. New input is " + << perm_outputs.at(0).value() << "\n" + << input_use.value() << " is removed permutation operation\n"; + } +} + +void PermutationEliminationPass::eliminateOutput(const OperandIndex &out_index, Operand &object) +{ + auto &model_outputs = _graph.getOutputs(); + + // get defs of the model's given output + auto defs = object.getDef(); + + // output must use just permutation + if (defs.size() != 1) + { + return; + } + + for (auto output_def : defs.list()) + { + auto &perm_operation = _graph.operations().at(output_def); + auto perm_outputs = perm_operation.getOutputs(); + + auto perm_inputs = perm_operation.getInputs(); + if (!isPermuteLayerToEliminate(perm_inputs, perm_outputs, false)) + { + return; + } + + assert(perm_outputs.at(0) == out_index); + + VERBOSE(PermutationEliminationPass::EliminateOutput) << "remove NCHW_TO_NHWC permutation\n"; + + // Update operations' output that is used by permute operand + for (auto perm_input_index : perm_inputs) + { + auto &perm_input_operand = _graph.operands().at(perm_input_index); + perm_input_operand.removeUse(output_def); + } + + // set model's new output, which was input of permutation + model_outputs.replace(out_index, perm_inputs.at(0)); + + // remove model's output, which is also output of permutation + _graph.removeOperand(out_index); + + // remove permutation operation + assert(_graph.subgraphs().containsOperation(output_def)); + auto subg_idx = _graph.subgraphs().getOperation(output_def); + _graph.subgraphs().remove(subg_idx); + _graph.operations().remove(output_def); + + VERBOSE(PermutationEliminationPass::EliminateOutput) + << out_index.value() << " is model's output and is removed. New output is " + << perm_inputs.at(0).value() << "\n" + << output_def.value() << " is removed permutation operation\n"; + } +} + +bool PermutationEliminationPass::isPermuteLayerToEliminate(const OperandIndexSequence &inp_indexes, + const OperandIndexSequence &out_indexes, + bool is_for_model_input) +{ + auto input_def_factors = _graph.getLowerInfo(inp_indexes.at(0))->def_factors(); + auto output_def_factors = _graph.getLowerInfo(out_indexes.at(0))->def_factors(); + + auto input_layout = input_def_factors.getOnlyElement().layout(); + auto output_layout = output_def_factors.getOnlyElement().layout(); + + if (input_def_factors.size() != 1 || output_def_factors.size() != 1) + { + return false; + } + + // all operands' factor must be the same + for (auto index : inp_indexes) + { + auto op_factor_set = _graph.getLowerInfo(index)->def_factors(); + if (op_factor_set.size() != 1 || + input_layout != _graph.getLowerInfo(index)->def_factors().getOnlyElement().layout()) + { + return false; + } + } + // all operands' factor must be the same + for (auto index : out_indexes) + { + auto op_factor_set = _graph.getLowerInfo(index)->def_factors(); + if (op_factor_set.size() != 1 || + output_layout != _graph.getLowerInfo(index)->def_factors().getOnlyElement().layout()) + { + return false; + } + } + + if (is_for_model_input) + { + // check if this is NHWC_TO_NCHW permutation: must have single input, which is model's input + return (inp_indexes.size() == 1 && input_layout == Layout::NHWC && + output_layout == Layout::NCHW); + } + + // check if this is NCHW_TO_NHWC permutation: must have single output, which is model's output + return (out_indexes.size() == 1 && input_layout == Layout::NCHW && output_layout == Layout::NHWC); +} + +} // namespace pass +} // namespace ir +} // namespace neurun |