diff options
Diffstat (limited to 'runtime/onert/backend/cpu/ConstantInitializer.cc')
-rw-r--r-- | runtime/onert/backend/cpu/ConstantInitializer.cc | 35 |
1 files changed, 29 insertions, 6 deletions
diff --git a/runtime/onert/backend/cpu/ConstantInitializer.cc b/runtime/onert/backend/cpu/ConstantInitializer.cc index 71e313628..deb27f0fe 100644 --- a/runtime/onert/backend/cpu/ConstantInitializer.cc +++ b/runtime/onert/backend/cpu/ConstantInitializer.cc @@ -15,6 +15,7 @@ */ #include "ConstantInitializer.h" +#include "Tensor.h" namespace onert { @@ -30,39 +31,61 @@ ConstantInitializer::ConstantInitializer(const ir::Operands &operands, // DO NOTHING } +void ConstantInitializer::registerDefaultInitializer(const ir::OperandIndex &index, + const ir::Operand &obj) +{ + registerExternalInitializer(index, obj); +} + +void ConstantInitializer::registerExternalInitializer(const ir::OperandIndex &index, + const ir::Operand &obj) +{ + // For only CONSTANTS + // TODO Add to check if tensor has been allocated + if (!obj.isConstant()) + return; + + _init_map[index] = [](const onert::ir::Operand &model_obj, onert::backend::ITensor &itensor) { + auto data = model_obj.shareData(); + assert(data && data->base()); + ExternalTensor &tensor = dynamic_cast<ExternalTensor &>(itensor); + tensor.setData(data); + }; +} + void ConstantInitializer::visit(const ir::operation::Conv2D &node) { const auto &kernel_index = node.getInputs().at(ir::operation::Conv2D::KERNEL); const auto &kernel_obj = _operands.at(kernel_index); - registerCopyInitializer(kernel_index, kernel_obj); + registerExternalInitializer(kernel_index, kernel_obj); const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::BIAS); const auto &bias_obj = _operands.at(bias_index); - registerCopyInitializer(bias_index, bias_obj); + registerExternalInitializer(bias_index, bias_obj); } void ConstantInitializer::visit(const ir::operation::DepthwiseConv2D &node) { const auto &kernel_index = node.getInputs().at(ir::operation::DepthwiseConv2D::KERNEL); const auto &kernel_obj = _operands.at(kernel_index); - registerCopyInitializer(kernel_index, kernel_obj); + registerExternalInitializer(kernel_index, kernel_obj); const auto &bias_index = node.getInputs().at(ir::operation::DepthwiseConv2D::BIAS); const auto &bias_obj = _operands.at(bias_index); - registerCopyInitializer(bias_index, bias_obj); + registerExternalInitializer(bias_index, bias_obj); } void ConstantInitializer::visit(const ir::operation::FullyConnected &node) { const auto &weight_index = node.getInputs().at(ir::operation::FullyConnected::WEIGHT); const auto &weight_obj = _operands.at(weight_index); - registerCopyInitializer(weight_index, weight_obj); + registerExternalInitializer(weight_index, weight_obj); const auto &bias_index = node.getInputs().at(ir::operation::FullyConnected::BIAS); if (!bias_index.undefined()) { const auto &bias_obj = _operands.at(bias_index); - registerCopyInitializer(bias_index, bias_obj); + registerExternalInitializer(bias_index, bias_obj); } } |