summaryrefslogtreecommitdiff
path: root/runtime/onert/backend/cpu/ConstantInitializer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/onert/backend/cpu/ConstantInitializer.cc')
-rw-r--r--runtime/onert/backend/cpu/ConstantInitializer.cc35
1 files changed, 29 insertions, 6 deletions
diff --git a/runtime/onert/backend/cpu/ConstantInitializer.cc b/runtime/onert/backend/cpu/ConstantInitializer.cc
index 71e313628..deb27f0fe 100644
--- a/runtime/onert/backend/cpu/ConstantInitializer.cc
+++ b/runtime/onert/backend/cpu/ConstantInitializer.cc
@@ -15,6 +15,7 @@
*/
#include "ConstantInitializer.h"
+#include "Tensor.h"
namespace onert
{
@@ -30,39 +31,61 @@ ConstantInitializer::ConstantInitializer(const ir::Operands &operands,
// DO NOTHING
}
+void ConstantInitializer::registerDefaultInitializer(const ir::OperandIndex &index,
+ const ir::Operand &obj)
+{
+ registerExternalInitializer(index, obj);
+}
+
+void ConstantInitializer::registerExternalInitializer(const ir::OperandIndex &index,
+ const ir::Operand &obj)
+{
+ // For only CONSTANTS
+ // TODO Add to check if tensor has been allocated
+ if (!obj.isConstant())
+ return;
+
+ _init_map[index] = [](const onert::ir::Operand &model_obj, onert::backend::ITensor &itensor) {
+ auto data = model_obj.shareData();
+ assert(data && data->base());
+ ExternalTensor &tensor = dynamic_cast<ExternalTensor &>(itensor);
+ tensor.setData(data);
+ };
+}
+
void ConstantInitializer::visit(const ir::operation::Conv2D &node)
{
const auto &kernel_index = node.getInputs().at(ir::operation::Conv2D::KERNEL);
const auto &kernel_obj = _operands.at(kernel_index);
- registerCopyInitializer(kernel_index, kernel_obj);
+ registerExternalInitializer(kernel_index, kernel_obj);
const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::BIAS);
const auto &bias_obj = _operands.at(bias_index);
- registerCopyInitializer(bias_index, bias_obj);
+ registerExternalInitializer(bias_index, bias_obj);
}
void ConstantInitializer::visit(const ir::operation::DepthwiseConv2D &node)
{
const auto &kernel_index = node.getInputs().at(ir::operation::DepthwiseConv2D::KERNEL);
const auto &kernel_obj = _operands.at(kernel_index);
- registerCopyInitializer(kernel_index, kernel_obj);
+ registerExternalInitializer(kernel_index, kernel_obj);
const auto &bias_index = node.getInputs().at(ir::operation::DepthwiseConv2D::BIAS);
const auto &bias_obj = _operands.at(bias_index);
- registerCopyInitializer(bias_index, bias_obj);
+ registerExternalInitializer(bias_index, bias_obj);
}
void ConstantInitializer::visit(const ir::operation::FullyConnected &node)
{
const auto &weight_index = node.getInputs().at(ir::operation::FullyConnected::WEIGHT);
const auto &weight_obj = _operands.at(weight_index);
- registerCopyInitializer(weight_index, weight_obj);
+ registerExternalInitializer(weight_index, weight_obj);
const auto &bias_index = node.getInputs().at(ir::operation::FullyConnected::BIAS);
if (!bias_index.undefined())
{
const auto &bias_obj = _operands.at(bias_index);
- registerCopyInitializer(bias_index, bias_obj);
+ registerExternalInitializer(bias_index, bias_obj);
}
}