/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ConstantInitializer.h"

#include <initializer_list>

namespace neurun
{
namespace backend
{
namespace acl_neon
{

/**
 * @brief Create an initializer for the constant operands of the ACL-NEON backend.
 * @param operands       Model operands; constant data is read from here in run().
 * @param tensor_builder Builder used by run() to obtain the backend tensor of each operand.
 */
ConstantInitializer::ConstantInitializer(const model::Operands &operands,
                                         const std::shared_ptr<TensorBuilder> &tensor_builder)
    : _operands{operands}, _tensor_builder{tensor_builder}
{
  // DO NOTHING
}

/**
 * @brief Execute every initializer registered by the visit() overloads, then forget them.
 *
 * For each registered operand index, the backend tensor is obtained through
 * TensorBuilder::wrapTensor and handed, together with the model operand, to the
 * registered initializer function.
 */
void ConstantInitializer::run()
{
  for (const auto &entry : _init_map)
  {
    const auto &index = entry.first;
    const auto &initialize = entry.second;

    // NOTE(review): assumes wrapTensor never returns null for a registered index — confirm.
    auto tensor = _tensor_builder->wrapTensor(index);
    initialize(_operands.at(index), *tensor);
  }

  // Each registered initializer is meant to run once; drop them after use.
  _init_map.clear();
}

/**
 * @brief Register initializers for the Conv2D kernel (permuted) and bias (copied).
 *
 * NOTE(review): registerPermuteInitializer presumably converts the kernel from the
 * model layout to the backend layout — confirm in IConstantInitializer.
 */
void ConstantInitializer::visit(const model::operation::Conv2DNode &node)
{
  const auto &kernel_index = node.getInputs().at(model::operation::Conv2DNode::KERNEL);
  registerPermuteInitializer(kernel_index, _operands.at(kernel_index));

  const auto &bias_index = node.getInputs().at(model::operation::Conv2DNode::BIAS);
  registerCopyInitializer(bias_index, _operands.at(bias_index));
}

/**
 * @brief Register initializers for the DepthwiseConv2D kernel (permuted) and bias (copied).
 */
void ConstantInitializer::visit(const model::operation::DepthwiseConv2DNode &node)
{
  const auto &kernel_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::KERNEL);
  registerPermuteInitializer(kernel_index, _operands.at(kernel_index));

  const auto &bias_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::BIAS);
  registerCopyInitializer(bias_index, _operands.at(bias_index));
}

/**
 * @brief Register copy initializers for the FullyConnected weight and bias.
 */
void ConstantInitializer::visit(const model::operation::FullyConnectedNode &node)
{
  const auto &weight_index = node.getInputs().at(model::operation::FullyConnectedNode::WEIGHT);
  registerCopyInitializer(weight_index, _operands.at(weight_index));

  const auto &bias_index = node.getInputs().at(model::operation::FullyConnectedNode::BIAS);
  registerCopyInitializer(bias_index, _operands.at(bias_index));
}

/**
 * @brief Register copy initializers for every weight and bias input of an LSTM node.
 *
 * All of these inputs are copied as-is (no permutation). CELL_BIAS is included:
 * previously only the input/forget/output gate biases were registered, so the
 * cell bias constant was never initialized.
 *
 * NOTE(review): several of these inputs (input-gate weights/bias under CIFG,
 * peephole weights, projection weights/bias) may be absent in a given model;
 * this assumes registerCopyInitializer tolerates such operands — confirm.
 */
void ConstantInitializer::visit(const model::operation::LSTMNode &node)
{
  using model::operation::LSTMNode;

  for (const auto input :
       {LSTMNode::INPUT_TO_INPUT_WEIGHTS, LSTMNode::INPUT_TO_FORGET_WEIGHTS,
        LSTMNode::INPUT_TO_CELL_WEIGHTS, LSTMNode::INPUT_TO_OUTPUT_WEIGHTS,
        LSTMNode::RECURRENT_TO_INPUT_WEIGHTS, LSTMNode::RECURRENT_TO_FORGET_WEIGHTS,
        LSTMNode::RECURRENT_TO_CELL_WEIGHTS, LSTMNode::RECURRENT_TO_OUTPUT_WEIGHTS,
        LSTMNode::CELL_TO_INPUT_WEIGHTS, LSTMNode::CELL_TO_FORGET_WEIGHTS,
        LSTMNode::CELL_TO_OUTPUT_WEIGHTS, LSTMNode::INPUT_GATE_BIAS, LSTMNode::FORGET_GATE_BIAS,
        LSTMNode::CELL_BIAS, LSTMNode::OUTPUT_GATE_BIAS, LSTMNode::PROJECTION_WEIGHTS,
        LSTMNode::PROJECTION_BIAS})
  {
    const auto &index = node.getInputs().at(input);
    registerCopyInitializer(index, _operands.at(index));
  }
}

/**
 * @brief Register copy initializers for the RNN weights, recurrent weights and bias.
 */
void ConstantInitializer::visit(const model::operation::RNNNode &node)
{
  const auto &weights_index = node.getInputs().at(model::operation::RNNNode::WEIGHTS);
  registerCopyInitializer(weights_index, _operands.at(weights_index));

  const auto &recurrent_weights_index =
      node.getInputs().at(model::operation::RNNNode::RECURRENT_WEIGHTS);
  registerCopyInitializer(recurrent_weights_index, _operands.at(recurrent_weights_index));

  const auto &bias_index = node.getInputs().at(model::operation::RNNNode::BIAS);
  registerCopyInitializer(bias_index, _operands.at(bias_index));
}

/**
 * @brief Register a permute initializer for the TransposeConv kernel.
 */
void ConstantInitializer::visit(const model::operation::TransposeConvNode &node)
{
  const auto &kernel_index = node.getInputs().at(model::operation::TransposeConvNode::KERNEL);
  registerPermuteInitializer(kernel_index, _operands.at(kernel_index));
}

} // namespace acl_neon
} // namespace backend
} // namespace neurun