Diffstat (limited to 'runtime/onert/core/src')
-rw-r--r--  runtime/onert/core/src/backend/BackendContext.cc  2
-rw-r--r--  runtime/onert/core/src/backend/basic/StaticTensorManager.cc  9
-rw-r--r--  runtime/onert/core/src/backend/basic/Tensor.cc  2
-rw-r--r--  runtime/onert/core/src/backend/basic/TensorBuilder.cc  8
-rw-r--r--  runtime/onert/core/src/backend/basic/train/TrainableTensor.cc  49
-rw-r--r--  runtime/onert/core/src/backend/builtin/Backend.h  28
-rw-r--r--  runtime/onert/core/src/backend/builtin/BackendContext.cc  2
-rw-r--r--  runtime/onert/core/src/backend/builtin/Config.cc  2
-rw-r--r--  runtime/onert/core/src/backend/builtin/Config.h  2
-rw-r--r--  runtime/onert/core/src/backend/builtin/KernelGenerator.cc  8
-rw-r--r--  runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc  4
-rw-r--r--  runtime/onert/core/src/backend/builtin/train/BackendContext.cc  78
-rw-r--r--  runtime/onert/core/src/backend/builtin/train/BackendContext.h  76
-rw-r--r--  runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc  98
-rw-r--r--  runtime/onert/core/src/backend/builtin/train/KernelGenerator.h  75
-rw-r--r--  runtime/onert/core/src/backend/builtin/train/Tensor.h  40
-rw-r--r--  runtime/onert/core/src/backend/builtin/train/TensorRegistry.h  132
-rw-r--r--  runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc  85
-rw-r--r--  runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h  60
-rw-r--r--  runtime/onert/core/src/compiler/Compiler.cc  38
-rw-r--r--  runtime/onert/core/src/compiler/CompilerFactory.cc  15
-rw-r--r--  runtime/onert/core/src/compiler/CompilerHelpers.h  52
-rw-r--r--  runtime/onert/core/src/compiler/CompilerOptions.cc  1
-rw-r--r--  runtime/onert/core/src/compiler/ExecutorFactory.cc  452
-rw-r--r--  runtime/onert/core/src/compiler/ExecutorFactory.h  66
-rw-r--r--  runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc  4
-rw-r--r--  runtime/onert/core/src/compiler/HEScheduler.cc  22
-rw-r--r--  runtime/onert/core/src/compiler/HEScheduler.h  8
-rw-r--r--  runtime/onert/core/src/compiler/HEScheduler.test.cc  6
-rw-r--r--  runtime/onert/core/src/compiler/Linear.cc  6
-rw-r--r--  runtime/onert/core/src/compiler/Linear.h  6
-rw-r--r--  runtime/onert/core/src/compiler/LoweredGraph.cc  22
-rw-r--r--  runtime/onert/core/src/compiler/ManualScheduler.cc  6
-rw-r--r--  runtime/onert/core/src/compiler/MultiModelCompiler.cc  46
-rw-r--r--  runtime/onert/core/src/compiler/MultiModelCompiler.h  6
-rw-r--r--  runtime/onert/core/src/compiler/ShapeValidator.cc  2
-rw-r--r--  runtime/onert/core/src/compiler/StaticShapeInferer.cc  36
-rw-r--r--  runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc  4
-rw-r--r--  runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h  2
-rw-r--r--  runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc  4
-rw-r--r--  runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h  2
-rw-r--r--  runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc  2
-rw-r--r--  runtime/onert/core/src/compiler/pass/IPass.h  41
-rw-r--r--  runtime/onert/core/src/compiler/pass/LoweredOperandPass.h  6
-rw-r--r--  runtime/onert/core/src/compiler/pass/LoweredOperationPass.h  8
-rw-r--r--  runtime/onert/core/src/compiler/pass/OperationPass.cc  4
-rw-r--r--  runtime/onert/core/src/compiler/pass/OperationPass.h  4
-rw-r--r--  runtime/onert/core/src/compiler/pass/Pass.h  4
-rw-r--r--  runtime/onert/core/src/compiler/pass/PassRunner.cc  2
-rw-r--r--  runtime/onert/core/src/compiler/pass/PassRunner.h  6
-rw-r--r--  runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc  9
-rw-r--r--  runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h  2
-rw-r--r--  runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc  6
-rw-r--r--  runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc  2
-rw-r--r--  runtime/onert/core/src/compiler/pass/PermutationOperationPass.h  2
-rw-r--r--  runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc  6
-rw-r--r--  runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc  285
-rw-r--r--  runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.cc  150
-rw-r--r--  runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.h  80
-rw-r--r--  runtime/onert/core/src/compiler/train/TensorRegistries.h  105
-rw-r--r--  runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc  86
-rw-r--r--  runtime/onert/core/src/compiler/train/TrainableOperationConverter.h  57
-rw-r--r--  runtime/onert/core/src/compiler/train/TrainingCompiler.cc  299
-rw-r--r--  runtime/onert/core/src/compiler/train/TrainingCompiler.h  83
-rw-r--r--  runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc  53
-rw-r--r--  runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h  52
-rw-r--r--  runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc  77
-rw-r--r--  runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h  55
-rw-r--r--  runtime/onert/core/src/compiler/train/pass/Pass.h  64
-rw-r--r--  runtime/onert/core/src/dumper/dot/DotBuilder.cc  4
-rw-r--r--  runtime/onert/core/src/dumper/dot/DotDumper.cc  15
-rw-r--r--  runtime/onert/core/src/dumper/dot/DotDumper.h  4
-rw-r--r--  runtime/onert/core/src/dumper/dot/OperationNode.cc  2
-rw-r--r--  runtime/onert/core/src/dumper/dot/OperationNode.h  4
-rw-r--r--  runtime/onert/core/src/dumper/h5/Dumper.cc  34
-rw-r--r--  runtime/onert/core/src/dumper/h5/Dumper.h  51
-rw-r--r--  runtime/onert/core/src/dumper/h5/MinMaxDumper.cc  75
-rw-r--r--  runtime/onert/core/src/dumper/h5/MinMaxDumper.h  70
-rw-r--r--  runtime/onert/core/src/dumper/text/GraphDumper.cc  28
-rw-r--r--  runtime/onert/core/src/dumper/text/GraphDumper.h  15
-rw-r--r--  runtime/onert/core/src/exec/DataflowExecutor.cc  10
-rw-r--r--  runtime/onert/core/src/exec/DynamicShapeInferer.cc  8
-rw-r--r--  runtime/onert/core/src/exec/ExecTime.test.cc  2
-rw-r--r--  runtime/onert/core/src/exec/Execution.cc  42
-rw-r--r--  runtime/onert/core/src/exec/ExecutionObservers.cc  2
-rw-r--r--  runtime/onert/core/src/exec/ExecutionObservers.h  2
-rw-r--r--  runtime/onert/core/src/exec/ExecutorBase.cc  2
-rw-r--r--  runtime/onert/core/src/exec/ExecutorBase.h  1
-rw-r--r--  runtime/onert/core/src/exec/Executors.cc  2
-rw-r--r--  runtime/onert/core/src/exec/FunctionSequence.cc  1
-rw-r--r--  runtime/onert/core/src/exec/LinearExecutor.h  2
-rw-r--r--  runtime/onert/core/src/exec/MinMaxRecorder.cc  112
-rw-r--r--  runtime/onert/core/src/exec/MinMaxRecorder.h  56
-rw-r--r--  runtime/onert/core/src/exec/ParallelScheduler.cc  2
-rw-r--r--  runtime/onert/core/src/exec/train/TrainableExecutor.cc  204
-rw-r--r--  runtime/onert/core/src/exec/train/TrainableExecutor.h  109
-rw-r--r--  runtime/onert/core/src/exec/train/TrainableExecutors.cc  89
-rw-r--r--  runtime/onert/core/src/exec/train/TrainableExecutors.h  92
-rw-r--r--  runtime/onert/core/src/exec/train/TrainableFnSequence.cc  67
-rw-r--r--  runtime/onert/core/src/exec/train/optimizer/OptimizerCode.cc  42
-rw-r--r--  runtime/onert/core/src/exec/train/optimizer/OptimizerHelpers.h  47
-rw-r--r--  runtime/onert/core/src/exec/train/optimizer/SGD.cc  66
-rw-r--r--  runtime/onert/core/src/ir/Graph.cc  59
-rw-r--r--  runtime/onert/core/src/ir/LayoutSet.cc  8
-rw-r--r--  runtime/onert/core/src/ir/LayoutSet.h  1
-rw-r--r--  runtime/onert/core/src/ir/OperandIndexSequence.cc  9
-rw-r--r--  runtime/onert/core/src/ir/OperationCloner.cc  2
-rw-r--r--  runtime/onert/core/src/ir/OperationCloner.h  2
-rw-r--r--  runtime/onert/core/src/ir/OperationDumper.cc  8
-rw-r--r--  runtime/onert/core/src/ir/OperationDumper.h  1
-rw-r--r--  runtime/onert/core/src/ir/OperationValidator.cc  6
-rw-r--r--  runtime/onert/core/src/ir/Operations.cc  2
-rw-r--r--  runtime/onert/core/src/ir/operation/Loss.cc  52
-rw-r--r--  runtime/onert/core/src/ir/train/TrainableGraph.cc  145
-rw-r--r--  runtime/onert/core/src/ir/train/operation/Conv2D.cc  49
-rw-r--r--  runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc  49
-rw-r--r--  runtime/onert/core/src/ir/train/operation/FullyConnected.cc  49
-rw-r--r--  runtime/onert/core/src/ir/train/operation/Loss.cc  48
-rw-r--r--  runtime/onert/core/src/ir/train/operation/Permute.cc  50
-rw-r--r--  runtime/onert/core/src/ir/train/operation/Pool2D.cc  49
-rw-r--r--  runtime/onert/core/src/ir/train/operation/Reshape.cc  49
-rw-r--r--  runtime/onert/core/src/ir/train/operation/Softmax.cc  49
-rw-r--r--  runtime/onert/core/src/ir/verifier/Verifier.cc  16
-rw-r--r--  runtime/onert/core/src/odc/QuantizeManager.cc  50
-rw-r--r--  runtime/onert/core/src/odc/QuantizeManager.test.cc  36
-rw-r--r--  runtime/onert/core/src/odc/QuantizerLoader.cc  104
-rw-r--r--  runtime/onert/core/src/odc/QuantizerLoader.h  89
-rw-r--r--  runtime/onert/core/src/odc/QuantizerLoader.test.cc  63
-rw-r--r--  runtime/onert/core/src/util/MDTableEventWriter.cc  4
129 files changed, 5164 insertions, 263 deletions
diff --git a/runtime/onert/core/src/backend/BackendContext.cc b/runtime/onert/core/src/backend/BackendContext.cc
index b9aab7994..7b36f106d 100644
--- a/runtime/onert/core/src/backend/BackendContext.cc
+++ b/runtime/onert/core/src/backend/BackendContext.cc
@@ -16,8 +16,6 @@
#include "backend/BackendContext.h"
-#include "ir/Operation.h"
-
namespace onert
{
namespace backend
diff --git a/runtime/onert/core/src/backend/basic/StaticTensorManager.cc b/runtime/onert/core/src/backend/basic/StaticTensorManager.cc
index b03eb607c..71cde4cde 100644
--- a/runtime/onert/core/src/backend/basic/StaticTensorManager.cc
+++ b/runtime/onert/core/src/backend/basic/StaticTensorManager.cc
@@ -35,6 +35,15 @@ StaticTensorManager::StaticTensorManager(const std::shared_ptr<TensorRegistry> &
// DO NOTHING
}
+StaticTensorManager::StaticTensorManager(const std::shared_ptr<TensorRegistry> &reg,
+ const std::string planner_id,
+ DynamicTensorManager *dynamic_tensor_manager)
+ : _nonconst_mgr{new MemoryManager(planner_id)}, _tensors{reg}, _dynamic_tensor_manager{
+ dynamic_tensor_manager}
+{
+ // DO NOTHING
+}
+
void StaticTensorManager::allocateNonconsts(void)
{
_nonconst_mgr->allocate();
diff --git a/runtime/onert/core/src/backend/basic/Tensor.cc b/runtime/onert/core/src/backend/basic/Tensor.cc
index c2bbc5a66..de1cff4f4 100644
--- a/runtime/onert/core/src/backend/basic/Tensor.cc
+++ b/runtime/onert/core/src/backend/basic/Tensor.cc
@@ -51,6 +51,7 @@ bool Tensor::applyShape(const ir::Shape &new_shape)
auto allocTensorMem = [&]() {
auto capacity = total_size();
+ assert(_dynamic_mem_mgr);
auto alloc = _dynamic_mem_mgr->allocate(this, capacity);
setBuffer(alloc);
};
@@ -68,6 +69,7 @@ bool Tensor::applyShape(const ir::Shape &new_shape)
auto new_size = new_shape.num_elements() * ir::sizeOfDataType(data_type());
if (previous_size != new_size)
{
+ assert(_dynamic_mem_mgr);
_dynamic_mem_mgr->deallocate(this);
setShape(new_shape);
diff --git a/runtime/onert/core/src/backend/basic/TensorBuilder.cc b/runtime/onert/core/src/backend/basic/TensorBuilder.cc
index a10cc2bf9..f9d83875d 100644
--- a/runtime/onert/core/src/backend/basic/TensorBuilder.cc
+++ b/runtime/onert/core/src/backend/basic/TensorBuilder.cc
@@ -34,6 +34,14 @@ TensorBuilder::TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg)
/* empty */
}
+TensorBuilder::TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg,
+ const std::string planner_id)
+ : _tensor_reg{tensor_reg}, _dynamic_tensor_mgr{new DynamicTensorManager(_tensor_reg)},
+ _static_tensor_mgr{new StaticTensorManager(_tensor_reg, planner_id, _dynamic_tensor_mgr.get())}
+{
+ /* empty */
+}
+
void TensorBuilder::registerTensorInfo(const ir::OperandIndex &ind, const ir::OperandInfo &info,
ir::Layout layout)
{
diff --git a/runtime/onert/core/src/backend/basic/train/TrainableTensor.cc b/runtime/onert/core/src/backend/basic/train/TrainableTensor.cc
new file mode 100644
index 000000000..d09604224
--- /dev/null
+++ b/runtime/onert/core/src/backend/basic/train/TrainableTensor.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <backend/basic/train/TrainableTensor.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace basic
+{
+namespace train
+{
+
+std::vector<ITensor *> TrainableTensor::optVars()
+{
+ std::vector<ITensor *> ret;
+ for (auto &&e : _opt_vars)
+ {
+ ret.emplace_back(e.get());
+ }
+ return ret;
+}
+
+void TrainableTensor::fillBuffer(const std::shared_ptr<ir::Data> &data)
+{
+ auto *buffer = _tensor.buffer();
+ assert(buffer);
+ assert(total_size() == data->size());
+ std::memcpy(buffer, data->base(), data->size());
+}
+
+} // namespace train
+} // namespace basic
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/builtin/Backend.h b/runtime/onert/core/src/backend/builtin/Backend.h
index 3791f3ffa..c05494a6a 100644
--- a/runtime/onert/core/src/backend/builtin/Backend.h
+++ b/runtime/onert/core/src/backend/builtin/Backend.h
@@ -22,8 +22,16 @@
#include "KernelGenerator.h"
#include "TensorBuilder.h"
#include "Tensor.h"
+#ifdef ONERT_TRAIN
+#include "train/BackendContext.h"
+#include "train/KernelGenerator.h"
+#include "train/TensorRegistry.h"
+#endif // ONERT_TRAIN
#include <backend/Backend.h>
+#ifdef ONERT_TRAIN
+#include <backend/train/ITrainableBackend.h>
+#endif // ONERT_TRAIN
#include <memory>
@@ -35,6 +43,10 @@ namespace builtin
{
class Backend : public ::onert::backend::Backend
+#ifdef ONERT_TRAIN
+ ,
+ public backend::train::ITrainableBackend
+#endif // ONERT_TRAIN
{
public:
Backend() : _config{std::make_shared<Config>()} {}
@@ -70,6 +82,22 @@ public:
return context;
}
+#ifdef ONERT_TRAIN
+ std::unique_ptr<backend::train::TrainableBackendContext>
+ newContext(backend::train::TrainableContextData &&tdata) const override
+ {
+ const auto &tgraph = *tdata.tgraph;
+ auto tr = std::make_shared<train::TensorRegistry>();
+ // TODO Create TensorBuilder if necessary
+ auto tdata_ptr = std::make_unique<backend::train::TrainableContextData>(std::move(tdata));
+ auto context = std::make_unique<train::BackendContext>(this, std::move(tdata_ptr), tr);
+
+ context->kernel_gen =
+ std::make_shared<train::KernelGenerator>(tgraph, tr, context->external_context());
+ return context;
+ }
+#endif // ONERT_TRAIN
+
private:
std::shared_ptr<IConfig> _config;
};
diff --git a/runtime/onert/core/src/backend/builtin/BackendContext.cc b/runtime/onert/core/src/backend/builtin/BackendContext.cc
index c1a2ed537..573617e28 100644
--- a/runtime/onert/core/src/backend/builtin/BackendContext.cc
+++ b/runtime/onert/core/src/backend/builtin/BackendContext.cc
@@ -32,7 +32,7 @@ FunctionMap BackendContext::genKernels()
{
FunctionMap ret;
- for (auto op_ind : _data.op_order)
+ for (auto &&op_ind : _data.op_order)
{
auto fn_seq = kernel_gen->generate(op_ind);
ret.emplace_back(op_ind, std::move(fn_seq));
diff --git a/runtime/onert/core/src/backend/builtin/Config.cc b/runtime/onert/core/src/backend/builtin/Config.cc
index f792c0c36..e5f6d4c21 100644
--- a/runtime/onert/core/src/backend/builtin/Config.cc
+++ b/runtime/onert/core/src/backend/builtin/Config.cc
@@ -27,7 +27,7 @@ std::string Config::ID = "builtin";
bool Config::initialize() { return true; }
-ir::Layout Config::supportLayout(const ir::Operation &, ir::Layout frontend_layout)
+ir::Layout Config::supportLayout(const ir::IOperation &, ir::Layout frontend_layout)
{
return frontend_layout;
}
diff --git a/runtime/onert/core/src/backend/builtin/Config.h b/runtime/onert/core/src/backend/builtin/Config.h
index 5226eba69..196b299d3 100644
--- a/runtime/onert/core/src/backend/builtin/Config.h
+++ b/runtime/onert/core/src/backend/builtin/Config.h
@@ -34,7 +34,7 @@ public:
static std::string ID;
std::string id() override { return ID; }
bool initialize() override;
- ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) override;
+ ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) override;
bool supportPermutation() override { return false; }
bool supportDynamicTensor() override
{
diff --git a/runtime/onert/core/src/backend/builtin/KernelGenerator.cc b/runtime/onert/core/src/backend/builtin/KernelGenerator.cc
index 4533703a6..00c200a92 100644
--- a/runtime/onert/core/src/backend/builtin/KernelGenerator.cc
+++ b/runtime/onert/core/src/backend/builtin/KernelGenerator.cc
@@ -71,14 +71,14 @@ void KernelGenerator::visit(const ir::operation::If &node)
const auto else_subg_index = node.param().else_subg_index;
std::vector<backend::IPortableTensor *> input_tensors;
- for (const auto input_index : node.getInputs())
+ for (const auto &input_index : node.getInputs())
{
auto input_tensor = getPortableTensor(input_index);
input_tensors.emplace_back(input_tensor);
}
std::vector<backend::IPortableTensor *> output_tensors;
- for (const auto output_index : node.getOutputs())
+ for (const auto &output_index : node.getOutputs())
{
auto output_tensor = getPortableTensor(output_index);
output_tensors.emplace_back(output_tensor);
@@ -117,14 +117,14 @@ void KernelGenerator::visit(const ir::operation::While &node)
// This op does not support input as a constant, because builtin backend does not have
// TensorBuilder
std::vector<backend::IPortableTensor *> input_tensors;
- for (const auto input_index : node.getInputs())
+ for (const auto &input_index : node.getInputs())
{
auto input_tensor = getPortableTensor(input_index);
input_tensors.emplace_back(input_tensor);
}
std::vector<backend::IPortableTensor *> output_tensors;
- for (const auto output_index : node.getOutputs())
+ for (const auto &output_index : node.getOutputs())
{
auto output_tensor = getPortableTensor(output_index);
output_tensors.emplace_back(output_tensor);
diff --git a/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc
index c0ca4046c..8b00db468 100644
--- a/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc
+++ b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc
@@ -96,7 +96,7 @@ void WhileLayer::run()
// Need some temp tensors to hold the body subgraph output
std::vector<std::unique_ptr<Tensor>> temp_outputs_o;
std::vector<IPortableTensor *> temp_outputs;
- for (auto io_tensor : body_exec->getOutputTensors())
+ for (auto &&io_tensor : body_exec->getOutputTensors())
{
auto tensor = std::make_unique<Tensor>(io_tensor->orig_info(), io_tensor->orig_layout(),
_dyn_memory_manager);
@@ -139,7 +139,7 @@ void WhileLayer::run()
// Clean-up the temp tensors
_dyn_memory_manager->deallocate(cond_output_tensor.get());
- for (auto tensor : temp_outputs)
+ for (auto &&tensor : temp_outputs)
{
_dyn_memory_manager->deallocate(tensor);
}
diff --git a/runtime/onert/core/src/backend/builtin/train/BackendContext.cc b/runtime/onert/core/src/backend/builtin/train/BackendContext.cc
new file mode 100644
index 000000000..fa9131f4d
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/BackendContext.cc
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BackendContext.h"
+
+#include "backend/basic/train/TrainableBackendContextHelpers.h"
+#include "exec/FunctionSequence.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+
+backend::ITensorRegistry *BackendContext::genTensors()
+{
+ // For now, there is no need to generate tensors for forwarding.
+ // builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`.
+ // `Permute`: Tensor generation is not required.
+ // `IF`, `WHILE`: Not supported yet
+ return tensor_registry().get();
+}
+
+backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
+{
+ // For now, there is no need to generate tensors for backwarding.
+ return tensor_registry().get();
+}
+
+backend::train::FunctionMap BackendContext::genKernels()
+{
+ backend::train::FunctionMap ret;
+
+ for (auto &&op_ind : _tdata->op_order)
+ {
+ auto tn_seq = kernel_gen->generate(op_ind);
+ ret.emplace_back(op_ind, std::move(tn_seq));
+ }
+
+ trainable_graph()->operands().iterate(
+ [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
+ if (!external_operands().contains(ind) && operand.isConstant())
+ {
+ throw std::runtime_error(
+ "BackendContext: builtin backend does not support updatable weights yet");
+ }
+ });
+
+ // TODO Enable prepare()
+ // for (auto &&it : ret)
+ // {
+ // auto &fn_seq = it.second;
+ // fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
+ // }
+
+ return ret;
+}
+
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/builtin/train/BackendContext.h b/runtime/onert/core/src/backend/builtin/train/BackendContext.h
new file mode 100644
index 000000000..6f8ce4cae
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/BackendContext.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_BACKEND_CONTEXT_H__
+#define __ONERT_BACKEND_BUILTIN_TRAIN_BACKEND_CONTEXT_H__
+
+#include <backend/train/TrainableBackendContext.h>
+
+#include "KernelGenerator.h"
+#include "../ExternalContext.h"
+#include "../TensorBuilder.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+
+class BackendContext : public backend::train::TrainableBackendContext
+{
+public:
+ BackendContext(const backend::train::ITrainableBackend *backend,
+ std::unique_ptr<backend::train::TrainableContextData> &&data,
+ std::shared_ptr<backend::train::ITensorRegistry> tensor_registry = nullptr,
+ std::shared_ptr<TensorBuilder> tensor_builder = nullptr,
+ std::shared_ptr<KernelGenerator> kernel_gen = nullptr)
+ : backend::train::TrainableBackendContext(backend, std::move(data), tensor_registry),
+ kernel_gen{kernel_gen},
+ _external_context(new ExternalContext), _tensor_builder{tensor_builder}
+ {
+ }
+
+ backend::ITensorRegistry *genTensors() override;
+ backend::train::ITensorRegistry *genTrainingTensors() override;
+
+public:
+ backend::train::FunctionMap genKernels() override;
+
+ std::shared_ptr<ExternalContext> external_context() { return _external_context; }
+
+public:
+ // TODO Make it private
+ std::shared_ptr<KernelGenerator> kernel_gen;
+
+private:
+ // NOTE ruy context has a thread pool, and when multiple ruy contexts are created,
+ // the thread pool is also created in duplicate
+ // TODO Create one ruy context for session
+ std::shared_ptr<ExternalContext> _external_context;
+
+private:
+ std::shared_ptr<TensorBuilder> _tensor_builder;
+};
+
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_TRAIN_BACKEND_CONTEXT_H__
diff --git a/runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc
new file mode 100644
index 000000000..6f2c0a3b9
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "KernelGenerator.h"
+
+#include "kernel/PermuteLayer.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+
+KernelGenerator::KernelGenerator(const ir::train::TrainableGraph &tgraph,
+ const std::shared_ptr<TensorRegistry> &tensor_reg,
+ const std::shared_ptr<ExternalContext> &external_context)
+ : KernelGeneratorBase{tgraph}, _tensor_reg{tensor_reg}, _external_context(external_context)
+{
+}
+
+std::unique_ptr<exec::train::TrainableFnSequence> KernelGenerator::generate(ir::OperationIndex ind)
+{
+ auto ret = std::make_unique<exec::train::TrainableFnSequence>();
+ const auto &op = _tgraph.operation(ind);
+ op.accept(*this);
+ // _return_fn must have been generated
+ if (_return_fn == nullptr)
+ {
+ throw std::runtime_error(op.name() + " op does not support a trainable kernel yet");
+ }
+
+ ret->_functions.emplace_back(std::move(_return_fn));
+
+ return ret;
+}
+
+void KernelGenerator::visit(const ir::train::operation::Permute &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(0)};
+
+ // Add PermuteLayer
+ std::vector<ITensor *> output_tensors{getTensor(output_index)};
+ std::vector<ITensor *> input_tensors{getTensor(input_index)};
+
+ std::vector<ITensor *> output_deriv_tensors;
+ std::vector<ITensor *> input_deriv_tensors;
+
+ auto input_deriv_tensor = getDerivativeTensor(input_index);
+ auto output_deriv_tensor = getDerivativeTensor(output_index);
+ output_deriv_tensors.emplace_back(output_deriv_tensor);
+ input_deriv_tensors.emplace_back(input_deriv_tensor);
+
+ // NOTE IOTensors of graph outputs for passing data to users must be ignored in training
+ // because the buffers of those IOTensors are unnecessary and nullptr
+ bool ignore_forward_in_training = _whole_graph_outputs.contains(output_index);
+ auto fn = std::make_unique<kernel::PermuteLayer>(input_tensors, output_tensors,
+ input_deriv_tensors, output_deriv_tensors,
+ ignore_forward_in_training, _external_context);
+
+ _return_fn = std::move(fn);
+}
+
+backend::ITensor *KernelGenerator::getTensor(const ir::OperandIndex &index)
+{
+ // Get Tensor from all tensor registries (for Permute op)
+ auto ret = _tensor_registries.getITensor(index);
+ assert(ret != nullptr);
+ return ret;
+}
+
+backend::ITensor *KernelGenerator::getDerivativeTensor(const ir::OperandIndex &index)
+{
+ // Get derivative Tensor from all tensor registries (for Permute op)
+ auto ret = _tensor_registries.getDerivativeITensor(index);
+ return ret;
+}
+
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/builtin/train/KernelGenerator.h b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.h
new file mode 100644
index 000000000..d8781c0d0
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_GENERATOR_H__
+#define __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_GENERATOR_H__
+
+#include "../ExternalContext.h"
+#include "../train/TensorRegistry.h"
+#include "../../../compiler/train/TensorRegistries.h"
+
+#include <backend/train/KernelGeneratorBase.h>
+#include <exec/train/TrainableFnSequence.h>
+#include <ir/train/TrainableGraph.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+
+class KernelGenerator : public backend::train::KernelGeneratorBase
+{
+public:
+ KernelGenerator(const ir::train::TrainableGraph &tgraph,
+ const std::shared_ptr<TensorRegistry> &tensor_reg,
+ const std::shared_ptr<ExternalContext> &external_context);
+
+ std::unique_ptr<exec::train::TrainableFnSequence> generate(ir::OperationIndex ind) override;
+
+ void setTensorRegistries(const compiler::train::TensorRegistries &tensor_registries)
+ {
+ _tensor_registries = tensor_registries;
+ }
+
+ void setWholeGraphOutputs(const ir::OperandIndexSequence &outputs)
+ {
+ _whole_graph_outputs = outputs;
+ }
+
+private:
+ void visit(const ir::train::operation::Permute &) override;
+
+private:
+ backend::ITensor *getTensor(const ir::OperandIndex &index);
+ backend::ITensor *getDerivativeTensor(const ir::OperandIndex &index);
+
+private:
+ std::shared_ptr<TensorRegistry> _tensor_reg;
+ compiler::train::TensorRegistries _tensor_registries;
+ const std::shared_ptr<ExternalContext> _external_context;
+ ir::OperandIndexSequence _whole_graph_outputs;
+};
+
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_GENERATOR_H__
diff --git a/runtime/onert/core/src/backend/builtin/train/Tensor.h b/runtime/onert/core/src/backend/builtin/train/Tensor.h
new file mode 100644
index 000000000..611407bd2
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/Tensor.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_TRAINABLE_TENSOR_H__
+#define __ONERT_BACKEND_BUILTIN_TRAIN_TRAINABLE_TENSOR_H__
+
+#include <backend/basic/train/TrainableTensor.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+
+using TrainableTensor = basic::train::TrainableTensor;
+using DerivativeTensor = basic::Tensor;
+using GradientTensor = basic::Tensor;
+
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_TRAIN_TRAINABLE_TENSOR_H__
diff --git a/runtime/onert/core/src/backend/builtin/train/TensorRegistry.h b/runtime/onert/core/src/backend/builtin/train/TensorRegistry.h
new file mode 100644
index 000000000..c48e5fe93
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/TensorRegistry.h
@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__
+#define __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__
+
+#include <backend/train/ITensorRegistry.h>
+
+#include "../IOTensor.h"
+#include "../Tensor.h"
+#include "Tensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+
+using BaseTensorRegistry =
+ backend::train::PortableTensorRegistryTemplate<Tensor, TrainableTensor, DerivativeTensor,
+ GradientTensor>;
+
+class TensorRegistry : public backend::train::ITensorRegistry
+{
+public:
+ TensorRegistry() : _base_reg{new BaseTensorRegistry} {}
+
+ ITensor *getITensor(const ir::OperandIndex &index) override
+ {
+ auto base_tensor = _base_reg->getITensor(index);
+ if (base_tensor)
+ return base_tensor;
+ return getNativeIOTensor(index);
+ }
+
+ ITensor *getNativeITensor(const ir::OperandIndex &index) override
+ {
+ auto base_tensor = _base_reg->getNativeITensor(index);
+ if (base_tensor)
+ return base_tensor;
+ return getNativeIOTensor(index);
+ }
+
+ IPortableTensor *getPortableTensor(const ir::OperandIndex &index)
+ {
+ auto base_tensor = _base_reg->getPortableTensor(index);
+ if (base_tensor)
+ return base_tensor;
+ return getNativeIOTensor(index);
+ }
+
+ IOTensor *getNativeIOTensor(const ir::OperandIndex &index)
+ {
+ auto tensor = _native_io_tensors.find(index);
+ if (tensor != _native_io_tensors.end())
+ return tensor->second.get();
+ return nullptr;
+ }
+
+ ITensor *getDerivativeITensor(const ir::OperandIndex &index) override
+ {
+ return _base_reg->getDerivativeTensor(index);
+ }
+
+ ITensor *getGradientITensor(const ir::OperandIndex &index) override
+ {
+ return _base_reg->getGradientTensor(index);
+ }
+
+ DerivativeTensor *getDerivativeTensor(const ir::OperandIndex &index)
+ {
+ return _base_reg->getDerivativeTensor(index);
+ }
+
+ bool setMigrantTensor(const ir::OperandIndex &index, IPortableTensor *tensor) override
+ {
+ assert(tensor);
+ assert(!getITensor(index)); // For the index, tensor is not registered yet
+ _base_reg->setMigrantTensor(index, tensor);
+ return true;
+ }
+
+ void setDerivativeTensor(const ir::OperandIndex &index, std::unique_ptr<DerivativeTensor> tensor)
+ {
+ _base_reg->setDerivativeTensor(index, std::move(tensor));
+ }
+
+ void setGradientTensor(const ir::OperandIndex &index, std::unique_ptr<GradientTensor> tensor)
+ {
+ _base_reg->setGradientTensor(index, std::move(tensor));
+ }
+
+ void setNativeIOTensor(ir::OperandIndex index, std::unique_ptr<IOTensor> &&tensor)
+ {
+ assert(tensor);
+ assert(!getITensor(index)); // For the index, tensor is not registered yet
+ _native_io_tensors[index] = std::move(tensor);
+ }
+
+ const ir::OperandIndexMap<std::unique_ptr<IOTensor>> &native_io_tensors()
+ {
+ return _native_io_tensors;
+ }
+ std::shared_ptr<BaseTensorRegistry> base_reg() { return _base_reg; }
+
+private:
+ std::shared_ptr<BaseTensorRegistry> _base_reg;
+ ir::OperandIndexMap<std::unique_ptr<IOTensor>> _native_io_tensors;
+};
+
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__
diff --git a/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc
new file mode 100644
index 000000000..929092dde
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc
@@ -0,0 +1,85 @@
+
+
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PermuteLayer.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+namespace kernel
+{
+
+PermuteLayer::PermuteLayer(const std::vector<ITensor *> &src_tensors,
+ const std::vector<ITensor *> &dst_tensors,
+ const std::vector<ITensor *> &input_deriv_tensors,
+ const std::vector<ITensor *> &output_deriv_tensors,
+ bool ignore_forward_in_training,
+ const std::shared_ptr<ExternalContext> &external_context)
+ : builtin::kernel::PermuteLayer{src_tensors, dst_tensors, external_context},
+ _input_deriv_tensors{input_deriv_tensors}, _output_deriv_tensors{output_deriv_tensors},
+ _ignore_forward_in_training{ignore_forward_in_training}
+{
+ assert(input_deriv_tensors.size() == output_deriv_tensors.size());
+ assert(src_tensors.size() == dst_tensors.size());
+}
+
+void PermuteLayer::optimize()
+{
+ builtin::kernel::PermuteLayer::optimize();
+
+ // TODO Calculate offsets of derivative tensors if necessary
+}
+
+void PermuteLayer::forward(bool training)
+{
+ if (training && _ignore_forward_in_training)
+ return;
+
+ builtin::kernel::PermuteLayer::run();
+}
+
+void PermuteLayer::backward()
+{
+ for (uint32_t i = 0; i < _output_deriv_tensors.size(); ++i)
+ {
+ auto src_deriv = _output_deriv_tensors.at(i);
+ auto dst_deriv = _input_deriv_tensors.at(i);
+
+ // NOTE The derivative tensors corresponding to inputs/outputs of model are nullptr
+ // because permuting those tensors is meaningless
+ if (src_deriv && dst_deriv)
+ {
+ const auto rank = src_deriv->getShape().rank();
+ auto output_offsets = _dst_tensors_offsets.at(i);
+ auto input_offsets = _src_tensors_offsets.at(i);
+
+ exec::IPermuteFunction::permute(src_deriv, dst_deriv, rank, output_offsets, input_offsets);
+ }
+ }
+}
+
+} // namespace kernel
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h
new file mode 100644
index 000000000..de8063a21
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_PERMUTELAYER_H__
+#define __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_PERMUTELAYER_H__
+
+#include "../../kernel/PermuteLayer.h"
+
+#include "exec/train/ITrainableFunction.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+namespace kernel
+{
+
+class PermuteLayer : public builtin::kernel::PermuteLayer, public exec::train::ITrainableFunction
+{
+public:
+ PermuteLayer(const std::vector<ITensor *> &src_tensors, const std::vector<ITensor *> &dst_tensors,
+ const std::vector<ITensor *> &input_deriv_tensors,
+ const std::vector<ITensor *> &output_deriv_tensors, bool ignore_forward_in_training,
+ const std::shared_ptr<ExternalContext> &external_context);
+
+ void optimize() override;
+
+ void forward(bool training) override;
+ void backward() override;
+
+private:
+ std::vector<ITensor *> _input_deriv_tensors;
+ std::vector<ITensor *> _output_deriv_tensors;
+ bool _ignore_forward_in_training;
+};
+
+} // namespace kernel
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_PERMUTELAYER_H__
diff --git a/runtime/onert/core/src/compiler/Compiler.cc b/runtime/onert/core/src/compiler/Compiler.cc
index 45124556b..ba621bb4f 100644
--- a/runtime/onert/core/src/compiler/Compiler.cc
+++ b/runtime/onert/core/src/compiler/Compiler.cc
@@ -16,6 +16,7 @@
#include "compiler/Compiler.h"
+#include "CompilerHelpers.h"
#include "ExecutorFactory.h"
#include "ShapeValidator.h"
#include "pass/ConstantOutputPass.h"
@@ -30,6 +31,7 @@
#include "compiler/StaticShapeInferer.h"
#include <misc/string_helpers.h>
+#include <misc/polymorphic_downcast.h>
namespace onert
{
@@ -69,10 +71,25 @@ std::shared_ptr<CompilerArtifact> Compiler::compile(void)
throw std::runtime_error("Profiling mode works only with 'Dataflow' executor");
}
+ if (!_options->minmax_filepath.empty())
+ {
+ if (_options->executor != "Linear")
+ throw std::runtime_error("Recording minmax works only with Linear executor");
+ }
+
+ if (!_model->hasOnly<ir::Graph>())
+ {
+ throw std::runtime_error("Compiler can only compile models for inference.");
+ }
+
_options->forceInternalOptions();
_options->verboseOptions();
- _model->iterate([&](const ir::SubgraphIndex &, ir::Graph &subg) {
+ auto custom_kernel_builder = _model->getKernelBuilder();
+
+ _model->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) {
+ auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);
+
// Mandatory passes
pass::PassRunner{}
.append(std::make_unique<pass::ConstantOutputPass>(subg))
@@ -96,7 +113,9 @@ std::shared_ptr<CompilerArtifact> Compiler::compile(void)
// Lower: Assign backend
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::LoweredGraph>> lowered_subgs;
{
- _model->iterate([&](const ir::SubgraphIndex &subg_index, ir::Graph &subg) {
+ _model->iterate([&](const ir::SubgraphIndex &subg_index, ir::IGraph &graph) {
+ auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);
+
// Lower: Assign backend
lowered_subgs[subg_index] = std::make_unique<compiler::LoweredGraph>(subg, *_options);
// Set tracing_ctx for copied graph
@@ -119,7 +138,7 @@ std::shared_ptr<CompilerArtifact> Compiler::compile(void)
// Run the StaticShapeInfer of primary subg. All child StaticShapeInferers are called
// recursively
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
- StaticShapeInferer::createStaticShapeInferers(lowered_subgs);
+ createStaticShapeInferers(lowered_subgs);
const auto primary_subg_idx = ir::SubgraphIndex{0};
inferers.at(primary_subg_idx)->infer();
@@ -158,10 +177,15 @@ std::shared_ptr<CompilerArtifact> Compiler::compile(void)
ir::OperationDumper dumper("Executor generation of Subgraph " +
std::to_string(subg_index.value()));
lowered_subg->graph().operations().iterate(
- [&](const ir::OperationIndex &, const ir::Operation &op) { op.accept(dumper); });
-
- auto executor = std::unique_ptr<exec::IExecutor>{ExecutorFactory::get().create(
- std::move(lowered_subg), tracing_ctx.get(), *_options, executors, model_index)};
+ [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); });
+
+ ExecutorFactoryArgs args;
+ args.tracing_ctx = tracing_ctx.get();
+ args.options = _options;
+ args.model_index = model_index;
+ args.custom_kernel_builder = custom_kernel_builder;
+ auto executor = std::unique_ptr<exec::IExecutor>{
+ ExecutorFactory::get().create(std::move(lowered_subg), executors, args)};
executor->setIndexedRanks(indexed_ranks);
executors->emplace(model_index, subg_index, std::move(executor));
}
diff --git a/runtime/onert/core/src/compiler/CompilerFactory.cc b/runtime/onert/core/src/compiler/CompilerFactory.cc
index d8d4bb277..aeb0876c4 100644
--- a/runtime/onert/core/src/compiler/CompilerFactory.cc
+++ b/runtime/onert/core/src/compiler/CompilerFactory.cc
@@ -17,6 +17,9 @@
#include "compiler/CompilerFactory.h"
#include "MultiModelCompiler.h"
+#ifdef ONERT_TRAIN
+#include "train/TrainingCompiler.h"
+#endif // ONERT_TRAIN
#include "compiler/Compiler.h"
@@ -33,8 +36,18 @@ CompilerFactory &CompilerFactory::get()
std::unique_ptr<ICompiler>
CompilerFactory::create(const std::shared_ptr<ir::NNPkg> &nnpkg,
- std::vector<std::unique_ptr<CompilerOptions>> &copts)
+ std::vector<std::unique_ptr<CompilerOptions>> &copts,
+ const compiler::train::TrainingInfo *training_info)
{
+#ifdef ONERT_TRAIN
+ // Returning compiler for training
+ if (training_info)
+ return std::make_unique<train::TrainingCompiler>(nnpkg, copts, *training_info);
+#else // ONERT_TRAIN
+ (void)training_info;
+#endif // ONERT_TRAIN
+
+ // Returning compiler for inference
if (nnpkg->model_count() == 1)
return std::make_unique<Compiler>(nnpkg, copts);
diff --git a/runtime/onert/core/src/compiler/CompilerHelpers.h b/runtime/onert/core/src/compiler/CompilerHelpers.h
new file mode 100644
index 000000000..798334b3b
--- /dev/null
+++ b/runtime/onert/core/src/compiler/CompilerHelpers.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_COMPILER_HELPERS_H__
+#define __ONERT_COMPILER_COMPILER_HELPERS_H__
+
+#include <compiler/ILoweredGraph.h>
+#include <compiler/StaticShapeInferer.h>
+#include <ir/Index.h>
+
+#include <memory>
+#include <unordered_map>
+
+namespace onert
+{
+namespace compiler
+{
+
+/**
+ * @brief Create a shape inferer map for a lowered model
+ * @param[in] lowered_subgs lowered model map
+ * @return Shape inferer map
+ */
+template <typename LoweredGraphType,
+ typename = std::enable_if_t<std::is_base_of<ILoweredGraph, LoweredGraphType>::value>>
+static std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
+createStaticShapeInferers(
+ const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraphType>> &lowered_subgs)
+{
+ std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> lsubgs;
+ for (auto &&e : lowered_subgs)
+ lsubgs[e.first] = e.second.get();
+ return StaticShapeInferer::createStaticShapeInferers(lsubgs);
+}
+
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_COMPILER_HELPERS_H__
diff --git a/runtime/onert/core/src/compiler/CompilerOptions.cc b/runtime/onert/core/src/compiler/CompilerOptions.cc
index b5fd392e0..830d9dd00 100644
--- a/runtime/onert/core/src/compiler/CompilerOptions.cc
+++ b/runtime/onert/core/src/compiler/CompilerOptions.cc
@@ -75,6 +75,7 @@ std::unique_ptr<CompilerOptions> CompilerOptions::fromGlobalConfig()
{
auto o = std::make_unique<CompilerOptions>();
o->backend_list = nnfw::misc::split(util::getConfigString(util::config::BACKENDS), ';');
+ o->minmax_filepath = util::getConfigString(util::config::MINMAX_FILEPATH);
o->trace_filepath = util::getConfigString(util::config::TRACE_FILEPATH);
o->graph_dump_level = util::getConfigInt(util::config::GRAPH_DOT_DUMP);
o->executor = util::getConfigString(util::config::EXECUTOR);
diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.cc b/runtime/onert/core/src/compiler/ExecutorFactory.cc
index b09d6b021..6a08524cc 100644
--- a/runtime/onert/core/src/compiler/ExecutorFactory.cc
+++ b/runtime/onert/core/src/compiler/ExecutorFactory.cc
@@ -25,6 +25,9 @@
#include "../exec/ExecTime.h"
#include "../exec/ExecutionObservers.h"
#include "../exec/LinearExecutor.h"
+#ifdef MINMAX_H5DUMPER
+#include "../exec/MinMaxRecorder.h"
+#endif
#include "../exec/ParallelExecutor.h"
#include "../ir/OperationCloner.h"
@@ -36,6 +39,14 @@
#include <functional>
#include <memory>
+#ifdef ONERT_TRAIN
+#include "../backend/builtin/train/BackendContext.h"
+#include "../exec/train/TrainableExecutor.h"
+
+#include <backend/train/TrainableBackendContext.h>
+#include <backend/train/ITrainableBackend.h>
+#endif // ONERT_TRAIN
+
namespace onert
{
namespace
@@ -74,7 +85,7 @@ public:
void run() override
{
- for (auto tensor : _dealloc_list)
+ for (auto &&tensor : _dealloc_list)
{
if (!tensor->is_dynamic())
continue;
@@ -86,7 +97,8 @@ private:
DeallocList _dealloc_list;
};
-void initializeSubgraphIOTensors(compiler::LoweredGraph &lowered_graph,
+// TODO Unify initializeSubgraphIOTensors
+void initializeSubgraphIOTensors(compiler::ILoweredGraph &lowered_graph,
const backend::BackendContexts &backend_contexts,
const ir::OperandIndexSequence &indices)
{
@@ -104,7 +116,38 @@ void initializeSubgraphIOTensors(compiler::LoweredGraph &lowered_graph,
}
assert(builtin_tensor_reg);
- for (auto ind : indices)
+ for (auto &&ind : indices)
+ {
+ const auto &operand = lowered_graph.graph().operands().at(ind);
+ auto tensor = std::make_unique<backend::builtin::IOTensor>(
+ operand.info(),
+ ir::Layout::NHWC /* FIXME find operation for this operand and use frontend_layout */
+ );
+
+ // Add tensor to builtin TensorRegistry.
+ builtin_tensor_reg->setNativeIOTensor(ind, std::move(tensor));
+ }
+}
+
+#ifdef ONERT_TRAIN
+void initializeSubgraphIOTensors(compiler::ILoweredGraph &lowered_graph,
+ const backend::train::TrainableBackendContexts &backend_contexts,
+ const ir::OperandIndexSequence &indices)
+{
+ std::shared_ptr<backend::builtin::train::TensorRegistry> builtin_tensor_reg;
+ for (const auto &e : backend_contexts)
+ {
+ auto backend = e.first;
+ auto &context = e.second;
+ if (backend->config()->id() == backend::builtin::Config::ID)
+ {
+ builtin_tensor_reg = std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(
+ context->tensor_registry());
+ }
+ }
+ assert(builtin_tensor_reg);
+
+ for (auto &&ind : indices)
{
const auto &operand = lowered_graph.graph().operands().at(ind);
auto tensor = std::make_unique<backend::builtin::IOTensor>(
@@ -116,8 +159,11 @@ void initializeSubgraphIOTensors(compiler::LoweredGraph &lowered_graph,
builtin_tensor_reg->setNativeIOTensor(ind, std::move(tensor));
}
}
+#endif // ONERT_TRAIN
-backend::BackendContexts createBackendContexts(compiler::LoweredGraph &lgraph, bool linear_executor)
+backend::BackendContexts
+createBackendContexts(compiler::ILoweredGraph &lgraph, bool linear_executor,
+ std::shared_ptr<backend::custom::IKernelBuilder> custom_kernel_builder)
{
backend::BackendContexts contexts;
auto &backend_manager = compiler::BackendManager::get();
@@ -125,7 +171,7 @@ backend::BackendContexts createBackendContexts(compiler::LoweredGraph &lgraph, b
std::unordered_map<const backend::Backend *, backend::ContextData> context_data_map;
// Generate partial graphs for each backend
- for (auto backend : backend_manager.getAll())
+ for (auto &&backend : backend_manager.getAll())
{
auto &data = context_data_map[backend];
auto graph = std::make_unique<ir::Graph>();
@@ -157,7 +203,7 @@ backend::BackendContexts createBackendContexts(compiler::LoweredGraph &lgraph, b
});
// Separate operations into partial graphs
whole_graph.operations().iterate(
- [&](const ir::OperationIndex &op_ind, const ir::Operation &operation) {
+ [&](const ir::OperationIndex &op_ind, const ir::IOperation &operation) {
auto &op_li = lgraph.lower_info().operation;
auto backend = op_li.at(op_ind).backend();
auto &partial_graph = *context_data_map[backend].graph;
@@ -168,7 +214,7 @@ backend::BackendContexts createBackendContexts(compiler::LoweredGraph &lgraph, b
// Add missing operands (externals)
auto io_list = (operation.getInputs() + operation.getOutputs()) | ir::Remove::DUPLICATED |
ir::Remove::UNDEFINED;
- for (auto operand_ind : io_list)
+ for (auto &&operand_ind : io_list)
{
if (partial_graph.operands().exist(operand_ind))
continue;
@@ -217,12 +263,33 @@ backend::BackendContexts createBackendContexts(compiler::LoweredGraph &lgraph, b
std::copy_if(whole_op_order.begin(), whole_op_order.end(), std::back_inserter(data.op_order),
[&](const auto &ind) { return data.graph->operations().exist(ind); });
data.is_linear_executor = linear_executor;
- data.custom_kernel_builder = lgraph.graph().getKernelBuilder();
+ data.custom_kernel_builder = custom_kernel_builder;
contexts.emplace(backend, backend->newContext(std::move(data)));
}
return contexts;
}
+template <typename Context>
+std::deque<std::pair<const backend::Backend *, Context *>> orderBackendContext(
+ const std::unordered_map<const backend::Backend *, std::unique_ptr<Context>> &tbackend_contexts)
+{
+ std::deque<std::pair<const backend::Backend *, Context *>> ordered_contexts;
+
+ for (auto &&pair : tbackend_contexts)
+ {
+ // NOTE builtin backend must be processed lastly.
+ // This is because of Permute layer's specialty which is the only operation that could have
+ // different ITensor objects for the input and the output. And it requires all other backends'
+ // tensors are ready to use.
+ if (pair.first->config()->id() == "builtin")
+ ordered_contexts.emplace_back(pair.first, pair.second.get());
+ else
+ ordered_contexts.emplace_front(pair.first, pair.second.get());
+ }
+
+ return ordered_contexts;
+}
+
} // namespace
} // namespace onert
@@ -240,34 +307,30 @@ ExecutorFactory &ExecutorFactory::get()
ExecutorFactory::ExecutorFactory()
{
_map["Linear"] = createLinearExecutor;
- _map["Dataflow"] =
- std::bind(createDataflowExecutor, std::placeholders::_1, std::placeholders::_2,
- std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, false);
- _map["Parallel"] =
- std::bind(createDataflowExecutor, std::placeholders::_1, std::placeholders::_2,
- std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, true);
+ _map["Dataflow"] = std::bind(createDataflowExecutor, std::placeholders::_1, std::placeholders::_2,
+ std::placeholders::_3, false);
+ _map["Parallel"] = std::bind(createDataflowExecutor, std::placeholders::_1, std::placeholders::_2,
+ std::placeholders::_3, true);
}
exec::IExecutor *ExecutorFactory::create(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const util::TracingCtx *tracing_ctx,
- const compiler::CompilerOptions &options,
const std::shared_ptr<exec::IExecutors> &executors,
- const ir::ModelIndex &index)
+ const ExecutorFactoryArgs &args)
{
- return _map.at(options.executor)(std::move(lowered_graph), tracing_ctx, options, executors,
- index);
+ assert(args.options != nullptr);
+ return _map.at(args.options->executor)(std::move(lowered_graph), executors, args);
}
-void ExecutorFactory::prepareMigrantTensors(compiler::LoweredGraph &lowered_graph,
+void ExecutorFactory::prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph,
const backend::BackendContexts &backend_contexts)
{
TensorRegistries tensor_regs{backend_contexts, true};
lowered_graph.graph().operations().iterate(
- [&](const ir::OperationIndex &op_ind, const ir::Operation &op) {
+ [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) {
auto lower_info = lowered_graph.lower_info().operation.getRawPtr(op_ind);
auto &backend_ctx = backend_contexts.at(lower_info->backend());
- for (auto ind :
+ for (auto &&ind :
(op.getInputs() + op.getOutputs()) | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
{
// If an Operation's input/output tensor does not have an own tensor object,
@@ -307,7 +370,6 @@ std::deque<std::pair<const backend::Backend *, backend::BackendContext *>>
ExecutorFactory::orderBackendContext(const backend::BackendContexts &backend_contexts)
{
std::deque<std::pair<const backend::Backend *, backend::BackendContext *>> ordered_contexts;
-
for (auto &&pair : backend_contexts)
{
// NOTE builtin backend must be processed lastly.
@@ -319,19 +381,22 @@ ExecutorFactory::orderBackendContext(const backend::BackendContexts &backend_con
else
ordered_contexts.emplace_front(pair.first, pair.second.get());
}
-
return ordered_contexts;
}
-exec::IExecutor *ExecutorFactory::createLinearExecutor(
- std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx,
- const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors,
- const ir::ModelIndex &index)
+exec::IExecutor *
+ExecutorFactory::createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args)
{
+ const auto options = args.options;
+ const auto &model_index = args.model_index;
+ const auto tracing_ctx = args.tracing_ctx;
+ auto custom_kernel_builder = args.custom_kernel_builder;
auto &graph = lowered_graph->graph();
backend::BackendContexts backend_contexts =
- createBackendContexts(*lowered_graph, options.executor == "Linear");
+ createBackendContexts(*lowered_graph, options->executor == "Linear", custom_kernel_builder);
TensorRegistries tensor_regs{backend_contexts, true};
@@ -352,7 +417,7 @@ exec::IExecutor *ExecutorFactory::createLinearExecutor(
prepareMigrantTensors(*lowered_graph, backend_contexts);
// Give some runtime objects to builtin KernelGenerator
- prepareBuiltinBackend(tensor_regs, executors, backend_contexts, index);
+ prepareBuiltinBackend(tensor_regs, executors, backend_contexts, model_index);
ExecutionBuilder builder;
@@ -382,7 +447,7 @@ exec::IExecutor *ExecutorFactory::createLinearExecutor(
uses_map[ind]++;
}
- for (const auto op_ind : order)
+ for (const auto &op_ind : order)
{
const auto &op = graph.operations().at(op_ind);
auto op_inputs = op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;
@@ -422,7 +487,7 @@ exec::IExecutor *ExecutorFactory::createLinearExecutor(
auto &fn_seq = pair.second;
auto &op = lowered_graph->graph().operations().at(op_ind);
auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind);
- if (options.he_profiling_mode)
+ if (options->he_profiling_mode)
fn_seq->wrap<SyncFunction>(lower_info->backend()->config());
if (!dealloc_list_map[op_ind].empty())
fn_seq->append(std::make_unique<DeallocFunction>(dealloc_list_map[op_ind]));
@@ -439,23 +504,33 @@ exec::IExecutor *ExecutorFactory::createLinearExecutor(
order,
tracing_ctx};
- if (!options.trace_filepath.empty())
+ if (!options->trace_filepath.empty())
{
std::unique_ptr<exec::IExecutionObserver> ctp =
- std::make_unique<exec::TracingObserver>(options.trace_filepath, exec->graph(), tracing_ctx);
+ std::make_unique<exec::TracingObserver>(options->trace_filepath, exec->graph(), tracing_ctx);
exec->addObserver(std::move(ctp));
}
+#ifdef MINMAX_H5DUMPER
+ if (!options->minmax_filepath.empty())
+ exec->addObserver(std::make_unique<exec::MinMaxRecorder>(
+ options->minmax_filepath, exec->graph(), exec->getBackendContexts()));
+#endif
return exec;
}
-exec::IExecutor *ExecutorFactory::createDataflowExecutor(
- std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx,
- const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors,
- const ir::ModelIndex &index, bool parallel)
+exec::IExecutor *
+ExecutorFactory::createDataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args, bool parallel)
{
+ const auto options = args.options;
+ const auto &model_index = args.model_index;
+ const auto tracing_ctx = args.tracing_ctx;
+ auto custom_kernel_builder = args.custom_kernel_builder;
+
backend::BackendContexts backend_contexts =
- createBackendContexts(*lowered_graph, options.executor == "Linear");
+ createBackendContexts(*lowered_graph, options->executor == "Linear", custom_kernel_builder);
TensorRegistries tensor_regs{backend_contexts, true};
@@ -472,7 +547,7 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor(
prepareMigrantTensors(*lowered_graph, backend_contexts);
// Give some runtime objects to builtin KernelGenerator
- prepareBuiltinBackend(tensor_regs, executors, backend_contexts, index);
+ prepareBuiltinBackend(tensor_regs, executors, backend_contexts, model_index);
ExecutionBuilder builder;
@@ -489,7 +564,7 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor(
auto &fn_seq = pair.second;
auto &op = lowered_graph->graph().operations().at(op_ind);
auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind);
- if (options.he_profiling_mode)
+ if (options->he_profiling_mode)
fn_seq->wrap<SyncFunction>(lower_info->backend()->config());
builder.append(op_ind, {op_ind, &op, lower_info, std::move(fn_seq)});
}
@@ -508,7 +583,7 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor(
auto dataflow_exec =
new exec::DataflowExecutor{std::move(lowered_graph), std::move(backend_contexts), tensor_regs,
std::move(code_map), tracing_ctx};
- if (options.he_profiling_mode)
+ if (options->he_profiling_mode)
{
std::vector<const backend::Backend *> backends;
for (const auto &pair : backend_contexts)
@@ -523,15 +598,304 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor(
exec = dataflow_exec;
}
- if (!options.trace_filepath.empty())
+ if (!options->trace_filepath.empty())
+ {
+ std::unique_ptr<exec::IExecutionObserver> ctp =
+ std::make_unique<exec::TracingObserver>(options->trace_filepath, exec->graph(), tracing_ctx);
+ exec->addObserver(std::move(ctp));
+ }
+
+ return exec;
+}
+
+#ifdef ONERT_TRAIN
+exec::IExecutor *
+ExecutorFactory::create(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args,
+ const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer)
+{
+ assert(args.options != nullptr);
+
+ if (args.options->executor != "Linear")
+ throw std::runtime_error("ExecutorFactory: TrainableExecutor supports only 'Linear' now");
+
+ return createTrainableExecutor(std::move(lowered_graph), executors, args, optimizer);
+}
+
+void ExecutorFactory::prepareMigrantTensors(
+ compiler::ILoweredGraph &lowered_graph,
+ const backend::train::TrainableBackendContexts &backend_contexts)
+{
+ train::TensorRegistries tensor_regs{backend_contexts, true};
+
+ lowered_graph.graph().operations().iterate(
+ [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) {
+ auto lower_info = lowered_graph.lower_info().operation.getRawPtr(op_ind);
+ auto &backend_ctx = backend_contexts.at(lower_info->backend());
+ for (auto &&ind :
+ (op.getInputs() + op.getOutputs()) | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+ {
+ // If an Operation's input/output tensor does not have an own tensor object,
+ // it must be using migrant tensors, so find the tensor from other tensor registries and
+ // register it to the current tensor registry if it is portable
+ if (!backend_ctx->tensor_registry()->getITensor(ind))
+ {
+ auto tensor = tensor_regs.getITensor(ind);
+ assert(tensor); // The tensor must have been registered
+ auto ptensor = dynamic_cast<backend::IPortableTensor *>(tensor);
+ if (ptensor)
+ backend_ctx->tensor_registry()->setMigrantTensor(ind, ptensor);
+ }
+ }
+ });
+}
+
+exec::IExecutor *ExecutorFactory::createTrainableExecutor(
+ std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &, const ExecutorFactoryArgs &args,
+ const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer)
+{
+ const auto options = args.options;
+ const auto tracing_ctx = args.tracing_ctx;
+ auto custom_kernel_builder = args.custom_kernel_builder;
+
+ auto &graph = lowered_graph->graph();
+
+ lowered_graph->trainable_graph().operations().iterate([](const onert::ir::OperationIndex &,
+ const onert::ir::IOperation &op) {
+ try
+ {
+ UNUSED_RELEASE(dynamic_cast<const ir::train::ITrainableOperation &>(op));
+ }
+ catch (std::bad_cast &)
+ {
+ throw std::runtime_error("ExecutorFactory: " + op.name() + " is not a trainable operation yet");
+ }
+ });
+
+ // TODO Create context only once instead of replacing
+ backend::train::TrainableBackendContexts tbackend_contexts;
+ backend::BackendContexts base_backend_contexts =
+ createBackendContexts(*lowered_graph, true, custom_kernel_builder);
+
+ // Replace BackendContext with TrainableBackendContext
+ for (auto &&pair : base_backend_contexts)
+ {
+ auto ctx = pair.second.get();
+ const auto &data = ctx->data();
+
+ // Create partial and trainable graphs
+ auto tgraph = std::make_unique<ir::train::TrainableGraph>(*data.graph);
+ data.graph->operations().iterate(
+ [&](const onert::ir::OperationIndex &op_index, const onert::ir::IOperation &) {
+ const auto &orig_tgraph = lowered_graph->trainable_graph();
+ const auto &trainable_op = orig_tgraph.operation(op_index);
+ auto gen_index = tgraph->replaceOperation(op_index, trainable_op.clone());
+ UNUSED_RELEASE(gen_index);
+ assert(gen_index == op_index);
+ });
+ data.graph->operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) {
+ const auto &orig_tgraph = lowered_graph->trainable_graph();
+ if (orig_tgraph.derivatives().exist(index))
+ {
+ const auto &deriv = orig_tgraph.derivatives().at(index);
+ auto new_deriv = std::make_unique<ir::Operand>(deriv);
+ auto gen_index = tgraph->addDerivative(index, std::move(new_deriv));
+ UNUSED_RELEASE(gen_index);
+ assert(gen_index == index);
+ }
+ });
+
+ // Remove outputs of whole graph from external_operands
+ auto external_operands = data.external_operands;
+ for (const auto &index : lowered_graph->trainable_graph().getOutputs())
+ {
+ if (external_operands.contains(index))
+ external_operands.remove(index);
+ }
+
+ // Set trainable context data
+ backend::train::TrainableContextData tdata;
+ tdata.tgraph = std::move(tgraph);
+ tdata.op_order = std::move(data.op_order);
+ tdata.external_operands = std::move(external_operands);
+ tdata.operand_layouts = std::move(data.operand_layouts);
+ tdata.custom_kernel_builder = std::move(data.custom_kernel_builder);
+ tdata.is_linear_executor = data.is_linear_executor;
+ tdata.optimizer = optimizer;
+
+ // TODO Remove dynamic_cast
+ try
+ {
+ const auto backend = pair.first;
+ const auto tbackend = dynamic_cast<const backend::train::ITrainableBackend *>(backend);
+ tbackend_contexts.emplace(backend, tbackend->newContext(std::move(tdata)));
+ }
+ catch (const std::bad_cast &)
+ {
+ throw std::runtime_error("ExecutorFactory: Invalid backend - TrainableExecutor does not "
+ "support non-trainable backends");
+ }
+ }
+ base_backend_contexts.clear();
+
+ train::TensorRegistries tensor_regs{tbackend_contexts, true};
+
+ initializeSubgraphIOTensors(
+ *lowered_graph, tbackend_contexts,
+ (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) |
+ ir::Remove::DUPLICATED | ir::Remove::UNDEFINED);
+
+ // linearize
+ auto order = Linear::linearize(*lowered_graph);
+ Linear::dump(*lowered_graph, order);
+
+ for (auto &&pair : tbackend_contexts)
+ {
+ pair.second->genTensors();
+ }
+
+ for (auto &&pair : tbackend_contexts)
+ {
+ auto tctx = pair.second.get();
+ tctx->genTrainingTensors();
+ }
+
+ prepareMigrantTensors(*lowered_graph, tbackend_contexts);
+
+ // Give some runtime objects to builtin KernelGenerator
+ for (auto &&pair : tbackend_contexts)
+ {
+ auto builtin_context =
+ dynamic_cast<backend::builtin::train::BackendContext *>(pair.second.get());
+ if (builtin_context != nullptr)
+ {
+ auto builtin_kernel_gen = builtin_context->kernel_gen;
+ builtin_kernel_gen->setTensorRegistries(tensor_regs);
+ builtin_kernel_gen->setWholeGraphOutputs(lowered_graph->trainable_graph().getOutputs());
+ }
+ }
+
+ // Adjust the order of backends for the upcoming iteration
+ auto ordered_contexts =
+ onert::orderBackendContext<backend::train::TrainableBackendContext>(tbackend_contexts);
+
+ // TODO Remove this simulation
+ // Simulate the execution for deallocation of tensors
+ std::unordered_map<ir::OperationIndex, DeallocList> dealloc_list_map;
+ {
+ ir::OperandIndexMap<uint32_t> uses_map;
+ ir::OperandIndexSequence constants;
+
+ auto model_io =
+ (graph.getInputs() + graph.getOutputs()) | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
+
+ // Prepare scanning
+ graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
+ uses_map[ind] = obj.getUses().size();
+
+ if (obj.isConstant())
+ constants.append(ind);
+ });
+
+ // A trick to consider constants as an exception
+ for (const auto &ind : constants)
+ {
+ uses_map[ind]++;
+ }
+
+ for (const auto op_ind : order)
+ {
+ const auto &op = graph.operations().at(op_ind);
+ auto op_inputs = op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;
+ auto op_outputs = op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;
+
+ for (const auto &ind : op_inputs)
+ {
+ const auto &operand = graph.operands().at(ind);
+ assert(uses_map.find(ind) != uses_map.end());
+ assert(uses_map[ind] > 0);
+ uses_map[ind]--;
+ if (uses_map[ind] == 0 && !operand.info().isVariable() && !model_io.contains(ind))
+ {
+ dealloc_list_map[op_ind].emplace_back(tensor_regs.getITensor(ind));
+ }
+ }
+ }
+
+ // Dispose and validate
+ for (const auto &ind : constants)
+ {
+ --uses_map[ind];
+ }
+
+ assert(
+ std::all_of(uses_map.begin(), uses_map.end(),
+ [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; }));
+ }
+
+ // Check derivative tensors
+ {
+ // TODO Support multiple subgraphs
+ // Check if the derivative tensors corresponding to inputs of model are nullptr
+ // NOTE The derivative tensors corresponding to inputs of model are for inputs of PermuteLayers
+ // and they are nullptr because they are meaningless.
+ assert(std::all_of(lowered_graph->trainable_graph().getInputs().begin(),
+ lowered_graph->trainable_graph().getInputs().end(),
+ [&](const auto &input_idx) {
+ return tensor_regs.getDerivativeITensor(input_idx) == nullptr;
+ }));
+
+ // Check if the derivative tensors corresponding to outputs of model exist
+ assert(std::all_of(lowered_graph->trainable_graph().getOutputs().begin(),
+ lowered_graph->trainable_graph().getOutputs().end(),
+ [&](const auto &output_idx) {
+ return tensor_regs.getDerivativeITensor(output_idx) == nullptr;
+ }));
+ }
+
+ train::TrainableCodeMap code_map;
+ // Generate kernels
+ for (auto &&pair : ordered_contexts)
+ {
+ auto codes = pair.second->genKernels();
+ for (auto &&pair : codes)
+ {
+ auto &op_ind = pair.first;
+ auto &tn_seq = pair.second;
+ auto &op = lowered_graph->trainable_graph().operation(op_ind);
+ auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind);
+
+ assert(code_map.find(op_ind) == code_map.end());
+ code_map.insert(
+ {op_ind, train::TrainableCodeAndInfo{op_ind, &op, lower_info, std::move(tn_seq)}});
+ }
+ }
+
+ if (order.size() != code_map.size())
+ {
+ throw std::runtime_error("ExecutorFactory: Some kernels are not generated");
+ }
+
+ auto exec = new exec::train::TrainableExecutor{std::move(lowered_graph),
+ std::move(tbackend_contexts),
+ tensor_regs,
+ std::move(code_map),
+ order,
+ tracing_ctx};
+
+ if (!options->trace_filepath.empty())
{
std::unique_ptr<exec::IExecutionObserver> ctp =
- std::make_unique<exec::TracingObserver>(options.trace_filepath, exec->graph(), tracing_ctx);
+ std::make_unique<exec::TracingObserver>(options->trace_filepath, exec->graph(), tracing_ctx);
exec->addObserver(std::move(ctp));
}
+ // TODO Support MINMAX_H5DUMPER
return exec;
}
+#endif // ONERT_TRAIN
} // namespace compiler
} // namespace onert
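
createTrainableExecutor above reuses the linear executor's trick of simulating execution to decide when tensors become dead: every operand starts with its remaining-use count, each executed operation decrements the counts of its inputs, and a count hitting zero marks the point after which the tensor may be deallocated. A minimal, self-contained sketch of that counting scheme follows, with toy indices rather than the onert types.

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main()
{
  // operand index -> number of operations that still read it
  std::map<int, uint32_t> uses_map{{0, 2}, {1, 1}, {2, 1}};
  // execution order: each operation lists the operand indices it reads
  std::vector<std::vector<int>> order{{0}, {0, 1}, {2}};

  for (std::size_t op_ind = 0; op_ind < order.size(); ++op_ind)
  {
    for (int ind : order[op_ind])
    {
      if (--uses_map[ind] == 0)
        std::cout << "operand " << ind << " can be freed after op " << op_ind << '\n';
    }
  }
  return 0;
}
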
diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.h b/runtime/onert/core/src/compiler/ExecutorFactory.h
index f8f989043..cc621bccf 100644
--- a/runtime/onert/core/src/compiler/ExecutorFactory.h
+++ b/runtime/onert/core/src/compiler/ExecutorFactory.h
@@ -20,7 +20,15 @@
#include "TensorRegistries.h"
#include "backend/ITensor.h"
+
+#ifdef ONERT_TRAIN
+#include "backend/train/TrainableBackendContext.h"
+#endif // ONERT_TRAIN
#include "compiler/LoweredGraph.h"
+#ifdef ONERT_TRAIN
+#include "compiler/train/LoweredTrainableGraph.h"
+#include "exec/train/optimizer/Optimizer.h"
+#endif // ONERT_TRAIN
#include "exec/IExecutors.h"
#include <deque>
@@ -31,6 +39,15 @@ namespace onert
namespace compiler
{
+// TODO Change to a better name
+struct ExecutorFactoryArgs
+{
+ const util::TracingCtx *tracing_ctx;
+ const compiler::CompilerOptions *options;
+ ir::ModelIndex model_index;
+ std::shared_ptr<backend::custom::IKernelBuilder> custom_kernel_builder;
+};
+
class ExecutorFactory
{
public:
@@ -38,16 +55,22 @@ public:
public:
exec::IExecutor *create(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const util::TracingCtx *tracing_ctx,
- const compiler::CompilerOptions &options,
const std::shared_ptr<exec::IExecutors> &executors,
- const ir::ModelIndex &index);
+ const ExecutorFactoryArgs &args);
+
+#ifdef ONERT_TRAIN
+ // TODO Unify create()
+ exec::IExecutor *create(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args,
+ const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer);
+#endif // ONERT_TRAIN
private:
ExecutorFactory();
private:
- static void prepareMigrantTensors(compiler::LoweredGraph &lowered_graph,
+ static void prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph,
const backend::BackendContexts &backend_contexts);
static void prepareBuiltinBackend(const TensorRegistries &tensor_regs,
const std::shared_ptr<exec::IExecutors> &executors,
@@ -56,22 +79,31 @@ private:
static std::deque<std::pair<const backend::Backend *, backend::BackendContext *>>
orderBackendContext(const backend::BackendContexts &backend_contexts);
- static exec::IExecutor *createLinearExecutor(
- std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx,
- const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors,
- const ir::ModelIndex &index);
- static exec::IExecutor *createDataflowExecutor(
- std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx,
- const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors,
- const ir::ModelIndex &index, bool parallel);
+ static exec::IExecutor *
+ createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args);
+ static exec::IExecutor *
+ createDataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args, bool parallel);
+#ifdef ONERT_TRAIN
+ // TODO Unify prepareMigrantTensors
+ static void
+ prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph,
+ const backend::train::TrainableBackendContexts &backend_contexts);
+ static exec::IExecutor *
+ createTrainableExecutor(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args,
+ const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer);
+#endif // ONERT_TRAIN
private:
std::unordered_map<
- std::string,
- std::function<exec::IExecutor *(
- std::unique_ptr<compiler::LoweredGraph>, const util::TracingCtx *tracing_ctx,
- const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors,
- const ir::ModelIndex &index)>>
+ std::string, std::function<exec::IExecutor *(std::unique_ptr<compiler::LoweredGraph>,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args)>>
_map;
};
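
For reference, callers now populate ExecutorFactoryArgs instead of passing the options, tracing context and model index individually. A condensed sketch of the call site, mirroring the MultiModelCompiler.cc hunk later in this diff (`lowered_subg`, `executors`, `tracing_ctx` and the other locals stand in for the caller's variables):

ExecutorFactoryArgs args;
args.tracing_ctx = tracing_ctx.get();
args.options = _voptions[model_index.value()];
args.model_index = model_index;
args.custom_kernel_builder = custom_kernel_builders[model_index];

auto executor = std::unique_ptr<exec::IExecutor>{
  ExecutorFactory::get().create(std::move(lowered_subg), executors, args)};
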
diff --git a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc
index fdf4e24f0..ce9b09c2d 100644
--- a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc
+++ b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc
@@ -776,7 +776,7 @@ Fp32ToFp16Converter::InputToOpSeqs Fp32ToFp16Converter::prepareInputToOpSeqs() c
InputToOpSeqs input_to_op_seqs;
op_seqs.iterate([&](const ir::OpSequenceIndex &op_seq_idx, const ir::OpSequence &op_seq) {
- for (auto input : op_seq.getInputs() | ir::Remove::UNDEFINED)
+ for (auto &&input : op_seq.getInputs() | ir::Remove::UNDEFINED)
{
auto it = input_to_op_seqs.find(input);
if (it == input_to_op_seqs.end())
@@ -862,7 +862,7 @@ void Fp32ToFp16Converter::manipulateContiguousOpSequences(
// |
// [OPERATION] // op_seq_ind_next_to_fp16
//
- for (auto it : opseq_map_to_delete)
+ for (auto &&it : opseq_map_to_delete)
{
// fp16_to_fp32's input/output num is always 1
auto &op_seq_ind_fp16_to_fp32 = it.first;
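
Most hunks in this file (and many below) only switch range-for loops from `for (auto x : ...)` to `for (auto &&x : ...)`. A small standalone illustration of the difference, independent of the onert types: `auto` copies each element, while `auto &&` binds a reference to it and avoids the copy.

#include <iostream>
#include <string>
#include <vector>

int main()
{
  std::vector<std::string> names{"conv", "relu", "softmax"};

  for (auto n : names) // n is a copy of each element
    n += "_copy";      // modifies only the copy

  for (auto &&n : names) // n binds to the element itself, no copy
    n += "_op";          // modifies the vector contents

  for (const auto &n : names)
    std::cout << n << '\n'; // prints conv_op, relu_op, softmax_op
  return 0;
}
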
diff --git a/runtime/onert/core/src/compiler/HEScheduler.cc b/runtime/onert/core/src/compiler/HEScheduler.cc
index 65fd4cd77..f662ef5b9 100644
--- a/runtime/onert/core/src/compiler/HEScheduler.cc
+++ b/runtime/onert/core/src/compiler/HEScheduler.cc
@@ -28,7 +28,7 @@ namespace
using namespace onert;
-uint32_t getOperationsFlattenedIOSize(const ir::Graph &graph, const ir::Operation &node)
+uint32_t getOperationsFlattenedIOSize(const ir::Graph &graph, const ir::IOperation &node)
{
uint32_t size = 0;
for (const auto &ind :
@@ -39,7 +39,7 @@ uint32_t getOperationsFlattenedIOSize(const ir::Graph &graph, const ir::Operatio
return size;
}
-bool isQuant(const ir::Graph &graph, const ir::Operation &node)
+bool isQuant(const ir::Graph &graph, const ir::IOperation &node)
{
for (const auto &input : node.getInputs() | ir::Remove::UNDEFINED)
{
@@ -52,14 +52,14 @@ bool isQuant(const ir::Graph &graph, const ir::Operation &node)
return false;
}
-bool isWorkaroundSkip(const ir::Graph &, const backend::Backend *, const ir::Operation &, bool)
+bool isWorkaroundSkip(const ir::Graph &, const backend::Backend *, const ir::IOperation &, bool)
{
// Now, there is no workaround
return false;
}
// if a node can be merged into op_seq
-bool isMergeable(const ir::Graph &graph, const ir::Operation &node)
+bool isMergeable(const ir::Graph &graph, const ir::IOperation &node)
{
size_t prev_op_cnt = 0;
for (const auto &input : node.getInputs() | ir::Remove::UNDEFINED)
@@ -137,7 +137,7 @@ void HEScheduler::scheduleShufflingBackends()
}
}
-bool HEScheduler::isNodeProfiled(const ir::Operation &node)
+bool HEScheduler::isNodeProfiled(const ir::IOperation &node)
{
const bool quant = isQuant(*_graph, node);
const auto size = getOperationsFlattenedIOSize(*_graph, node);
@@ -207,7 +207,7 @@ std::unique_ptr<compiler::BackendResolver> HEScheduler::schedule(const ir::Graph
{
// Check if profiling info about all backend/node pairs already exists
bool all_nodes_are_profiled = true;
- _graph->operations().iterate([&](const ir::OperationIndex &, const ir::Operation &op) {
+ _graph->operations().iterate([&](const ir::OperationIndex &, const ir::IOperation &op) {
if (all_nodes_are_profiled)
all_nodes_are_profiled = isNodeProfiled(op);
});
@@ -224,7 +224,7 @@ std::unique_ptr<compiler::BackendResolver> HEScheduler::schedule(const ir::Graph
ir::OperationIndexMap<bool> visited;
graph.operations().iterate(
- [&](const ir::OperationIndex &index, const ir::Operation &) { visited[index] = false; });
+ [&](const ir::OperationIndex &index, const ir::IOperation &) { visited[index] = false; });
// for each task select the backend with the smallest earliest finishing time(eft)
for (const auto &rank : _rank_to_op)
{
@@ -258,7 +258,7 @@ int64_t HEScheduler::getPermuteTime(const backend::Backend *src_backend,
return size / 400;
}
-int64_t HEScheduler::tryBackend(const ir::Operation &node, const backend::Backend *backend)
+int64_t HEScheduler::tryBackend(const ir::IOperation &node, const backend::Backend *backend)
{
// if there is no profiling info don't use this backend during scheduling
if (!_is_profiling_mode)
@@ -297,10 +297,10 @@ void HEScheduler::makeRank()
VERBOSE(HEScheduler::makeRank) << "task prioritizing" << std::endl;
_graph->operations().iterate(
- [&](const ir::OperationIndex &index, const ir::Operation &) { DFSMaxRank(index); });
+ [&](const ir::OperationIndex &index, const ir::IOperation &) { DFSMaxRank(index); });
// Check that ranks are calculated for all operations(nodes)
- _graph->operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &) {
+ _graph->operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &) {
UNUSED_RELEASE(index);
assert(_op_to_rank->find(index) != _op_to_rank->end());
});
@@ -564,7 +564,7 @@ HEScheduler::ESTAndExecTime(const backend::Backend *backend, const ir::Operation
return {prev_op_ft, exec_time};
}
-int64_t HEScheduler::predMaxEFT(const backend::Backend *backend, const ir::Operation &node,
+int64_t HEScheduler::predMaxEFT(const backend::Backend *backend, const ir::IOperation &node,
std::multimap<int64_t, int64_t> &transfer_st_exec_time)
{
int64_t max_pred_eft = 0;
diff --git a/runtime/onert/core/src/compiler/HEScheduler.h b/runtime/onert/core/src/compiler/HEScheduler.h
index 18ea388fd..df6c07926 100644
--- a/runtime/onert/core/src/compiler/HEScheduler.h
+++ b/runtime/onert/core/src/compiler/HEScheduler.h
@@ -58,7 +58,7 @@ public:
_is_profiling_mode{options.he_profiling_mode}, _is_linear_exec{options.executor == "Linear"},
_is_parallel_exec{options.executor == "Parallel"}
{
- for (auto entry : backends)
+ for (auto &&entry : backends)
{
if (entry->config()->id() == backend::builtin::Config::ID)
continue;
@@ -88,7 +88,7 @@ public:
std::shared_ptr<ir::OperationIndexMap<int64_t>> getIndexedRanks() { return _op_to_rank; }
private:
- bool isNodeProfiled(const ir::Operation &);
+ bool isNodeProfiled(const ir::IOperation &);
bool schedule(const ir::OperationIndex &, const backend::Backend *parent_backend);
/**
@@ -115,7 +115,7 @@ private:
*
* @return earliest finishing time of parent nodes
*/
- int64_t predMaxEFT(const backend::Backend *backend, const ir::Operation &node,
+ int64_t predMaxEFT(const backend::Backend *backend, const ir::IOperation &node,
std::multimap<int64_t, int64_t> &transfer_st_exec_time);
void makeRank();
@@ -146,7 +146,7 @@ private:
void scheduleShufflingBackends();
- int64_t tryBackend(const ir::Operation &node, const backend::Backend *backend);
+ int64_t tryBackend(const ir::IOperation &node, const backend::Backend *backend);
/**
* @brief Schedule a node and its successor until:
diff --git a/runtime/onert/core/src/compiler/HEScheduler.test.cc b/runtime/onert/core/src/compiler/HEScheduler.test.cc
index 589331b49..1654bfc8b 100644
--- a/runtime/onert/core/src/compiler/HEScheduler.test.cc
+++ b/runtime/onert/core/src/compiler/HEScheduler.test.cc
@@ -43,7 +43,7 @@ struct MockConfigCPU : public IConfig
std::string id() override { return "cpu"; }
bool initialize() override { return true; };
bool supportPermutation() override { return false; }
- Layout supportLayout(const Operation &, Layout) override { return Layout::UNKNOWN; }
+ Layout supportLayout(const IOperation &, Layout) override { return Layout::UNKNOWN; }
bool supportDynamicTensor() override { return false; }
bool supportFP16() override { return false; }
};
@@ -70,7 +70,7 @@ struct MockConfigGPU : public IConfig
std::string id() override { return "gpu"; }
bool initialize() override { return true; };
bool supportPermutation() override { return false; }
- ir::Layout supportLayout(const ir::Operation &, ir::Layout) override
+ ir::Layout supportLayout(const ir::IOperation &, ir::Layout) override
{
return ir::Layout::UNKNOWN;
}
@@ -92,7 +92,7 @@ struct MockConfigNPU : public IConfig
std::string id() override { return "npu"; }
bool initialize() override { return true; };
bool supportPermutation() override { return false; }
- ir::Layout supportLayout(const ir::Operation &, ir::Layout) override
+ ir::Layout supportLayout(const ir::IOperation &, ir::Layout) override
{
return ir::Layout::UNKNOWN;
}
diff --git a/runtime/onert/core/src/compiler/Linear.cc b/runtime/onert/core/src/compiler/Linear.cc
index f85b8d1bd..4dbe229c8 100644
--- a/runtime/onert/core/src/compiler/Linear.cc
+++ b/runtime/onert/core/src/compiler/Linear.cc
@@ -28,16 +28,16 @@ namespace compiler
{
// TODO(easy) Change the LoweredGraph param to Graph
-std::vector<ir::OperationIndex> Linear::linearize(const compiler::LoweredGraph &lowered_graph)
+std::vector<ir::OperationIndex> Linear::linearize(const compiler::ILoweredGraph &lowered_graph)
{
return lowered_graph.graph().topolSortOperations();
}
// TODO(easy) Change the LoweredGraph param to Graph
-void Linear::dump(const compiler::LoweredGraph &lowered_graph,
+void Linear::dump(const compiler::ILoweredGraph &lowered_graph,
const std::vector<ir::OperationIndex> &order)
{
- for (const auto ind : order)
+ for (const auto &ind : order)
{
// TODO Could logging system can handle this? (Inserting prefix for each line)
std::istringstream iss{dumper::text::formatOperation(lowered_graph.graph(), ind)};
diff --git a/runtime/onert/core/src/compiler/Linear.h b/runtime/onert/core/src/compiler/Linear.h
index 9ac9a0139..4f92dc88d 100644
--- a/runtime/onert/core/src/compiler/Linear.h
+++ b/runtime/onert/core/src/compiler/Linear.h
@@ -21,7 +21,7 @@
#include <memory>
#include "ir/Index.h"
-#include "compiler/LoweredGraph.h"
+#include "compiler/ILoweredGraph.h"
namespace onert
{
@@ -31,8 +31,8 @@ namespace compiler
class Linear
{
public:
- static std::vector<ir::OperationIndex> linearize(const compiler::LoweredGraph &lowered_graph);
- static void dump(const compiler::LoweredGraph &lowered_graph,
+ static std::vector<ir::OperationIndex> linearize(const compiler::ILoweredGraph &lowered_graph);
+ static void dump(const compiler::ILoweredGraph &lowered_graph,
const std::vector<ir::OperationIndex> &order);
};
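
The two static helpers are used together, as in the ExecutorFactory changes above; a condensed usage sketch, with `lowered_graph` standing in for any ILoweredGraph implementation:

auto order = compiler::Linear::linearize(*lowered_graph); // topological sort of operation indices
compiler::Linear::dump(*lowered_graph, order);            // log the operations in that order
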
diff --git a/runtime/onert/core/src/compiler/LoweredGraph.cc b/runtime/onert/core/src/compiler/LoweredGraph.cc
index d53d0ed00..46a45e44a 100644
--- a/runtime/onert/core/src/compiler/LoweredGraph.cc
+++ b/runtime/onert/core/src/compiler/LoweredGraph.cc
@@ -49,7 +49,7 @@ void LoweredGraph::lowerGraph(const CompilerOptions &options)
// Build backend contexts
auto &backend_manager = BackendManager::get();
// Create contexts for other backends
- for (auto backend_str : options.backend_list)
+ for (auto &&backend_str : options.backend_list)
{
backend_manager.loadBackend(backend_str);
auto backend = backend_manager.get(backend_str);
@@ -100,9 +100,9 @@ void LoweredGraph::lowerGraph(const CompilerOptions &options)
pass::PassRunner{}.append(std::make_unique<pass::PermutationEliminationPass>(*this)).run();
VERBOSE(LoweredGraph) << "Dump after all the passes" << std::endl;
- for (auto operand : _graph.getInputs())
+ for (auto &&operand : _graph.getInputs())
VERBOSE(LoweredGraph) << "Graph Input : " << operand << std::endl;
- for (auto operand : _graph.getOutputs())
+ for (auto &&operand : _graph.getOutputs())
VERBOSE(LoweredGraph) << "Graph Output : " << operand << std::endl;
dumper::text::dumpLoweredGraph(*this);
@@ -121,8 +121,8 @@ void LoweredGraph::makeLowerInfo(const compiler::BackendResolver &backend_resolv
});
// Set operand lower info using assigned backends to operations
- _graph.operations().iterate([&](const ir::OperationIndex &op_ind, const ir::Operation &) {
- const ir::Operation &op = _graph.operations().at(op_ind);
+ _graph.operations().iterate([&](const ir::OperationIndex &op_ind, const ir::IOperation &) {
+ const ir::IOperation &op = _graph.operations().at(op_ind);
auto backend = backend_resolver.getBackend(op_ind);
if (!backend)
{
@@ -135,12 +135,12 @@ void LoweredGraph::makeLowerInfo(const compiler::BackendResolver &backend_resolv
// TODO Change setting layout of each backend at another place
auto backend_layout = backend->config()->supportLayout(op, frontend_layout);
- for (auto ind : op.getInputs() | ir::Remove::UNDEFINED)
+ for (auto &&ind : op.getInputs() | ir::Remove::UNDEFINED)
{
auto &operand_li = lower_info().operand.at(ind);
operand_li.addUsePermuteFactor(PermuteFactor{backend, backend_layout});
}
- for (auto ind : op.getOutputs() | ir::Remove::UNDEFINED)
+ for (auto &&ind : op.getOutputs() | ir::Remove::UNDEFINED)
{
auto &operand_li = lower_info().operand.at(ind);
operand_li.addDefPermuteFactor(PermuteFactor{backend, backend_layout});
@@ -152,13 +152,13 @@ void LoweredGraph::makeLowerInfo(const compiler::BackendResolver &backend_resolv
// Handle graph inputs and outputs
const auto builtin_backend = BackendManager::get().getBuiltin();
auto factor = PermuteFactor{builtin_backend, _graph.layout()};
- for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
+ for (auto &&index : _graph.getInputs() | ir::Remove::UNDEFINED)
{
auto &operand_li = lower_info().operand.at(index);
assert(operand_li.def_factors().empty());
operand_li.addDefPermuteFactor(factor);
}
- for (auto index : _graph.getOutputs() | ir::Remove::UNDEFINED)
+ for (auto &&index : _graph.getOutputs() | ir::Remove::UNDEFINED)
{
auto &operand_li = lower_info().operand.at(index);
operand_li.addUsePermuteFactor(factor);
@@ -204,7 +204,7 @@ void LoweredGraph::dumpLowerInfo()
auto factors_to_string = [](const PermuteFactorSet &factors) {
std::string str;
- for (auto factor : factors)
+ for (auto &&factor : factors)
{
str += factor.backend()->config()->id();
str += "(" + to_string(factor.layout()) + ")";
@@ -216,7 +216,7 @@ void LoweredGraph::dumpLowerInfo()
auto operation_index_set_to_string = [](const ir::OperationIndexSet &operations) {
std::stringstream sstream;
sstream << "{ ";
- for (auto op : operations)
+ for (auto &&op : operations)
sstream << op << " ";
sstream << "}";
return sstream.str();
diff --git a/runtime/onert/core/src/compiler/ManualScheduler.cc b/runtime/onert/core/src/compiler/ManualScheduler.cc
index 621f0c7b7..ccd08893f 100644
--- a/runtime/onert/core/src/compiler/ManualScheduler.cc
+++ b/runtime/onert/core/src/compiler/ManualScheduler.cc
@@ -42,7 +42,7 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap
// This fallback will be used in case that `backend_for_all` is unavailable
auto fallback = [&]() -> const backend::Backend * {
- for (auto backend_id : _options.backend_list)
+ for (auto &&backend_id : _options.backend_list)
{
auto backend = resolveBackend(backend_id);
if (backend)
@@ -58,7 +58,7 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap
VERBOSE(ManualScheduler) << "Default backend for all ops: " << backend_all->config()->id()
<< std::endl;
- graph.operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &) {
+ graph.operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &) {
backend_resolver->setBackend(index, backend_all);
});
@@ -71,7 +71,7 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap
// By default, Custom uses cpu backend
op_type_map[ir::OpCode::Custom] = BackendManager::get().get("cpu");
- graph.operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &operation) {
+ graph.operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &operation) {
auto itr = op_type_map.find(operation.opcode());
if (itr != op_type_map.end())
{
diff --git a/runtime/onert/core/src/compiler/MultiModelCompiler.cc b/runtime/onert/core/src/compiler/MultiModelCompiler.cc
index fea6a7f25..141fdfe09 100644
--- a/runtime/onert/core/src/compiler/MultiModelCompiler.cc
+++ b/runtime/onert/core/src/compiler/MultiModelCompiler.cc
@@ -16,6 +16,7 @@
#include "MultiModelCompiler.h"
+#include "CompilerHelpers.h"
#include "ExecutorFactory.h"
#include "ShapeValidator.h"
#include "pass/ConstantOutputPass.h"
@@ -30,6 +31,7 @@
#include "compiler/StaticShapeInferer.h"
#include <misc/string_helpers.h>
+#include <misc/polymorphic_downcast.h>
namespace onert
{
@@ -53,7 +55,7 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void)
/***************************************************
* Prepare compilation phase
***************************************************/
- for (auto options : _voptions)
+ for (auto &&options : _voptions)
{
if (!options)
throw std::runtime_error{"Empty compile option"};
@@ -63,6 +65,9 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void)
if (options->he_profiling_mode)
throw std::runtime_error("NYI: Profiling mode for multiple model is not supported yet");
+ if (!options->minmax_filepath.empty())
+ throw std::runtime_error("Recording minmax is not supported for multiple models");
+
options->forceInternalOptions();
options->verboseOptions();
}
@@ -74,7 +79,15 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void)
for (uint16_t i = 0; i < model_count; i++)
{
- _nnpkg->model(ir::ModelIndex{i})->iterate([&](const ir::SubgraphIndex &, ir::Graph &subg) {
+ if (!_nnpkg->model(ir::ModelIndex{i})->hasOnly<ir::Graph>())
+ throw std::runtime_error("MultiModelCompiler can only compile models for inference.");
+ }
+
+ for (uint16_t i = 0; i < model_count; i++)
+ {
+ _nnpkg->model(ir::ModelIndex{i})->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) {
+ auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);
+
// Mandatory passes
pass::PassRunner{}
.append(std::make_unique<pass::ConstantOutputPass>(subg))
@@ -100,6 +113,15 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void)
// Model edge context: copy model edge context
auto model_edges = std::make_unique<ir::ModelEdges>(_nnpkg->model_edges());
+ // Custom kernels
+ std::unordered_map<ir::ModelIndex, std::shared_ptr<backend::custom::IKernelBuilder>>
+ custom_kernel_builders;
+ for (uint16_t i = 0; i < model_count; i++)
+ {
+ auto const model_index = ir::ModelIndex{i};
+ custom_kernel_builders[model_index] = _nnpkg->model(model_index)->getKernelBuilder();
+ }
+
// Lower: Assign backend
std::unordered_map<ir::ModelIndex,
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::LoweredGraph>>>
@@ -110,7 +132,9 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void)
auto const model_index = ir::ModelIndex{i};
auto model = _nnpkg->model(model_index);
- model->iterate([&](const ir::SubgraphIndex &subg_index, ir::Graph &subg) {
+ model->iterate([&](const ir::SubgraphIndex &subg_index, ir::IGraph &graph) {
+ auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);
+
dot_dumper.dump(subg,
nnfw::misc::str("before_lower_model-", i, "-subg-", subg_index.value()));
// Lower: Assign backend
@@ -146,7 +170,7 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void)
// Run the StaticShapeInfer of primary subg. All child StaticShapeInferers are called
// recursively
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
- StaticShapeInferer::createStaticShapeInferers(model_lsubgs);
+ createStaticShapeInferers(model_lsubgs);
const auto primary_subg_idx = ir::SubgraphIndex{0};
inferers.at(primary_subg_idx)->infer();
@@ -194,11 +218,15 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void)
ir::OperationDumper dumper("Executor generation of Subgraph " +
std::to_string(subg_index.value()));
lowered_subg->graph().operations().iterate(
- [&](const ir::OperationIndex &, const ir::Operation &op) { op.accept(dumper); });
-
- auto &options = *_voptions[model_index.value()];
- auto executor = std::unique_ptr<exec::IExecutor>{ExecutorFactory::get().create(
- std::move(lowered_subg), tracing_ctx.get(), options, executors, model_index)};
+ [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); });
+
+ ExecutorFactoryArgs args;
+ args.tracing_ctx = tracing_ctx.get();
+ args.options = _voptions[model_index.value()];
+ args.model_index = model_index;
+ args.custom_kernel_builder = custom_kernel_builders[model_index];
+ auto executor = std::unique_ptr<exec::IExecutor>{
+ ExecutorFactory::get().create(std::move(lowered_subg), executors, args)};
executor->setIndexedRanks(indexed_ranks);
executors->emplace(model_index, subg_index, std::move(executor));
}
diff --git a/runtime/onert/core/src/compiler/MultiModelCompiler.h b/runtime/onert/core/src/compiler/MultiModelCompiler.h
index 89af664f8..b282a5087 100644
--- a/runtime/onert/core/src/compiler/MultiModelCompiler.h
+++ b/runtime/onert/core/src/compiler/MultiModelCompiler.h
@@ -59,12 +59,6 @@ public:
std::shared_ptr<CompilerArtifact> compile(void);
private:
- std::shared_ptr<ir::Graph> &primary_subgraph()
- {
- return _nnpkg->primary_model()->at(ir::SubgraphIndex{0});
- }
-
-private:
std::shared_ptr<ir::NNPkg> _nnpkg;
std::vector<CompilerOptions *> _voptions;
};
diff --git a/runtime/onert/core/src/compiler/ShapeValidator.cc b/runtime/onert/core/src/compiler/ShapeValidator.cc
index 8c6421744..3e940f037 100644
--- a/runtime/onert/core/src/compiler/ShapeValidator.cc
+++ b/runtime/onert/core/src/compiler/ShapeValidator.cc
@@ -52,7 +52,7 @@ void ShapeValidator::checkUnaryOp(const ir::Operation &node)
void ShapeValidator::operator()()
{
_graph.operations().iterate(
- [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); });
+ [&](const ir::OperationIndex &, const ir::IOperation &node) { node.accept(*this); });
}
void ShapeValidator::visit(const ir::operation::BatchMatMul &node)
diff --git a/runtime/onert/core/src/compiler/StaticShapeInferer.cc b/runtime/onert/core/src/compiler/StaticShapeInferer.cc
index 25747d950..a25b326f1 100644
--- a/runtime/onert/core/src/compiler/StaticShapeInferer.cc
+++ b/runtime/onert/core/src/compiler/StaticShapeInferer.cc
@@ -99,10 +99,10 @@ void StaticShapeInferer::infer()
}
}
-bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op)
+bool StaticShapeInferer::checkDynamicInput(const ir::IOperation &op)
{
const auto &operands = _lowered_subg->graph().operands();
- for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
+ for (auto &&input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
{
if (operands.at(input_idx).info().isDynamic())
{
@@ -113,10 +113,10 @@ bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op)
return false;
}
-bool StaticShapeInferer::checkDynamicOutput(const ir::Operation &op)
+bool StaticShapeInferer::checkDynamicOutput(const ir::IOperation &op)
{
auto &operands = _lowered_subg->graph().operands();
- for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
+ for (auto &&output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
{
if (operands.at(output_idx).info().isDynamic())
{
@@ -126,10 +126,10 @@ bool StaticShapeInferer::checkDynamicOutput(const ir::Operation &op)
return false;
}
-void StaticShapeInferer::setDynamicOutput(const ir::Operation &op)
+void StaticShapeInferer::setDynamicOutput(const ir::IOperation &op)
{
auto &operands = _lowered_subg->graph().operands();
- for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
+ for (auto &&output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
{
operands.at(output_idx).info().setDynamic();
}
@@ -192,7 +192,7 @@ void StaticShapeInferer::dump()
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
StaticShapeInferer::createStaticShapeInferers(
- const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraph>> &lowered_subgs)
+ const std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> &lowered_subgs)
{
// Allocate StaticShapeInferer per each subgraph
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers;
@@ -200,7 +200,7 @@ StaticShapeInferer::createStaticShapeInferers(
{
const auto &subg_index = pair.first;
auto &lowered_subg = pair.second;
- inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg.get());
+ inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg);
}
// Append observers in all StaticShapeInferers
@@ -211,7 +211,7 @@ StaticShapeInferer::createStaticShapeInferers(
// TODO: Change this iteration for all to controlflow iteration
lowered_subg->graph().operations().iterate(
- [&](const ir::OperationIndex &, const ir::Operation &op) {
+ [&](const ir::OperationIndex &, const ir::IOperation &op) {
// A Function to append child inferers. These make it possible for a StaticShapeInferer to
// call StaticShapeInferes of child subgraphs recursively
auto appendChildInferer = [&](const ir::SubgraphIndex &child_subg_idx) {
@@ -251,7 +251,9 @@ StaticShapeInferer::createStaticShapeInferers(
// Append Observers in a StaticShapeInferer
if (op.opcode() == ir::OpCode::If)
{
- const auto &if_op = nnfw::misc::polymorphic_downcast<const ir::operation::If &>(op);
+ // TODO Remove dynamic_cast
+ // A virtual base class cannot be downcast with static_cast
+ const auto &if_op = dynamic_cast<const ir::operation::If &>(op);
appendChildInferer(if_op.param().then_subg_index);
appendChildInferer(if_op.param().else_subg_index);
@@ -263,7 +265,8 @@ StaticShapeInferer::createStaticShapeInferers(
}
else if (op.opcode() == ir::OpCode::While)
{
- const auto &while_op = nnfw::misc::polymorphic_downcast<const ir::operation::While &>(op);
+ // TODO Remove dynamic_cast
+ const auto &while_op = dynamic_cast<const ir::operation::While &>(op);
appendChildInferer(while_op.param().cond_subg_index);
appendChildInferer(while_op.param().body_subg_index);
@@ -602,6 +605,13 @@ void StaticShapeInferer::visit(const ir::operation::L2Normalization &op)
handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT));
}
+void StaticShapeInferer::visit(const ir::operation::Loss &)
+{
+ // TODO Consider SparseCategoricalCrossentropy case
+
+ // TODO Consider output shape in case of reduction option
+}
+
void StaticShapeInferer::visit(const ir::operation::LSTM &op)
{
auto &operands = _lowered_subg->graph().operands();
@@ -1119,7 +1129,7 @@ void StaticShapeInferer::visit(const ir::operation::Split &op)
auto outputs = op.getOutputs();
if (!axis.isConstant())
{
- for (auto output_idx : outputs)
+ for (auto &&output_idx : outputs)
{
ir::Operand &output = operands.at(output_idx);
output.info().setDynamic();
@@ -1137,7 +1147,7 @@ void StaticShapeInferer::visit(const ir::operation::Split &op)
ir::Shape new_shape =
shape_inference::inferSplitShape(input.info().shape(), axis_value, num_splits);
- for (auto output_idx : outputs)
+ for (auto &&output_idx : outputs)
{
ir::Operand &output = operands.at(output_idx);
output.info().shape(new_shape);
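
On the dynamic_cast change above: polymorphic_downcast is a checked static_cast, and static_cast cannot downcast from a virtual base class, which (per the comment in that hunk) ir::IOperation now is; dynamic_cast can. A standalone illustration of the language rule with toy types unrelated to onert:

#include <iostream>

struct Base { virtual ~Base() = default; };
struct Derived : virtual Base { int x = 42; };

int main()
{
  Derived d;
  Base *b = &d;

  // static_cast<Derived *>(b);          // ill-formed: Base is a virtual base of Derived
  auto *p = dynamic_cast<Derived *>(b);  // OK: RTTI recovers the Derived object
  std::cout << (p ? p->x : -1) << '\n';  // prints 42
  return 0;
}
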
diff --git a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc
index 89dd303d4..a6590b13f 100644
--- a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc
+++ b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc
@@ -28,14 +28,14 @@ namespace compiler
namespace pass
{
-void ConstantInsertionPass::callback(const ir::OperationIndex &node_index, ir::Operation &node)
+void ConstantInsertionPass::callback(const ir::OperationIndex &node_index, ir::IOperation &node)
{
const auto op_lower_info = _lowered_graph.lower_info().operation.getRawPtr(node_index);
const auto backend = op_lower_info->backend();
const auto layout = op_lower_info->layout();
const auto factor = PermuteFactor{backend, layout};
- for (const auto input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+ for (const auto &input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
{
auto &object = _graph.operands().at(input);
diff --git a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h
index 4911ace2f..d5b9aa14e 100644
--- a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h
+++ b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h
@@ -39,7 +39,7 @@ public:
std::string id() final { return "ConstantInsertionPass"; }
public:
- void callback(const ir::OperationIndex &index, ir::Operation &node) final;
+ void callback(const ir::OperationIndex &index, ir::IOperation &node) final;
private:
struct ReplaceKey
diff --git a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc
index 6ed154548..32e32d0ef 100644
--- a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc
+++ b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc
@@ -29,7 +29,7 @@ namespace compiler
namespace pass
{
-void ConstantLoweringPass::callback(const ir::OperationIndex &node_index, ir::Operation &node)
+void ConstantLoweringPass::callback(const ir::OperationIndex &node_index, ir::IOperation &node)
{
const auto op_lower_info = _lowered_graph.lower_info().operation.getRawPtr(node_index);
const auto backend = op_lower_info->backend();
@@ -37,7 +37,7 @@ void ConstantLoweringPass::callback(const ir::OperationIndex &node_index, ir::Op
const auto factor = PermuteFactor{backend, layout};
// Now this runtime does not support the node making output of operation as constant
- for (const auto input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+ for (const auto &input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
{
auto &object = _graph.operands().at(input);
if (object.isConstant())
diff --git a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h
index e17d776d1..d60a1033f 100644
--- a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h
+++ b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h
@@ -36,7 +36,7 @@ public:
std::string id() final { return "ConstantLoweringPass"; }
public:
- void callback(const ir::OperationIndex &index, ir::Operation &node) final;
+ void callback(const ir::OperationIndex &index, ir::IOperation &node) final;
};
} // namespace pass
diff --git a/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc
index c176f6ffb..1448de473 100644
--- a/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc
+++ b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc
@@ -49,7 +49,7 @@ void ConstantOutputPass::callback(const ir::OperandIndex &ind, ir::Operand &obj)
// Make the operations that uses this operand to use the generated operand
auto orig_uses = obj.getUses();
- for (auto use : orig_uses)
+ for (auto &&use : orig_uses)
{
permute_input_obj.insertUse(use);
obj.removeUse(use);
diff --git a/runtime/onert/core/src/compiler/pass/IPass.h b/runtime/onert/core/src/compiler/pass/IPass.h
new file mode 100644
index 000000000..77f5916fd
--- /dev/null
+++ b/runtime/onert/core/src/compiler/pass/IPass.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_PASS_IPASS_H__
+#define __ONERT_COMPILER_PASS_IPASS_H__
+
+#include <string>
+
+namespace onert
+{
+namespace compiler
+{
+namespace pass
+{
+
+struct IPass
+{
+ virtual ~IPass() = default;
+
+ virtual std::string id() = 0;
+ virtual void run() = 0;
+};
+
+} // namespace pass
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_PASS_IPASS_H__
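
For illustration, a minimal concrete pass against this new interface; the class name is hypothetical and not part of this change.

#include "IPass.h"

#include <iostream>

namespace onert
{
namespace compiler
{
namespace pass
{

// Hypothetical pass used only to show the IPass contract: an id() and a run().
class DumpIdPass : public IPass
{
public:
  std::string id() override { return "DumpIdPass"; }
  void run() override { std::cout << id() << " executed" << std::endl; }
};

} // namespace pass
} // namespace compiler
} // namespace onert
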
diff --git a/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h b/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h
index 1f1f32f6d..64831a0ac 100644
--- a/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h
+++ b/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h
@@ -18,7 +18,7 @@
#define __ONERT_IR_PASS_LOWERED_OPERAND_PASS_H__
#include "OperandPass.h"
-#include "compiler/LoweredGraph.h"
+#include "compiler/ILoweredGraph.h"
namespace onert
{
@@ -30,7 +30,7 @@ namespace pass
class LoweredOperandPass : public OperandPass
{
public:
- LoweredOperandPass(compiler::LoweredGraph &lowered_graph)
+ LoweredOperandPass(compiler::ILoweredGraph &lowered_graph)
: OperandPass{lowered_graph.graph()}, _lowered_graph{lowered_graph}
{
// DO NOTHING
@@ -42,7 +42,7 @@ public:
void callback(const ir::OperandIndex &i, ir::Operand &o) override = 0;
protected:
- compiler::LoweredGraph &_lowered_graph;
+ compiler::ILoweredGraph &_lowered_graph;
};
} // namespace pass
diff --git a/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h b/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h
index 76ee3d7ff..27ca77c91 100644
--- a/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h
+++ b/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h
@@ -18,7 +18,7 @@
#define __ONERT_IR_PASS_LOWERED_OPERATION_PASS_H__
#include "OperationPass.h"
-#include "compiler/LoweredGraph.h"
+#include "compiler/ILoweredGraph.h"
namespace onert
{
@@ -30,7 +30,7 @@ namespace pass
class LoweredOperationPass : public OperationPass
{
public:
- LoweredOperationPass(LoweredGraph &lowered_graph)
+ LoweredOperationPass(ILoweredGraph &lowered_graph)
: OperationPass{lowered_graph.graph()}, _lowered_graph{lowered_graph}
{
// DO NOTHING
@@ -39,10 +39,10 @@ public:
virtual ~LoweredOperationPass() = default;
std::string id() override = 0;
- void callback(const ir::OperationIndex &i, ir::Operation &o) override = 0;
+ void callback(const ir::OperationIndex &i, ir::IOperation &o) override = 0;
protected:
- LoweredGraph &_lowered_graph;
+ ILoweredGraph &_lowered_graph;
};
} // namespace pass
diff --git a/runtime/onert/core/src/compiler/pass/OperationPass.cc b/runtime/onert/core/src/compiler/pass/OperationPass.cc
index 357a8798a..bd9bcb4a4 100644
--- a/runtime/onert/core/src/compiler/pass/OperationPass.cc
+++ b/runtime/onert/core/src/compiler/pass/OperationPass.cc
@@ -17,7 +17,7 @@
#include "OperationPass.h"
#include "ir/Index.h"
-#include "ir/Operation.h"
+#include "ir/IOperation.h"
#include "ir/Graph.h"
namespace onert
@@ -30,7 +30,7 @@ namespace pass
void OperationPass::run()
{
_graph.operations().iterate(
- [&](const ir::OperationIndex &index, ir::Operation &node) { callback(index, node); });
+ [&](const ir::OperationIndex &index, ir::IOperation &node) { callback(index, node); });
}
} // namespace pass
diff --git a/runtime/onert/core/src/compiler/pass/OperationPass.h b/runtime/onert/core/src/compiler/pass/OperationPass.h
index ac4d818a2..0a00b11d1 100644
--- a/runtime/onert/core/src/compiler/pass/OperationPass.h
+++ b/runtime/onert/core/src/compiler/pass/OperationPass.h
@@ -29,7 +29,7 @@ namespace onert
{
namespace ir
{
-class Operation;
+struct IOperation;
} // namespace ir
} // namespace onert
@@ -62,7 +62,7 @@ public:
* @param index is the index of a node in graph
* @param node is the node in graph
*/
- virtual void callback(const ir::OperationIndex &index, ir::Operation &node) = 0;
+ virtual void callback(const ir::OperationIndex &index, ir::IOperation &node) = 0;
/**
* @brief Run the pass
diff --git a/runtime/onert/core/src/compiler/pass/Pass.h b/runtime/onert/core/src/compiler/pass/Pass.h
index 3016df490..b34695c97 100644
--- a/runtime/onert/core/src/compiler/pass/Pass.h
+++ b/runtime/onert/core/src/compiler/pass/Pass.h
@@ -17,6 +17,8 @@
#ifndef __ONERT_COMPILER_PASS_PASS_H__
#define __ONERT_COMPILER_PASS_PASS_H__
+#include "IPass.h"
+
#include <string>
namespace onert
@@ -34,7 +36,7 @@ namespace compiler
namespace pass
{
-class Pass
+class Pass : public IPass
{
public:
Pass(ir::Graph &graph) : _graph{graph} {}
diff --git a/runtime/onert/core/src/compiler/pass/PassRunner.cc b/runtime/onert/core/src/compiler/pass/PassRunner.cc
index 2d11be201..cd1b82bb2 100644
--- a/runtime/onert/core/src/compiler/pass/PassRunner.cc
+++ b/runtime/onert/core/src/compiler/pass/PassRunner.cc
@@ -23,7 +23,7 @@ namespace compiler
namespace pass
{
-PassRunner &PassRunner::append(std::unique_ptr<Pass> pass)
+PassRunner &PassRunner::append(std::unique_ptr<IPass> pass)
{
_passes.emplace_back(std::move(pass));
return *this;
diff --git a/runtime/onert/core/src/compiler/pass/PassRunner.h b/runtime/onert/core/src/compiler/pass/PassRunner.h
index a43c83f89..03bfbe220 100644
--- a/runtime/onert/core/src/compiler/pass/PassRunner.h
+++ b/runtime/onert/core/src/compiler/pass/PassRunner.h
@@ -21,7 +21,7 @@
#include <memory>
#include <vector>
-#include "Pass.h"
+#include "IPass.h"
#include "util/logging.h"
namespace onert
@@ -38,12 +38,12 @@ class PassRunner
{
public:
PassRunner() = default;
- PassRunner &append(std::unique_ptr<Pass> pass);
+ PassRunner &append(std::unique_ptr<IPass> pass);
void run();
private:
- std::vector<std::unique_ptr<Pass>> _passes;
+ std::vector<std::unique_ptr<IPass>> _passes;
};
} // namespace pass
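
With the member now holding IPass pointers, any pass implementing the interface can be queued. A condensed usage sketch mirroring existing call sites such as the ones in LoweredGraph.cc and MultiModelCompiler.cc, with `subg` standing in for a caller-provided ir::Graph:

pass::PassRunner{}
  .append(std::make_unique<pass::ConstantOutputPass>(subg))
  .run();
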
diff --git a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc
index c27ce3d09..d9452c7f9 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc
+++ b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc
@@ -16,6 +16,7 @@
#include "PermutationEliminationPass.h"
+#include "backend/Backend.h"
#include "util/logging.h"
namespace onert
@@ -25,7 +26,7 @@ namespace compiler
namespace pass
{
-void PermutationEliminationPass::callback(const ir::OperationIndex &ind, ir::Operation &node)
+void PermutationEliminationPass::callback(const ir::OperationIndex &ind, ir::IOperation &node)
{
_op_ind = ind;
node.accept(*this);
@@ -73,7 +74,7 @@ void PermutationEliminationPass::visit(const ir::operation::Permute &node)
auto &out_operand_obj = _graph.operands().at(out_operand);
assert(out_operand_obj.getDef() == _op_ind);
out_operand_obj.unsetDef();
- _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::Operation &op) {
+ _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) {
if (!op.getOutputs().contains(in_operand))
return;
// Update Operation and Operand edges
@@ -87,7 +88,7 @@ void PermutationEliminationPass::visit(const ir::operation::Permute &node)
_graph.operations().remove(_op_ind);
}
- _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::Operation &op) {
+ _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) {
if (!op.getInputs().contains(in_operand))
return;
op.replaceInputs(in_operand, out_operand);
@@ -106,7 +107,7 @@ void PermutationEliminationPass::visit(const ir::operation::Permute &node)
in_operand_obj.removeUse(_op_ind);
// Make operations(that use the output) use the input
- _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::Operation &op) {
+ _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) {
if (!op.getInputs().contains(out_operand))
return;
op.replaceInputs(out_operand, in_operand);
diff --git a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h
index 50c38c53f..18ba99804 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h
+++ b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h
@@ -49,7 +49,7 @@ public:
std::string id() final { return "PermutationEliminationPass"; }
public:
- void callback(const ir::OperationIndex &i, ir::Operation &n) final;
+ void callback(const ir::OperationIndex &i, ir::IOperation &n) final;
private:
void visit(const ir::operation::Permute &) final;
diff --git a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc
index 0da1e54df..39eb803f5 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc
+++ b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc
@@ -54,13 +54,13 @@ void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Opera
std::unordered_map<PermuteFactor, ir::OperandIndex> factor_to_index;
{
assert(operand_li->def_factors().size() == 1);
- for (auto factor : operand_li->def_factors())
+ for (auto &&factor : operand_li->def_factors())
{
factor_to_index.emplace(factor, index);
}
auto insert_set = operand_li->use_factors() - operand_li->def_factors();
- for (auto factor : insert_set)
+ for (auto &&factor : insert_set)
{
const auto permute_operation_index = insertPermute(index, factor);
permute_indexes.push_back(permute_operation_index);
@@ -75,7 +75,7 @@ void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Opera
std::list<ir::OperationIndex> remove_list;
auto uses = object.getUses();
- for (auto use : uses)
+ for (auto &&use : uses)
{
// If permute operation, ignore it
if (std::find(permute_indexes.begin(), permute_indexes.end(), use) != permute_indexes.end())
diff --git a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc
index f83b1ba31..f014d29d3 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc
+++ b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc
@@ -30,7 +30,7 @@ namespace pass
using namespace ir;
-void PermutationOperationPass::callback(const OperationIndex &, Operation &node)
+void PermutationOperationPass::callback(const OperationIndex &, IOperation &node)
{
node.accept(*this);
}
diff --git a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h
index cea5de288..e253a77ad 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h
+++ b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h
@@ -36,7 +36,7 @@ public:
std::string id() final { return "PermutationOperationPass"; }
public:
- void callback(const ir::OperationIndex &i, ir::Operation &n) final;
+ void callback(const ir::OperationIndex &i, ir::IOperation &n) final;
public:
void visit(const ir::operation::BinaryArithmetic &) final;
diff --git a/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc
index 35fb575b0..162c4e7ef 100644
--- a/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc
+++ b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc
@@ -37,15 +37,15 @@ void UnusedOperandEliminationPass::run()
{
util::Set<ir::OperandIndex> used;
- _graph.operations().iterate([&](const ir::OperationIndex &, const ir::Operation &node) {
- for (auto ind : (node.getInputs() + node.getOutputs()) | ir::Remove::UNDEFINED)
+ _graph.operations().iterate([&](const ir::OperationIndex &, const ir::IOperation &node) {
+ for (auto &&ind : (node.getInputs() + node.getOutputs()) | ir::Remove::UNDEFINED)
{
used.add(ind);
}
});
// Graph's inputs/outputs are always considered as used
- for (auto ind : (_graph.getInputs() + _graph.getOutputs()) | ir::Remove::UNDEFINED)
+ for (auto &&ind : (_graph.getInputs() + _graph.getOutputs()) | ir::Remove::UNDEFINED)
{
used.add(ind);
}
diff --git a/runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc b/runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc
new file mode 100644
index 000000000..490c648cd
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc
@@ -0,0 +1,285 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "compiler/train/LoweredTrainableGraph.h"
+
+#include "../ManualScheduler.h"
+#include "../pass/ConstantInsertionPass.h"
+#include "../pass/ConstantLoweringPass.h"
+#include "../pass/PassRunner.h"
+#include "../pass/PermutationEliminationPass.h"
+#include "../pass/PermutationInsertionPass.h"
+#include "../pass/PermutationOperationPass.h"
+#include "../../backend/builtin/Config.h"
+#include "../../dumper/text/GraphDumper.h"
+#include "../../ir/verifier/Verifier.h"
+#include "TrainableOperationConverter.h"
+
+#include "backend/Backend.h"
+#include "backend/train/ITrainableBackend.h"
+#include "compiler/BackendResolver.h"
+#include "util/logging.h"
+
+#include <cassert>
+#include <sstream>
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+LoweredTrainableGraph::LoweredTrainableGraph(ir::train::TrainableGraph &graph,
+ const CompilerOptions &options)
+ : _trainable_graph{graph}
+{
+ lowerGraph(options);
+}
+
+void LoweredTrainableGraph::lowerGraph(const CompilerOptions &options)
+{
+ // Build backend contexts
+ auto &backend_manager = BackendManager::get();
+ // Create contexts for other backends
+ for (auto &&backend_str : options.backend_list)
+ {
+ backend_manager.loadBackend(backend_str);
+ auto backend = backend_manager.get(backend_str);
+
+    // TODO The default backend list contains "cpu", "acl_cl" and "acl_neon", but some of them
+    // are not available on x64 and other platforms, so skipping unloadable backends is a
+    // workaround for that. Change this back to throwing when a backend fails to load.
+ if (!backend)
+ {
+ VERBOSE(LoweredTrainableGraph) << "Cannot load backend - " << backend_str << std::endl;
+ continue;
+ }
+ }
+ if (backend_manager.num_backends() == 0)
+ throw std::runtime_error{"No available backends loaded."};
+
+ // TODO Move "schedule" phase out of here
+ // TODO Scheduling
+ std::unique_ptr<BackendResolver> backend_resolver;
+ auto all_backends = backend_manager.getAll();
+
+ auto scheduler = ManualScheduler(all_backends, options);
+ backend_resolver = scheduler.schedule(_trainable_graph.graph());
+
+ // Check if backends are trainable
+ _trainable_graph.operations().iterate(
+ [&](const ir::OperationIndex &op_ind, const ir::IOperation &) {
+ const auto backend = backend_resolver->getBackend(op_ind);
+
+ // TODO Remove dynamic_cast
+ if (dynamic_cast<const backend::train::ITrainableBackend *>(backend) == nullptr)
+ {
+        throw std::runtime_error(backend->config()->id() + " backend does not support training");
+ }
+ });
+
+ makeLowerInfo(*backend_resolver);
+ VERBOSE(LoweredTrainableGraph) << "dump before mandatory passes" << std::endl;
+ dumper::text::dumpLoweredGraph(*this);
+
+  // Mandatory passes - a kind of legalization
+ compiler::pass::PassRunner{}
+ .append(std::make_unique<compiler::pass::ConstantInsertionPass>(*this))
+ .append(std::make_unique<compiler::pass::ConstantLoweringPass>(*this))
+ .append(std::make_unique<compiler::pass::PermutationOperationPass>(*this))
+ .append(std::make_unique<compiler::pass::PermutationInsertionPass>(*this))
+ .run();
+
+ // TODO Move converting Permute op into PermutationInsertionPass
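+  // NOTE Only Permute operations are converted here, so the converter is created without
+  //      TrainingInfo (nullptr)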
+ auto op_converter = TrainableOperationConverter{_trainable_graph, nullptr};
+ _trainable_graph.operations().iterate(
+ [&](const onert::ir::OperationIndex &index, const onert::ir::IOperation &op) {
+ if (op.opcode() == ir::OpCode::Permute)
+ {
+ auto trainable_op = op_converter(op);
+ auto gen_index = _trainable_graph.replaceOperation(index, std::move(trainable_op));
+ UNUSED_RELEASE(gen_index);
+ assert(gen_index == index);
+ }
+ });
+
+ dumpLowerInfo();
+
+ // Optimization passes (optional)
+ compiler::pass::PassRunner{}
+ .append(std::make_unique<compiler::pass::PermutationEliminationPass>(*this))
+ .run();
+
+ // TODO Update LowerInfo for training
+
+ VERBOSE(LoweredTrainableGraph) << "Dump after all the passes" << std::endl;
+ for (auto &&operand : _trainable_graph.getInputs())
+ VERBOSE(LoweredTrainableGraph) << "Graph Input : " << operand << std::endl;
+ for (auto &&operand : _trainable_graph.getOutputs())
+ VERBOSE(LoweredTrainableGraph) << "Graph Output : " << operand << std::endl;
+ dumper::text::dumpLoweredGraph(*this);
+
+ // Graph verifications
+ {
+ assert(ir::verifier::InputOutputChecker().verify(_trainable_graph.graph()));
+ assert(ir::verifier::DAGChecker().verify(_trainable_graph.graph()));
+ assert(ir::verifier::EdgeChecker().verify(_trainable_graph.graph()));
+ }
+}
+
+void LoweredTrainableGraph::makeLowerInfo(const compiler::BackendResolver &backend_resolver)
+{
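+  // Create an empty lower info entry for every operand first; def/use factors are filled in below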
+ _trainable_graph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) {
+ lower_info().operand.set(index, std::make_unique<OperandLowerInfo>());
+ });
+
+ // Set operand lower info using assigned backends to operations
+ _trainable_graph.operations().iterate(
+ [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) {
+ auto backend = backend_resolver.getBackend(op_ind);
+ if (!backend)
+ {
+ throw std::runtime_error{"Fail to find backend for " + op.name() + " operation"};
+ }
+
+ auto frontend_layout = _trainable_graph.layout();
+
+        // The layout of each backend should be set elsewhere
+        // TODO Move setting the layout of each backend to another place
+ auto backend_layout = backend->config()->supportLayout(op, frontend_layout);
+
+ for (auto &&ind : op.getInputs() | ir::Remove::UNDEFINED)
+ {
+ auto &operand_li = lower_info().operand.at(ind);
+ operand_li.addUsePermuteFactor(PermuteFactor{backend, backend_layout});
+ }
+ for (auto &&ind : op.getOutputs() | ir::Remove::UNDEFINED)
+ {
+ auto &operand_li = lower_info().operand.at(ind);
+ operand_li.addDefPermuteFactor(PermuteFactor{backend, backend_layout});
+ }
+ lower_info().operation.set(
+ op_ind, std::make_unique<compiler::OperationLowerInfo>(backend, backend_layout));
+ });
+
+ // Handle graph inputs and outputs
+ const auto builtin_backend = BackendManager::get().getBuiltin();
+ auto factor = PermuteFactor{builtin_backend, _trainable_graph.layout()};
+ for (auto &&index : _trainable_graph.getInputs() | ir::Remove::UNDEFINED)
+ {
+ auto &operand_li = lower_info().operand.at(index);
+ assert(operand_li.def_factors().empty());
+ operand_li.addDefPermuteFactor(factor);
+ }
+ for (auto &&index : _trainable_graph.getOutputs() | ir::Remove::UNDEFINED)
+ {
+ auto &operand_li = lower_info().operand.at(index);
+ operand_li.addUsePermuteFactor(factor);
+ }
+
+ // Handle variable tensors
+ _trainable_graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &operand) {
+    // Some inputs of an operation may be non-constant yet neither appear in the graph
+    // inputs/outputs nor be undefined operands - these are variable tensors. For example,
+    // UnidirectionalSequenceLSTM has such inputs.
+ if (operand.info().isVariable())
+ {
+ // The variable operand with buffer is not supported yet
+ assert(operand.data() == nullptr);
+ assert(operand.getUses().size() == 1 && !operand.getDef().valid());
+      auto &operand_li = lower_info().operand.at(index);
+ assert(operand_li.def_factors().empty());
+ operand_li.addDefPermuteFactor(operand_li.use_factors().getOnlyElement());
+ }
+ });
+}
+
+void LoweredTrainableGraph::dumpLowerInfo()
+{
+ if (::onert::util::logging::ctx.enabled() == false)
+ return;
+
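+  // Collect per-operand dump text keyed by operand index so that the output is ordered by index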
+ std::map<uint32_t, std::string> dumps;
+
+ _trainable_graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &object) {
+ const auto operand_lower_info = lower_info().operand.getRawPtr(index);
+ assert(operand_lower_info);
+ if (!operand_lower_info->def_factors().empty() || !operand_lower_info->use_factors().empty())
+ {
+ auto shape_to_string = [](const ir::Shape &shape) {
+ std::stringstream sstream;
+ sstream << "{ ";
+ for (auto i = 0; i < shape.rank(); ++i)
+ sstream << (shape.dim(i)) << " ";
+ sstream << "}";
+ return sstream.str();
+ };
+
+ auto factors_to_string = [](const PermuteFactorSet &factors) {
+ std::string str;
+ for (auto &&factor : factors)
+ {
+ str += factor.backend()->config()->id();
+ str += "(" + to_string(factor.layout()) + ")";
+ str += " ";
+ }
+ return "{ " + str + "}";
+ };
+
+ auto operation_index_set_to_string = [](const ir::OperationIndexSet &operations) {
+ std::stringstream sstream;
+ sstream << "{ ";
+ for (auto &&op : operations)
+ sstream << op << " ";
+ sstream << "}";
+ return sstream.str();
+ };
+
+ auto data_to_str = [](const ir::Data *data) {
+ return (data ? (std::to_string(data->size()) + " bytes") : "N/A");
+ };
+
+ std::string shape_str = shape_to_string(object.shape());
+ std::string def_op = operation_index_set_to_string({object.getDef()});
+ std::string use_ops = operation_index_set_to_string(object.getUses());
+ std::string def_factors = factors_to_string(operand_lower_info->def_factors());
+ std::string use_factors = factors_to_string(operand_lower_info->use_factors());
+ std::stringstream sstream;
+ sstream << "Operand " << index << " Info" << std::endl;
+ sstream << " - Shape : " << shape_str << std::endl;
+ sstream << " - Def/Uses : Def " << def_op << " Uses " << use_ops << std::endl;
+ sstream << " - Data : " << data_to_str(object.data()) << std::endl;
+ sstream << " - LowerInfo : Def " << def_factors << " Uses " << use_factors << std::endl;
+ dumps.emplace(index.value(), sstream.str());
+ }
+ });
+
+ for (const auto &e : dumps)
+ {
+ if (!e.second.empty())
+ {
+ std::istringstream iss(e.second);
+ std::string line;
+ while (std::getline(iss, line))
+ VERBOSE(Lower) << line << std::endl;
+ }
+ }
+}
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.cc b/runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.cc
new file mode 100644
index 000000000..d2153296f
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.cc
@@ -0,0 +1,150 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "StaticDerivativeShapeInferer.h"
+#include "util/ShapeInference.h"
+#include "util/logging.h"
+
+#include <misc/polymorphic_downcast.h>
+
+#include <sstream>
+#include <stdexcept>
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+void StaticDerivativeShapeInferer::infer()
+{
+  // NOTE It has not been decided yet whether iterating in reverse topological order is required.
+ auto sorted_ops = _lowered_subg->graph().topolSortOperations();
+ for (auto it = sorted_ops.rbegin(); it != sorted_ops.rend(); ++it)
+ {
+ const auto op_idx = *it;
+ const auto &op = _lowered_subg->trainable_graph().operation(op_idx);
+ if (checkDynamicInput(op))
+ {
+ std::stringstream msg;
+ msg << "StaticDerivativeShapeInferer does not support dynamic shape yet, ";
+ msg << op.name() << "(op index: " << op_idx << ") has dynamic shape.";
+ throw std::runtime_error(msg.str());
+ }
+
+ checkOutput(op);
+
+ op.accept(*this);
+ }
+}
+
+void StaticDerivativeShapeInferer::dump()
+{
+ // TODO dump
+}
+
+bool StaticDerivativeShapeInferer::checkDynamicInput(const ir::IOperation &op)
+{
+ const auto &operands = _lowered_subg->graph().operands();
+ for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
+ {
+ if (operands.at(input_idx).info().isDynamic())
+ {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+void StaticDerivativeShapeInferer::checkOutput(const ir::IOperation &op)
+{
+ const auto &derivatives = _lowered_subg->trainable_graph().derivatives();
+ for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
+ {
+ if (!derivatives.exist(output_idx))
+ {
+ std::stringstream msg;
+ msg << "StaticDerivativeShapeInferer : Invalid output, ";
+ msg << op.name() << "'s derivative output(index: " << output_idx << ") does not exist.";
+ throw std::runtime_error(msg.str());
+ }
+ }
+}
+
+void StaticDerivativeShapeInferer::setShape(const ir::OperandIndex &index, const ir::Shape &shape)
+{
+ auto &tgraph = _lowered_subg->trainable_graph();
+
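+  // Resize the existing derivative if there is one; otherwise register a new derivative operand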
+ if (tgraph.derivatives().exist(index))
+ tgraph.changeDerivativeShape(index, shape);
+ else
+ {
+    // NOTE This assumes the derivative always has the same type as the operand, which has not
+    //      been verified.
+ const auto &type = tgraph.operands().at(index).typeInfo();
+ const auto new_index = tgraph.addDerivative(index, std::make_unique<ir::Operand>(shape, type));
+ assert(new_index == index);
+ UNUSED_RELEASE(new_index);
+ }
+}
+
+void StaticDerivativeShapeInferer::visit(const ir::train::operation::Conv2D &)
+{
+ // NYI
+}
+
+void StaticDerivativeShapeInferer::visit(const ir::train::operation::ElementwiseActivation &)
+{
+ // NYI
+}
+
+void StaticDerivativeShapeInferer::visit(const ir::train::operation::Loss &)
+{
+ // NYI
+}
+
+void StaticDerivativeShapeInferer::visit(const ir::train::operation::Permute &op)
+{
+ const auto &derivatives = _lowered_subg->trainable_graph().derivatives();
+
+ const auto &output_idx = op.getOutputs().at(0);
+ const auto &output = derivatives.at(output_idx);
+
+ // re-sizing input derivative shape
+ const auto &input_idx = op.getInputs().at(0);
+ const auto &new_shape = output.info().shape();
+ setShape(input_idx, new_shape);
+}
+
+void StaticDerivativeShapeInferer::visit(const ir::train::operation::Pool2D &)
+{
+ // NYI
+}
+
+void StaticDerivativeShapeInferer::visit(const ir::train::operation::Reshape &)
+{
+ // NYI
+}
+
+void StaticDerivativeShapeInferer::visit(const ir::train::operation::Softmax &)
+{
+ // NYI
+}
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.h b/runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.h
new file mode 100644
index 000000000..48b3172d2
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_STATIC_DERIVATIVE_SHAPE_INFERER_H__
+#define __ONERT_COMPILER_STATIC_DERIVATIVE_SHAPE_INFERER_H__
+
+#include "ir/train/TrainableOperationVisitor.h"
+
+#include "compiler/train/LoweredTrainableGraph.h"
+#include "ir/Index.h"
+
+#include <memory>
+#include <unordered_map>
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+/**
+ * @brief Class to infer derivative shapes before running kernels. It does the following:
+ *        - re-calculate and set derivative shapes at compile time (before running kernels)
+ *        - if the calculation cannot be done at compile time, mark the outputs as dynamic, meaning
+ *          their shapes will be calculated while running kernels
+ */
+class StaticDerivativeShapeInferer : public ir::train::TrainableOperationVisitor
+{
+public:
+ StaticDerivativeShapeInferer(compiler::train::LoweredTrainableGraph *lowered_subg)
+ : _lowered_subg{lowered_subg}
+ {
+ }
+
+ /**
+   * @brief Infer shapes of operands belonging to ops and set the output shapes.
+   *        If an output shape cannot be known without running the op, mark it so that it can be
+   *        allocated when running the kernel.
+ */
+ void infer(void);
+
+ void dump();
+
+private:
+ bool checkDynamicInput(const ir::IOperation &op);
+ void checkOutput(const ir::IOperation &op);
+ void setShape(const ir::OperandIndex &index, const ir::Shape &shape);
+
+private:
+ void visit(const ir::train::operation::Conv2D &op) override;
+ void visit(const ir::train::operation::ElementwiseActivation &op) override;
+ void visit(const ir::train::operation::Loss &op) override;
+ void visit(const ir::train::operation::Permute &op) override;
+ void visit(const ir::train::operation::Pool2D &op) override;
+ void visit(const ir::train::operation::Reshape &op) override;
+ void visit(const ir::train::operation::Softmax &op) override;
+
+private:
+ compiler::train::LoweredTrainableGraph *_lowered_subg;
+};
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_STATIC_DERIVATIVE_SHAPE_INFERER_H__
diff --git a/runtime/onert/core/src/compiler/train/TensorRegistries.h b/runtime/onert/core/src/compiler/train/TensorRegistries.h
new file mode 100644
index 000000000..48eaf10a1
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/TensorRegistries.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
+#define __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
+
+#include "../../backend/builtin/Config.h"
+#include "../../backend/builtin/train/TensorRegistry.h"
+
+#include <backend/train/TrainableBackendContext.h>
+
+#include <memory>
+#include <unordered_set>
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+class TensorRegistries
+{
+public:
+ TensorRegistries() = default;
+
+ TensorRegistries(const backend::train::TrainableBackendContexts &backend_contexts,
+ bool include_builtin)
+ {
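+    // Keep the builtin backend's registry separately; add it to the generic set only on request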
+ for (const auto &e : backend_contexts)
+ {
+ auto tensor_reg = e.second->tensor_registry();
+ if (e.first->config()->id() == backend::builtin::Config::ID)
+ {
+ _builtin_tensor_reg =
+ std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(tensor_reg);
+ if (include_builtin)
+ _tensor_regs.insert(tensor_reg);
+ }
+ else
+ {
+ _tensor_regs.insert(tensor_reg);
+ }
+ }
+ }
+
+ std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator begin() const
+ {
+ return _tensor_regs.cbegin();
+ }
+ std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator end() const
+ {
+ return _tensor_regs.cend();
+ }
+
+ std::shared_ptr<backend::builtin::train::TensorRegistry> getBuiltinTensorRegistry() const
+ {
+ return _builtin_tensor_reg;
+ }
+
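+  // Search every registry and return the first tensor registered for the operand, or nullptr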
+ backend::ITensor *getITensor(ir::OperandIndex index) const
+ {
+ for (auto &&tensor_reg : _tensor_regs)
+ {
+ auto tensor = tensor_reg->getITensor(index);
+ if (tensor)
+ return tensor;
+ }
+ return nullptr;
+ }
+
+ backend::ITensor *getDerivativeITensor(ir::OperandIndex index) const
+ {
+ for (auto &&tensor_reg : _tensor_regs)
+ {
+ auto tensor = tensor_reg->getDerivativeITensor(index);
+ if (tensor)
+ return tensor;
+ }
+ return nullptr;
+ }
+
+private:
+ std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>> _tensor_regs;
+ std::shared_ptr<backend::builtin::train::TensorRegistry> _builtin_tensor_reg;
+};
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
diff --git a/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc
new file mode 100644
index 000000000..d20ae9fd3
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TrainableOperationConverter.h"
+
+#include "ir/train/Operations.Include.h"
+#include "util/Utils.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+TrainableOperationConverter::TrainableOperationConverter(
+ ir::train::TrainableGraph &tgraph, const compiler::train::TrainingInfo *training_info)
+ : UntrainableOperationConverter{tgraph}, _training_info{training_info}
+{
+ // Avoid unused-private-field error
+ UNUSED_RELEASE(_training_info);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Conv2D &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Conv2D>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::ElementwiseActivation &node)
+{
+ if (node.param().op_type == ir::operation::ElementwiseActivation::Type::RELU)
+ {
+ _return_op = std::make_unique<ir::train::operation::ElementwiseActivation>(node);
+ }
+ else
+ {
+ UntrainableOperationConverter::visit(node);
+ }
+}
+
+void TrainableOperationConverter::visit(const ir::operation::FullyConnected &node)
+{
+ _return_op = std::make_unique<ir::train::operation::FullyConnected>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Loss &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Loss>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Permute &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Permute>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Pool2D &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Pool2D>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Reshape &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Reshape>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Softmax &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Softmax>(node);
+}
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/train/TrainableOperationConverter.h b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.h
new file mode 100644
index 000000000..5f6fc10c3
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_TRAINABLE_OPERATION_CONVERTER_H__
+#define __ONERT_COMPILER_TRAIN_TRAINABLE_OPERATION_CONVERTER_H__
+
+#include "UntrainableOperationConverter.h"
+
+#include "compiler/train/TrainingInfo.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+class TrainableOperationConverter : public UntrainableOperationConverter
+{
+public:
+ TrainableOperationConverter(ir::train::TrainableGraph &trainable_graph,
+ const compiler::train::TrainingInfo *training_info);
+
+ using UntrainableOperationConverter::operator();
+
+private:
+ void visit(const ir::operation::Conv2D &) override;
+ void visit(const ir::operation::ElementwiseActivation &) override;
+ void visit(const ir::operation::FullyConnected &) override;
+ void visit(const ir::operation::Loss &node) override;
+ void visit(const ir::operation::Permute &node) override;
+ void visit(const ir::operation::Pool2D &node) override;
+ void visit(const ir::operation::Reshape &) override;
+ void visit(const ir::operation::Softmax &) override;
+
+private:
+ const compiler::train::TrainingInfo *_training_info;
+};
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_TRAINABLE_OPERATION_CONVERTER_H__
diff --git a/runtime/onert/core/src/compiler/train/TrainingCompiler.cc b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc
new file mode 100644
index 000000000..711af1651
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc
@@ -0,0 +1,299 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TrainingCompiler.h"
+
+#include "StaticDerivativeShapeInferer.h"
+#include "TrainableOperationConverter.h"
+#include "pass/LossInsertionPass.h"
+#include "../CompilerHelpers.h"
+#include "../ExecutorFactory.h"
+#include "../pass/ConstantOutputPass.h"
+#include "../pass/OddOutputPass.h"
+#include "../pass/PassRunner.h"
+#include "../pass/UnusedOperandEliminationPass.h"
+#include "../ShapeValidator.h"
+#include "../../dumper/dot/DotDumper.h"
+#include "../../exec/train/TrainableExecutors.h"
+#include "../../ir/OperationDumper.h"
+#include "../../ir/verifier/Verifier.h"
+
+#include <compiler/StaticShapeInferer.h>
+#include <compiler/train/LoweredTrainableGraph.h>
+#include <ir/train/TrainableGraph.h>
+#include <exec/train/optimizer/SGD.h>
+
+#include <misc/polymorphic_downcast.h>
+#include <misc/string_helpers.h>
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+TrainingCompiler::TrainingCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg,
+ std::vector<std::unique_ptr<CompilerOptions>> &copts,
+ const TrainingInfo &training_info)
+ : _model{nnpkg->primary_model()}, _options{copts[0].get()}, _training_info{training_info}
+{
+ if (nnpkg->model_count() > 1)
+ throw std::runtime_error("TrainingCompiler does not support multiple models yet");
+
+ if (nnpkg->primary_model()->subgraphs_count() > 1)
+ throw std::runtime_error("TrainingCompiler does not support multiple subgraphs yet");
+}
+
+std::shared_ptr<CompilerArtifact> TrainingCompiler::compile(void)
+{
+ /***************************************************
+ * Prepare compilation phase
+ ***************************************************/
+ if (!_options)
+ throw std::runtime_error{"Empty compile option"};
+
+ // Mode check
+ // TODO handle option for each model
+ if (_options->he_profiling_mode)
+ {
+ if (!_options->he_scheduler)
+ throw std::runtime_error("Heterogeneous scheduler must be enabled during profiling.");
+
+ if (_options->executor != "Dataflow")
+ throw std::runtime_error("Profiling mode works only with 'Dataflow' executor");
+ }
+
+ if (!_options->minmax_filepath.empty())
+ {
+ if (_options->executor != "Linear")
+ throw std::runtime_error("Recording minmax works only with Linear executor");
+ }
+
+ _options->forceInternalOptions();
+ _options->verboseOptions();
+
+ auto custom_kernel_builder = _model->getKernelBuilder();
+
+ _model->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) {
+ auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);
+ // Mandatory passes
+ compiler::pass::PassRunner{}
+ .append(std::make_unique<compiler::pass::ConstantOutputPass>(subg))
+ .append(std::make_unique<compiler::pass::OddOutputPass>(subg))
+ .run();
+
+ // Optimizations
+ compiler::pass::PassRunner{}
+ .append(std::make_unique<compiler::pass::UnusedOperandEliminationPass>(subg))
+ .run();
+ });
+
+ std::unordered_map<ir::SubgraphIndex, std::shared_ptr<ir::train::TrainableGraph>>
+ trainable_subgraphs;
+
+ if (_model->hasOnly<ir::Graph>())
+ {
+ // Create trainable subgraphs by copy and converting inference model
+ _model->iterate([&](const ir::SubgraphIndex &subg_index, const ir::IGraph &graph) {
+ const auto &subg = nnfw::misc::polymorphic_downcast<const ir::Graph &>(graph);
+ // Create TrainableGraph by copying Graph
+ auto trainable_subg = std::make_shared<ir::train::TrainableGraph>(subg);
+
+ // Convert operations to trainable operations
+ auto converter = TrainableOperationConverter{*trainable_subg, &_training_info};
+ subg.operations().iterate(
+ [&](const onert::ir::OperationIndex &op_index, const onert::ir::IOperation &op) {
+ auto trainable_op = converter(op);
+ auto gen_index = trainable_subg->replaceOperation(op_index, std::move(trainable_op));
+ UNUSED_RELEASE(gen_index);
+ assert(gen_index == op_index);
+ });
+
+ trainable_subgraphs[subg_index] = std::move(trainable_subg);
+ });
+ }
+ else
+ {
+ // TODO Support models that have TrainableGraphs
+ throw std::runtime_error("TrainingCompiler: Invalid model");
+ }
+
+  // The original model is no longer needed; only the trainable subgraphs are used from here on
+ _model.reset();
+
+ // Apply pass for trainable subgraphs
+ for (auto &&pair : trainable_subgraphs)
+ {
+ auto trainable_subg = pair.second;
+ auto subg_index = pair.first;
+
+ compiler::pass::PassRunner{}
+ .append(std::make_unique<train::pass::LossInsertionPass>(*trainable_subg, &_training_info,
+ subg_index))
+ .run();
+ }
+
+ // Change input shape according to batch_size
+ for (auto &&pair : trainable_subgraphs)
+ {
+ auto trainable_subg = pair.second;
+
+ for (const auto &ind : trainable_subg->getInputs())
+ {
+ auto &input = trainable_subg->operands().at(ind);
+ auto new_shape = input.info().shape();
+ // TODO Consider batch size index
+ if (new_shape.dim(0) != 1)
+ throw std::runtime_error("the first dim is not 1. It is not supported yet.");
+ new_shape.dim(0) = _training_info.batchSize();
+ input.info().shape(new_shape);
+ }
+ }
+
+ /***************************************************
+ * Backend independent analysis & optimization phase
+ ***************************************************/
+ // TODO Handle dump level for each model
+ auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options->graph_dump_level);
+ onert::dumper::dot::DotDumper dot_dumper(dump_level);
+
+ // Tracing context
+ auto tracing_ctx = std::make_unique<util::TracingCtx>();
+
+ // Lower: Assign backend
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::train::LoweredTrainableGraph>>
+ lowered_subgs;
+ {
+ for (auto &&pair : trainable_subgraphs)
+ {
+ auto &subg_index = pair.first;
+ auto trainable_subg = pair.second;
+
+ // Lower: Assign backend
+ lowered_subgs[subg_index] =
+ std::make_unique<compiler::train::LoweredTrainableGraph>(*trainable_subg, *_options);
+ // Set tracing_ctx for copied graph
+ if (tracing_ctx != nullptr)
+ tracing_ctx->setSubgraphIndex(&(lowered_subgs[subg_index]->graph()), subg_index.value());
+ }
+ }
+
+ for (const auto &pair : lowered_subgs)
+ {
+ const auto &subg_index = pair.first;
+ const auto &lowered_subg = pair.second;
+ dot_dumper.dump(*lowered_subg, nnfw::misc::str("after_lower_subg-", subg_index.value()));
+ }
+
+ // Set derivatives as default tensor info
+ for (const auto &pair : lowered_subgs)
+ {
+ auto lowered_subg = pair.second.get();
+ auto &tgraph = lowered_subg->trainable_graph();
+ tgraph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &obj) {
+ if (!obj.isConstant())
+ {
+ auto deriv = std::make_unique<ir::Operand>(obj);
+ const auto gen_index = tgraph.addDerivative(index, std::move(deriv));
+ assert(gen_index == index);
+ UNUSED_RELEASE(gen_index);
+ }
+ });
+ }
+
+ // Shape inference.
+ {
+    // Run the StaticShapeInferer of the primary subgraph. All child StaticShapeInferers are
+    // called recursively
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
+ createStaticShapeInferers(lowered_subgs);
+
+ const auto primary_subg_idx = ir::SubgraphIndex{0};
+ inferers.at(primary_subg_idx)->infer();
+
+ for (const auto &pair_inferer : inferers)
+ {
+ const auto inferer = pair_inferer.second.get();
+ inferer->dump();
+ }
+
+ // NOTE StaticDerivativeShapeInferer is allocated for each subgraph,
+ // so it does not support models that have controlflow operations yet.
+ for (auto &&pair : lowered_subgs)
+ {
+ auto &lowered_subg = pair.second;
+ auto inferer = std::make_unique<StaticDerivativeShapeInferer>(lowered_subg.get());
+ inferer->infer();
+ inferer->dump();
+ }
+ }
+
+ // Shape validation
+ for (const auto &pair : lowered_subgs)
+ {
+ auto &lowered_subg = pair.second;
+ compiler::ShapeValidator{lowered_subg->graph()}();
+ }
+
+ // TODO Validate shapes of derivative tensors
+
+ // Create optimizer
+ // TODO Set properties of optimizer
+ std::shared_ptr<exec::train::optimizer::Optimizer> optimizer;
+ const auto &optim_info = _training_info.optimizerInfo();
+ if (optim_info.optim_code == exec::train::optimizer::OptimizerCode::SGD)
+ optimizer = std::make_shared<exec::train::optimizer::SGD>(optim_info.learning_rate);
+ else
+ throw std::runtime_error("Invalid optimizer type, " +
+ exec::train::optimizer::toString(optim_info.optim_code));
+
+ /*************************************************************
+ * Backend independent analysis & optimization phase finished
+ *************************************************************/
+ auto executors = std::make_shared<exec::train::TrainableExecutors>();
+ for (auto &&pair : lowered_subgs)
+ {
+ auto const model_index = ir::ModelIndex{0};
+ auto const subg_index = pair.first;
+ auto &lowered_subg = pair.second;
+ auto const indexed_ranks = lowered_subg->indexed_ranks();
+
+ ir::OperationDumper dumper("Executor generation of Subgraph " +
+ std::to_string(subg_index.value()));
+ lowered_subg->graph().operations().iterate(
+ [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); });
+
+ ExecutorFactoryArgs args;
+ args.tracing_ctx = tracing_ctx.get();
+ args.options = _options;
+ args.model_index = model_index;
+ args.custom_kernel_builder = custom_kernel_builder;
+ auto executor = std::unique_ptr<exec::IExecutor>{
+ ExecutorFactory::get().create(std::move(lowered_subg), executors, args, optimizer)};
+ executor->setIndexedRanks(indexed_ranks);
+ executors->emplace(model_index, subg_index, std::move(executor));
+ }
+
+ /********************************
+ * Code generation phase finished
+ ********************************/
+ return std::make_shared<CompilerArtifact>(executors, std::move(tracing_ctx));
+}
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/train/TrainingCompiler.h b/runtime/onert/core/src/compiler/train/TrainingCompiler.h
new file mode 100644
index 000000000..b93437217
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/TrainingCompiler.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * @file TrainingCompiler.h
+ * @brief This file contains TrainingCompiler class to define and run compilation phase
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_TRAINING_COMPILER_H_
+#define __ONERT_COMPILER_TRAIN_TRAINING_COMPILER_H_
+
+#include "compiler/CompilerOptions.h"
+#include "compiler/ICompiler.h"
+#include "compiler/train/TrainingInfo.h"
+#include "ir/NNPkg.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+/**
+ * @brief Class to compile NN package
+ */
+class TrainingCompiler : public ICompiler
+{
+public:
+ /**
+   * @brief Construct a new TrainingCompiler object for a single model
+   * @param[in] nnpkg NN package to compile
+   * @param[in] copts Compiler options
+   * @param[in] training_info Training information
+ */
+ explicit TrainingCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg,
+ std::vector<std::unique_ptr<CompilerOptions>> &copts,
+ const TrainingInfo &training_info);
+
+ /**
+ * @brief Default Construct
+ *
+ */
+ TrainingCompiler(void) = delete;
+
+ /**
+ * @brief Destroy the TrainingCompiler object
+ */
+ ~TrainingCompiler() = default;
+
+public:
+ /**
+ * @brief Do compilation with the options
+ *
+ * @return std::shared_ptr<CompilerArtifact> Executors as a result of compilation
+ */
+ std::shared_ptr<CompilerArtifact> compile(void);
+
+private:
+ std::shared_ptr<ir::Model> _model;
+ CompilerOptions *_options;
+ const TrainingInfo _training_info;
+};
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_TRAINING_COMPILER_H_
diff --git a/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc
new file mode 100644
index 000000000..6a5a052b6
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "UntrainableOperationConverter.h"
+
+#include "ir/train/operation/UntrainableOperation.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+UntrainableOperationConverter::UntrainableOperationConverter(ir::train::TrainableGraph &tgraph)
+ : _tgraph{tgraph}, _return_op{nullptr}
+{
+}
+
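+// Dispatch the visitor on the given operation; the matching visit() fills in _return_op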
+std::unique_ptr<ir::train::ITrainableOperation> UntrainableOperationConverter::
+operator()(const ir::IOperation &op)
+{
+ op.accept(*this);
+
+ return std::move(_return_op);
+}
+
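+// By default, wrap every IR operation as an UntrainableOperation. Trainable operations are
+// handled by the overriding visitors in TrainableOperationConverter.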
+#define OP(InternalName) \
+ void UntrainableOperationConverter::visit(const ir::operation::InternalName &node) \
+ { \
+ _return_op = \
+ std::make_unique<ir::train::operation::UntrainableOperation<ir::operation::InternalName>>( \
+ node); \
+ }
+#include "ir/Operations.lst"
+#undef OP
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h
new file mode 100644
index 000000000..e960b3831
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_UNTRAINABLE_OPERATION_CONVERTER_H__
+#define __ONERT_COMPILER_TRAIN_UNTRAINABLE_OPERATION_CONVERTER_H__
+
+#include "ir/Operations.Include.h"
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableGraph.h"
+
+#include <memory>
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+class UntrainableOperationConverter : public ir::OperationVisitor
+{
+public:
+ UntrainableOperationConverter(ir::train::TrainableGraph &tgraph);
+ std::unique_ptr<ir::train::ITrainableOperation> operator()(const ir::IOperation &op);
+
+#define OP(InternalName) void visit(const ir::operation::InternalName &node);
+#include "ir/Operations.lst"
+#undef OP
+
+protected:
+ ir::train::TrainableGraph &_tgraph;
+ std::unique_ptr<ir::train::ITrainableOperation> _return_op;
+};
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_UNTRAINABLE_OPERATION_CONVERTER_H__
diff --git a/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc
new file mode 100644
index 000000000..3e01a9739
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LossInsertionPass.h"
+
+#include "ir/train/TrainableGraph.h"
+#include "ir/train/operation/Loss.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+namespace pass
+{
+
+void LossInsertionPass::run()
+{
+ const auto &loss_info = _training_info->lossInfo();
+
+ ir::operation::Loss::Param param;
+ param.op_type = loss_info.type;
+
+ if (_trainable_graph.getOutputs().size() != 1)
+ {
+ throw std::runtime_error("LossInsertionPass: Not supported multiple outputs");
+ }
+
+ // TODO Consider SparseCategoricalCrossentropy y_true shape
+ // SparseCategoricalCrossentropy loss has a different y_true shape than y_pred.
+
+ // TODO Implement Loop [0, getOutputs().size())
+ // index: a loop index
+ const auto index = 0;
+ const auto &y_pred_index = _trainable_graph.getOutputs().at(index);
+ const auto &y_pred = _trainable_graph.operands().at(y_pred_index);
+ const auto &shape = y_pred.shape();
+ const auto &type_info = y_pred.typeInfo();
+ auto y_true_index = _trainable_graph.addOperand(shape, type_info);
+ ir::OperandIndexSequence inputs{y_pred_index, y_true_index};
+
+ // TODO Consider Reduction
+  // For some reduction types, y_true and the output have the same shape.
+
+ const ir::TypeInfo float_op(ir::DataType::FLOAT32);
+ auto output_index = _trainable_graph.addOperand(ir::Shape{1}, float_op);
+ ir::OperandIndexSequence outputs{output_index};
+
+ auto loss_op = std::make_unique<ir::operation::Loss>(inputs, outputs, param);
+ auto trainable_loss_op = std::make_unique<ir::train::operation::Loss>(*loss_op);
+
+ _trainable_graph.addOperation(std::move(trainable_loss_op));
+
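+  // The newly created y_true operand becomes an additional graph input to be fed with
+  // expected values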
+ _trainable_graph.addInput(y_true_index);
+
+ // TODO Add loss as many as output size
+ _trainable_graph.addLoss(output_index, ir::IOIndex{index});
+}
+
+} // namespace pass
+} // namespace train
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h
new file mode 100644
index 000000000..ed4d60c96
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_PASS_LOSS_INSERTION_PASS_H__
+#define __ONERT_COMPILER_TRAIN_PASS_LOSS_INSERTION_PASS_H__
+
+#include "Pass.h"
+
+#include "compiler/train/TrainingInfo.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+namespace pass
+{
+
+class LossInsertionPass : public Pass
+{
+public:
+ LossInsertionPass(ir::train::TrainableGraph &trainable_graph, const TrainingInfo *training_info,
+ const ir::SubgraphIndex &subg_index)
+ : Pass{trainable_graph, training_info}, _subg_index{subg_index}
+ {
+ }
+
+public:
+ std::string id() final { return "LossInsertionPass"; }
+ void run() final;
+
+private:
+ ir::SubgraphIndex _subg_index;
+};
+
+} // namespace pass
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_PASS_LOSS_INSERTION_PASS_H__
diff --git a/runtime/onert/core/src/compiler/train/pass/Pass.h b/runtime/onert/core/src/compiler/train/pass/Pass.h
new file mode 100644
index 000000000..d64c06cf4
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/pass/Pass.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_PASS_PASS_H__
+#define __ONERT_COMPILER_TRAIN_PASS_PASS_H__
+
+#include "../../pass/IPass.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+class TrainableGraph;
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+class TrainingInfo;
+
+namespace pass
+{
+
+class Pass : public compiler::pass::IPass
+{
+public:
+ Pass(ir::train::TrainableGraph &trainable_graph, const TrainingInfo *training_info)
+ : _trainable_graph{trainable_graph}, _training_info{training_info}
+ {
+ }
+ virtual ~Pass() = default;
+
+protected:
+ ir::train::TrainableGraph &_trainable_graph;
+ const TrainingInfo *_training_info;
+};
+
+} // namespace pass
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_PASS_PASS_H__
diff --git a/runtime/onert/core/src/dumper/dot/DotBuilder.cc b/runtime/onert/core/src/dumper/dot/DotBuilder.cc
index d4e4d5484..9257434fa 100644
--- a/runtime/onert/core/src/dumper/dot/DotBuilder.cc
+++ b/runtime/onert/core/src/dumper/dot/DotBuilder.cc
@@ -29,7 +29,7 @@ DotBuilder::DotBuilder() {}
void DotBuilder::update(const Node &node_info)
{
add(node_info);
- for (auto edge : node_info.out_edges())
+ for (auto &&edge : node_info.out_edges())
{
addEdge(node_info, *edge);
}
@@ -47,7 +47,7 @@ void DotBuilder::add(const Node &node)
_dot << node.id();
std::stringstream ss;
_dot << "[";
- for (auto attr : node.attributes())
+ for (auto &&attr : node.attributes())
{
_dot << attr.first << "=\"" << attr.second << "\" ";
}
diff --git a/runtime/onert/core/src/dumper/dot/DotDumper.cc b/runtime/onert/core/src/dumper/dot/DotDumper.cc
index 0bb2fa11f..ab77a6c62 100644
--- a/runtime/onert/core/src/dumper/dot/DotDumper.cc
+++ b/runtime/onert/core/src/dumper/dot/DotDumper.cc
@@ -98,10 +98,10 @@ generate_dot_operations(const ir::Graph &graph,
{
ir::OperationIndexMap<std::unique_ptr<Operation>> dot_operations;
const auto &operations = graph.operations();
- operations.iterate([&](const ir::OperationIndex &index, const ir::Operation &op) {
+ operations.iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) {
auto node = std::make_unique<Operation>(index, op);
- for (auto input : op.getInputs())
+ for (auto &&input : op.getInputs())
{
using onert::dumper::dot::Operand;
@@ -113,7 +113,7 @@ generate_dot_operations(const ir::Graph &graph,
input_node->addOutEdge(node.get());
}
- for (auto output : op.getOutputs() | ir::Remove::UNDEFINED)
+ for (auto &&output : op.getOutputs() | ir::Remove::UNDEFINED)
{
using onert::dumper::dot::Operand;
auto &output_node = dot_operands.at(output);
@@ -126,7 +126,7 @@ generate_dot_operations(const ir::Graph &graph,
return dot_operations;
}
-void update_lower_info(const compiler::LoweredGraph &lowered_graph,
+void update_lower_info(const compiler::ILoweredGraph &lowered_graph,
ir::OperandIndexMap<std::unique_ptr<Operand>> *dot_operands)
{
const auto &operands = lowered_graph.graph().operands();
@@ -153,11 +153,11 @@ void update_lower_info(const compiler::LoweredGraph &lowered_graph,
});
}
-void update_lower_info(const compiler::LoweredGraph &lowered_graph,
+void update_lower_info(const compiler::ILoweredGraph &lowered_graph,
ir::OperationIndexMap<std::unique_ptr<Operation>> *dot_operations)
{
const auto &operations = lowered_graph.graph().operations();
- operations.iterate([&](const ir::OperationIndex &index, const ir::Operation &) {
+ operations.iterate([&](const ir::OperationIndex &index, const ir::IOperation &) {
const auto lower_info = lowered_graph.lower_info().operation.getRawPtr(index);
if (lower_info)
{
@@ -213,7 +213,8 @@ void DotDumper::dump(const ir::Graph &graph, const std::string &tag)
dump_to_file(dot_operands, dot_operations, tag);
}
-void DotDumper::dump(const compiler::LoweredGraph &lowered_graph, const std::string &tag)
+// TODO Support derivative tensors
+void DotDumper::dump(const compiler::ILoweredGraph &lowered_graph, const std::string &tag)
{
if (_level == Level::OFF)
{
diff --git a/runtime/onert/core/src/dumper/dot/DotDumper.h b/runtime/onert/core/src/dumper/dot/DotDumper.h
index 6249010d3..fca5f356c 100644
--- a/runtime/onert/core/src/dumper/dot/DotDumper.h
+++ b/runtime/onert/core/src/dumper/dot/DotDumper.h
@@ -15,7 +15,7 @@
*/
#include "ir/Graph.h"
-#include "compiler/LoweredGraph.h"
+#include "compiler/ILoweredGraph.h"
#ifndef __ONERT_DUMPER_DOT_DOT_DUMPER_H__
#define __ONERT_DUMPER_DOT_DOT_DUMPER_H__
@@ -57,7 +57,7 @@ public:
* @param[in] tag The name of dot file that would be created
* @return N/A
*/
- void dump(const compiler::LoweredGraph &lowered_graph, const std::string &tag);
+ void dump(const compiler::ILoweredGraph &lowered_graph, const std::string &tag);
private:
Level _level;
diff --git a/runtime/onert/core/src/dumper/dot/OperationNode.cc b/runtime/onert/core/src/dumper/dot/OperationNode.cc
index 87c5ba148..2ef08c9c6 100644
--- a/runtime/onert/core/src/dumper/dot/OperationNode.cc
+++ b/runtime/onert/core/src/dumper/dot/OperationNode.cc
@@ -31,7 +31,7 @@ namespace dot
const std::string Operation::OPERATION_SHAPE = "rect";
const std::string Operation::BG_COLOR_SCHEME = "pastel18";
-Operation::Operation(const ir::OperationIndex &index, const ir::Operation &node)
+Operation::Operation(const ir::OperationIndex &index, const ir::IOperation &node)
: Node{"operation" + std::to_string(index.value())}
{
setAttribute("label", std::to_string(index.value()) + " : " + node.name());
diff --git a/runtime/onert/core/src/dumper/dot/OperationNode.h b/runtime/onert/core/src/dumper/dot/OperationNode.h
index 74a37d3fb..d9292ad0c 100644
--- a/runtime/onert/core/src/dumper/dot/OperationNode.h
+++ b/runtime/onert/core/src/dumper/dot/OperationNode.h
@@ -25,7 +25,7 @@
#define __ONERT_DUMPER_DOT_DOT_NODE_INFO_H__
#include "Node.h"
-#include "ir/Operation.h"
+#include "ir/IOperation.h"
#include "ir/Index.h"
namespace onert
@@ -52,7 +52,7 @@ public:
* @param[in] index operation index
* @param[in] node operation object
*/
- Operation(const ir::OperationIndex &index, const ir::Operation &node);
+ Operation(const ir::OperationIndex &index, const ir::IOperation &node);
};
} // namespace dot
diff --git a/runtime/onert/core/src/dumper/h5/Dumper.cc b/runtime/onert/core/src/dumper/h5/Dumper.cc
new file mode 100644
index 000000000..5e12c2dbb
--- /dev/null
+++ b/runtime/onert/core/src/dumper/h5/Dumper.cc
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Dumper.h"
+
+#include <iostream>
+#include <sstream>
+#include <stdexcept>
+
+namespace onert
+{
+namespace dumper
+{
+namespace h5
+{
+
+Dumper::Dumper(const std::string &filepath) : _file{filepath, H5F_ACC_CREAT | H5F_ACC_RDWR} {}
+
+} // namespace h5
+} // namespace dumper
+} // namespace onert
diff --git a/runtime/onert/core/src/dumper/h5/Dumper.h b/runtime/onert/core/src/dumper/h5/Dumper.h
new file mode 100644
index 000000000..53d0e0332
--- /dev/null
+++ b/runtime/onert/core/src/dumper/h5/Dumper.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_DUMPER_H5_DUMPER_H__
+#define __ONERT_DUMPER_H5_DUMPER_H__
+
+#include "exec/MinMaxMap.h"
+
+#include <H5Cpp.h>
+#include <string>
+
+namespace onert
+{
+namespace dumper
+{
+namespace h5
+{
+
+class Dumper
+{
+public:
+ /**
+ * @brief Construct dumper
+ *
+   * @param[in] filepath Path of the file to dump into
+ * @throw H5::FileIException on error during file open/create
+ */
+ Dumper(const std::string &filepath);
+
+protected:
+ H5::H5File _file;
+};
+
+} // namespace h5
+} // namespace dumper
+} // namespace onert
+
+#endif // __ONERT_DUMPER_H5_DUMPER_H__
diff --git a/runtime/onert/core/src/dumper/h5/MinMaxDumper.cc b/runtime/onert/core/src/dumper/h5/MinMaxDumper.cc
new file mode 100644
index 000000000..8a9de9f95
--- /dev/null
+++ b/runtime/onert/core/src/dumper/h5/MinMaxDumper.cc
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MinMaxDumper.h"
+
+#include <iostream>
+#include <sstream>
+#include <stdexcept>
+
+namespace onert
+{
+namespace dumper
+{
+namespace h5
+{
+
+static const char *h5_value_grpname = "value";
+
+/*
+ * Ensure the group named `child` exists under `parent`: open it if present, create it otherwise
+ */
+H5::Group ensureGroup(H5::Group parent, const char *child)
+{
+ H5::Exception::dontPrint();
+ try
+ {
+ return parent.openGroup(child);
+ }
+ catch (H5::Exception &e)
+ {
+ return parent.createGroup(child);
+ }
+}
+
+MinMaxDumper::MinMaxDumper(const std::string &filepath) : Dumper(filepath)
+{
+ auto root_grp = _file.openGroup("/");
+ ensureGroup(root_grp, h5_value_grpname);
+}
+
+void MinMaxDumper::dump(const exec::SMMinMaxMap &mmmap) const
+{
+ auto val_grp = _file.openGroup(h5_value_grpname);
+ auto num_run = val_grp.getNumObjs();
+ auto num_grp = val_grp.createGroup(std::to_string(num_run));
+ auto model_grp = ensureGroup(num_grp, "0");
+ hsize_t dims[] = {2};
+ H5::DataSpace dspace(1, dims); // rank=1, dim(0)=2, {min, max}
+ for (auto &&e : mmmap)
+ {
+ // key = {subg_idx, op_idx} = e.first
+ const auto subg_idx = e.first.first.value();
+ const auto op_idx = e.first.second.value();
+ auto subg_grp = ensureGroup(model_grp, std::to_string(subg_idx).c_str());
+ auto op_dset = subg_grp.createDataSet(std::to_string(op_idx), H5::PredType::IEEE_F32BE, dspace);
+ op_dset.write(e.second.data, H5::PredType::NATIVE_FLOAT);
+ }
+}
+
+} // namespace h5
+} // namespace dumper
+} // namespace onert
diff --git a/runtime/onert/core/src/dumper/h5/MinMaxDumper.h b/runtime/onert/core/src/dumper/h5/MinMaxDumper.h
new file mode 100644
index 000000000..1f1b27c6e
--- /dev/null
+++ b/runtime/onert/core/src/dumper/h5/MinMaxDumper.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_DUMPER_H5_MINMAX_DUMPER_H__
+#define __ONERT_DUMPER_H5_MINMAX_DUMPER_H__
+
+#include "exec/MinMaxMap.h"
+#include "Dumper.h"
+
+#include <H5Cpp.h>
+#include <string>
+
+namespace onert
+{
+namespace dumper
+{
+namespace h5
+{
+
+// The hierarchy of a single-model minmax h5 file
+//
+// GROUP /
+// GROUP value
+// └── GROUP run_idx
+// └── GROUP model_idx
+// └── GROUP subg_idx
+// └── DATASET op_idx
+// DATATYPE Float32
+// DATASPACE (2)
+// DATA { min, max }
+// GROUP name (optional, for debug)
+// └── GROUP model_idx
+// └── GROUP subg_idx
+// └── ATTRIBUTE op_idx
+// DATATYPE String
+// DATA { "model/your/op/name"}
+//
+class MinMaxDumper : private Dumper
+{
+public:
+ MinMaxDumper(const std::string &filepath);
+ /**
+ * @brief Dump minmax map
+ *
+ * @param[in] map single model minmax map
+ */
+ void dump(const exec::SMMinMaxMap &map) const;
+
+private:
+ H5::Group _val_grp;
+};
+
+} // namespace h5
+} // namespace dumper
+} // namespace onert
+
+#endif // __ONERT_DUMPER_H5_MINMAX_DUMPER_H__
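For reference, a file produced with this layout can be read back through the same H5Cpp API. A minimal read-back sketch (file name and indices are illustrative, not part of the dumper):

#include <H5Cpp.h>

// Reads the {min, max} pair recorded for run 0, model 0, subgraph 0, op 3.
void readMinMaxExample()
{
  H5::H5File file("minmax.h5", H5F_ACC_RDONLY);
  auto op_dset = file.openGroup("/value/0/0/0") // value/run_idx/model_idx/subg_idx
                   .openDataSet("3");           // dataset is named by op_idx
  float minmax[2];                              // {min, max}
  op_dset.read(minmax, H5::PredType::NATIVE_FLOAT);
}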
diff --git a/runtime/onert/core/src/dumper/text/GraphDumper.cc b/runtime/onert/core/src/dumper/text/GraphDumper.cc
index 80cfbbc34..6bd7904aa 100644
--- a/runtime/onert/core/src/dumper/text/GraphDumper.cc
+++ b/runtime/onert/core/src/dumper/text/GraphDumper.cc
@@ -18,6 +18,9 @@
#include "ir/Graph.h"
#include "compiler/LoweredGraph.h"
+#ifdef ONERT_TRAIN
+#include "compiler/train/LoweredTrainableGraph.h"
+#endif // ONERT_TRAIN
#include "util/logging.h"
#include "misc/string_helpers.h"
@@ -34,7 +37,7 @@ namespace
std::string formatOperandIndexSequence(const ir::OperandIndexSequence &seq)
{
std::vector<std::string> strs;
- for (auto ind : seq)
+ for (auto &&ind : seq)
strs.push_back(dumper::text::formatOperandBrief(ind));
return nnfw::misc::join(strs.begin(), strs.end(), ", ");
}
@@ -56,10 +59,9 @@ std::string formatOperand(const ir::Graph &, ir::OperandIndex ind)
return ss.str();
}
-std::string formatOperation(const ir::Graph &graph, ir::OperationIndex ind)
+std::string formatOperation(const ir::IOperation &op, ir::OperationIndex ind)
{
std::stringstream ss;
- const auto &op = graph.operations().at(ind);
ss << formatOperandIndexSequence(op.getOutputs());
ss << " = ";
@@ -69,13 +71,21 @@ std::string formatOperation(const ir::Graph &graph, ir::OperationIndex ind)
return ss.str();
}
+std::string formatOperation(const ir::Graph &graph, ir::OperationIndex ind)
+{
+ std::stringstream ss;
+ const auto &op = graph.operations().at(ind);
+ return formatOperation(op, ind);
+}
+
void dumpGraph(const ir::Graph &graph)
{
VERBOSE(GraphDumper) << "{\n";
auto ops_topol = graph.topolSortOperations();
- for (auto op_ind : ops_topol)
+ for (auto &&op_ind : ops_topol)
{
- VERBOSE(GraphDumper) << " " << formatOperation(graph, op_ind) << "\n";
+ const auto &op = graph.operations().at(op_ind);
+ VERBOSE(GraphDumper) << " " << formatOperation(op, op_ind) << "\n";
}
VERBOSE(GraphDumper) << "}\n";
VERBOSE(GraphDumper) << std::endl;
@@ -87,6 +97,14 @@ void dumpLoweredGraph(const compiler::LoweredGraph &lgraph)
dumpGraph(lgraph.graph());
}
+#ifdef ONERT_TRAIN
+void dumpLoweredGraph(const compiler::train::LoweredTrainableGraph &lgraph)
+{
+ // TODO Graph dump with backend info
+ dumpGraph(lgraph.graph());
+}
+#endif // ONERT_TRAIN
+
} // namespace text
} // namespace dumper
} // namespace onert
diff --git a/runtime/onert/core/src/dumper/text/GraphDumper.h b/runtime/onert/core/src/dumper/text/GraphDumper.h
index 0501ff050..ab0061465 100644
--- a/runtime/onert/core/src/dumper/text/GraphDumper.h
+++ b/runtime/onert/core/src/dumper/text/GraphDumper.h
@@ -24,7 +24,8 @@ namespace onert
namespace ir
{
class Graph;
-}
+struct IOperation;
+} // namespace ir
} // namespace onert
namespace onert
@@ -32,7 +33,14 @@ namespace onert
namespace compiler
{
class LoweredGraph;
-}
+
+#ifdef ONERT_TRAIN
+namespace train
+{
+class LoweredTrainableGraph;
+} // namespace train
+#endif // ONERT_TRAIN
+} // namespace compiler
} // namespace onert
namespace onert
@@ -47,6 +55,9 @@ std::string formatOperand(const ir::Graph &, ir::OperandIndex ind);
std::string formatOperation(const ir::Graph &graph, ir::OperationIndex ind);
void dumpGraph(const ir::Graph &graph);
void dumpLoweredGraph(const compiler::LoweredGraph &lgraph);
+#ifdef ONERT_TRAIN
+void dumpLoweredGraph(const compiler::train::LoweredTrainableGraph &lgraph);
+#endif // ONERT_TRAIN
} // namespace text
} // namespace dumper
diff --git a/runtime/onert/core/src/exec/DataflowExecutor.cc b/runtime/onert/core/src/exec/DataflowExecutor.cc
index 8dac1219e..e0b00077f 100644
--- a/runtime/onert/core/src/exec/DataflowExecutor.cc
+++ b/runtime/onert/core/src/exec/DataflowExecutor.cc
@@ -60,7 +60,7 @@ void DataflowExecutor::emplaceToReadyJobs(const uint32_t &id)
void DataflowExecutor::notify(uint32_t finished_job_id)
{
- for (auto id : _output_info[finished_job_id])
+ for (auto &&id : _output_info[finished_job_id])
{
assert(_input_info[id] > 0);
auto count = --_input_info[id];
@@ -90,7 +90,7 @@ DataflowExecutor::DataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lower
uint32_t next_job_index = 0;
std::unordered_map<ir::OperationIndex, uint32_t> op_to_job;
const auto &operations = _lowered_graph->graph().operations();
- operations.iterate([&](const ir::OperationIndex &op_ind, const ir::Operation &) {
+ operations.iterate([&](const ir::OperationIndex &op_ind, const ir::IOperation &) {
VERBOSE(DataflowExecutor) << "Create a job " << next_job_index << " with Operation " << op_ind
<< std::endl;
_finished_jobs.emplace_back(
@@ -102,12 +102,12 @@ DataflowExecutor::DataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lower
_output_info.resize(next_job_index);
_initial_input_info.resize(next_job_index, 0);
- operations.iterate([&](const ir::OperationIndex &op_ind, const ir::Operation &op) {
+ operations.iterate([&](const ir::OperationIndex &op_ind, const ir::IOperation &op) {
auto job_index = op_to_job[op_ind];
- for (auto output : op.getOutputs())
+ for (auto &&output : op.getOutputs())
{
// Update output and input info
- operations.iterate([&](const ir::OperationIndex &op_cur_ind, const ir::Operation &op_cur) {
+ operations.iterate([&](const ir::OperationIndex &op_cur_ind, const ir::IOperation &op_cur) {
if (op_cur.getInputs().contains(output))
{
auto dep_index = op_to_job[op_cur_ind];
diff --git a/runtime/onert/core/src/exec/DynamicShapeInferer.cc b/runtime/onert/core/src/exec/DynamicShapeInferer.cc
index fb8058d23..78b21cf49 100644
--- a/runtime/onert/core/src/exec/DynamicShapeInferer.cc
+++ b/runtime/onert/core/src/exec/DynamicShapeInferer.cc
@@ -253,7 +253,7 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op)
So, only when all inputs are static, we can skip dynamic shape inference.
*/
bool all_static = true;
- for (auto input_ind : op.getInputs())
+ for (auto &&input_ind : op.getInputs())
{
auto input = _tensor_registry->getITensor(input_ind);
if (input->is_dynamic())
@@ -290,7 +290,7 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op)
auto first_input_ind = op.getInputs().at(0);
auto first_input = _tensor_registry->getITensor(first_input_ind);
- for (auto input_ind : op.getInputs())
+ for (auto &&input_ind : op.getInputs())
{
auto input = _tensor_registry->getITensor(input_ind);
if (input != first_input && !isConcatible(first_input, input, op.param().axis))
@@ -300,7 +300,7 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op)
// getting output shape
onert::shape_inference::Shapes in_shapes;
- for (auto input_ind : op.getInputs())
+ for (auto &&input_ind : op.getInputs())
{
auto input = _tensor_registry->getITensor(input_ind);
ir::Shape shape = input->getShape();
@@ -1042,7 +1042,7 @@ void DynamicShapeInferer::visit(const ir::operation::Split &op)
// Return if all tensors are not dynamic
bool has_dynamic = false;
- for (const auto output_idx : op.getOutputs())
+ for (const auto &output_idx : op.getOutputs())
{
auto output = _tensor_registry->getITensor(output_idx);
has_dynamic |= output->is_dynamic();
diff --git a/runtime/onert/core/src/exec/ExecTime.test.cc b/runtime/onert/core/src/exec/ExecTime.test.cc
index 1f7152e7b..939184e4e 100644
--- a/runtime/onert/core/src/exec/ExecTime.test.cc
+++ b/runtime/onert/core/src/exec/ExecTime.test.cc
@@ -34,7 +34,7 @@ struct MockConfig : public IConfig
std::string id() override { return "b1"; }
bool initialize() override { return true; };
bool supportPermutation() override { return false; }
- ir::Layout supportLayout(const ir::Operation &, ir::Layout) override
+ ir::Layout supportLayout(const ir::IOperation &, ir::Layout) override
{
return ir::Layout::UNKNOWN;
}
diff --git a/runtime/onert/core/src/exec/Execution.cc b/runtime/onert/core/src/exec/Execution.cc
index 7d5b406ef..1384c9fdc 100644
--- a/runtime/onert/core/src/exec/Execution.cc
+++ b/runtime/onert/core/src/exec/Execution.cc
@@ -16,6 +16,8 @@
#include "exec/Execution.h"
+#include "train/TrainableExecutors.h"
+
#include "util/logging.h"
namespace onert
@@ -151,6 +153,35 @@ void Execution::waitFinish()
bool Execution::isFinished(void) const { return finished; }
+#ifdef ONERT_TRAIN
+void Execution::train(uint32_t training_step)
+{
+ auto execs = dynamic_cast<exec::train::TrainableExecutors *>(_executors.get());
+ if (!execs)
+ {
+ throw std::runtime_error{"Supported only TrainableExecutors"};
+ }
+
+ VERBOSE(Execution) << "Start training" << std::endl;
+
+ execs->train(_io_desc, training_step);
+ finished = true;
+
+ VERBOSE(Execution) << "training finished" << std::endl;
+}
+
+float Execution::getLoss(const ir::IOIndex &ind)
+{
+ auto execs = dynamic_cast<exec::train::TrainableExecutors *>(_executors.get());
+ if (!execs)
+ {
+ throw std::runtime_error{"Supported only TrainableExecutors"};
+ }
+
+ return execs->getLoss(ind);
+}
+#endif // ONERT_TRAIN
+
ir::Shape Execution::getInputShape(ir::IOIndex ind) const
{
auto itr = _io_desc.dynamic_input_shapes.find(ind);
@@ -180,5 +211,16 @@ ir::Shape Execution::getOutputShape(ir::IOIndex ind) const
return output_desc->info.shape();
}
+size_t Execution::getInputTotalSize(ir::IOIndex ind) const
+{
+ // TODO Support dynamic shape
+ return _executors->inputInfo(ind).total_size();
+}
+
+size_t Execution::getOutputTotalSize(ir::IOIndex ind) const
+{
+ return _executors->outputInfo(ind).total_size();
+}
+
} // namespace exec
} // namespace onert
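The new train() and getLoss() entry points are meant to be driven in a loop from the API layer. A minimal sketch, assuming an ONERT_TRAIN build and the existing exec::Execution input-setup API (includes omitted; buffer names, sizes, and the label input index are illustrative):

void trainingLoopExample(const std::shared_ptr<exec::IExecutors> &executors,
                         const void *input_buf, size_t input_size,
                         const void *label_buf, size_t label_size, uint32_t num_steps)
{
  exec::Execution execution{executors};
  execution.setInput(ir::IOIndex{0}, input_buf, input_size); // features
  execution.setInput(ir::IOIndex{1}, label_buf, label_size); // labels, if the graph expects them
  for (uint32_t step = 0; step < num_steps; ++step)
  {
    execution.train(step); // one forward + backward + weight update
    std::cout << "loss = " << execution.getLoss(ir::IOIndex{0}) << std::endl;
  }
}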
diff --git a/runtime/onert/core/src/exec/ExecutionObservers.cc b/runtime/onert/core/src/exec/ExecutionObservers.cc
index 9abde7ba4..5245518a0 100644
--- a/runtime/onert/core/src/exec/ExecutionObservers.cc
+++ b/runtime/onert/core/src/exec/ExecutionObservers.cc
@@ -28,7 +28,7 @@
namespace
{
-void setUserData(const onert::ir::Graph &g, const onert::ir::Operation *op,
+void setUserData(const onert::ir::Graph &g, const onert::ir::IOperation *op,
decltype(EventCollector::Event::userData) &data)
{
// From a tensor of shape [a, b, c], this will return a string "shape(a b c)".
diff --git a/runtime/onert/core/src/exec/ExecutionObservers.h b/runtime/onert/core/src/exec/ExecutionObservers.h
index 91fbac323..7e93ecf7c 100644
--- a/runtime/onert/core/src/exec/ExecutionObservers.h
+++ b/runtime/onert/core/src/exec/ExecutionObservers.h
@@ -24,7 +24,7 @@
#include "exec/IExecutor.h"
#include "ir/Index.h"
-#include "ir/Operation.h"
+#include "ir/IOperation.h"
#include "util/ITimer.h"
#include "util/TracingCtx.h"
diff --git a/runtime/onert/core/src/exec/ExecutorBase.cc b/runtime/onert/core/src/exec/ExecutorBase.cc
index 515cf8e48..ad0073477 100644
--- a/runtime/onert/core/src/exec/ExecutorBase.cc
+++ b/runtime/onert/core/src/exec/ExecutorBase.cc
@@ -35,7 +35,7 @@ ExecutorBase::ExecutorBase(std::unique_ptr<compiler::LoweredGraph> &&lowered_gra
{
auto build_tensor_list = [&](const auto &ind_seq, auto &tensors) {
assert(tensors.empty());
- for (auto ind : ind_seq)
+ for (auto &&ind : ind_seq)
{
backend::ITensor *tensor = tensor_regs.getITensor(ind);
assert(tensor != nullptr);
diff --git a/runtime/onert/core/src/exec/ExecutorBase.h b/runtime/onert/core/src/exec/ExecutorBase.h
index 7aee3d9ee..4f97de922 100644
--- a/runtime/onert/core/src/exec/ExecutorBase.h
+++ b/runtime/onert/core/src/exec/ExecutorBase.h
@@ -77,6 +77,7 @@ public:
{
return _output_tensors;
}
+ backend::BackendContexts &getBackendContexts() { return _backend_contexts; }
protected:
/**
diff --git a/runtime/onert/core/src/exec/Executors.cc b/runtime/onert/core/src/exec/Executors.cc
index 7edd5aaa2..8a1be3df4 100644
--- a/runtime/onert/core/src/exec/Executors.cc
+++ b/runtime/onert/core/src/exec/Executors.cc
@@ -147,7 +147,7 @@ void Executors::checkSupportedMultimodel() const
// Assumption: edges
// m1 < m2, s1 == 0 and s2 == 0 if edge 'm1:s1:o1 -> m2:s2:o2'
- for (auto edge : _model_edges->edges)
+ for (auto &&edge : _model_edges->edges)
{
auto const model_from = std::get<ir::ModelIndex>(edge.from);
auto const model_to = std::get<ir::ModelIndex>(edge.to);
diff --git a/runtime/onert/core/src/exec/FunctionSequence.cc b/runtime/onert/core/src/exec/FunctionSequence.cc
index f87c271f7..578123a54 100644
--- a/runtime/onert/core/src/exec/FunctionSequence.cc
+++ b/runtime/onert/core/src/exec/FunctionSequence.cc
@@ -16,7 +16,6 @@
#include "exec/FunctionSequence.h"
-#include "ir/Operation.h"
#include "backend/ITensorRegistry.h"
#include "util/logging.h"
diff --git a/runtime/onert/core/src/exec/LinearExecutor.h b/runtime/onert/core/src/exec/LinearExecutor.h
index a833466da..cc073411a 100644
--- a/runtime/onert/core/src/exec/LinearExecutor.h
+++ b/runtime/onert/core/src/exec/LinearExecutor.h
@@ -52,7 +52,7 @@ public:
const std::vector<ir::OperationIndex> &order, const util::TracingCtx *tracing_ctx)
: ExecutorBase{std::move(lowered_graph), std::move(backend_contexts), tensor_regs, tracing_ctx}
{
- for (auto index : order)
+ for (auto &&index : order)
{
_code.emplace_back(std::move(code_map.at(index)));
}
diff --git a/runtime/onert/core/src/exec/MinMaxRecorder.cc b/runtime/onert/core/src/exec/MinMaxRecorder.cc
new file mode 100644
index 000000000..88fc104d1
--- /dev/null
+++ b/runtime/onert/core/src/exec/MinMaxRecorder.cc
@@ -0,0 +1,112 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MinMaxRecorder.h"
+
+#include "backend/ITensor.h"
+
+#include <cassert>
+#include <cmath>
+
+namespace onert
+{
+namespace exec
+{
+
+MinMaxRecorder::MinMaxRecorder(const std::string &minmax_filepath, const ir::Graph &graph,
+ const backend::BackendContexts &backend_contexts)
+ : _graph{graph}, _backend_contexts{backend_contexts}, _h5dumper(minmax_filepath)
+{
+}
+
+void MinMaxRecorder::handleJobEnd(IExecutor *, ir::SubgraphIndex subg_idx,
+ ir::OperationIndex op_idx, const backend::Backend *backend)
+{
+ const auto &tensor_reg = _backend_contexts.at(backend)->tensor_registry;
+ const auto &op = _graph.operations().at(op_idx);
+ const auto &outputs = op.getOutputs();
+ // TODO: Support multiple output
+ if (outputs.size() != 1)
+ throw std::runtime_error("Only 1 output operator is supported for recording minmax.");
+
+ auto tensor = tensor_reg->getITensor(outputs.at(0));
+
+ // Logic copied from MinMaxObserver.cpp.
+
+ // Filter Ops
+ if (tensor->is_constant())
+ return;
+
+ if (tensor->data_type() != ir::DataType::FLOAT32)
+ return;
+
+ switch (op.opcode())
+ {
+ // Operators with multiple outputs
+ case ir::OpCode::If:
+ case ir::OpCode::Split:
+ case ir::OpCode::SplitV:
+ case ir::OpCode::TopKV2:
+ case ir::OpCode::Unpack:
+ case ir::OpCode::While:
+ return;
+ // NOTE: Sin, Cos, Tanh's output is in [-1, 1]
+ // We may not need to dump those operators.
+ default:; // Do Nothing
+ }
+
+ // Otherwise, dump!
+ assert(tensor->data_type() == ir::DataType::FLOAT32);
+ const auto data = reinterpret_cast<float *>(tensor->buffer());
+ const auto num_elements = tensor->total_size() / sizeof(float);
+
+ float max = std::numeric_limits<float>::lowest();
+ float min = std::numeric_limits<float>::max();
+
+ bool all_nan = true;
+ for (size_t i = 0; i < num_elements; ++i)
+ {
+ const float number = data[i];
+ if (std::isnan(number))
+ continue;
+
+ if (number == std::numeric_limits<float>::lowest())
+ continue;
+
+ all_nan = false;
+
+ if (number > max)
+ max = number;
+
+ if (number < min)
+ min = number;
+ }
+
+ if (all_nan)
+ throw std::runtime_error("All values are NaN(Not a Number)");
+
+ _minmax_map.append({subg_idx, op_idx}, min, max);
+}
+
+void MinMaxRecorder::handleSubgraphEnd(ir::SubgraphIndex)
+{
+  // It would be better to dump at the end of model execution rather than after each subgraph,
+  // but that would require more extensive changes.
+ _h5dumper.dump(_minmax_map);
+}
+
+} // namespace exec
+} // namespace onert
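Since MinMaxRecorder is just another IExecutionObserver, recording is enabled by registering it on an executor. A sketch of how it could be attached, assuming an ExecutorBase-derived executor and an externally supplied output path (variable names are illustrative):

executor->addObserver(std::make_unique<exec::MinMaxRecorder>(
  minmax_filepath, executor->graph(), executor->getBackendContexts()));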
diff --git a/runtime/onert/core/src/exec/MinMaxRecorder.h b/runtime/onert/core/src/exec/MinMaxRecorder.h
new file mode 100644
index 000000000..7a0817f5f
--- /dev/null
+++ b/runtime/onert/core/src/exec/MinMaxRecorder.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_EXEC_MINMAX_RECORDER__
+#define __ONERT_EXEC_MINMAX_RECORDER__
+
+#include "ExecutionObservers.h"
+#include "ir/Index.h"
+#include "exec/MinMaxMap.h"
+#include "../dumper/h5/MinMaxDumper.h"
+
+#include <memory>
+
+namespace onert
+{
+namespace exec
+{
+
+class MinMaxRecorder : public IExecutionObserver
+{
+public:
+ MinMaxRecorder(const std::string &minmax_filepath, const ir::Graph &graph,
+ const backend::BackendContexts &backend_contexts);
+ void handleJobBegin(IExecutor *, ir::SubgraphIndex, ir::OperationIndex,
+ const backend::Backend *) override
+ {
+ return;
+ }
+ void handleJobEnd(IExecutor *, ir::SubgraphIndex, ir::OperationIndex,
+ const backend::Backend *) override;
+ void handleSubgraphEnd(ir::SubgraphIndex) override;
+
+private:
+ const ir::Graph &_graph;
+ const backend::BackendContexts &_backend_contexts;
+ dumper::h5::MinMaxDumper _h5dumper;
+ SMMinMaxMap _minmax_map;
+};
+
+} // namespace exec
+} // namespace onert
+
+#endif // __ONERT_EXEC_MINMAX_RECORDER__
diff --git a/runtime/onert/core/src/exec/ParallelScheduler.cc b/runtime/onert/core/src/exec/ParallelScheduler.cc
index 456663f91..538945631 100644
--- a/runtime/onert/core/src/exec/ParallelScheduler.cc
+++ b/runtime/onert/core/src/exec/ParallelScheduler.cc
@@ -30,7 +30,7 @@ ParallelScheduler::ParallelScheduler(const BackendSet &backends)
{
assert(!backends.empty());
- for (auto backend : backends)
+ for (auto &&backend : backends)
{
_thread_pools[backend] = std::make_unique<ThreadPool>();
}
diff --git a/runtime/onert/core/src/exec/train/TrainableExecutor.cc b/runtime/onert/core/src/exec/train/TrainableExecutor.cc
new file mode 100644
index 000000000..9c7e70c29
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/TrainableExecutor.cc
@@ -0,0 +1,204 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TrainableExecutor.h"
+#ifdef RUY_PROFILER
+#include "ruy/profiler/instrumentation.h"
+#endif
+
+#include <misc/polymorphic_downcast.h>
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+
+TrainableExecutor::TrainableExecutor(
+ std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ backend::train::TrainableBackendContexts &&backend_contexts,
+ const compiler::train::TensorRegistries &tensor_regs,
+ compiler::train::TrainableCodeMap &&code_map, const std::vector<ir::OperationIndex> &order,
+ const util::TracingCtx *tracing_ctx)
+ : _lowered_graph{std::move(lowered_graph)}, _backend_contexts{std::move(backend_contexts)},
+ _trainable_graph{_lowered_graph->trainable_graph()}, _tensor_regs{std::move(tensor_regs)},
+ _mutex(), _tracing_ctx(tracing_ctx)
+{
+ auto build_tensor_list = [&](const auto &ind_seq, auto &tensors) {
+ assert(tensors.empty());
+ for (auto &&ind : ind_seq)
+ {
+ backend::ITensor *tensor = tensor_regs.getITensor(ind);
+ assert(tensor != nullptr);
+ auto io_tensor = nnfw::misc::polymorphic_downcast<backend::builtin::IOTensor *>(tensor);
+ tensors.push_back(io_tensor);
+ }
+ };
+ build_tensor_list(_trainable_graph.getInputs(), _input_tensors);
+ build_tensor_list(_trainable_graph.getOutputs(), _output_tensors);
+
+ for (auto &&index : order)
+ {
+ auto &trainable_code = code_map.at(index);
+ _code.emplace_back(std::move(trainable_code));
+ }
+}
+
+void TrainableExecutor::execute(const std::vector<backend::IPortableTensor *> &,
+ const std::vector<backend::IPortableTensor *> &)
+{
+ throw std::runtime_error("TrainableExecutor does not support multiple subgraphs yet");
+}
+
+void TrainableExecutor::forward(const IODescription &desc, bool training)
+{
+  // For thread safety, use a mutex
+  // TODO: If all backends used by this executor are thread-safe,
+  //       this mutex is unnecessary (otherwise, keep using it)
+ std::lock_guard<std::mutex> lock(_mutex);
+
+ // TODO Update IO tensors if desc has dynamic input
+ // Set input(s)
+ assert(_input_tensors.size() == desc.inputs.size());
+ for (uint32_t i = 0; i < _input_tensors.size(); ++i)
+ {
+ auto tensor = _input_tensors[i];
+
+ // TODO Check if (desc.inputs[i] == nullptr)
+ // TODO Better design for ITensor? (we need const_cast as ITensor is writable)
+ tensor->setUserTensor(static_cast<uint8_t *>(const_cast<void *>(desc.inputs[i]->buffer)),
+ desc.inputs[i]->size);
+ }
+
+ if (!training)
+ {
+ // Set output(s)
+ assert(_output_tensors.size() == desc.outputs.size());
+ for (uint32_t i = 0; i < _output_tensors.size(); ++i)
+ {
+ auto tensor = _output_tensors[i];
+
+ if (desc.outputs[i] == nullptr)
+ throw std::runtime_error{"Output " + std::to_string(i) + "'s buffer is not set."};
+ tensor->setUserTensor(static_cast<uint8_t *>(desc.outputs[i]->buffer), desc.outputs[i]->size);
+ }
+ }
+
+ forwardImpl(training);
+
+ // TODO Update output(s) desc if desc has dynamic input
+}
+
+void TrainableExecutor::forwardImpl(bool training)
+{
+ if (_tracing_ctx)
+ {
+ auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph());
+
+ _subject.notifySubgraphBegin(profiling_subg_index);
+ for (auto &&code : _code)
+ {
+ const auto backend = code.lower_info->backend();
+// TODO : Move ruy profiler into ExecutionObserver
+#ifdef RUY_PROFILER
+ ruy::profiler::ScopeLabel label(code.op->name());
+#endif
+ _subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend);
+
+ auto &tn_seq = code.tn_seq;
+ tn_seq->forward(training);
+
+ _subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend);
+ }
+ _subject.notifySubgraphEnd(profiling_subg_index);
+ }
+ else
+ {
+ for (auto &&code : _code)
+ {
+// TODO : Move ruy profiler into ExecutionObserver
+#ifdef RUY_PROFILER
+ ruy::profiler::ScopeLabel label(code.op->name());
+#endif
+ auto &tn_seq = code.tn_seq;
+ tn_seq->forward(training);
+ }
+ }
+}
+
+void TrainableExecutor::backward(const IODescription &, uint32_t training_step)
+{
+  // For thread safety, use a mutex
+  // TODO: If all backends used by this executor are thread-safe,
+  //       this mutex is unnecessary (otherwise, keep using it)
+ std::lock_guard<std::mutex> lock(_mutex);
+
+ backwardImpl(training_step);
+}
+
+void TrainableExecutor::backwardImpl(uint32_t training_step)
+{
+ if (_tracing_ctx)
+ {
+ auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph());
+
+ _subject.notifySubgraphBegin(profiling_subg_index);
+ for (auto it = _code.rbegin(); it != _code.rend(); ++it)
+ {
+ const auto &code = *it;
+ const auto backend = code.lower_info->backend();
+// TODO : Move ruy profiler into ExecutionObserver
+#ifdef RUY_PROFILER
+ ruy::profiler::ScopeLabel label(code.op->name());
+#endif
+ _subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend);
+
+ auto &tn_seq = code.tn_seq;
+ tn_seq->backward(training_step);
+
+ _subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend);
+ }
+ _subject.notifySubgraphEnd(profiling_subg_index);
+ }
+ else
+ {
+ for (auto it = _code.rbegin(); it != _code.rend(); ++it)
+ {
+ const auto &code = *it;
+// TODO : Move ruy profiler into ExecutionObserver
+#ifdef RUY_PROFILER
+ ruy::profiler::ScopeLabel label(code.op->name());
+#endif
+ auto &tn_seq = code.tn_seq;
+ tn_seq->backward(training_step);
+ }
+ }
+}
+
+float TrainableExecutor::getLoss(const ir::IOIndex &pred_io_ind) const
+{
+ const auto &loss_ind = _trainable_graph.getLossIndex(pred_io_ind);
+ if (loss_ind.undefined())
+ throw std::runtime_error{"Loss " + std::to_string(loss_ind.value()) + " is not defined."};
+ backend::ITensor *tensor = _tensor_regs.getITensor(loss_ind);
+ auto loss_buf = reinterpret_cast<float *>(tensor->buffer());
+ return *loss_buf;
+}
+
+} // namespace train
+} // namespace exec
+} // namespace onert
diff --git a/runtime/onert/core/src/exec/train/TrainableExecutor.h b/runtime/onert/core/src/exec/train/TrainableExecutor.h
new file mode 100644
index 000000000..6b645305f
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/TrainableExecutor.h
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_
+#define __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_
+
+#include "exec/IExecutor.h"
+
+#include "../ExecutionObservee.h"
+#include "../../compiler/train/TensorRegistries.h"
+
+#include "backend/train/TrainableBackendContext.h"
+#include "compiler/train/TrainableCodeMap.h"
+#include "compiler/train/LoweredTrainableGraph.h"
+#include "ir/Index.h"
+#include "util/TracingCtx.h"
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+
+class TrainableExecutor : public IExecutor
+{
+public:
+ /**
+ * @brief Construct a new TrainableExecutor object
+ * @param lowered_graph LoweredTrainableGraph object
+   * @param backend_contexts Trainable backend contexts that are currently used
+   * @param tensor_regs Tensor registries that are currently used
+   * @param code_map @c ir::Operation and its code map
+   * @param order Execution order of operations
+   * @param tracing_ctx Tracing context, or nullptr when tracing is disabled
+ */
+ TrainableExecutor(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ backend::train::TrainableBackendContexts &&backend_contexts,
+ const compiler::train::TensorRegistries &tensor_regs,
+ compiler::train::TrainableCodeMap &&code_map,
+ const std::vector<ir::OperationIndex> &order,
+ const util::TracingCtx *tracing_ctx);
+
+public:
+ const ir::Graph &graph() const final { return _trainable_graph.graph(); }
+
+ void execute(const IODescription &desc) override { forward(desc, false); };
+
+ void execute(const std::vector<backend::IPortableTensor *> &inputs,
+ const std::vector<backend::IPortableTensor *> &outputs) override;
+
+ void forward(const IODescription &desc, bool training);
+ void backward(const IODescription &desc, uint32_t training_step);
+
+ // Used only in Dataflow and Parallel Executors
+ void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>> ranks) final
+ {
+ _indexed_ranks = std::move(ranks);
+ };
+
+ void addObserver(std::unique_ptr<IExecutionObserver> ref) { _subject.add(std::move(ref)); };
+
+ const std::vector<backend::builtin::IOTensor *> &getInputTensors() const override
+ {
+ return _input_tensors;
+ }
+
+ const std::vector<backend::builtin::IOTensor *> &getOutputTensors() const override
+ {
+ return _output_tensors;
+ }
+
+ float getLoss(const ir::IOIndex &pred_io_ind) const;
+
+ backend::train::TrainableBackendContexts &getBackendContexts() { return _backend_contexts; }
+
+private:
+ void forwardImpl(bool training);
+ void backwardImpl(uint32_t training_step);
+
+private:
+ std::vector<compiler::train::TrainableCodeAndInfo> _code;
+ ExecutionObservee _subject;
+ std::shared_ptr<ir::OperationIndexMap<int64_t>> _indexed_ranks;
+ std::unique_ptr<compiler::train::LoweredTrainableGraph> _lowered_graph;
+ backend::train::TrainableBackendContexts _backend_contexts;
+ const ir::train::TrainableGraph &_trainable_graph;
+ compiler::train::TensorRegistries _tensor_regs;
+ std::vector<backend::builtin::IOTensor *> _input_tensors;
+ std::vector<backend::builtin::IOTensor *> _output_tensors;
+ std::mutex _mutex;
+ const util::TracingCtx *_tracing_ctx;
+};
+
+} // namespace train
+} // namespace exec
+} // namespace onert
+
+#endif // __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_
diff --git a/runtime/onert/core/src/exec/train/TrainableExecutors.cc b/runtime/onert/core/src/exec/train/TrainableExecutors.cc
new file mode 100644
index 000000000..ba39bf0f0
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/TrainableExecutors.cc
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TrainableExecutors.h"
+
+#include "../../backend/builtin/IOTensor.h"
+
+#include <misc/polymorphic_downcast.h>
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+
+void TrainableExecutors::emplace(const ir::ModelIndex &, const ir::SubgraphIndex &subg_index,
+ std::unique_ptr<IExecutor> exec)
+{
+ std::unique_ptr<TrainableExecutor> t_exec{
+ nnfw::misc::polymorphic_downcast<TrainableExecutor *>(exec.release())};
+ _executors.emplace(subg_index, std::move(t_exec));
+}
+
+TrainableExecutor *TrainableExecutors::at(const ir::ModelIndex &,
+ const ir::SubgraphIndex &subg_index) const
+{
+ return _executors.at(subg_index).get();
+}
+
+uint32_t TrainableExecutors::inputSize() const { return entryExecutor()->getInputTensors().size(); }
+
+uint32_t TrainableExecutors::outputSize() const
+{
+ return entryExecutor()->getOutputTensors().size();
+}
+
+const ir::OperandInfo &TrainableExecutors::inputInfo(const ir::IOIndex &index) const
+{
+ return entryExecutor()->getInputTensors().at(index.value())->orig_info();
+}
+
+const ir::OperandInfo &TrainableExecutors::outputInfo(const ir::IOIndex &index) const
+{
+ return entryExecutor()->getOutputTensors().at(index.value())->orig_info();
+}
+
+void TrainableExecutors::execute(const IODescription &desc)
+{
+ if (_executors.size() > 1)
+ throw std::runtime_error("TrainableExecutors does not support multiple executors yet");
+ entryExecutor()->forward(desc, false);
+
+  // TODO Support multiple executors
+}
+
+void TrainableExecutors::train(const IODescription &desc, uint32_t training_step)
+{
+ if (_executors.size() > 1)
+ throw std::runtime_error("TrainableExecutors does not support multiple executors yet");
+ entryExecutor()->forward(desc, true);
+ entryExecutor()->backward(desc, training_step);
+
+  // TODO Support multiple executors
+}
+
+float TrainableExecutors::getLoss(const ir::IOIndex &index) const
+{
+ if (_executors.size() > 1)
+ throw std::runtime_error("TrainableExecutors does not support multiple executors yet");
+ return entryExecutor()->getLoss(index);
+}
+
+} // namespace train
+} // namespace exec
+} // namespace onert
diff --git a/runtime/onert/core/src/exec/train/TrainableExecutors.h b/runtime/onert/core/src/exec/train/TrainableExecutors.h
new file mode 100644
index 000000000..db6d198b1
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/TrainableExecutors.h
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__
+#define __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__
+
+#include "TrainableExecutor.h"
+#include "exec/IExecutors.h"
+#include "ir/NNPkg.h"
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+
+/**
+ * @brief Class to gather the executor set for a trainable-model NN package
+ */
+class TrainableExecutors : public IExecutors
+{
+public:
+ /**
+ * @brief Construct a new TrainableExecutors object
+ */
+ TrainableExecutors(void) = default;
+ TrainableExecutors(const TrainableExecutors &) = delete;
+ TrainableExecutors(TrainableExecutors &&) = default;
+
+ /**
+ * @brief Destroy the TrainableExecutors object
+ */
+ ~TrainableExecutors() = default;
+
+public:
+ TrainableExecutors &operator=(const TrainableExecutors &) = delete;
+ TrainableExecutors &operator=(TrainableExecutors &&) = default;
+
+public:
+ void emplace(const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index,
+ std::unique_ptr<IExecutor> exec) override;
+
+ TrainableExecutor *at(const ir::ModelIndex &model_index,
+ const ir::SubgraphIndex &subg_index) const override;
+
+ TrainableExecutor *entryExecutor() const { return at(ir::ModelIndex{0}, ir::SubgraphIndex{0}); }
+
+ uint32_t inputSize() const override;
+
+ uint32_t outputSize() const override;
+
+ const ir::OperandInfo &inputInfo(const ir::IOIndex &index) const override;
+
+ const ir::OperandInfo &outputInfo(const ir::IOIndex &index) const override;
+
+ void execute(const IODescription &desc) override;
+
+ /**
+ * @brief Train
+ *
+ * @param desc IO information
+   * @param training_step The number of iterations of the training process.
+   *                      In other words, the number of gradient updates.
+ */
+ void train(const IODescription &desc, uint32_t training_step);
+
+ float getLoss(const ir::IOIndex &index) const;
+
+private:
+ // TODO Append model index to ModelIndex
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<TrainableExecutor>> _executors;
+};
+
+} // namespace train
+} // namespace exec
+} // namespace onert
+
+#endif // __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__
diff --git a/runtime/onert/core/src/exec/train/TrainableFnSequence.cc b/runtime/onert/core/src/exec/train/TrainableFnSequence.cc
new file mode 100644
index 000000000..084b3d708
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/TrainableFnSequence.cc
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "exec/train/TrainableFnSequence.h"
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+
+void TrainableFnSequence::forward(bool training)
+{
+ for (const auto &function : _functions)
+ {
+ function->forward(training);
+ }
+}
+
+void TrainableFnSequence::backward(uint32_t training_step)
+{
+ for (auto it = _functions.rbegin(); it != _functions.rend(); ++it)
+ {
+ (*it)->backward();
+ }
+
+ for (const auto &applier : _appliers)
+ {
+ applier->applyGradient(training_step);
+ }
+}
+
+void TrainableFnSequence::append(std::unique_ptr<ITrainableFunction> &&function)
+{
+ _functions.push_back(std::move(function));
+}
+
+void TrainableFnSequence::append(std::unique_ptr<IGradientApplier> &&applier)
+{
+ _appliers.push_back(std::move(applier));
+}
+
+void TrainableFnSequence::iterate(const std::function<void(ITrainableFunction &)> &fn)
+{
+ for (const auto &func : _functions)
+ {
+ fn(*func);
+ }
+}
+
+} // namespace train
+} // namespace exec
+} // namespace onert
diff --git a/runtime/onert/core/src/exec/train/optimizer/OptimizerCode.cc b/runtime/onert/core/src/exec/train/optimizer/OptimizerCode.cc
new file mode 100644
index 000000000..72b581bf6
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/optimizer/OptimizerCode.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "exec/train/optimizer/OptimizerCode.h"
+
+#include <unordered_map>
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+namespace optimizer
+{
+
+std::string toString(OptimizerCode code)
+{
+ static const std::unordered_map<OptimizerCode, const char *> map{
+ {OptimizerCode::Invalid, "Invalid"},
+ {OptimizerCode::SGD, "SGD"},
+ {OptimizerCode::Adam, "Adam"}};
+ return map.at(code);
+}
+
+} // namespace optimizer
+} // namespace train
+} // namespace exec
+} // namespace onert
diff --git a/runtime/onert/core/src/exec/train/optimizer/OptimizerHelpers.h b/runtime/onert/core/src/exec/train/optimizer/OptimizerHelpers.h
new file mode 100644
index 000000000..66a08b50f
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/optimizer/OptimizerHelpers.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_HELPERS_H__
+#define __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_HELPERS_H__
+
+#include "backend/IPortableTensor.h"
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+namespace optimizer
+{
+
+template <typename T, typename L>
+void elementwise(const ir::Shape &shape, const backend::ITensor &src, backend::ITensor &dst,
+ const L &f)
+{
+ ShapeLoop(shape, [&](const ir::Coordinates &coords) {
+ const T src_val = *reinterpret_cast<const T *>(src.buffer() + src.calcOffset(coords));
+ T *dst_data = reinterpret_cast<T *>(dst.buffer() + dst.calcOffset(coords));
+ *dst_data = f(src_val, *dst_data);
+ });
+}
+
+} // namespace optimizer
+} // namespace train
+} // namespace exec
+} // namespace onert
+
+#endif // __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_HELPERS_H__
diff --git a/runtime/onert/core/src/exec/train/optimizer/SGD.cc b/runtime/onert/core/src/exec/train/optimizer/SGD.cc
new file mode 100644
index 000000000..abfbc1b4b
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/optimizer/SGD.cc
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <exec/train/optimizer/SGD.h>
+
+#include "OptimizerHelpers.h"
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+namespace optimizer
+{
+
+double SGD::getLearningRate(uint32_t) const
+{
+ // TODO Use iteration, momentum, and nesterov
+ return _learning_rate;
+}
+
+void SGD::applyGradient(const UpdateFactors &factors) const
+{
+ const auto lr = getLearningRate(std::get<size_t>(factors));
+ const auto &grad_tensor = std::get<const backend::IPortableTensor &>(factors);
+ auto &trainable_tensor = std::get<backend::train::ITrainableTensor &>(factors);
+ assert(trainable_tensor.data_type() == grad_tensor.data_type());
+
+ const auto shape = trainable_tensor.getShape();
+ const auto &grad_shape = grad_tensor.get_info().shape();
+
+ // TODO Support for different shapes
+ if (shape != grad_shape)
+ {
+ throw std::runtime_error("SGD: Invalid gradient tensor");
+ }
+
+ switch (grad_tensor.data_type())
+ {
+ case ir::DataType::FLOAT32:
+ elementwise<float>(shape, grad_tensor, trainable_tensor,
+ [&](float src, float dst) -> float { return dst - src * lr; });
+ break;
+ default:
+ throw std::runtime_error("SGD: Not supported data type");
+ }
+}
+
+} // namespace optimizer
+} // namespace train
+} // namespace exec
+} // namespace onert
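The FLOAT32 lambda above applies the plain (momentum-free) SGD rule: each trainable weight w with gradient g is replaced by w - lr * g. For example, with lr = 0.01, w = 0.5 and g = 2.0, the updated weight is 0.5 - 0.01 * 2.0 = 0.48. Momentum and Nesterov handling are still TODO, as noted in getLearningRate().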
diff --git a/runtime/onert/core/src/ir/Graph.cc b/runtime/onert/core/src/ir/Graph.cc
index 28cf4137d..ef0f988fa 100644
--- a/runtime/onert/core/src/ir/Graph.cc
+++ b/runtime/onert/core/src/ir/Graph.cc
@@ -42,33 +42,33 @@ OperandIndex Graph::addOperand(OperandIndex index, std::unique_ptr<Operand> &&op
return _operands.push(std::move(operand), index);
}
-bool Graph::checkOperandsForOperation(const Operation &operation)
+bool Graph::checkOperandsForOperation(const IOperation &operation)
{
auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
- for (auto input : inputs)
+ for (auto &&input : inputs)
if (!operands().exist(input))
return false;
- for (auto input : outputs)
+ for (auto &&input : outputs)
if (!operands().exist(input))
return false;
return true;
}
-void Graph::linkOperandToOperation(OperationIndex index, const Operation &operation)
+void Graph::linkOperandToOperation(OperationIndex index, const IOperation &operation)
{
auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
- for (auto input : inputs)
+ for (auto &&input : inputs)
operands().at(input).insertUse(index);
- for (auto output : outputs)
+ for (auto &&output : outputs)
operands().at(output).setDef(index);
}
-OperationIndex Graph::addOperation(std::unique_ptr<Operation> &&operation)
+OperationIndex Graph::addOperation(std::unique_ptr<IOperation> &&operation)
{
- const Operation &op_ref = *operation;
+ const IOperation &op_ref = *operation;
if (!checkOperandsForOperation(op_ref))
return OperationIndex{};
auto ind = _operations.push(std::move(operation));
@@ -77,9 +77,9 @@ OperationIndex Graph::addOperation(std::unique_ptr<Operation> &&operation)
return ind;
}
-OperationIndex Graph::addOperation(OperationIndex index, std::unique_ptr<Operation> &&operation)
+OperationIndex Graph::addOperation(OperationIndex index, std::unique_ptr<IOperation> &&operation)
{
- const Operation &op_ref = *operation;
+ const IOperation &op_ref = *operation;
if (!checkOperandsForOperation(op_ref))
return OperationIndex{};
auto ind_gen = _operations.push(std::move(operation), index);
@@ -91,12 +91,35 @@ OperationIndex Graph::addOperation(OperationIndex index, std::unique_ptr<Operati
return index;
}
+OperationIndex Graph::replaceOperation(OperationIndex index,
+ std::unique_ptr<IOperation> &&operation)
+{
+ const IOperation &op_ref = *operation;
+ if (!checkOperandsForOperation(op_ref) || !_operations.exist(index))
+ return OperationIndex{};
+
+ // Check the new operation has the same inputs/outputs as the existing operation
+ const auto &old_op = _operations.at(index);
+ if (!(old_op.getInputs() == op_ref.getInputs() && old_op.getOutputs() == op_ref.getOutputs()))
+ {
+ return OperationIndex{};
+ }
+
+ return _operations.set(index, std::move(operation));
+}
+
void Graph::setOperandValue(const OperandIndex &ind, std::shared_ptr<Data> data)
{
assert(_operands.exist(ind));
_operands.at(ind).data(std::move(data));
}
+void Graph::changeShape(const OperandIndex &ind, const ir::Shape &new_shape)
+{
+ assert(_operands.exist(ind));
+ _operands.at(ind).info().shape(new_shape);
+}
+
void Graph::addInput(const OperandIndex &ind, const std::string &name)
{
if (!name.empty())
@@ -123,7 +146,7 @@ IOIndex Graph::getOutputIndex(const std::string &name) const
return (itr == _name_to_output.end()) ? IOIndex{} : itr->second;
}
-void Graph::verify(void)
+void Graph::verify(void) const
{
// Call graph verifications for the MODEL phase
{
@@ -144,14 +167,14 @@ void Graph::verify(void)
void Graph::initializeUseDef()
{
- operations().iterate([&](const OperationIndex &index, const Operation &node) -> void {
+ operations().iterate([&](const OperationIndex &index, const IOperation &node) -> void {
auto outputs = node.getOutputs();
- for (auto output : outputs | ir::Remove::UNDEFINED)
+ for (auto &&output : outputs | ir::Remove::UNDEFINED)
{
operands().at(output).setDef(index);
}
- for (auto input : node.getInputs() | ir::Remove::UNDEFINED)
+ for (auto &&input : node.getInputs() | ir::Remove::UNDEFINED)
{
operands().at(input).insertUse(index);
}
@@ -163,15 +186,15 @@ std::vector<ir::OperationIndex> Graph::topolSortOperations() const
std::vector<ir::OperationIndex> ret;
util::Set<ir::OperationIndex> unvisited;
operations().iterate(
- [&](const ir::OperationIndex &index, const ir::Operation &) { unvisited.add(index); });
+ [&](const ir::OperationIndex &index, const ir::IOperation &) { unvisited.add(index); });
- std::function<void(const ir::OperationIndex &, const ir::Operation &)> dfs =
- [&](const ir::OperationIndex &index, const ir::Operation &op) -> void {
+ std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs =
+ [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void {
if (!unvisited.contains(index))
return;
unvisited.remove(index);
- for (const auto output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+ for (const auto &output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
{
const auto &operand = operands().at(output);
for (const auto &use : operand.getUses())
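[Editor's note] A minimal usage sketch for the two new Graph mutators above (not part of the commit): replaceOperation() and changeShape() come from this diff, while the wrapper function, the indices and the Shape literal are illustrative assumptions.

    #include <memory>
    #include "ir/Graph.h"

    // Swap an operation in place, then resize one of its operands.
    void replace_and_resize(onert::ir::Graph &graph, onert::ir::OperationIndex op_index,
                            onert::ir::OperandIndex operand_index,
                            std::unique_ptr<onert::ir::IOperation> replacement)
    {
      // replaceOperation() refuses the swap (returning a default OperationIndex)
      // when the index does not exist or when the replacement's input/output
      // sequences differ from the original node's.
      const auto result = graph.replaceOperation(op_index, std::move(replacement));
      if (!(result == op_index))
        return; // swap rejected; the graph is left unchanged

      // changeShape() only rewrites the operand's shape metadata; use/def links
      // are untouched, so shape consistency stays the caller's responsibility.
      graph.changeShape(operand_index, onert::ir::Shape{1, 8});
    }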
diff --git a/runtime/onert/core/src/ir/LayoutSet.cc b/runtime/onert/core/src/ir/LayoutSet.cc
index bd3f438ad..732460aa2 100644
--- a/runtime/onert/core/src/ir/LayoutSet.cc
+++ b/runtime/onert/core/src/ir/LayoutSet.cc
@@ -23,7 +23,7 @@ namespace ir
LayoutSet::LayoutSet(std::initializer_list<Layout> layouts)
{
- for (auto layout : layouts)
+ for (auto &&layout : layouts)
{
_set.insert(layout);
}
@@ -32,7 +32,7 @@ LayoutSet::LayoutSet(std::initializer_list<Layout> layouts)
LayoutSet LayoutSet::operator|(const LayoutSet &other) const
{
auto ret = *this;
- for (auto layout : other)
+ for (auto &&layout : other)
{
ret.add(layout);
}
@@ -42,7 +42,7 @@ LayoutSet LayoutSet::operator|(const LayoutSet &other) const
LayoutSet LayoutSet::operator&(const LayoutSet &other) const
{
LayoutSet ret;
- for (auto layout : other)
+ for (auto &&layout : other)
{
if (contains(layout))
{
@@ -55,7 +55,7 @@ LayoutSet LayoutSet::operator&(const LayoutSet &other) const
LayoutSet LayoutSet::operator-(const LayoutSet &other) const
{
auto ret = *this;
- for (auto layout : other)
+ for (auto &&layout : other)
{
ret.remove(layout);
}
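[Editor's note] Many hunks in this change only switch range-for loops from auto to auto &&. A small standalone sketch of the difference, independent of onert's types: auto copies each element, auto && binds a (possibly const) reference, which matters once the element type is heavier than the small index and layout values used here.

    #include <string>
    #include <vector>

    void copies_each_element(const std::vector<std::string> &v)
    {
      for (auto s : v) // s is a fresh std::string copy on every iteration
        (void)s;
    }

    void binds_without_copying(const std::vector<std::string> &v)
    {
      for (auto &&s : v) // deduces const std::string &, so no copy is made
        (void)s;
    }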
diff --git a/runtime/onert/core/src/ir/LayoutSet.h b/runtime/onert/core/src/ir/LayoutSet.h
index 6ce4e38c6..be077f2f0 100644
--- a/runtime/onert/core/src/ir/LayoutSet.h
+++ b/runtime/onert/core/src/ir/LayoutSet.h
@@ -17,6 +17,7 @@
#ifndef __ONERT_IR_LAYOUT_SET_H__
#define __ONERT_IR_LAYOUT_SET_H__
+#include <cstdint>
#include <initializer_list>
#include <unordered_set>
diff --git a/runtime/onert/core/src/ir/OperandIndexSequence.cc b/runtime/onert/core/src/ir/OperandIndexSequence.cc
index b092f5cee..a15b6d0d6 100644
--- a/runtime/onert/core/src/ir/OperandIndexSequence.cc
+++ b/runtime/onert/core/src/ir/OperandIndexSequence.cc
@@ -31,7 +31,7 @@ OperandIndexSequence::OperandIndexSequence(std::initializer_list<OperandIndex> l
OperandIndexSequence::OperandIndexSequence(std::initializer_list<int32_t> list)
{
- for (auto val : list)
+ for (auto &&val : list)
{
_vec.emplace_back(static_cast<uint32_t>(val));
}
@@ -39,7 +39,7 @@ OperandIndexSequence::OperandIndexSequence(std::initializer_list<int32_t> list)
OperandIndexSequence::OperandIndexSequence(std::initializer_list<uint32_t> list)
{
- for (auto val : list)
+ for (auto &&val : list)
{
_vec.emplace_back(val);
}
@@ -55,6 +55,11 @@ void OperandIndexSequence::replace(const OperandIndex &from, const OperandIndex
std::replace(_vec.begin(), _vec.end(), from, to);
}
+bool OperandIndexSequence::operator==(const OperandIndexSequence &other) const
+{
+ return _vec == other._vec;
+}
+
OperandIndexSequence OperandIndexSequence::operator+(const OperandIndexSequence &other) const
{
OperandIndexSequence ret = *this;
diff --git a/runtime/onert/core/src/ir/OperationCloner.cc b/runtime/onert/core/src/ir/OperationCloner.cc
index c06315814..64e1cc807 100644
--- a/runtime/onert/core/src/ir/OperationCloner.cc
+++ b/runtime/onert/core/src/ir/OperationCloner.cc
@@ -57,7 +57,7 @@ std::unique_ptr<Operation> OperationCloner::releaseClone()
} // namespace
-std::unique_ptr<Operation> clone(const Operation &operation)
+std::unique_ptr<Operation> clone(const IOperation &operation)
{
OperationCloner cloner;
operation.accept(cloner);
diff --git a/runtime/onert/core/src/ir/OperationCloner.h b/runtime/onert/core/src/ir/OperationCloner.h
index 6424549e9..49297a05c 100644
--- a/runtime/onert/core/src/ir/OperationCloner.h
+++ b/runtime/onert/core/src/ir/OperationCloner.h
@@ -26,7 +26,7 @@ namespace onert
namespace ir
{
-std::unique_ptr<Operation> clone(const Operation &operation);
+std::unique_ptr<Operation> clone(const IOperation &operation);
} // namespace ir
} // namespace onert
diff --git a/runtime/onert/core/src/ir/OperationDumper.cc b/runtime/onert/core/src/ir/OperationDumper.cc
index 0b596ff13..5e6d700f3 100644
--- a/runtime/onert/core/src/ir/OperationDumper.cc
+++ b/runtime/onert/core/src/ir/OperationDumper.cc
@@ -202,6 +202,14 @@ void OperationDumper::visit(const L2Normalization &node) { dumpOpGeneric(node);
void OperationDumper::visit(const LocalResponseNormalization &node) { dumpOpGeneric(node); }
+void OperationDumper::visit(const Loss &node)
+{
+ VERBOSE(LIR) << "* " << node.name() << std::endl;
+ VERBOSE(LIR) << " - Inputs : Prediction(" << node.getInputs().at(Loss::Input::Y_PRED) << ") True("
+ << node.getInputs().at(Loss::Input::Y_TRUE) << ")" << std::endl;
+ VERBOSE(LIR) << " - Outputs : Output(" << node.getOutputs().at(0) << ")" << std::endl;
+}
+
void OperationDumper::visit(const LSTM &node)
{
VERBOSE(LIR) << "* " << node.name() << std::endl;
diff --git a/runtime/onert/core/src/ir/OperationDumper.h b/runtime/onert/core/src/ir/OperationDumper.h
index fe18307b9..99bf869d5 100644
--- a/runtime/onert/core/src/ir/OperationDumper.h
+++ b/runtime/onert/core/src/ir/OperationDumper.h
@@ -55,6 +55,7 @@ public:
void visit(const operation::InstanceNorm &) override;
void visit(const operation::L2Normalization &) override;
void visit(const operation::LocalResponseNormalization &) override;
+ void visit(const operation::Loss &node) override;
void visit(const operation::LSTM &) override;
void visit(const operation::Pack &) override;
void visit(const operation::Pad &) override;
diff --git a/runtime/onert/core/src/ir/OperationValidator.cc b/runtime/onert/core/src/ir/OperationValidator.cc
index 094dbc0d5..cf7323d77 100644
--- a/runtime/onert/core/src/ir/OperationValidator.cc
+++ b/runtime/onert/core/src/ir/OperationValidator.cc
@@ -38,7 +38,7 @@ OperationValidator::OperationValidator(const Graph &graph)
void OperationValidator::operator()()
{
- _operations.iterate([&](const OperationIndex &, const Operation &node) { node.accept(*this); });
+ _operations.iterate([&](const OperationIndex &, const IOperation &node) { node.accept(*this); });
}
DataType OperationValidator::operandType(const OperandIndex &idx)
@@ -75,7 +75,7 @@ bool OperationValidator::isValidType(const OperandIndex &idx, const DataType &ty
bool OperationValidator::isValidType(const OperandIndex &idx,
std::initializer_list<DataType> valid_types)
{
- for (auto type_to_check : valid_types)
+ for (auto &&type_to_check : valid_types)
{
if (isValidType(idx, type_to_check))
{
@@ -163,7 +163,7 @@ void OperationValidator::visit(const operation::Concat &node)
{
const auto output_index{node.getOutputs().at(0)};
- for (auto input_index : node.getInputs())
+ for (auto &&input_index : node.getInputs())
{
OP_REQUIRES(isSameType(input_index, output_index));
diff --git a/runtime/onert/core/src/ir/Operations.cc b/runtime/onert/core/src/ir/Operations.cc
index e7e0c88cf..1b4691f58 100644
--- a/runtime/onert/core/src/ir/Operations.cc
+++ b/runtime/onert/core/src/ir/Operations.cc
@@ -26,7 +26,7 @@ namespace ir
Operations::Operations(const Operations &obj)
{
obj.iterate(
- [&](const OperationIndex &index, const Operation &op) { _objects.emplace(index, clone(op)); });
+ [&](const OperationIndex &index, const IOperation &op) { _objects.emplace(index, clone(op)); });
_next_index = obj._next_index;
}
diff --git a/runtime/onert/core/src/ir/operation/Loss.cc b/runtime/onert/core/src/ir/operation/Loss.cc
new file mode 100644
index 000000000..fa3520b2c
--- /dev/null
+++ b/runtime/onert/core/src/ir/operation/Loss.cc
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/operation/Loss.h"
+#include "ir/OperationVisitor.h"
+
+#include <unordered_map>
+
+namespace onert
+{
+namespace ir
+{
+namespace operation
+{
+
+void Loss::accept(OperationVisitor &v) const { v.visit(*this); }
+
+Loss::Loss(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
+ const Param &param)
+ : Operation{OperandConstraint::createAtLeast(2u), inputs, outputs}, _param{param}
+{
+ if (param.op_type == Type::CATEGORICAL_CROSSENTROPY)
+ {
+ assert(inputs.size() == 2 && "CategoricalCrossentropy Loss has 2 inputs");
+ }
+}
+
+std::string Loss::name() const
+{
+ using LossType = onert::ir::operation::Loss::Type;
+ static const std::unordered_map<Type, std::string> name_map{
+ {LossType::MEAN_SQUARED_ERROR, "MeanSquaredError Loss"},
+ {LossType::CATEGORICAL_CROSSENTROPY, "CategoricalCrossentropy Loss"}};
+ return name_map.at(_param.op_type);
+}
+
+} // namespace operation
+} // namespace ir
+} // namespace onert
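[Editor's note] A hedged sketch of how a frontend might construct the new Loss node (assumed usage, not taken from this commit; Param is treated as carrying only op_type, the operand indices are placeholders, and Y_PRED is assumed to precede Y_TRUE in Loss::Input):

    #include <memory>
    #include "ir/operation/Loss.h"

    std::unique_ptr<onert::ir::operation::Loss>
    make_mse_loss(onert::ir::OperandIndex y_pred, onert::ir::OperandIndex y_true,
                  onert::ir::OperandIndex output)
    {
      using onert::ir::operation::Loss;
      Loss::Param param;
      param.op_type = Loss::Type::MEAN_SQUARED_ERROR; // name() reports "MeanSquaredError Loss"
      return std::make_unique<Loss>(onert::ir::OperandIndexSequence{y_pred, y_true},
                                    onert::ir::OperandIndexSequence{output}, param);
    }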
diff --git a/runtime/onert/core/src/ir/train/TrainableGraph.cc b/runtime/onert/core/src/ir/train/TrainableGraph.cc
new file mode 100644
index 000000000..781f04956
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/TrainableGraph.cc
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/TrainableGraph.h"
+#include "util/Utils.h"
+
+#include <algorithm>
+#include <misc/polymorphic_downcast.h>
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+TrainableGraph::TrainableGraph() : _graph{} {}
+
+TrainableGraph::TrainableGraph(const TrainableGraph &tgraph)
+ : _graph{tgraph._graph}, _derivatives{tgraph._derivatives}, _losses{tgraph._losses}
+{
+ tgraph.operations().iterate(
+ [&](const onert::ir::OperationIndex &index, const onert::ir::IOperation &op) {
+ replaceOperation(index, dynamic_cast<const ITrainableOperation &>(op).clone());
+ });
+}
+
+TrainableGraph::TrainableGraph(const Graph &graph) : _graph{graph} {}
+
+OperandIndex TrainableGraph::addOperand(const Shape &shape, const TypeInfo &type)
+{
+ return _graph.addOperand(shape, type);
+}
+
+OperandIndex TrainableGraph::addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand)
+{
+ return _graph.addOperand(index, std::move(operand));
+}
+
+OperationIndex TrainableGraph::addOperation(std::unique_ptr<ITrainableOperation> &&operation)
+{
+ return _graph.addOperation(std::move(operation));
+}
+
+OperationIndex TrainableGraph::replaceOperation(OperationIndex index,
+ std::unique_ptr<ITrainableOperation> &&operation)
+{
+ return _graph.replaceOperation(index, std::move(operation));
+}
+
+OperandIndex TrainableGraph::addDerivative(OperandIndex index,
+ std::unique_ptr<Operand> &&derivative)
+{
+ return _derivatives.push(std::move(derivative), index);
+}
+
+IOIndex TrainableGraph::getInputIndex(const std::string &name) const
+{
+ return _graph.getInputIndex(name);
+}
+
+IOIndex TrainableGraph::getOutputIndex(const std::string &name) const
+{
+ return _graph.getOutputIndex(name);
+}
+
+void TrainableGraph::changeShape(const OperandIndex &index, const ir::Shape &new_shape)
+{
+ _graph.changeShape(index, new_shape);
+}
+
+void TrainableGraph::changeDerivativeShape(const OperandIndex &index, const ir::Shape &new_shape)
+{
+ assert(_derivatives.exist(index));
+ _derivatives.at(index).info().shape(new_shape);
+}
+
+void TrainableGraph::addInput(const OperandIndex &ind, const std::string &name)
+{
+ _graph.addInput(ind, name);
+}
+
+void TrainableGraph::addOutput(const OperandIndex &ind, const std::string &name)
+{
+ _graph.addOutput(ind, name);
+}
+
+void TrainableGraph::verify(void) const
+{
+ _graph.verify();
+
+ operations().iterate([](const onert::ir::OperationIndex &, const onert::ir::IOperation &op) {
+ try
+ {
+ UNUSED_RELEASE(dynamic_cast<const onert::ir::train::ITrainableOperation &>(op));
+ }
+ catch (const std::bad_cast &)
+ {
+ throw std::runtime_error("TrainableGraph: " + op.name() + " is not a trainable operation");
+ }
+ });
+}
+
+void TrainableGraph::removeOperand(const OperandIndex &ind) { _graph.removeOperand(ind); }
+
+void TrainableGraph::setLayout(Layout layout) { _graph.setLayout(layout); }
+
+const ITrainableOperation &TrainableGraph::operation(OperationIndex index) const
+{
+ // NOTE Objects reached through virtual inheritance cannot be downcast with static_cast.
+ return dynamic_cast<const ITrainableOperation &>(_graph.operations().at(index));
+}
+
+std::vector<ir::OperationIndex> TrainableGraph::topolSortOperations() const
+{
+ return _graph.topolSortOperations();
+}
+
+void TrainableGraph::addLoss(const OperandIndex &loss_ind, const IOIndex &pred_ioind)
+{
+ _losses.emplace(pred_ioind, loss_ind);
+}
+
+OperandIndex TrainableGraph::getLossIndex(const IOIndex &pred_ioind) const
+{
+ auto itr = _losses.find(pred_ioind);
+ return (itr == _losses.end()) ? OperandIndex{} : itr->second;
+}
+
+} // namespace train
+} // namespace ir
+} // namespace onert
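[Editor's note] The loss bookkeeping added above is a simple IOIndex-to-OperandIndex map. A minimal sketch of the intended flow (assumed usage): after a Loss node is appended, its output operand is registered against the prediction output it scores and can be probed later via getLossIndex(), which returns a default OperandIndex when nothing was registered.

    #include <cassert>
    #include "ir/train/TrainableGraph.h"

    void record_loss(onert::ir::train::TrainableGraph &tgraph,
                     onert::ir::OperandIndex loss_operand, onert::ir::IOIndex pred_output)
    {
      tgraph.addLoss(loss_operand, pred_output);
      assert(tgraph.getLossIndex(pred_output) == loss_operand);
    }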
diff --git a/runtime/onert/core/src/ir/train/operation/Conv2D.cc b/runtime/onert/core/src/ir/train/operation/Conv2D.cc
new file mode 100644
index 000000000..923861ae3
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Conv2D.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Conv2D.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Conv2D::clone() const
+{
+ return std::make_unique<Conv2D>(*this);
+}
+
+void Conv2D::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Conv2D::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Conv2D::Conv2D(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc b/runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc
new file mode 100644
index 000000000..1dae3f674
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/ElementwiseActivation.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> ElementwiseActivation::clone() const
+{
+ return std::make_unique<ElementwiseActivation>(*this);
+}
+
+void ElementwiseActivation::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void ElementwiseActivation::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+ElementwiseActivation::ElementwiseActivation(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/FullyConnected.cc b/runtime/onert/core/src/ir/train/operation/FullyConnected.cc
new file mode 100644
index 000000000..a26f7c489
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/FullyConnected.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/FullyConnected.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> FullyConnected::clone() const
+{
+ return std::make_unique<FullyConnected>(*this);
+}
+
+void FullyConnected::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void FullyConnected::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+FullyConnected::FullyConnected(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Loss.cc b/runtime/onert/core/src/ir/train/operation/Loss.cc
new file mode 100644
index 000000000..abd79929b
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Loss.cc
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Loss.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+#include <misc/polymorphic_downcast.h>
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Loss::clone() const { return std::make_unique<Loss>(*this); }
+
+void Loss::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Loss::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Loss::Loss(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Permute.cc b/runtime/onert/core/src/ir/train/operation/Permute.cc
new file mode 100644
index 000000000..adc23aa49
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Permute.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Permute.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Permute::clone() const
+{
+ return std::make_unique<Permute>(*this);
+}
+
+void Permute::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Permute::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Permute::Permute(const OperationType &operation)
+ : OperationType{operation.getInputs().at(0), operation.getOutputs().at(0),
+ operation.getPermuteType()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Pool2D.cc b/runtime/onert/core/src/ir/train/operation/Pool2D.cc
new file mode 100644
index 000000000..021574f19
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Pool2D.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Pool2D.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Pool2D::clone() const
+{
+ return std::make_unique<Pool2D>(*this);
+}
+
+void Pool2D::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Pool2D::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Pool2D::Pool2D(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Reshape.cc b/runtime/onert/core/src/ir/train/operation/Reshape.cc
new file mode 100644
index 000000000..c76158607
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Reshape.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Reshape.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Reshape::clone() const
+{
+ return std::make_unique<Reshape>(*this);
+}
+
+void Reshape::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Reshape::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Reshape::Reshape(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Softmax.cc b/runtime/onert/core/src/ir/train/operation/Softmax.cc
new file mode 100644
index 000000000..dbd403879
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Softmax.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Softmax.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Softmax::clone() const
+{
+ return std::make_unique<Softmax>(*this);
+}
+
+void Softmax::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Softmax::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Softmax::Softmax(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
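[Editor's note] The eight train::operation files above all follow one pattern: each trainable node is constructed from its untrained ir::operation counterpart and adds clone() plus a second accept() overload for TrainableOperationVisitor. A hedged sketch of how a conversion step could use that pattern (the converter itself is not part of this diff):

    #include <memory>
    #include "ir/train/operation/FullyConnected.h"

    std::unique_ptr<onert::ir::train::ITrainableOperation>
    wrap_fully_connected(const onert::ir::operation::FullyConnected &op)
    {
      // The trainable node copies the inputs, outputs and params of the original.
      return std::make_unique<onert::ir::train::operation::FullyConnected>(op);
    }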
diff --git a/runtime/onert/core/src/ir/verifier/Verifier.cc b/runtime/onert/core/src/ir/verifier/Verifier.cc
index 25a82d5a2..6260d29ff 100644
--- a/runtime/onert/core/src/ir/verifier/Verifier.cc
+++ b/runtime/onert/core/src/ir/verifier/Verifier.cc
@@ -39,11 +39,11 @@ bool DAGChecker::verify(const Graph &graph) const noexcept
OperationIndexMap<bool> visited;
operations.iterate(
- [&](const OperationIndex &index, const Operation &) { visited[index] = false; });
+ [&](const OperationIndex &index, const IOperation &) { visited[index] = false; });
OperationIndexMap<bool> on_stack = visited; // Copy from visited
- std::function<void(const OperationIndex &index, const Operation &)> dfs_recursive =
- [&](const OperationIndex &index, const Operation &node) -> void {
+ std::function<void(const OperationIndex &index, const IOperation &)> dfs_recursive =
+ [&](const OperationIndex &index, const IOperation &node) -> void {
if (on_stack[index])
cyclic = true;
if (visited[index])
@@ -51,7 +51,7 @@ bool DAGChecker::verify(const Graph &graph) const noexcept
visited[index] = true;
on_stack[index] = true;
- for (auto output : node.getOutputs() | Remove::DUPLICATED | Remove::UNDEFINED)
+ for (auto &&output : node.getOutputs() | Remove::DUPLICATED | Remove::UNDEFINED)
{
const auto &operand = graph.operands().at(output);
for (const auto &use : operand.getUses())
@@ -76,8 +76,8 @@ bool EdgeChecker::verify(const Graph &graph) const noexcept
{
auto &operations = graph.operations();
uint32_t errors = 0;
- operations.iterate([&](const OperationIndex &index, const Operation &node) {
- for (auto operand_index : node.getInputs() | ir::Remove::UNDEFINED)
+ operations.iterate([&](const OperationIndex &index, const IOperation &node) {
+ for (auto &&operand_index : node.getInputs() | ir::Remove::UNDEFINED)
{
try
{
@@ -98,7 +98,7 @@ bool EdgeChecker::verify(const Graph &graph) const noexcept
errors += 1;
}
}
- for (auto operand_index : node.getOutputs() | ir::Remove::UNDEFINED)
+ for (auto &&operand_index : node.getOutputs() | ir::Remove::UNDEFINED)
{
try
{
@@ -127,7 +127,7 @@ bool EdgeChecker::verify(const Graph &graph) const noexcept
bool InputOutputChecker::verify(const Graph &graph) const noexcept
{
- for (auto operand_ind :
+ for (auto &&operand_ind :
(graph.getInputs() + graph.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
{
if (!graph.operands().exist(operand_ind))
diff --git a/runtime/onert/core/src/odc/QuantizeManager.cc b/runtime/onert/core/src/odc/QuantizeManager.cc
new file mode 100644
index 000000000..71572a7e0
--- /dev/null
+++ b/runtime/onert/core/src/odc/QuantizeManager.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizerLoader.h"
+#include "odc/QuantizeManager.h"
+
+#include <iostream>
+#include <mutex>
+
+namespace onert
+{
+namespace odc
+{
+
+bool QuantizeManager::quantize()
+{
+ // quantize() is not thread-safe, so serialize calls with a function-local mutex
+ static std::mutex lock;
+ std::lock_guard<std::mutex> guard(lock);
+
+ if (_export_model_path.empty())
+ throw std::runtime_error("Export model path is not set");
+
+ auto &quantize_loader = QuantizerLoader::instance();
+ if (quantize_loader.loadLibrary() != 0)
+ return false;
+
+ auto quantizer = quantize_loader.get();
+ auto result = quantizer->quantize(_model_path.c_str(), _export_model_path.c_str(), _is_q16);
+
+ // TODO Unload quantize library to reduce memory usage
+
+ return (result == 0);
+}
+
+} // namespace odc
+} // namespace onert
diff --git a/runtime/onert/core/src/odc/QuantizeManager.test.cc b/runtime/onert/core/src/odc/QuantizeManager.test.cc
new file mode 100644
index 000000000..4e155a6ef
--- /dev/null
+++ b/runtime/onert/core/src/odc/QuantizeManager.test.cc
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "odc/QuantizeManager.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::odc;
+
+// Test export model path is not set
+TEST(odc_QuantizeManager, neg_export_model_path)
+{
+ QuantizeManager manager("model_path");
+ ASSERT_THROW(manager.quantize(), std::runtime_error);
+}
+
+// Test invalid model path
+TEST(odc_QuantizeManager, neg_invalid_model_path)
+{
+ QuantizeManager manager("invalid_model_path.circle");
+ manager.exportModelPath("export_model_path.circle");
+ ASSERT_EQ(manager.quantize(), false);
+}
diff --git a/runtime/onert/core/src/odc/QuantizerLoader.cc b/runtime/onert/core/src/odc/QuantizerLoader.cc
new file mode 100644
index 000000000..8a972e97e
--- /dev/null
+++ b/runtime/onert/core/src/odc/QuantizerLoader.cc
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizerLoader.h"
+
+#include <dlfcn.h>
+#include <iostream>
+#include <string>
+
+static const char *SHARED_LIB_EXT =
+#if defined(__APPLE__) && defined(__MACH__)
+ ".dylib";
+#else
+ ".so";
+#endif
+
+namespace onert
+{
+namespace odc
+{
+
+QuantizerLoader &QuantizerLoader::instance()
+{
+ static QuantizerLoader singleton;
+ return singleton;
+}
+
+int32_t QuantizerLoader::loadLibrary()
+{
+ if (get() != nullptr)
+ return 0;
+
+ const std::string quantize_so = std::string("libonert_odc") + SHARED_LIB_EXT;
+ void *handle = dlopen(quantize_so.c_str(), RTLD_LAZY | RTLD_LOCAL);
+ auto dlerror_msg = dlerror();
+
+ if (handle == nullptr)
+ {
+ std::cerr << "Failed to load " << quantize_so << std::endl;
+ std::cerr << dlerror_msg << std::endl;
+ return 1;
+ }
+
+ {
+ const char *factory_name = "create_quantizer";
+ auto factory = (factory_t)dlsym(handle, factory_name);
+ dlerror_msg = dlerror();
+
+ if (factory == nullptr)
+ {
+ std::cerr << "QuantizerLoader: unable to find function " << factory_name << ": " << dlerror_msg
+ << std::endl;
+ dlclose(handle);
+ return 1;
+ }
+
+ auto destroyer = (quantizer_destroy_t)dlsym(handle, "destroy_quantizer");
+ _quantizer = std::unique_ptr<IQuantizer, quantizer_destroy_t>(factory(), destroyer);
+
+ if (_quantizer == nullptr)
+ {
+ std::cerr << "QuantizerLoader: unable to create quantizer" << std::endl;
+ dlclose(handle);
+ return 1;
+ }
+ }
+
+ // Keep the library handle alive (avoids the static-analysis warning about a handle lost without dlclose())
+ // clang-format off
+ _dlhandle = std::unique_ptr<void, dlhandle_destroy_t>{handle, [filename = quantize_so](void *h) {
+ if (dlclose(h) != 0)
+ std::cerr << "Failed to unload backend " << filename << std::endl;
+ }};
+ // clang-format on
+
+ return 0;
+}
+
+int32_t QuantizerLoader::unloadLibrary()
+{
+ if (get() == nullptr)
+ return 0;
+
+ _quantizer.reset(nullptr);
+ _dlhandle.reset(nullptr);
+
+ return 0;
+}
+
+} // namespace odc
+} // namespace onert
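[Editor's note] QuantizerLoader resolves exactly two C symbols from libonert_odc: create_quantizer (matching factory_t) and destroy_quantizer. A hedged sketch of the exporting side of that contract; makeQuantizerImpl() and the concrete quantizer behind it are assumptions, not part of this commit.

    #include "odc/IQuantizer.h"

    // Assumed to exist inside libonert_odc and to return a concrete IQuantizer.
    onert::odc::IQuantizer *makeQuantizerImpl();

    extern "C" onert::odc::IQuantizer *create_quantizer() { return makeQuantizerImpl(); }

    extern "C" void destroy_quantizer(onert::odc::IQuantizer *q) { delete q; }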
diff --git a/runtime/onert/core/src/odc/QuantizerLoader.h b/runtime/onert/core/src/odc/QuantizerLoader.h
new file mode 100644
index 000000000..36a9f2996
--- /dev/null
+++ b/runtime/onert/core/src/odc/QuantizerLoader.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_ODC_QUANTIZER_LOADER_H__
+#define __ONERT_ODC_QUANTIZER_LOADER_H__
+
+#include "odc/IQuantizer.h"
+
+#include <functional>
+#include <memory>
+
+namespace onert
+{
+namespace odc
+{
+
+/**
+ * @brief Class to manage loading and unloading of dynamic library containing
+ * implementation of IQuantizer interface
+ */
+class QuantizerLoader
+{
+public:
+ /**
+ * @brief Typedef for function pointer to destroy loaded library handle
+ */
+ using dlhandle_destroy_t = std::function<void(void *)>;
+ /**
+ * @brief Typedef for function pointer to create instance of IQuantizer
+ */
+ using factory_t = IQuantizer *(*)();
+ /**
+ * @brief Typedef for function pointer to destroy instance of IQuantizer
+ */
+ using quantizer_destroy_t = void (*)(IQuantizer *);
+
+ /**
+ * @brief Get singleton instance of QuantizerLoader
+ * @return Reference to singleton instance of QuantizerLoader
+ */
+ static QuantizerLoader &instance();
+
+private:
+ // Cannot create instance of QuantizerLoader outside of this class
+ QuantizerLoader() = default;
+ QuantizerLoader(QuantizerLoader const &) = delete;
+ QuantizerLoader &operator=(QuantizerLoader const &) = delete;
+ ~QuantizerLoader() = default;
+
+public:
+ /**
+ * @brief Load dynamic library containing implementation of IQuantizer
+ * @return 0 if success, otherwise errno value
+ */
+ int32_t loadLibrary();
+ /**
+ * @brief Unload dynamic library containing implementation of IQuantizer
+ * @return 0 if success, otherwise errno value
+ */
+ int32_t unloadLibrary();
+ /**
+ * @brief Get instance of IQuantizer created through factory method
+ * @return Pointer to instance of IQuantizer
+ */
+ IQuantizer *get() const { return _quantizer.get(); }
+
+private:
+ // Note: Keep handle to avoid svace warning of "handle lost without dlclose()"
+ std::unique_ptr<void, dlhandle_destroy_t> _dlhandle;
+ std::unique_ptr<IQuantizer, quantizer_destroy_t> _quantizer{nullptr, nullptr};
+};
+
+} // namespace odc
+} // namespace onert
+
+#endif // __ONERT_ODC_QUANTIZER_LOADER_H__
diff --git a/runtime/onert/core/src/odc/QuantizerLoader.test.cc b/runtime/onert/core/src/odc/QuantizerLoader.test.cc
new file mode 100644
index 000000000..112e65b27
--- /dev/null
+++ b/runtime/onert/core/src/odc/QuantizerLoader.test.cc
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizerLoader.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::odc;
+
+// Test QuantizerLoader singleton
+TEST(odc_QuantizerLoader, singleton)
+{
+ QuantizerLoader &loader1 = QuantizerLoader::instance();
+ QuantizerLoader &loader2 = QuantizerLoader::instance();
+ ASSERT_EQ(&loader1, &loader2);
+}
+
+// Test load quantizer library
+TEST(odc_QuantizerLoader, load)
+{
+ QuantizerLoader &loader = QuantizerLoader::instance();
+ // Unload because it may have been loaded in previous tests
+ ASSERT_EQ(loader.unloadLibrary(), 0);
+
+ if (loader.loadLibrary() == 0)
+ {
+ // Load twice to check if it is thread-safe
+ ASSERT_EQ(loader.loadLibrary(), 0);
+ }
+}
+
+// Get quantizer instance without loading the quantizer library
+TEST(odc_QuantizerLoader, neg_get)
+{
+ QuantizerLoader &loader = QuantizerLoader::instance();
+ // Unload because it may have been loaded in previous tests
+ ASSERT_EQ(loader.unloadLibrary(), 0);
+ ASSERT_EQ(loader.get(), nullptr);
+}
+
+// Check quantizer instance pointer after QuantizerLoader is unloaded
+TEST(odc_QuantizerLoader, neg_unload)
+{
+ QuantizerLoader &loader = QuantizerLoader::instance();
+ if (loader.loadLibrary() == 0)
+ ASSERT_NE(loader.get(), nullptr);
+
+ ASSERT_EQ(loader.unloadLibrary(), 0);
+ ASSERT_EQ(loader.get(), nullptr);
+}
diff --git a/runtime/onert/core/src/util/MDTableEventWriter.cc b/runtime/onert/core/src/util/MDTableEventWriter.cc
index 13dab5b77..e7d90eec4 100644
--- a/runtime/onert/core/src/util/MDTableEventWriter.cc
+++ b/runtime/onert/core/src/util/MDTableEventWriter.cc
@@ -124,7 +124,7 @@ struct Graph : public MDContent
void setOperations(const std::map<std::string, Operation> &name_to_op)
{
uint64_t graph_latency = end_ts - begin_ts;
- for (auto it : name_to_op)
+ for (auto &&it : name_to_op)
{
auto op = it.second;
op.graph_latency = graph_latency;
@@ -172,7 +172,7 @@ struct Graph : public MDContent
writeMDTableRow(os, op_headers_line);
// Operation's contents
- for (auto op : ops)
+ for (auto &&op : ops)
{
op.write(os);
}