/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "SubTensorAnalyzer.h"

// NOTE(review): this file previously carried a bare `#include` whose header name
// was lost. The standard headers below are the ones this file uses directly
// (assert, int32_t, std::unique_ptr, std::move); confirm against project history.
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>

#include "cpp14/memory.h"
#include "model/OperandIndexSequence.h"
#include "util/logging.h"
#include "util/Coordinates.h"

namespace neurun
{
namespace compiler
{

/**
 * @brief Record, for every input of a concat node, that it is a sub-tensor of the output
 *
 * For each input operand a parent info is attached that refers to the concat output
 * operand and carries the input's start coordinate inside it: zero on every dimension
 * except the concat axis, where it is the sum of the axis extents of the preceding inputs.
 *
 * If any input already has a parent info, nothing is modified.
 *
 * @param[in] node Concat operation to analyze
 */
void SubTensorAnalyzer::visit(const model::operation::ConcatNode &node)
{
  // If operator is concat (or other operators related with subsumption), fill subsumption info
  // TODO: if one tensor is subset of many parents or model input
  //       Solution 1. Handle 1st parent only, ignore others (need to invert for other children)
  //       Solution 2. Insert copy operation for other parents
  const int32_t axis_raw = node.param().axis;

  const auto &output_index = node.getOutputs().at(0);
  const auto &inputs = node.getInputs();

  const auto rank = _ctx.at(output_index).shape().rank();

  // A negative axis counts from the last dimension
  const int32_t axis = axis_raw < 0 ? (axis_raw + rank) : axis_raw;
  // Check both bounds: an axis_raw below -rank is still negative after normalization
  assert(axis >= 0 && axis < rank);

  // NOTE Not support multiple parent tensor yet
  for (const auto &input_index : inputs)
  {
    if (_ctx.at(input_index).parent_info() != nullptr)
    {
      return;
    }
  }

  // Start offset of the current input along the concat axis
  int32_t axis_point = 0;

  for (const auto &input_index : inputs)
  {
    // Bind by reference: the shape is only read here, so no copy is needed
    const auto &input_shape = _ctx.at(input_index).shape();
    assert(rank == input_shape.rank());

    neurun::util::Coordinates coordinate_info{};
    for (int i = 0; i < rank; ++i)
    {
      coordinate_info.set(i, 0);
    }
    coordinate_info.set(axis, axis_point);

    // NOTE(review): the template argument was missing in the previous text of this file;
    // graph::operand::ParentInfo is assumed here - confirm it matches the type accepted by
    // the operand's parent_info() setter.
    auto parent_info =
        nnfw::cpp14::make_unique<graph::operand::ParentInfo>(output_index, coordinate_info);

    _ctx.at(input_index).parent_info(std::move(parent_info));

    axis_point += input_shape.dim(axis);
  }
}

} // namespace compiler
} // namespace neurun