/* * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef __NEURUN_GRAPH_PASS_PERMUTATION_ELIMINATION_PASS_H__ #define __NEURUN_GRAPH_PASS_PERMUTATION_ELIMINATION_PASS_H__ #include "OperandPass.h" #include "model/Operand.h" #include "model/OperandIndexSequence.h" namespace neurun { namespace graph { namespace pass { class PermutationEliminationPass : public OperandPass { public: using OperandPass::OperandPass; public: std::string id() override { return "PermutationEliminationPass"; } void callback(const model::OperandIndex &index, model::Operand &object) override; private: /** * @brief Remove Permute operation that permutates input * * Note: This function aslo removes model's input and * sets output of permutation as model's new input * * @param inp_index is the target operand index for the elimination * @param object is the target operand object for the elimination * * @return */ void eliminateInput(const model::OperandIndex &inp_index, model::Operand &object); /** * @brief Remove Permute operation that permutates output of a model * * Note: This function aslo removes model's output and * sets input of permutation as model's new output * * @param out_index is the target operand index for the elimination * @param object is the target operand object for the elimination * * @return */ void eliminateOutput(const model::OperandIndex &out_index, model::Operand &object); /** * @brief Determine if passed operands are permute layer's input and output, that must be * eliminated * * @param inp_index indexes of the input operand to operation * @param out_index indexes of the output operand to operation * @param is_for_model_input checking for model's input or output * * @return if it is permutation layer */ bool isPermuteLayerToEliminate(const model::OperandIndexSequence &inp_indexes, const model::OperandIndexSequence &out_indexes, bool is_for_model_input); }; } // namespace pass } // namespace graph } // namespace neurun #endif // __NEURUN_GRAPH_PASS_PERMUTATION_ELIMINATION_PASS_H__