path: root/runtime/onert/core/src
author     Chunseok Lee <chunseok.lee@samsung.com>  2022-09-07 19:04:21 +0900
committer  Chunseok Lee <chunseok.lee@samsung.com>  2022-09-07 19:04:21 +0900
commit     c690d52bdd137ed6a17353aa7af35e8141ece77b (patch)
tree       dbb7dd99133132dfbffcb8c9e9af4f1ffc2f4808 /runtime/onert/core/src
parent     3ad689f0803519e343c36d5700646e86059df961 (diff)
download   nnfw-c690d52bdd137ed6a17353aa7af35e8141ece77b.tar.gz
           nnfw-c690d52bdd137ed6a17353aa7af35e8141ece77b.tar.bz2
           nnfw-c690d52bdd137ed6a17353aa7af35e8141ece77b.zip
Diffstat (limited to 'runtime/onert/core/src')
-rw-r--r-- runtime/onert/core/src/backend/builtin/ExternalContext.h | 2
-rw-r--r-- runtime/onert/core/src/backend/builtin/KernelGenerator.cc | 32
-rw-r--r-- runtime/onert/core/src/backend/builtin/KernelGenerator.h | 17
-rw-r--r-- runtime/onert/core/src/backend/builtin/kernel/IfLayer.cc | 16
-rw-r--r-- runtime/onert/core/src/backend/builtin/kernel/IfLayer.h | 7
-rw-r--r-- runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc | 4
-rw-r--r-- runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.h | 6
-rw-r--r-- runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc | 19
-rw-r--r-- runtime/onert/core/src/backend/builtin/kernel/WhileLayer.h | 6
-rw-r--r-- runtime/onert/core/src/compiler/BackendManager.cc | 15
-rw-r--r-- runtime/onert/core/src/compiler/Compiler.cc | 505
-rw-r--r-- runtime/onert/core/src/compiler/ExecutorFactory.cc | 85
-rw-r--r-- runtime/onert/core/src/compiler/ExecutorFactory.h | 26
-rw-r--r-- runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc | 10
-rw-r--r-- runtime/onert/core/src/compiler/HEScheduler.cc | 11
-rw-r--r-- runtime/onert/core/src/compiler/HEScheduler.h | 18
-rw-r--r-- runtime/onert/core/src/compiler/HEScheduler.test.cc | 572
-rw-r--r-- runtime/onert/core/src/compiler/Linear.cc | 10
-rw-r--r-- runtime/onert/core/src/compiler/LoweredGraph.cc | 44
-rw-r--r-- runtime/onert/core/src/compiler/ShapeValidator.cc | 667
-rw-r--r-- runtime/onert/core/src/compiler/ShapeValidator.h | 8
-rw-r--r-- runtime/onert/core/src/compiler/StaticShapeInferer.cc | 648
-rw-r--r-- runtime/onert/core/src/compiler/TensorRegistries.h | 13
-rw-r--r-- runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc | 1
-rw-r--r-- runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc | 18
-rw-r--r-- runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.test.cc | 47
-rw-r--r-- runtime/onert/core/src/dumper/dot/DotDumper.cc | 222
-rw-r--r-- runtime/onert/core/src/dumper/dot/DotDumper.h | 25
-rw-r--r-- runtime/onert/core/src/exec/DataflowExecutor.h | 17
-rw-r--r-- runtime/onert/core/src/exec/ExecTime.cc | 6
-rw-r--r-- runtime/onert/core/src/exec/ExecTime.test.cc | 106
-rw-r--r-- runtime/onert/core/src/exec/Execution.cc | 24
-rw-r--r-- runtime/onert/core/src/exec/Execution.test.cc | 302
-rw-r--r-- runtime/onert/core/src/exec/ExecutionObservee.h | 5
-rw-r--r-- runtime/onert/core/src/exec/ExecutionObservers.cc | 14
-rw-r--r-- runtime/onert/core/src/exec/ExecutionObservers.h | 13
-rw-r--r-- runtime/onert/core/src/exec/ExecutorBase.cc | 5
-rw-r--r-- runtime/onert/core/src/exec/ExecutorBase.h | 15
-rw-r--r-- runtime/onert/core/src/exec/Executors.cc | 183
-rw-r--r-- runtime/onert/core/src/exec/FunctionSequence.cc | 4
-rw-r--r-- runtime/onert/core/src/exec/JSONExecTime.cc | 4
-rw-r--r-- runtime/onert/core/src/exec/LinearExecutor.h | 5
-rw-r--r-- runtime/onert/core/src/exec/ParallelExecutor.h | 14
-rw-r--r-- runtime/onert/core/src/exec/feature/MockTensor.h | 66
-rw-r--r-- runtime/onert/core/src/exec/feature/nchw/Reader.test.cc | 85
-rw-r--r-- runtime/onert/core/src/exec/feature/nchw/View.test.cc | 85
-rw-r--r-- runtime/onert/core/src/exec/feature/nhwc/Reader.test.cc | 86
-rw-r--r-- runtime/onert/core/src/exec/feature/nhwc/View.h | 2
-rw-r--r-- runtime/onert/core/src/exec/feature/nhwc/View.test.cc | 86
-rw-r--r-- runtime/onert/core/src/interp/InterpExecutor.cc | 7
-rw-r--r-- runtime/onert/core/src/interp/InterpExecutor.h | 7
-rw-r--r-- runtime/onert/core/src/interp/InterpExecutor.test.cc | 355
-rw-r--r-- runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc | 10
-rw-r--r-- runtime/onert/core/src/interp/operations/Concat.cc | 8
-rw-r--r-- runtime/onert/core/src/interp/operations/Conv2D.cc | 10
-rw-r--r-- runtime/onert/core/src/interp/operations/DepthwiseConv2D.cc | 10
-rw-r--r-- runtime/onert/core/src/interp/operations/ElementwiseActivations.cc | 9
-rw-r--r-- runtime/onert/core/src/interp/operations/FullyConnected.cc | 8
-rw-r--r-- runtime/onert/core/src/interp/operations/Gather.cc | 8
-rw-r--r-- runtime/onert/core/src/interp/operations/InstanceNorm.cc | 8
-rw-r--r-- runtime/onert/core/src/interp/operations/Pad.cc | 6
-rw-r--r-- runtime/onert/core/src/interp/operations/Pool2D.cc | 12
-rw-r--r-- runtime/onert/core/src/interp/operations/Reshape.cc | 2
-rw-r--r-- runtime/onert/core/src/interp/operations/Softmax.cc | 8
-rw-r--r-- runtime/onert/core/src/interp/operations/TransposeConv.cc | 8
-rw-r--r-- runtime/onert/core/src/ir/Graph.cc | 14
-rw-r--r-- runtime/onert/core/src/ir/Graph.test.cc | 147
-rw-r--r-- runtime/onert/core/src/ir/LayoutSet.test.cc | 67
-rw-r--r-- runtime/onert/core/src/ir/MockNode.h | 47
-rw-r--r-- runtime/onert/core/src/ir/Operand.test.cc | 86
-rw-r--r-- runtime/onert/core/src/ir/OperandIndexSequence.test.cc | 52
-rw-r--r-- runtime/onert/core/src/ir/Operands.test.cc | 45
-rw-r--r-- runtime/onert/core/src/ir/Operation.test.cc | 98
-rw-r--r-- runtime/onert/core/src/ir/Operations.test.cc | 42
-rw-r--r-- runtime/onert/core/src/ir/Shape.test.cc | 58
-rw-r--r-- runtime/onert/core/src/ir/verifier/Verifier.test.cc | 93
-rw-r--r-- runtime/onert/core/src/util/ChromeTracingEventWriter.cc | 6
-rw-r--r-- runtime/onert/core/src/util/ConfigSource.cc | 25
-rw-r--r-- runtime/onert/core/src/util/EventCollector.cc | 2
-rw-r--r-- runtime/onert/core/src/util/EventCollector.h | 7
-rw-r--r-- runtime/onert/core/src/util/EventRecorder.cc | 2
-rw-r--r-- runtime/onert/core/src/util/EventWriter.cc | 2
-rw-r--r-- runtime/onert/core/src/util/GeneralConfigSource.cc | 45
-rw-r--r-- runtime/onert/core/src/util/Index.test.cc (renamed from runtime/onert/core/src/util/EnvConfigSource.cc) | 34
-rw-r--r-- runtime/onert/core/src/util/MDTableEventWriter.cc | 10
-rw-r--r-- runtime/onert/core/src/util/ObjectManager.test.cc | 211
-rw-r--r-- runtime/onert/core/src/util/SNPEEventWriter.cc | 5
-rw-r--r-- runtime/onert/core/src/util/ShapeInference.test.cc | 544
88 files changed, 4962 insertions, 1337 deletions
diff --git a/runtime/onert/core/src/backend/builtin/ExternalContext.h b/runtime/onert/core/src/backend/builtin/ExternalContext.h
index e67be988d..390dbb579 100644
--- a/runtime/onert/core/src/backend/builtin/ExternalContext.h
+++ b/runtime/onert/core/src/backend/builtin/ExternalContext.h
@@ -24,6 +24,8 @@
#include <ruy/ctx.h>
#include <ruy/tune.h>
+#include <memory>
+
namespace onert
{
namespace backend
diff --git a/runtime/onert/core/src/backend/builtin/KernelGenerator.cc b/runtime/onert/core/src/backend/builtin/KernelGenerator.cc
index 3d6358d9d..fa2fc0b94 100644
--- a/runtime/onert/core/src/backend/builtin/KernelGenerator.cc
+++ b/runtime/onert/core/src/backend/builtin/KernelGenerator.cc
@@ -16,12 +16,10 @@
#include "KernelGenerator.h"
-#include <backend/BackendContext.h>
-#include <util/Utils.h>
#include "kernel/IfLayer.h"
-#include "kernel/WhileLayer.h"
#include "kernel/PermuteLayer.h"
-#include "exec/ExecutorBase.h"
+#include "kernel/WhileLayer.h"
+
#include "exec/FunctionSequence.h"
namespace onert
@@ -35,12 +33,12 @@ KernelGenerator::KernelGenerator(const ir::Graph &graph, DynamicTensorManager *d
const std::shared_ptr<TensorRegistry> &tensor_reg,
const std::shared_ptr<ExternalContext> &external_context)
: basic::KernelGeneratorBase{graph}, _dyn_tensor_manager{dyn_tensor_manager},
- _tensor_reg{tensor_reg}, _tensor_registries{}, _executor_map{nullptr}, _external_context{
- external_context}
+ _tensor_reg{tensor_reg}, _tensor_registries{}, _executors{nullptr}, _external_context{
+ external_context}
{
UNUSED_RELEASE(_graph);
UNUSED_RELEASE(_tensor_registries);
- UNUSED_RELEASE(_executor_map);
+ UNUSED_RELEASE(_executors);
}
std::unique_ptr<exec::FunctionSequence> KernelGenerator::generate(ir::OperationIndex ind)
@@ -48,20 +46,16 @@ std::unique_ptr<exec::FunctionSequence> KernelGenerator::generate(ir::OperationI
assert(_dyn_tensor_manager);
assert(_tensor_reg);
- auto dyn_shape_inferer =
- std::make_unique<exec::DynamicShapeInferer>(_graph.operands(), _tensor_reg);
-
auto ret = std::make_unique<exec::FunctionSequence>();
// Prepare to handle dynamic tensors later
auto dyn_ctx = std::make_shared<exec::FunctionSequence::DynamicTensorCtx>();
{
- dyn_ctx->op_ind = ind;
- dyn_ctx->operations = &_graph.operations();
- dyn_ctx->dynamic_shape_inferer = std::move(dyn_shape_inferer);
-
- ret->dynamic_tensor_ctx(dyn_ctx);
+ dyn_ctx->op = &_graph.operations().at(ind);
+ dyn_ctx->dynamic_shape_inferer =
+ std::make_unique<exec::DynamicShapeInferer>(_graph.operands(), _tensor_reg);
}
+ ret->dynamic_tensor_ctx(dyn_ctx);
auto &op = _graph.operations().at(ind);
op.accept(*this);
@@ -90,12 +84,12 @@ void KernelGenerator::visit(const ir::operation::If &node)
output_tensors.emplace_back(output_tensor);
}
- // IfLayer just set ExecutorMap instead of then and else executor to avoid complexity of
+ // IfLayer just set Executors instead of then and else executor to avoid complexity of
// creating executor recusively
const auto cond_tensor = input_tensors.front();
input_tensors.erase(input_tensors.begin());
auto fn = std::make_unique<::onert::backend::builtin::kernel::IfLayer>(
- cond_tensor, input_tensors, output_tensors, then_subg_index, else_subg_index, _executor_map,
+ cond_tensor, input_tensors, output_tensors, then_subg_index, else_subg_index, _executors,
_external_context);
_return_fn = std::move(fn);
@@ -136,10 +130,10 @@ void KernelGenerator::visit(const ir::operation::While &node)
output_tensors.emplace_back(output_tensor);
}
- // WhileLayer just set ExecutorMap instead of cond and body executor to avoid complexity of
+ // WhileLayer just set Executors instead of cond and body executor to avoid complexity of
// creating executor recusively
auto fn = std::make_unique<::onert::backend::builtin::kernel::WhileLayer>(
- input_tensors, output_tensors, cond_subg_index, body_subg_index, _executor_map,
+ input_tensors, output_tensors, cond_subg_index, body_subg_index, _executors,
_dyn_tensor_manager->dynamic_mem_mgr().get(), _external_context);
_return_fn = std::move(fn);
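The comments in the If/While visitors above note that the generated kernels receive the whole Executors container rather than concrete then/else (or cond/body) executors, because those child executors may not exist yet when the kernel is generated. Below is a minimal, self-contained sketch of that deferred-lookup pattern; the types are simplified stand-ins, not the real onert classes.

#include <cstdint>
#include <map>
#include <memory>

// Stand-ins for onert's exec::IExecutor / exec::Executors / ir::SubgraphIndex.
struct IExecutor
{
  virtual ~IExecutor() = default;
  virtual void execute() = 0;
};
using SubgraphIndex = uint32_t;
using Executors = std::map<SubgraphIndex, std::unique_ptr<IExecutor>>;

class IfKernel
{
public:
  IfKernel(SubgraphIndex then_idx, SubgraphIndex else_idx, Executors *executors)
    : _then_idx{then_idx}, _else_idx{else_idx}, _executors{executors}
  {
    // At this point, `executors` may not contain the then/else executors yet.
  }

  void run(bool cond)
  {
    // The lookup is deferred to run time, when every subgraph executor has
    // been registered in the shared container.
    auto &subg_exec = cond ? _executors->at(_then_idx) : _executors->at(_else_idx);
    subg_exec->execute();
  }

private:
  SubgraphIndex _then_idx;
  SubgraphIndex _else_idx;
  Executors *_executors; // not owned
};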
diff --git a/runtime/onert/core/src/backend/builtin/KernelGenerator.h b/runtime/onert/core/src/backend/builtin/KernelGenerator.h
index 00ad962b9..d5931ca26 100644
--- a/runtime/onert/core/src/backend/builtin/KernelGenerator.h
+++ b/runtime/onert/core/src/backend/builtin/KernelGenerator.h
@@ -17,13 +17,14 @@
#ifndef __ONERT_BACKEND_BUILTIN_KERNEL_GENERATOR_H__
#define __ONERT_BACKEND_BUILTIN_KERNEL_GENERATOR_H__
-#include "exec/IExecutor.h"
+#include "DynamicTensorManager.h"
#include "ExternalContext.h"
-#include "ir/Graph.h"
-#include "TensorBuilder.h"
-#include "compiler/TensorRegistries.h"
-#include "backend/basic/KernelGeneratorBase.h"
#include "TensorRegistry.h"
+#include "../../compiler/TensorRegistries.h"
+
+#include "backend/basic/KernelGeneratorBase.h"
+#include "exec/Executors.h"
+#include "ir/Graph.h"
namespace onert
{
@@ -43,10 +44,10 @@ public:
{
_tensor_registries = tensor_registries;
}
- void setExecutorMap(const std::shared_ptr<exec::ExecutorMap> &executor_map)
+ void setExecutors(const std::shared_ptr<exec::Executors> &executors)
{
// FIXME Using shared_ptr's raw pointer!
- _executor_map = executor_map.get();
+ _executors = executors.get();
}
std::unique_ptr<exec::FunctionSequence> generate(ir::OperationIndex ind) override;
@@ -64,7 +65,7 @@ private:
DynamicTensorManager *_dyn_tensor_manager;
std::shared_ptr<TensorRegistry> _tensor_reg;
compiler::TensorRegistries _tensor_registries;
- exec::ExecutorMap *_executor_map;
+ exec::Executors *_executors;
const std::shared_ptr<ExternalContext> _external_context;
};
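The setExecutors change above keeps only the raw pointer taken from the caller's shared_ptr (hence the FIXME), so the Executors container has to outlive the kernel generator and every kernel that captured the pointer. A small sketch of that implied lifetime contract, using a stand-in Executors type rather than the real class:

#include <memory>

struct Executors
{
  // Stand-in for exec::Executors (a subgraph-index to executor mapping).
};

class KernelGenerator
{
public:
  // Mirrors the FIXME above: only the raw pointer is kept, so the caller's
  // shared_ptr must stay alive at least as long as this generator and the
  // kernels it produces keep dereferencing _executors.
  void setExecutors(const std::shared_ptr<Executors> &executors) { _executors = executors.get(); }

private:
  Executors *_executors = nullptr; // not owned
};

int main()
{
  auto executors = std::make_shared<Executors>();
  KernelGenerator kgen;
  kgen.setExecutors(executors); // `executors` must outlive `kgen` and its kernels
  return 0;
}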
diff --git a/runtime/onert/core/src/backend/builtin/kernel/IfLayer.cc b/runtime/onert/core/src/backend/builtin/kernel/IfLayer.cc
index fdd9d9d14..cdb41960a 100644
--- a/runtime/onert/core/src/backend/builtin/kernel/IfLayer.cc
+++ b/runtime/onert/core/src/backend/builtin/kernel/IfLayer.cc
@@ -16,10 +16,6 @@
#include "IfLayer.h"
-#include <backend/ITensor.h>
-#include "exec/ExecutorBase.h"
-#include "PermuteLayer.h"
-
namespace onert
{
namespace backend
@@ -33,13 +29,13 @@ IfLayer::IfLayer(backend::IPortableTensor *cond_tensor,
const std::vector<backend::IPortableTensor *> input_tensors,
const std::vector<backend::IPortableTensor *> output_tensors,
const ir::SubgraphIndex &then_subg_index, const ir::SubgraphIndex &else_subg_index,
- exec::ExecutorMap *executor_map,
+ exec::Executors *executors,
const std::shared_ptr<ExternalContext> &external_context)
: _cond_tensor{cond_tensor}, _input_tensors{input_tensors}, _output_tensors{output_tensors},
- _then_subg_index{then_subg_index}, _else_subg_index{else_subg_index},
- _executor_map{executor_map}, _external_context{external_context}
+ _then_subg_index{then_subg_index}, _else_subg_index{else_subg_index}, _executors{executors},
+ _external_context{external_context}
{
- // At this point, executor_map may not have executors of then subg and else subg
+ // At this point, executors may not have executors of then subg and else subg
}
void IfLayer::run()
@@ -65,12 +61,12 @@ void IfLayer::run()
if (cond_result)
{
VERBOSE(If) << "Call to $" << _then_subg_index << " (then)" << std::endl;
- subg_exec = _executor_map->at(_then_subg_index).get();
+ subg_exec = _executors->at(_then_subg_index).get();
}
else
{
VERBOSE(If) << "Call to $" << _else_subg_index << " (else)" << std::endl;
- subg_exec = _executor_map->at(_else_subg_index).get();
+ subg_exec = _executors->at(_else_subg_index).get();
}
subg_exec->execute(_input_tensors, _output_tensors);
diff --git a/runtime/onert/core/src/backend/builtin/kernel/IfLayer.h b/runtime/onert/core/src/backend/builtin/kernel/IfLayer.h
index f12ef3605..fa5537a67 100644
--- a/runtime/onert/core/src/backend/builtin/kernel/IfLayer.h
+++ b/runtime/onert/core/src/backend/builtin/kernel/IfLayer.h
@@ -18,7 +18,7 @@
#define __ONERT_BACKEND_BUILTIN_KERNEL_IF_LAYER_H__
#include <backend/IPortableTensor.h>
-#include <exec/IExecutor.h>
+#include <exec/Executors.h>
#include "../ExternalContext.h"
namespace onert
@@ -37,8 +37,7 @@ public:
const std::vector<backend::IPortableTensor *> input_tensors,
const std::vector<backend::IPortableTensor *> output_tensors,
const ir::SubgraphIndex &then_subg_index, const ir::SubgraphIndex &else_subg_index,
- exec::ExecutorMap *executor_map,
- const std::shared_ptr<ExternalContext> &external_context);
+ exec::Executors *executors, const std::shared_ptr<ExternalContext> &external_context);
public:
void run() override;
@@ -49,7 +48,7 @@ private:
const std::vector<backend::IPortableTensor *> _output_tensors;
const ir::SubgraphIndex _then_subg_index;
const ir::SubgraphIndex _else_subg_index;
- exec::ExecutorMap *_executor_map;
+ exec::Executors *_executors;
const std::shared_ptr<ExternalContext> _external_context;
};
diff --git a/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc b/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc
index 20cd87ad1..ddaecdf57 100644
--- a/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc
+++ b/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc
@@ -16,9 +16,9 @@
#include "PermuteLayer.h"
-#include "exec/ShapeConverter.h"
+#include "../../../exec/ShapeConverter.h"
-#include "ruy/context.h" // from @ruy
+#include <ruy/context.h> // from @ruy
namespace onert
{
diff --git a/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.h b/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.h
index ac5470e85..227e32434 100644
--- a/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.h
+++ b/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.h
@@ -17,10 +17,10 @@
#ifndef __ONERT_BACKEND_BUILTIN_KERNEL_PERMUTELAYER_H__
#define __ONERT_BACKEND_BUILTIN_KERNEL_PERMUTELAYER_H__
-#include "exec/IPermuteFunction.h"
-#include "exec/IExecutor.h"
#include "../ExternalContext.h"
-#include "ruy/thread_pool.h" // from @ruy
+#include "../../../exec/IPermuteFunction.h"
+
+#include <ruy/thread_pool.h> // from @ruy
namespace onert
{
diff --git a/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc
index 81b4a6378..8e006c5ea 100644
--- a/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc
+++ b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc
@@ -16,11 +16,12 @@
#include "WhileLayer.h"
-#include <algorithm>
-#include <backend/ITensor.h>
-#include "exec/ExecutorBase.h"
-#include <misc/polymorphic_downcast.h>
#include "PermuteLayer.h"
+#include "../../../exec/ExecutorBase.h"
+
+#include <misc/polymorphic_downcast.h>
+
+#include <algorithm>
namespace onert
{
@@ -34,14 +35,14 @@ namespace kernel
WhileLayer::WhileLayer(const std::vector<backend::IPortableTensor *> input_tensors,
const std::vector<backend::IPortableTensor *> output_tensors,
const ir::SubgraphIndex &cond_subg_index,
- const ir::SubgraphIndex &body_subg_index, exec::ExecutorMap *executor_map,
+ const ir::SubgraphIndex &body_subg_index, exec::Executors *executors,
basic::DynamicMemoryManager *dyn_memory_manager,
const std::shared_ptr<ExternalContext> &external_context)
: _cond_subg_index{cond_subg_index}, _body_subg_index{body_subg_index},
- _input_tensors{input_tensors}, _output_tensors{output_tensors}, _executor_map{executor_map},
+ _input_tensors{input_tensors}, _output_tensors{output_tensors}, _executors{executors},
_dyn_memory_manager{dyn_memory_manager}, _external_context{external_context}
{
- // At this point, executor_map may not have executors of cond subg and body subg
+ // At this point, executors may not have executors of cond subg and body subg
}
void WhileLayer::run()
@@ -56,8 +57,8 @@ void WhileLayer::run()
// // Run cond subg
// If there is no loop copy "_input_tensors" -> "_dst_tensors", else copy "cond subg inputs" ->
// "_dst_tensors"
- auto cond_exec = _executor_map->at(_cond_subg_index).get();
- auto body_exec = _executor_map->at(_body_subg_index).get();
+ auto cond_exec = _executors->at(_cond_subg_index).get();
+ auto body_exec = _executors->at(_body_subg_index).get();
// Need a temp tensor to hold the cond subgraph output
assert(cond_exec->getOutputTensors().size() == 1);
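WhileLayer::run (partially shown above) fetches the cond and body executors from the shared container and drives them in a loop, holding the single cond-subgraph output in a temporary tensor. The following is a simplified, self-contained sketch of that control flow, using callable stand-ins instead of the real executor and tensor interfaces:

#include <functional>
#include <vector>

using Tensor = std::vector<float>;
// Stand-in for a subgraph executor: maps input tensors to output tensors.
using SubgraphExec = std::function<std::vector<Tensor>(const std::vector<Tensor> &)>;

// Roughly the WhileLayer::run control flow: evaluate the cond subgraph and,
// while its single boolean output is true, run the body subgraph and feed
// the body outputs back in as the next iteration's inputs.
std::vector<Tensor> runWhile(const SubgraphExec &cond_exec, const SubgraphExec &body_exec,
                             std::vector<Tensor> inputs)
{
  auto cond_out = cond_exec(inputs); // the real kernel keeps this in a temp tensor
  while (cond_out.at(0).at(0) != 0.0f)
  {
    inputs = body_exec(inputs); // body outputs become the next cond/body inputs
    cond_out = cond_exec(inputs);
  }
  return inputs; // copied into _output_tensors (via PermuteLayer) in the real code
}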
diff --git a/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.h b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.h
index 912102781..8551b3d09 100644
--- a/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.h
+++ b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.h
@@ -18,7 +18,7 @@
#define __ONERT_BACKEND_BUILTIN_KERNEL_WHILE_LAYER_H__
#include <backend/IPortableTensor.h>
-#include <exec/IExecutor.h>
+#include <exec/Executors.h>
#include <exec/IFunction.h>
#include <ir/OperandIndexSequence.h>
#include <ir/Graph.h>
@@ -41,7 +41,7 @@ public:
WhileLayer(const std::vector<backend::IPortableTensor *> input_tensors,
const std::vector<backend::IPortableTensor *> output_tensors,
const ir::SubgraphIndex &cond_subg_index, const ir::SubgraphIndex &body_subg_index,
- exec::ExecutorMap *executor_map, basic::DynamicMemoryManager *dyn_memory_manager,
+ exec::Executors *executors, basic::DynamicMemoryManager *dyn_memory_manager,
const std::shared_ptr<ExternalContext> &external_context);
public:
@@ -52,7 +52,7 @@ private:
const ir::SubgraphIndex _body_subg_index;
const std::vector<backend::IPortableTensor *> _input_tensors;
const std::vector<backend::IPortableTensor *> _output_tensors;
- exec::ExecutorMap *_executor_map;
+ exec::Executors *_executors;
basic::DynamicMemoryManager *_dyn_memory_manager; // For generating temp tensors
const std::shared_ptr<ExternalContext> _external_context;
};
diff --git a/runtime/onert/core/src/compiler/BackendManager.cc b/runtime/onert/core/src/compiler/BackendManager.cc
index 0d6051b21..44442c065 100644
--- a/runtime/onert/core/src/compiler/BackendManager.cc
+++ b/runtime/onert/core/src/compiler/BackendManager.cc
@@ -16,16 +16,11 @@
#include "compiler/BackendManager.h"
-#include <memory>
-#include <dlfcn.h>
+#include "../backend/builtin/Backend.h"
+#include "../backend/builtin/Config.h"
-#include "backend/Backend.h"
-#include "backend/builtin/Backend.h"
-#include "backend/builtin/Config.h"
-#include "backend/IConfig.h"
-#include "util/logging.h"
-#include "util/ConfigSource.h"
-#include "misc/string_helpers.h"
+#include <dlfcn.h>
+#include <memory>
static const char *SHARED_LIB_EXT =
#if defined(__APPLE__) && defined(__MACH__)
@@ -152,7 +147,7 @@ const backend::Backend *BackendManager::get(const std::string &key) const
return nullptr;
}
-const backend::builtin::Backend *BackendManager::getBuiltin() const { return _builtin; }
+const backend::Backend *BackendManager::getBuiltin() const { return _builtin; }
} // namespace compiler
} // namespace onert
diff --git a/runtime/onert/core/src/compiler/Compiler.cc b/runtime/onert/core/src/compiler/Compiler.cc
index 6a1d8fcec..7be9c1e3b 100644
--- a/runtime/onert/core/src/compiler/Compiler.cc
+++ b/runtime/onert/core/src/compiler/Compiler.cc
@@ -18,29 +18,27 @@
#include "ExecutorFactory.h"
#include "ShapeValidator.h"
+#include "pass/ConstantOutputPass.h"
+#include "pass/OddOutputPass.h"
+#include "pass/PassRunner.h"
+#include "pass/UnusedOperandEliminationPass.h"
+#include "../backend/builtin/Config.h"
+#include "../dumper/dot/DotDumper.h"
+#include "../interp/InterpExecutor.h"
+#include "../ir/OperationCloner.h"
+#include "../ir/OperationDumper.h"
+#include "../ir/verifier/Verifier.h"
-#include <backend/builtin/Config.h>
-#include "compiler/BackendManager.h"
-#include "compiler/IScheduler.h"
-#include "compiler/ManualScheduler.h"
-#include "compiler/HEScheduler.h"
#include "compiler/StaticShapeInferer.h"
-#include "compiler/OperationLowerInfo.h"
-#include "compiler/pass/ConstantOutputPass.h"
-#include "compiler/pass/OddOutputPass.h"
-#include "compiler/pass/PassRunner.h"
-#include "compiler/pass/UnusedOperandEliminationPass.h"
-#include "exec/ExecTime.h"
-#include "ir/verifier/Verifier.h"
-#include "dumper/dot/DotDumper.h"
-#include "compiler/Linear.h"
-#include "interp/InterpExecutor.h"
#include "util/ConfigSource.h"
#include "util/logging.h"
-#include "ir/OperationDumper.h"
-#include "ir/OperationCloner.h"
-#include "misc/string_helpers.h"
-#include "json/json.h"
+
+#include <misc/polymorphic_downcast.h>
+#include <misc/string_helpers.h>
+#include <json/json.h>
+
+// TODO Remove using fstream header
+#include <fstream>
namespace
{
@@ -86,8 +84,104 @@ void verboseOptions(compiler::CompilerOptions &options)
<< std::noboolalpha;
}
-void setBackendMap(compiler::ManualSchedulerOptions &ms_options, const ir::Subgraphs &subgs,
- const std::string &str)
+std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::StaticShapeInferer>>
+createStaticShapeInferers(
+ const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::LoweredGraph>>
+ &lowered_subgs)
+{
+ // Allocate StaticShapeInferer per each subgraph
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::StaticShapeInferer>> inferers;
+ for (auto &pair : lowered_subgs)
+ {
+ const auto &subg_index = pair.first;
+ auto &lowered_subg = pair.second;
+ inferers[subg_index] = std::make_unique<compiler::StaticShapeInferer>(lowered_subg.get());
+ }
+
+ // Append observers in all StaticShapeInferers
+ for (auto &pair : lowered_subgs)
+ {
+ const auto &subg_index = pair.first;
+ auto &lowered_subg = pair.second;
+
+ // TODO: Change this iteration for all to controlflow iteration
+ lowered_subg->graph().operations().iterate([&](const ir::OperationIndex &,
+ const ir::Operation &op) {
+ // A Function to append child inferers. These make it possible for a StaticShapeInferer to
+ // call StaticShapeInferers of child subgraphs recursively
+ auto appendChildInferer = [&](const ir::SubgraphIndex &child_subg_idx) {
+ auto *child_inferer = inferers.at(child_subg_idx).get();
+ inferers.at(subg_index)->appendChildInferer(child_subg_idx, child_inferer);
+ };
+
+ // A Function to append subg input observers. This makes it possible for a StaticShapeInferer
+ // to update inputs of child subgraphs
+ auto appendSubgraphInputObserver = [&](const ir::SubgraphIndex &child_subg_idx) {
+ std::vector<ir::Operand *> child_subg_inputs;
+ auto &child_subg = lowered_subgs.at(child_subg_idx)->graph();
+ for (const auto &input_idx : child_subg.getInputs())
+ {
+ auto operand_ptr = child_subg.operands().getRawPtr(input_idx);
+ child_subg_inputs.emplace_back(operand_ptr);
+ }
+ inferers.at(subg_index)
+ ->appendSubgInputObserver(child_subg_idx,
+ std::make_unique<compiler::OperandObserver>(child_subg_inputs));
+ };
+
+ // A Function to set controlflow output observers. This makes it possible for a
+ // StaticShapeInferer to update outputs of parent controlflow operations
+ auto setControlFlowOutputObserver = [&](const ir::SubgraphIndex &child_subg_idx) {
+ std::vector<ir::Operand *> cf_outputs;
+ auto &subg = lowered_subg->graph();
+ for (const auto &output_idx : op.getOutputs())
+ {
+ auto operand_ptr = subg.operands().getRawPtr(output_idx);
+ cf_outputs.emplace_back(operand_ptr);
+ }
+ inferers.at(child_subg_idx)
+ ->setControlflowOutputObserver(std::make_unique<compiler::OperandObserver>(cf_outputs));
+ };
+
+ // Append Observers in a StaticShapeInferer
+ if (op.opcode() == ir::OpCode::If)
+ {
+ const auto &if_op = nnfw::misc::polymorphic_downcast<const ir::operation::If &>(op);
+
+ appendChildInferer(if_op.param().then_subg_index);
+ appendChildInferer(if_op.param().else_subg_index);
+
+ appendSubgraphInputObserver(if_op.param().then_subg_index);
+ appendSubgraphInputObserver(if_op.param().else_subg_index);
+
+ setControlFlowOutputObserver(if_op.param().then_subg_index);
+ }
+ else if (op.opcode() == ir::OpCode::While)
+ {
+ const auto &while_op = nnfw::misc::polymorphic_downcast<const ir::operation::While &>(op);
+
+ appendChildInferer(while_op.param().cond_subg_index);
+ appendChildInferer(while_op.param().body_subg_index);
+
+ appendSubgraphInputObserver(while_op.param().cond_subg_index);
+ appendSubgraphInputObserver(while_op.param().body_subg_index);
+
+ setControlFlowOutputObserver(while_op.param().body_subg_index);
+ }
+ });
+ }
+
+ return inferers;
+}
+
+} // namespace
+
+namespace onert
+{
+
+namespace compiler
+{
+void ManualSchedulerOptions::setBackendMap(const std::string &str)
{
// TODO Support multiple subgraphs for manual scheduling
auto key_val_list = nnfw::misc::split(str, ';');
@@ -102,37 +196,24 @@ void setBackendMap(compiler::ManualSchedulerOptions &ms_options, const ir::Subgr
const auto &key_str = key_val.at(0);
const auto &val = key_val.at(1);
auto key = static_cast<uint32_t>(std::stoi(key_str));
-
- subgs.at(ir::SubgraphIndex{0})
- ->operations()
- .at(ir::OperationIndex{key}); // Check if exist, or this wil throw
- ms_options.index_to_backend.emplace(ir::OperationIndex{key}, val);
+ this->index_to_backend.emplace(ir::OperationIndex{key}, val);
}
}
-} // namespace
-
-namespace onert
-{
-
-namespace compiler
+std::unique_ptr<CompilerOptions> CompilerOptions::fromGlobalConfig()
{
-
-CompilerOptions fetchCompilerOptionsFromGlobalConfig(const ir::Subgraphs &subgs)
-{
- CompilerOptions options;
- options.backend_list = nnfw::misc::split(util::getConfigString(util::config::BACKENDS), ';');
- options.trace_filepath = util::getConfigString(util::config::TRACE_FILEPATH);
- options.graph_dump_level = util::getConfigInt(util::config::GRAPH_DOT_DUMP);
- options.executor = util::getConfigString(util::config::EXECUTOR);
- options.he_scheduler = util::getConfigBool(util::config::USE_SCHEDULER);
- options.he_profiling_mode = util::getConfigBool(util::config::PROFILING_MODE);
- options.disable_compile = util::getConfigBool(util::config::DISABLE_COMPILE);
- options.fp16_enable = util::getConfigBool(util::config::FP16_ENABLE);
-
+ auto o = std::make_unique<CompilerOptions>();
+ o->backend_list = nnfw::misc::split(util::getConfigString(util::config::BACKENDS), ';');
+ o->trace_filepath = util::getConfigString(util::config::TRACE_FILEPATH);
+ o->graph_dump_level = util::getConfigInt(util::config::GRAPH_DOT_DUMP);
+ o->executor = util::getConfigString(util::config::EXECUTOR);
+ o->he_scheduler = util::getConfigBool(util::config::USE_SCHEDULER);
+ o->he_profiling_mode = util::getConfigBool(util::config::PROFILING_MODE);
+ o->disable_compile = util::getConfigBool(util::config::DISABLE_COMPILE);
+ o->fp16_enable = util::getConfigBool(util::config::FP16_ENABLE);
{
// Backend for all
- auto &ms_options = options.manual_scheduler_options;
+ auto &ms_options = o->manual_scheduler_options;
// Default value for op_backend_all is first element in the backend list
ms_options.backend_for_all = util::getConfigString(util::config::OP_BACKEND_ALLOPS);
@@ -151,54 +232,67 @@ CompilerOptions fetchCompilerOptionsFromGlobalConfig(const ir::Subgraphs &subgs)
// Index to Backend
auto map_str = util::getConfigString(util::config::OP_BACKEND_MAP);
- setBackendMap(ms_options, subgs, map_str);
+ ms_options.setBackendMap(map_str);
}
- return options;
+ return o;
}
-Compiler::Compiler(const std::shared_ptr<ir::Subgraphs> &subgs, util::TracingCtx *tracing_ctx)
- : _subgraphs{subgs}, _state{State::CREATED}
+Compiler::Compiler(const std::shared_ptr<ir::Model> &model, CompilerOptions &copt)
+ : _nnpkg{std::make_shared<ir::NNPkg>(model)}, _state{State::CREATED}, _voptions{&copt}
{
- // Set default values for CompilerOptions
- // All these default values should not be fetched from Env, when we stop supporting Android NN
- // API.
- _options = fetchCompilerOptionsFromGlobalConfig(*subgs);
-
- _options.tracing_ctx = tracing_ctx;
+ // DO NOTHING
}
-void Compiler::enableToFp16() { _options.fp16_enable = true; }
+Compiler::Compiler(const std::shared_ptr<ir::NNPkg> &nnpkg,
+ std::vector<std::unique_ptr<CompilerOptions>> &copts)
+ : _nnpkg{nnpkg}, _state{State::CREATED}, _voptions{}
+{
+ for (uint32_t i = 0; i < copts.size(); i++)
+ {
+ _voptions.push_back(copts[i].get());
+ }
+}
-void Compiler::set_backend_from_str(const char *backend_settings)
+void Compiler::enableToFp16()
{
- assert(_subgraphs != nullptr);
- // Backend for all
- auto &ms_options = _options.manual_scheduler_options;
- setBackendMap(ms_options, *_subgraphs, std::string{backend_settings});
+ for (auto options : _voptions)
+ options->fp16_enable = true;
}
void Compiler::checkProfilerConditions()
{
- if (!_options.he_scheduler)
+ if (_nnpkg->model_count() != 1)
+ throw std::runtime_error("NYI: Profiling mode for multiple model is not supported yet");
+
+ auto &options = *_voptions[0];
+
+ if (options.he_scheduler)
throw std::runtime_error("Heterogeneous scheduler must be enabled during profiling.");
- if (_options.executor != "Dataflow")
+ if (options.executor != "Dataflow")
throw std::runtime_error("Profiling mode works only with 'Dataflow' executor");
}
bool Compiler::buildPartialGraph(uint32_t num_graphs)
{
- if (_subgraphs->count() > 1)
+ // Use 1st model and options only on partial graph (pipeline) compile
+ assert(_nnpkg->model_count() == 1);
+ assert(_voptions.size() == 1);
+
+ auto model = _nnpkg->primary_model();
+ auto &options = *_voptions[0];
+
+ if (model->subgraphs_count() > 1)
return false;
- auto partialgraphs = std::make_shared<ir::Subgraphs>();
+ auto partialgraphs = std::make_shared<ir::Model>();
for (uint32_t idx = 0; idx < num_graphs; idx++)
{
auto partialgraph = std::make_unique<ir::Graph>();
partialgraphs->push(ir::SubgraphIndex{idx}, std::move(partialgraph));
}
- _subgraphs->primary()->setPartialgraphs(partialgraphs);
+ model->primary_subgraph()->setPartialModel(partialgraphs);
auto partial_graph = primary_subgraph()->partialgraphs();
@@ -208,8 +302,8 @@ bool Compiler::buildPartialGraph(uint32_t num_graphs)
for (auto use_operation : use_operations)
{
- auto graph_index = _options.partial_graph_options.index_to_graph.find(use_operation);
- if (graph_index == _options.partial_graph_options.index_to_graph.end())
+ auto graph_index = options.partial_graph_options.index_to_graph.find(use_operation);
+ if (graph_index == options.partial_graph_options.index_to_graph.end())
{
throw std::runtime_error("Invalid Partition Map");
}
@@ -230,8 +324,8 @@ bool Compiler::buildPartialGraph(uint32_t num_graphs)
primary_subgraph()->operations().iterate(
[&](const ir::OperationIndex &operation_index, const ir::Operation &operation) {
- auto graph_index = _options.partial_graph_options.index_to_graph.find(operation_index);
- if (graph_index == _options.partial_graph_options.index_to_graph.end())
+ auto graph_index = options.partial_graph_options.index_to_graph.find(operation_index);
+ if (graph_index == options.partial_graph_options.index_to_graph.end())
{
throw std::runtime_error("Invalid Partition Map");
}
@@ -259,7 +353,7 @@ bool Compiler::buildPartialGraph(uint32_t num_graphs)
assert(new_operation_index == operation_index);
});
- for (uint32_t idx = 0; idx < partial_graph->count(); idx++)
+ for (uint32_t idx = 0; idx < partial_graph->subgraphs_count(); idx++)
{
auto partition = partial_graph->at(ir::SubgraphIndex{idx});
@@ -282,10 +376,10 @@ bool Compiler::buildPartialGraph(uint32_t num_graphs)
auto use_operations = primary_subgraph()->operands().at(operand_index).getUses();
auto iter = use_operations.begin();
ir::SubgraphIndex graph_index =
- _options.partial_graph_options.index_to_graph.find(*iter++)->second;
+ options.partial_graph_options.index_to_graph.find(*iter++)->second;
while (iter != use_operations.end())
{
- if (graph_index != _options.partial_graph_options.index_to_graph.find(*iter)->second &&
+ if (graph_index != options.partial_graph_options.index_to_graph.find(*iter)->second &&
!partition->getOutputs().contains(operand_index))
{
partition->addOutput(operand_index,
@@ -344,96 +438,157 @@ bool Compiler::buildPartialGraph(uint32_t num_graphs)
return true;
}
-std::shared_ptr<exec::ExecutorMap> Compiler::compile(void)
+std::shared_ptr<CompilerArtifact> Compiler::compile(void)
{
- // Set control flow backend for control flow operators
+ for (auto options : _voptions)
{
+ // Set control flow backend for control flow operators
auto &builtin_id = backend::builtin::Config::ID;
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::If] = builtin_id;
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::While] = builtin_id;
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::Permute] = builtin_id;
- }
+ options->manual_scheduler_options.opcode_to_backend[ir::OpCode::If] = builtin_id;
+ options->manual_scheduler_options.opcode_to_backend[ir::OpCode::While] = builtin_id;
+ options->manual_scheduler_options.opcode_to_backend[ir::OpCode::Permute] = builtin_id;
- // FIXME This is a workaround for bcq operations, should remove it
- {
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQFullyConnected] = "bcq";
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQGather] = "bcq";
+ // FIXME This is a workaround for bcq operations, should remove it
+ options->manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQFullyConnected] = "bcq";
+ options->manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQGather] = "bcq";
+
+ // FIXME This is a workaround for bulk operations, should remove it
+ options->manual_scheduler_options.opcode_to_backend[ir::OpCode::Bulk] = "trix";
+
+ verboseOptions(*options);
}
- verboseOptions(_options);
+ // NYI: allow one model compilation
+ auto const model_count = _nnpkg->model_count();
+ if (model_count != _voptions.size())
+ throw std::runtime_error{"Model count and option vector size mismatch"};
- _subgraphs->iterate([&](const ir::SubgraphIndex &, ir::Graph &subg) {
- // Mandatory passes
- pass::PassRunner{}
- .append(std::make_unique<pass::ConstantOutputPass>(subg))
- .append(std::make_unique<pass::OddOutputPass>(subg))
- .run();
+ for (uint32_t i = 0; i < model_count; i++)
+ {
+ _nnpkg->model(ir::ModelIndex{i})->iterate([&](const ir::SubgraphIndex &, ir::Graph &subg) {
+ // Mandatory passes
+ pass::PassRunner{}
+ .append(std::make_unique<pass::ConstantOutputPass>(subg))
+ .append(std::make_unique<pass::OddOutputPass>(subg))
+ .run();
- // Optimizations
- pass::PassRunner{}.append(std::make_unique<pass::UnusedOperandEliminationPass>(subg)).run();
- });
+ // Optimizations
+ pass::PassRunner{}.append(std::make_unique<pass::UnusedOperandEliminationPass>(subg)).run();
+ });
+ }
/***************************************************
* Prepare compilation phase
***************************************************/
- auto executors = std::make_shared<exec::ExecutorMap>();
-
// Compilable check
// TODO: Support hybrid execution -
// execution between interpreter and compiled executor (including control flow)
- if (_options.disable_compile)
+ if (_voptions[0]->disable_compile)
{
- _subgraphs->iterate([&](const ir::SubgraphIndex &index, ir::Graph &subg) {
+ if (model_count > 1)
+ throw std::runtime_error{"NYI: Disable compilation for multi model is not supported yet"};
+
+ auto executors = std::make_shared<exec::Executors>();
+
+ _nnpkg->primary_model()->iterate([&](const ir::SubgraphIndex &index, ir::Graph &subg) {
executors->emplace(index, std::make_unique<interp::InterpExecutor>(subg));
});
_state = State::COMPILED;
- return executors;
+ return std::make_shared<CompilerArtifact>(executors, nullptr);
}
// Mode check
- if (_options.he_profiling_mode)
+ // TODO handle option for each model
+ if (_voptions[0]->he_profiling_mode)
checkProfilerConditions();
/***************************************************
* Backend independent analysis & optimization phase
***************************************************/
- auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options.graph_dump_level);
+ // TODO Handle dump level for each model
+ auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_voptions[0]->graph_dump_level);
+ onert::dumper::dot::DotDumper dot_dumper(dump_level);
+
+ // Tracing context
+ auto tracing_ctx = std::make_unique<util::TracingCtx>();
+
+ // Model edge context
+ std::unique_ptr<ir::ModelEdges> model_edges = nullptr;
// Lower: Assign backend
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::LoweredGraph>> lowered_subgs;
- _subgraphs->iterate([&](const ir::SubgraphIndex &index, ir::Graph &subg) {
- onert::dumper::dot::DotDumper dot_dumper(subg, dump_level);
- dot_dumper.dump(nnfw::misc::str("before_lower_subg-", index.value()));
- // Lower: Assign backend
- lowered_subgs[index] = std::make_unique<compiler::LoweredGraph>(subg, _options);
+ if (model_count == 1)
+ {
+ _nnpkg->primary_model()->iterate([&](const ir::SubgraphIndex &index, ir::Graph &subg) {
+ dot_dumper.dump(subg, nnfw::misc::str("before_lower_subg-", index.value()));
+ // Lower: Assign backend
+ lowered_subgs[index] = std::make_unique<compiler::LoweredGraph>(subg, *_voptions[0]);
+ // Set tracing_ctx for copied graph
+ tracing_ctx->setSubgraphIndex(&(lowered_subgs[index]->graph()), index.value());
+ });
+ }
+ else
+ {
+ // TODO Support tracing_ctx for multiple model
+ tracing_ctx = nullptr;
+
+ // Copy model edge context
+ model_edges = std::make_unique<ir::ModelEdges>(_nnpkg->model_edges());
- subg.setSubgraphs(nullptr);
- });
+ for (uint32_t i = 0; i < model_count; i++)
+ {
+ auto model = _nnpkg->model(ir::ModelIndex{i});
+ if (model->subgraphs_count() != 1)
+ throw std::runtime_error{"NYI: Lowering subgraphs for multiple model is not supported yet"};
+ auto subg = model->primary_subgraph();
+ dot_dumper.dump(*subg, nnfw::misc::str("before_lower_model-", i));
+
+ // For multimodel, model index is used for lowered graph index in lowered graph map
+ // and index type is SubgraphIndex
+ // TODO Find better way to represent lowered graph index for multimodel's subgraph
+ lowered_subgs[ir::SubgraphIndex{i}] =
+ std::make_unique<compiler::LoweredGraph>(*model->primary_subgraph(), *_voptions[i]);
+ }
+ }
- _subgraphs.reset();
+ _nnpkg.reset();
for (auto &pair : lowered_subgs)
{
const auto &subg_index = pair.first;
auto &lowered_subg = pair.second;
- onert::dumper::dot::DotDumper dot_dumper_lowered(lowered_subg.get(), dump_level);
- dot_dumper_lowered.dump("after_lower_subg-" + std::to_string(subg_index.value()));
+ dot_dumper.dump(*lowered_subg, "after_lower_subg-" + std::to_string(subg_index.value()));
}
// Shape inference.
{
- const auto primary_subg_idx = ir::SubgraphIndex{0};
- StaticShapeInferer inferer(primary_subg_idx, lowered_subgs);
- auto &lowered_subg = lowered_subgs.at(primary_subg_idx);
- auto ordered_ops = lowered_subg->graph().topolSortOperations();
- for (auto op_ind : ordered_ops)
+ // Run the StaticShapeInfer of primary subg. All child StaticShapeInferers are called
+ // recursively
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
+ createStaticShapeInferers(lowered_subgs);
+
+ if (model_count == 1)
{
- const auto &op = lowered_subg->graph().operations().at(op_ind);
- bool has_dynamic_tensor = inferer.infer(op);
- lowered_subg->setHasDynamicTensor(op_ind, has_dynamic_tensor);
+ const auto primary_subg_idx = ir::SubgraphIndex{0};
+ inferers.at(primary_subg_idx)->infer();
+
+ for (const auto &pair : inferers)
+ {
+ const auto inferer = pair.second.get();
+ inferer->dump();
+ }
+ }
+ else
+ {
+ // Assume multi model has only one subgraph on each model
+ for (const auto &pair : inferers)
+ {
+ const auto inferer = pair.second.get();
+ inferer->infer();
+ inferer->dump();
+ }
}
- inferer.dump();
}
// Shape validation
@@ -452,8 +607,7 @@ std::shared_ptr<exec::ExecutorMap> Compiler::compile(void)
/*************************************************************
* Backend independent analysis & optimization phase finished
*************************************************************/
-
- executors = std::make_shared<exec::ExecutorMap>();
+ auto executors = std::make_shared<exec::Executors>(std::move(model_edges));
for (auto &pair : lowered_subgs)
{
const auto &subg_index = pair.first;
@@ -464,24 +618,31 @@ std::shared_ptr<exec::ExecutorMap> Compiler::compile(void)
std::to_string(subg_index.value()));
lowered_subg->graph().operations().iterate(
[&](const ir::OperationIndex &, const ir::Operation &op) { op.accept(dumper); });
- auto executor = std::unique_ptr<exec::IExecutor>{
- ExecutorFactory::get().create(std::move(lowered_subg), _options, executors)};
+
+ auto &options = (model_count > 1) ? *_voptions[subg_index.value()] : *_voptions[0];
+ auto executor = std::unique_ptr<exec::IExecutor>{ExecutorFactory::get().create(
+ std::move(lowered_subg), tracing_ctx.get(), options, executors)};
executor->setIndexedRanks(indexed_ranks);
- executors->insert(std::make_pair(subg_index, std::move(executor)));
+ executors->emplace(subg_index, std::move(executor));
}
/********************************
* Code generation phase finished
********************************/
_state = State::COMPILED;
- return executors;
+ return std::make_shared<CompilerArtifact>(executors, std::move(tracing_ctx));
}
-std::vector<std::shared_ptr<exec::ExecutorMap>> Compiler::compile(const char *package_file_path,
- const char *map_file_path)
+std::vector<std::shared_ptr<CompilerArtifact>> Compiler::compile(const char *package_file_path,
+ const char *map_file_path)
{
- std::vector<std::shared_ptr<exec::ExecutorMap>> executors;
- auto executor_map = std::make_shared<exec::ExecutorMap>();
+ // Allow one model compilation for pipeline
+ if (_nnpkg->model_count() != 1)
+ throw std::runtime_error{"Multiple models compilation for pipeline is not supported yet."};
+ assert(_voptions.size() == 1);
+
+ auto model = _nnpkg->primary_model();
+ auto &options = *_voptions[0];
std::string package_path(package_file_path);
std::string partition_map_file;
@@ -508,7 +669,7 @@ std::vector<std::shared_ptr<exec::ExecutorMap>> Compiler::compile(const char *pa
num_graphs = np.asUInt();
for (uint32_t i = 0; i < (uint32_t)map.size(); ++i)
{
- _options.partial_graph_options.index_to_graph[ir::OperationIndex{i}] =
+ options.partial_graph_options.index_to_graph[ir::OperationIndex{i}] =
ir::SubgraphIndex{map[i].asUInt()};
}
}
@@ -525,25 +686,25 @@ std::vector<std::shared_ptr<exec::ExecutorMap>> Compiler::compile(const char *pa
// Set control flow backend for control flow operators
{
auto &builtin_id = backend::builtin::Config::ID;
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::If] = builtin_id;
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::While] = builtin_id;
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::Permute] = builtin_id;
+ options.manual_scheduler_options.opcode_to_backend[ir::OpCode::If] = builtin_id;
+ options.manual_scheduler_options.opcode_to_backend[ir::OpCode::While] = builtin_id;
+ options.manual_scheduler_options.opcode_to_backend[ir::OpCode::Permute] = builtin_id;
}
// FIXME This is a workaround for bcq operations, should remove it
{
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQFullyConnected] = "bcq";
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQGather] = "bcq";
+ options.manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQFullyConnected] = "bcq";
+ options.manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQGather] = "bcq";
}
- // It doesn't support tracing in case of partial graph
+ // FIXME This is a workaround for bulk operations, should remove it
{
- _options.tracing_ctx = nullptr;
+ options.manual_scheduler_options.opcode_to_backend[ir::OpCode::Bulk] = "trix";
}
- verboseOptions(_options);
+ verboseOptions(options);
- _subgraphs->iterate([&](const ir::SubgraphIndex &, ir::Graph &subg) {
+ model->iterate([&](const ir::SubgraphIndex &, ir::Graph &subg) {
// Mandatory passes
auto part = subg.partialgraphs();
part->iterate([&](const ir::SubgraphIndex &, ir::Graph &partialgraph) {
@@ -566,38 +727,41 @@ std::vector<std::shared_ptr<exec::ExecutorMap>> Compiler::compile(const char *pa
// Compilable check
// TODO: Support hybrid execution -
// execution between interpreter and compiled executor (including control flow)
- if (_options.disable_compile)
+ if (options.disable_compile)
{
- _subgraphs->iterate([&](const ir::SubgraphIndex &index, ir::Graph &subg) {
- executor_map->emplace(index, std::make_unique<interp::InterpExecutor>(subg));
- executors.push_back(executor_map);
+ std::vector<std::shared_ptr<CompilerArtifact>> results;
+ auto executors = std::make_shared<exec::Executors>();
+
+ model->iterate([&](const ir::SubgraphIndex &index, ir::Graph &subg) {
+ executors->emplace(index, std::make_unique<interp::InterpExecutor>(subg));
});
+ results.push_back(std::make_shared<CompilerArtifact>(executors, nullptr));
_state = State::COMPILED;
- return executors;
+ return results;
}
// Mode check
- if (_options.he_profiling_mode)
+ if (options.he_profiling_mode)
checkProfilerConditions();
/***************************************************
* Backend independent analysis & optimization phase
***************************************************/
- auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options.graph_dump_level);
+ auto dump_level = static_cast<dumper::dot::DotDumper::Level>(options.graph_dump_level);
+ onert::dumper::dot::DotDumper dot_dumper_part(dump_level);
// Lower: Assign backend
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::LoweredGraph>>
lowered_partialgraphs;
- _subgraphs->iterate([&](const ir::SubgraphIndex &, ir::Graph &subg) {
+ model->iterate([&](const ir::SubgraphIndex &, ir::Graph &subg) {
auto part = subg.partialgraphs();
part->iterate([&](const ir::SubgraphIndex &pindex, ir::Graph &partialgraph) {
- onert::dumper::dot::DotDumper dot_dumper_part(partialgraph, dump_level);
- dot_dumper_part.dump(nnfw::misc::str("before_lower_subg_partialgraph-", pindex.value()));
+ dot_dumper_part.dump(partialgraph,
+ nnfw::misc::str("before_lower_subg_partialgraph-", pindex.value()));
// // Lower: Assign backend
lowered_partialgraphs[pindex] =
- std::make_unique<compiler::LoweredGraph>(subg, partialgraph, _options);
- partialgraph.setSubgraphs(nullptr);
+ std::make_unique<compiler::LoweredGraph>(subg, partialgraph, options);
});
});
@@ -606,25 +770,20 @@ std::vector<std::shared_ptr<exec::ExecutorMap>> Compiler::compile(const char *pa
const auto &partialgraph_index = pair.first;
auto &lowered_partialgraph = pair.second;
- onert::dumper::dot::DotDumper dot_dumper_lowered_part(lowered_partialgraph.get(), dump_level);
- dot_dumper_lowered_part.dump("after_lower_subg_partialgraph-" +
- std::to_string(partialgraph_index.value()));
+ dot_dumper_part.dump(*lowered_partialgraph, "after_lower_subg_partialgraph-" +
+ std::to_string(partialgraph_index.value()));
}
// Partial Graph shape inference
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
+ createStaticShapeInferers(lowered_partialgraphs);
+ // NOTE If partialgraph has subgraphs StaticShapeInferer may be called multiple times
for (auto &pair : lowered_partialgraphs)
{
const auto &partialgraph_index = pair.first;
- auto &lowered_partialgraph = pair.second;
- StaticShapeInferer partial_inferer(partialgraph_index, lowered_partialgraphs);
- auto ordered_ops = lowered_partialgraph->graph().topolSortOperations();
- for (auto op_ind : ordered_ops)
- {
- const auto &op = lowered_partialgraph->graph().operations().at(op_ind);
- bool has_dynamic_tensor = partial_inferer.infer(op);
- lowered_partialgraph->setHasDynamicTensor(op_ind, has_dynamic_tensor);
- }
- partial_inferer.dump();
+ const auto partial_inferer = inferers.at(partialgraph_index).get();
+ partial_inferer->infer();
+ partial_inferer->dump();
}
// Shape validation
@@ -652,9 +811,11 @@ std::vector<std::shared_ptr<exec::ExecutorMap>> Compiler::compile(const char *pa
ordered.insert(make_pair(pair.first.value(), std::move(lowered_partialgraph)));
}
+ std::vector<std::shared_ptr<CompilerArtifact>> results;
for (auto &pair : ordered)
{
- executor_map = std::make_shared<exec::ExecutorMap>();
+ auto executors = std::make_shared<exec::Executors>();
+
const auto &partialgraph_index = ir::SubgraphIndex(pair.first);
auto &lowered_partialgraph = pair.second;
auto indexed_ranks = lowered_partialgraph->indexed_ranks();
@@ -663,19 +824,21 @@ std::vector<std::shared_ptr<exec::ExecutorMap>> Compiler::compile(const char *pa
lowered_partialgraph->graph().operations().iterate(
[&](const ir::OperationIndex &, const ir::Operation &op) { op.accept(dumper); });
auto executor = std::unique_ptr<exec::IExecutor>{
- ExecutorFactory::get().create(std::move(lowered_partialgraph), _options, executor_map)};
+ ExecutorFactory::get().create(std::move(lowered_partialgraph), nullptr, options, executors)};
executor->setIndexedRanks(indexed_ranks);
- executor_map->insert(std::make_pair(ir::SubgraphIndex{0}, std::move(executor)));
- executors.push_back(executor_map);
+ executors->emplace(ir::SubgraphIndex{0}, std::move(executor));
+
+ // It doesn't support tracing in case of partial graph
+ results.push_back(std::make_shared<CompilerArtifact>(executors, nullptr));
}
- _subgraphs.reset();
+ _nnpkg.reset();
/********************************
* Code generation phase finished
********************************/
_state = State::COMPILED;
- return executors;
+ return results;
}
} // namespace compiler
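createStaticShapeInferers() above builds the inferers in two passes: first one StaticShapeInferer per lowered subgraph, then a wiring pass that walks every If/While operation and registers the child subgraphs' inferers (plus input/output observers) with the parent, so that inferring the primary subgraph recurses into its children. A stripped-down, self-contained sketch of that two-pass construction with stand-in types:

#include <cstdint>
#include <map>
#include <memory>
#include <vector>

using SubgraphIndex = uint32_t;

// Stand-in for compiler::StaticShapeInferer: a parent inferer holds raw
// pointers to the inferers of the subgraphs referenced by its If/While ops.
struct ShapeInferer
{
  std::map<SubgraphIndex, ShapeInferer *> children;

  void appendChildInferer(SubgraphIndex idx, ShapeInferer *child) { children[idx] = child; }

  void infer()
  {
    // Infer shapes for this subgraph, then recurse into the child subgraphs.
    // (The real implementation also pushes shapes into child inputs and back
    // to parent outputs through OperandObserver objects.)
    for (auto &pair : children)
      pair.second->infer();
  }
};

// Pass 1: allocate one inferer per subgraph. Pass 2: wire parents to children.
std::map<SubgraphIndex, std::unique_ptr<ShapeInferer>>
buildInferers(const std::map<SubgraphIndex, std::vector<SubgraphIndex>> &child_subgs)
{
  std::map<SubgraphIndex, std::unique_ptr<ShapeInferer>> inferers;
  for (const auto &pair : child_subgs)
    inferers[pair.first] = std::make_unique<ShapeInferer>();
  for (const auto &pair : child_subgs)
    for (auto child_idx : pair.second)
      inferers.at(pair.first)->appendChildInferer(child_idx, inferers.at(child_idx).get());
  return inferers;
}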
diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.cc b/runtime/onert/core/src/compiler/ExecutorFactory.cc
index f9db1ca89..024556e7e 100644
--- a/runtime/onert/core/src/compiler/ExecutorFactory.cc
+++ b/runtime/onert/core/src/compiler/ExecutorFactory.cc
@@ -16,23 +16,22 @@
#include "ExecutorFactory.h"
-#include "backend/builtin/Config.h"
-#include "backend/builtin/KernelGenerator.h"
-#include "backend/builtin/TensorBuilder.h"
-#include "backend/builtin/UserTensor.h"
-#include "backend/IPortableTensor.h"
-#include "compiler/BackendManager.h"
-#include "compiler/BackendManager.h"
-#include "compiler/ExecutionBuilder.h"
-#include "compiler/Linear.h"
-#include "dumper/text/GraphDumper.h"
-#include "exec/DataflowExecutor.h"
-#include "exec/ExecTime.h"
-#include "exec/ExecutionObservers.h"
-#include "exec/LinearExecutor.h"
-#include "exec/ParallelExecutor.h"
-#include "ir/OperationCloner.h"
-#include "util/TracingCtx.h"
+#include "Linear.h"
+#include "../backend/builtin/BackendContext.h"
+#include "../backend/builtin/Config.h"
+#include "../backend/builtin/UserTensor.h"
+#include "../dumper/text/GraphDumper.h"
+#include "../exec/DataflowExecutor.h"
+#include "../exec/ExecTime.h"
+#include "../exec/ExecutionObservers.h"
+#include "../exec/LinearExecutor.h"
+#include "../exec/ParallelExecutor.h"
+#include "../ir/OperationCloner.h"
+
+#include <backend/IPortableTensor.h>
+#include <compiler/BackendManager.h>
+#include <compiler/ExecutionBuilder.h>
+#include <util/TracingCtx.h>
#include <functional>
#include <memory>
@@ -242,16 +241,17 @@ ExecutorFactory::ExecutorFactory()
{
_map["Linear"] = createLinearExecutor;
_map["Dataflow"] = std::bind(createDataflowExecutor, std::placeholders::_1, std::placeholders::_2,
- std::placeholders::_3, false);
+ std::placeholders::_3, std::placeholders::_4, false);
_map["Parallel"] = std::bind(createDataflowExecutor, std::placeholders::_1, std::placeholders::_2,
- std::placeholders::_3, true);
+ std::placeholders::_3, std::placeholders::_4, true);
}
exec::IExecutor *ExecutorFactory::create(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ const util::TracingCtx *tracing_ctx,
const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map)
+ const std::shared_ptr<exec::Executors> &executors)
{
- return _map.at(options.executor)(std::move(lowered_graph), options, executor_map);
+ return _map.at(options.executor)(std::move(lowered_graph), tracing_ctx, options, executors);
}
void ExecutorFactory::prepareMigrantTensors(compiler::LoweredGraph &lowered_graph,
@@ -282,7 +282,7 @@ void ExecutorFactory::prepareMigrantTensors(compiler::LoweredGraph &lowered_grap
}
void ExecutorFactory::prepareBuiltinBackend(const TensorRegistries &tensor_regs,
- const std::shared_ptr<exec::ExecutorMap> &executor_map,
+ const std::shared_ptr<exec::Executors> &executors,
const backend::BackendContexts &backend_contexts)
{
for (auto &pair : backend_contexts)
@@ -292,7 +292,7 @@ void ExecutorFactory::prepareBuiltinBackend(const TensorRegistries &tensor_regs,
{
auto builtin_kernel_gen = builtin_context->kernel_gen;
builtin_kernel_gen->setTensorRegistries(tensor_regs);
- builtin_kernel_gen->setExecutorMap(executor_map);
+ builtin_kernel_gen->setExecutors(executors);
}
}
}
@@ -317,12 +317,11 @@ ExecutorFactory::orderBackendContext(const backend::BackendContexts &backend_con
return ordered_contexts;
}
-exec::IExecutor *
-ExecutorFactory::createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map)
+exec::IExecutor *ExecutorFactory::createLinearExecutor(
+ std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx,
+ const compiler::CompilerOptions &options, const std::shared_ptr<exec::Executors> &executors)
{
- auto graph = lowered_graph->graph();
+ auto &graph = lowered_graph->graph();
backend::BackendContexts backend_contexts =
createBackendContexts(*lowered_graph, options.executor == "Linear");
@@ -346,7 +345,7 @@ ExecutorFactory::createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lo
prepareMigrantTensors(*lowered_graph, backend_contexts);
// Give some runtime objects to builtin KernelGenerator
- prepareBuiltinBackend(tensor_regs, executor_map, backend_contexts);
+ prepareBuiltinBackend(tensor_regs, executors, backend_contexts);
ExecutionBuilder builder;
@@ -426,14 +425,17 @@ ExecutorFactory::createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lo
auto code_map = builder.releaseCodeMap();
- auto exec = new exec::LinearExecutor{
- std::move(lowered_graph), std::move(backend_contexts), tensor_regs, std::move(code_map), order,
- options.tracing_ctx};
+ auto exec = new exec::LinearExecutor{std::move(lowered_graph),
+ std::move(backend_contexts),
+ tensor_regs,
+ std::move(code_map),
+ order,
+ tracing_ctx};
if (!options.trace_filepath.empty())
{
- std::unique_ptr<exec::IExecutionObserver> ctp = std::make_unique<exec::TracingObserver>(
- options.trace_filepath, exec->graph(), options.tracing_ctx);
+ std::unique_ptr<exec::IExecutionObserver> ctp =
+ std::make_unique<exec::TracingObserver>(options.trace_filepath, exec->graph(), tracing_ctx);
exec->addObserver(std::move(ctp));
}
@@ -441,8 +443,9 @@ ExecutorFactory::createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lo
}
exec::IExecutor *ExecutorFactory::createDataflowExecutor(
- std::unique_ptr<compiler::LoweredGraph> lowered_graph, const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map, bool parallel)
+ std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx,
+ const compiler::CompilerOptions &options, const std::shared_ptr<exec::Executors> &executors,
+ bool parallel)
{
backend::BackendContexts backend_contexts =
createBackendContexts(*lowered_graph, options.executor == "Linear");
@@ -462,7 +465,7 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor(
prepareMigrantTensors(*lowered_graph, backend_contexts);
// Give some runtime objects to builtin KernelGenerator
- prepareBuiltinBackend(tensor_regs, executor_map, backend_contexts);
+ prepareBuiltinBackend(tensor_regs, executors, backend_contexts);
ExecutionBuilder builder;
@@ -491,13 +494,13 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor(
if (parallel)
{
exec = new exec::ParallelExecutor{std::move(lowered_graph), std::move(backend_contexts),
- tensor_regs, std::move(code_map), options.tracing_ctx};
+ tensor_regs, std::move(code_map), tracing_ctx};
}
else
{
auto dataflow_exec =
new exec::DataflowExecutor{std::move(lowered_graph), std::move(backend_contexts), tensor_regs,
- std::move(code_map), options.tracing_ctx};
+ std::move(code_map), tracing_ctx};
if (options.he_profiling_mode)
{
std::vector<const backend::Backend *> backends;
@@ -515,8 +518,8 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor(
if (!options.trace_filepath.empty())
{
- std::unique_ptr<exec::IExecutionObserver> ctp = std::make_unique<exec::TracingObserver>(
- options.trace_filepath, exec->graph(), options.tracing_ctx);
+ std::unique_ptr<exec::IExecutionObserver> ctp =
+ std::make_unique<exec::TracingObserver>(options.trace_filepath, exec->graph(), tracing_ctx);
exec->addObserver(std::move(ctp));
}
diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.h b/runtime/onert/core/src/compiler/ExecutorFactory.h
index 2ee05fae3..70c089f8c 100644
--- a/runtime/onert/core/src/compiler/ExecutorFactory.h
+++ b/runtime/onert/core/src/compiler/ExecutorFactory.h
@@ -21,7 +21,7 @@
#include "backend/ITensor.h"
#include "compiler/LoweredGraph.h"
-#include "exec/IExecutor.h"
+#include "exec/Executors.h"
#include <deque>
#include <unordered_map>
@@ -38,8 +38,9 @@ public:
public:
exec::IExecutor *create(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ const util::TracingCtx *tracing_ctx,
const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map);
+ const std::shared_ptr<exec::Executors> &executors);
private:
ExecutorFactory();
@@ -48,25 +49,26 @@ private:
static void prepareMigrantTensors(compiler::LoweredGraph &lowered_graph,
const backend::BackendContexts &backend_contexts);
static void prepareBuiltinBackend(const TensorRegistries &tensor_regs,
- const std::shared_ptr<exec::ExecutorMap> &executor_map,
+ const std::shared_ptr<exec::Executors> &executors,
const backend::BackendContexts &backend_contexts);
static std::deque<std::pair<const backend::Backend *, backend::BackendContext *>>
orderBackendContext(const backend::BackendContexts &backend_contexts);
- static exec::IExecutor *
- createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map);
+ static exec::IExecutor *createLinearExecutor(
+ std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx,
+ const compiler::CompilerOptions &options, const std::shared_ptr<exec::Executors> &executors);
static exec::IExecutor *
createDataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ const util::TracingCtx *tracing_ctx,
const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map, bool parallel);
+ const std::shared_ptr<exec::Executors> &executors, bool parallel);
private:
- std::unordered_map<std::string, std::function<exec::IExecutor *(
- std::unique_ptr<compiler::LoweredGraph>,
- const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map)>>
+ std::unordered_map<
+ std::string,
+ std::function<exec::IExecutor *(
+ std::unique_ptr<compiler::LoweredGraph>, const util::TracingCtx *tracing_ctx,
+ const compiler::CompilerOptions &options, const std::shared_ptr<exec::Executors> &executors)>>
_map;
};
diff --git a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc
index 5c1cef1ab..98dc906e4 100644
--- a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc
+++ b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc
@@ -180,7 +180,7 @@ void Fp32ToFp16Converter::appendOpSequences()
{
_lowered_graph.op_seqs().iterate(
[&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) {
- const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
+ const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
assert(lower_info != nullptr);
// For now, the only acl_cl supports fully fp16 type
@@ -375,7 +375,7 @@ void Fp32ToFp16Converter::convertOperands()
{
_lowered_graph.op_seqs().iterate(
[&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) {
- const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
+ const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
assert(lower_info != nullptr);
// For now, the only acl_cl supports fully fp16
if (lower_info->backend()->config()->id() != kAclClBackendConfigId)
@@ -515,7 +515,7 @@ ir::OperandIndex Fp32ToFp16Converter::newCopiedOperand(const ir::OperandIndex &o
void Fp32ToFp16Converter::setNewOperandLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
const ir::OperandIndex &new_op_ind)
{
- const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
+ const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
assert(lower_info != nullptr);
auto new_lower_info = std::make_unique<compiler::OperandLowerInfo>();
auto permute_factor = compiler::PermuteFactor(lower_info->backend(), lower_info->layout());
@@ -527,7 +527,7 @@ void Fp32ToFp16Converter::setNewOperandLowerInfo(const ir::OpSequenceIndex &op_s
void Fp32ToFp16Converter::setNewOperationLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
const ir::OpSequenceIndex &new_op_seq_ind)
{
- const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
+ const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
assert(lower_info != nullptr);
auto new_lower_info =
@@ -635,7 +635,7 @@ ir::OpSequenceIndex Fp32ToFp16Converter::newOpSequence(const ir::OpSequenceIndex
const ir::OperationIndex &node_index)
{
auto &node = _lowered_graph.graph().operations().at(node_index);
- const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
+ const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
assert(lower_info != nullptr);
auto layout = lower_info->layout();
diff --git a/runtime/onert/core/src/compiler/HEScheduler.cc b/runtime/onert/core/src/compiler/HEScheduler.cc
index 2f996c8e8..c4bfddb8f 100644
--- a/runtime/onert/core/src/compiler/HEScheduler.cc
+++ b/runtime/onert/core/src/compiler/HEScheduler.cc
@@ -14,17 +14,14 @@
* limitations under the License.
*/
-#include "ir/Operand.h"
-#include "compiler/HEScheduler.h"
-#include "ir/Graph.h"
-#include "util/ConfigSource.h"
+#include "HEScheduler.h"
+
#include "compiler/BackendResolver.h"
+#include "ir/Graph.h"
#include "util/logging.h"
-#include "util/Utils.h"
-#include "exec/FunctionSequence.h"
+
#include <cassert>
#include <cmath>
-#include <chrono>
namespace
{
diff --git a/runtime/onert/core/src/compiler/HEScheduler.h b/runtime/onert/core/src/compiler/HEScheduler.h
index 1a95b9881..18ea388fd 100644
--- a/runtime/onert/core/src/compiler/HEScheduler.h
+++ b/runtime/onert/core/src/compiler/HEScheduler.h
@@ -23,14 +23,16 @@
#ifndef __ONERT_COMPILER_H_E_SCHEDULER_H_
#define __ONERT_COMPILER_H_E_SCHEDULER_H_
-#include "compiler/IScheduler.h"
-#include "compiler/BackendManager.h"
-#include "compiler/Compiler.h"
-#include "ir/Graph.h"
-#include "exec/ExecTime.h"
-#include "backend/Backend.h"
-#include <memory>
-#include "ir/OperationIndexMap.h"
+#include "IScheduler.h"
+#include "../backend/builtin/Config.h"
+#include "../exec/ExecTime.h"
+
+#include <backend/Backend.h>
+#include <compiler/BackendManager.h>
+#include <compiler/Compiler.h>
+#include <ir/Graph.h>
+#include <ir/OperationIndexMap.h>
+
#include <map>
#include <memory>
diff --git a/runtime/onert/core/src/compiler/HEScheduler.test.cc b/runtime/onert/core/src/compiler/HEScheduler.test.cc
new file mode 100644
index 000000000..c4a2df025
--- /dev/null
+++ b/runtime/onert/core/src/compiler/HEScheduler.test.cc
@@ -0,0 +1,572 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "HEScheduler.h"
+#include "../exec/ExecTime.h"
+
+#include <ir/DataType.h>
+#include <ir/InternalType.h>
+#include <ir/Shape.h>
+#include <ir/TypeInfo.h>
+#include <ir/operation/BinaryArithmetic.h>
+#include <ir/operation/FullyConnected.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+using namespace onert;
+using namespace ir;
+using namespace backend;
+using namespace operation;
+using namespace exec;
+
+//
+// Mock backends classes
+//
+
+struct MockConfigCPU : public IConfig
+{
+ std::string id() override { return "cpu"; }
+ bool initialize() override { return true; };
+ bool supportPermutation() override { return false; }
+ Layout supportLayout(const Operation &, Layout) override { return Layout::UNKNOWN; }
+ bool supportDynamicTensor() override { return false; }
+ bool supportFP16() override { return false; }
+};
+
+class MockBackendContext : public BackendContext
+{
+public:
+ using BackendContext::BackendContext;
+ ITensorRegistry *genTensors() override { return nullptr; }
+ FunctionMap genKernels() override { return {}; }
+};
+
+struct MockBackendCPU : public Backend
+{
+ std::shared_ptr<IConfig> config() const override { return std::make_shared<MockConfigCPU>(); }
+ std::unique_ptr<BackendContext> newContext(ContextData &&data) const override
+ {
+ return std::make_unique<MockBackendContext>(this, std::move(data), nullptr);
+ }
+};
+
+struct MockConfigGPU : public IConfig
+{
+ std::string id() override { return "gpu"; }
+ bool initialize() override { return true; };
+ bool supportPermutation() override { return false; }
+ ir::Layout supportLayout(const ir::Operation &, ir::Layout) override
+ {
+ return ir::Layout::UNKNOWN;
+ }
+ bool supportDynamicTensor() override { return false; }
+ bool supportFP16() override { return false; }
+};
+
+struct MockBackendGPU : public Backend
+{
+ std::shared_ptr<IConfig> config() const override { return std::make_shared<MockConfigGPU>(); }
+ std::unique_ptr<BackendContext> newContext(ContextData &&data) const override
+ {
+ return std::make_unique<MockBackendContext>(this, std::move(data), nullptr);
+ }
+};
+
+struct MockConfigNPU : public IConfig
+{
+ std::string id() override { return "npu"; }
+ bool initialize() override { return true; };
+ bool supportPermutation() override { return false; }
+ ir::Layout supportLayout(const ir::Operation &, ir::Layout) override
+ {
+ return ir::Layout::UNKNOWN;
+ }
+ bool supportDynamicTensor() override { return false; }
+ bool supportFP16() override { return false; }
+};
+
+struct MockBackendNPU : public Backend
+{
+ std::shared_ptr<IConfig> config() const override { return std::make_shared<MockConfigNPU>(); }
+ std::unique_ptr<BackendContext> newContext(ContextData &&data) const override
+ {
+ return std::make_unique<MockBackendContext>(this, std::move(data), nullptr);
+ }
+};
+
+//
+// Constants
+//
+
+const int OPERAND_ELEMS = 268203;
+const int OPERAND_SIZE = OPERAND_ELEMS * 4;
+const int OPERATION_SIZE = OPERAND_SIZE * 3;
+
+const std::string LINEAR("Linear");
+const std::string DATAFLOW("Dataflow");
+const std::string PARALLEL("Parallel");
+
+//
+// Helper functions
+//
+
+// Set executor through environment variable
+void setExecutor(const std::string &executor) { setenv("EXECUTOR", executor.c_str(), true); }
+
+// Set profiling mode through environment variable
+void setProfilingMode(const bool value) { setenv("PROFILING_MODE", value ? "1" : "0", true); }
+
+// Calculate operation size by adding the sizes of all input and output operands
+uint32_t calcOpSize(const std::shared_ptr<Graph> &graph, const OperationIndex &op_idx)
+{
+ uint32_t size = 0;
+ const auto &op = graph->operations().at(op_idx);
+ for (const auto &ind : op.getInputs() + op.getOutputs())
+ size += graph->operands().at(ind).info().total_size();
+ return size;
+}
+
+// Set operation execution time. This method is needed since ExecTime has only the
+// 'updateOperationExecTime' method.
+void setOperationExecTime(ExecTime &et, const Backend *backend, const std::string &operation,
+ bool quant, uint32_t op_size, int64_t time)
+{
+ // You shouldn't set negative time with this method since nnfw JSON deserializer can't read it
+ assert(time > 0);
+ int64_t prev_time = et.getOperationExecTime(backend, operation, quant, op_size);
+ int64_t time_to_set = prev_time == ExecTime::NOT_FOUND ? time : 2 * time - prev_time;
+ et.updateOperationExecTime(backend, operation, quant, op_size, time_to_set);
+ assert(et.getOperationExecTime(backend, operation, quant, op_size) == time);
+}
+
+// Set same execution time for all given backends/operations
+void setOperationsExecutionTime(const std::vector<const Backend *> &backends,
+ const std::vector<std::string> &op_names,
+ const std::vector<uint32_t> &op_sizes, int64_t exec_time)
+{
+ assert(op_names.size() == op_sizes.size());
+ ExecTime et(backends);
+ for (int i = 0; i < op_names.size(); ++i)
+ {
+ for (auto &backend : backends)
+ setOperationExecTime(et, backend, op_names[i], false, op_sizes[i], exec_time);
+ }
+ et.storeOperationsExecTime();
+}
+
+// Set permute time from one backend to another. This method is needed since ExecTime has only the
+// 'updatePermuteTime' method.
+void setPermutationTime(ExecTime &et, const Backend *from_backend, const Backend *to_backend,
+ bool quant, uint32_t op_size, int64_t time)
+{
+ // You shouldn't set negative time with this method since nnfw JSON deserializer can't read it
+ assert(time > 0);
+ int64_t prev_time = et.getPermuteTime(from_backend, to_backend, quant, op_size);
+ int64_t time_to_set = prev_time == ExecTime::NOT_FOUND ? time : 2 * time - prev_time;
+ et.updatePermuteTime(from_backend, to_backend, quant, op_size, time_to_set);
+ assert(et.getPermuteTime(from_backend, to_backend, quant, op_size) == time);
+}
+
+// Set same permutation time between all given backends
+void setPermutationsExecutionTime(const std::vector<const Backend *> &backends,
+ const int operand_size, const int64_t exec_time)
+{
+ ExecTime et(backends);
+ for (const auto &backend : backends)
+ {
+ for (auto &other_backend : backends)
+ {
+ if (backend == other_backend)
+ continue;
+ setPermutationTime(et, backend, other_backend, false, operand_size, exec_time);
+ }
+ }
+ et.storeOperationsExecTime();
+}
+
+//
+// Functions for creating graphs
+//
+
+using OIS = OperandIndexSequence;
+
+template <typename NodeT, typename... Types>
+OperationIndex create(std::shared_ptr<Graph> graph, Types &&... args)
+{
+ auto op = std::make_unique<NodeT>(std::forward<Types>(args)...);
+ auto op_idx = graph->addOperation(std::move(op));
+ // For now in scheduler test all operations in tested graphs have the same size (for simplicity)
+ assert(calcOpSize(graph, op_idx) == OPERATION_SIZE);
+ return op_idx;
+}
+
+// Create straight graph: Add->Sub->Mul
+std::shared_ptr<Graph> createStraightGraph()
+{
+ auto graph = std::make_shared<Graph>();
+ const TypeInfo float_op(DataType::FLOAT32);
+
+ // Create add node
+ auto add_lhs_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto add_rhs_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto add_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param add_op_params{BinaryArithmetic::ArithmeticType::ADD, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{add_lhs_idx, add_rhs_idx}, OIS{add_out_idx}, add_op_params);
+
+ // Create sub node
+ auto sub_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto sub_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param sub_op_params{BinaryArithmetic::ArithmeticType::SUB, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{add_out_idx, sub_const_idx}, OIS{sub_out_idx}, sub_op_params);
+
+ // Create mul node
+ auto mul_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto mul_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param mul_op_params{BinaryArithmetic::ArithmeticType::MUL, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{sub_out_idx, mul_const_idx}, OIS{mul_out_idx}, mul_op_params);
+
+ graph->verify();
+ return graph;
+}
+
+/* Create branched graph:
+ * [Add]
+ * // \\
+ *   [Mul1]  [FC1]
+ * || ||
+ * [Mul2] [FC2]
+ * \\ //
+ * [Sub]
+ */
+std::shared_ptr<Graph> createBranchedGraph()
+{
+ auto graph = std::make_shared<Graph>();
+ const TypeInfo float_op(DataType::FLOAT32);
+
+ // Create add node
+ auto add_lhs_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto add_rhs_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto add_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param add_op_params{BinaryArithmetic::ArithmeticType::ADD, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{add_lhs_idx, add_rhs_idx}, OIS{add_out_idx}, add_op_params);
+
+ // Create mul1 node
+ auto mul1_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto mul1_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param mul1_op_params{BinaryArithmetic::ArithmeticType::MUL, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{add_out_idx, mul1_const_idx}, OIS{mul1_out_idx},
+ mul1_op_params);
+
+ // Create mul2 node
+ auto mul2_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto mul2_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param mul2_op_params{BinaryArithmetic::ArithmeticType::MUL, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{mul1_out_idx, mul2_const_idx}, OIS{mul2_out_idx},
+ mul2_op_params);
+
+ // Create fc1 node
+ auto fc1_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto fc1_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ FullyConnected::Param fc1_op_params{Activation::NONE};
+ create<FullyConnected>(graph, OIS{add_out_idx, fc1_const_idx}, OIS{fc1_out_idx}, fc1_op_params);
+
+ // Create fc2 node
+ auto fc2_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto fc2_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ FullyConnected::Param fc2_op_params{Activation::NONE};
+ create<FullyConnected>(graph, OIS{fc1_out_idx, fc2_const_idx}, OIS{fc2_out_idx}, fc2_op_params);
+
+ // Create sub node
+ auto sub_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param sub_op_params{BinaryArithmetic::ArithmeticType::SUB, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{mul2_out_idx, fc2_out_idx}, OIS{sub_out_idx}, sub_op_params);
+
+ graph->verify();
+ return graph;
+}
+
+//
+// Tests setup/teardown
+//
+
+// SetUp/TearDown methods run before/after each test and perform actions common to every test
+class HESchedulerTest : public ::testing::Test
+{
+protected:
+ void SetUp() override
+ {
+ // Initialize mock backends
+ _cpu_backend = new MockBackendCPU();
+ _gpu_backend = new MockBackendGPU();
+ _npu_backend = new MockBackendNPU();
+ _mock_backends = {_cpu_backend, _gpu_backend, _npu_backend};
+
+ // Remove previous profile data if it exists
+ if (!remove("exec_time.json"))
+ {
+ // DO NOTHING (no profile data)
+ }
+
+ // Remember original value of 'EXECUTOR' environment variable
+ char *executor = std::getenv("EXECUTOR");
+ _original_executor = executor == nullptr ? "" : executor;
+
+ // Remember original value of 'PROFILING_MODE' environment variable
+ char *profiling_mode = std::getenv("PROFILING_MODE");
+ _original_profiling_mode = profiling_mode == nullptr ? "" : profiling_mode;
+ }
+
+ void TearDown() override
+ {
+ delete _cpu_backend;
+ delete _gpu_backend;
+ delete _npu_backend;
+ EXPECT_EQ(remove("exec_time.json"), 0);
+ setenv("EXECUTOR", _original_executor.c_str(), true);
+ setenv("PROFILING_MODE", _original_profiling_mode.c_str(), true);
+ }
+
+ const MockBackendCPU *_cpu_backend{nullptr};
+ const MockBackendGPU *_gpu_backend{nullptr};
+ const MockBackendNPU *_npu_backend{nullptr};
+ std::vector<const Backend *> _mock_backends;
+
+ std::string _original_executor;
+ std::string _original_profiling_mode;
+};
+
+//
+// HEScheduler tests
+//
+
+class HESchedulerTestWithExecutorParam : public HESchedulerTest,
+ public testing::WithParamInterface<std::string>
+{
+};
+
+// HESchedulerTestWithExecutorParam tests are parameterized with the executor name and run three
+// times - once for each executor
+INSTANTIATE_TEST_SUITE_P(AllExecutors, HESchedulerTestWithExecutorParam,
+ testing::Values(LINEAR, DATAFLOW, PARALLEL));
+
+// Test scheduler behavior for straight graph with known execution time of all nodes and permutes.
+TEST_P(HESchedulerTestWithExecutorParam, straight_graph_known_exec_time)
+{
+ setExecutor(GetParam());
+
+ // Prepare graph
+ ir::Model model;
+ auto graph(createStraightGraph());
+ model.push(ir::SubgraphIndex{0}, graph);
+ OperationIndex add_op_idx(0), sub_op_idx(1), mul_op_idx(2);
+
+ // Set default execution and transfer time
+ setPermutationsExecutionTime(_mock_backends, OPERAND_SIZE, 1);
+ setOperationsExecutionTime(_mock_backends, {"Add", "Sub", "Mul"},
+ {OPERATION_SIZE, OPERATION_SIZE, OPERATION_SIZE}, 1e4);
+
+ // Test 1
+ // Expected behaviour: scheduler assigns a different backend to each node
+ {
+ // For each backend reduce execution time of one node
+ ExecTime et(_mock_backends);
+ setOperationExecTime(et, _cpu_backend, "Add", false, OPERATION_SIZE, 1);
+ setOperationExecTime(et, _gpu_backend, "Sub", false, OPERATION_SIZE, 1);
+ setOperationExecTime(et, _npu_backend, "Mul", false, OPERATION_SIZE, 1);
+ et.storeOperationsExecTime();
+
+ // Test scheduler
+ auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig();
+ auto scheduler = compiler::HEScheduler(_mock_backends, coptions);
+ const auto br = scheduler.schedule(*graph);
+ ASSERT_EQ(br->getBackend(add_op_idx)->config()->id(), "cpu");
+ ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "gpu");
+ ASSERT_EQ(br->getBackend(mul_op_idx)->config()->id(), "npu");
+ }
+
+ // Test 2
+ // Expected behaviour: scheduler assigns a single backend to all nodes because of the large transfer time
+ {
+ // Increase transfer time
+ setPermutationsExecutionTime(_mock_backends, OPERAND_SIZE, 1e5);
+
+ // Test scheduler
+ auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig();
+ auto scheduler = compiler::HEScheduler(_mock_backends, coptions);
+ const auto br = scheduler.schedule(*graph);
+ ASSERT_EQ(br->getBackend(add_op_idx)->config()->id(), "cpu");
+ ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "cpu");
+ ASSERT_EQ(br->getBackend(mul_op_idx)->config()->id(), "cpu");
+ }
+}
+
+// Test scheduler behavior for branched graph with known execution time of all nodes and permutes
+TEST_P(HESchedulerTestWithExecutorParam, branched_graph_known_exec_time)
+{
+ const int64_t NPU_ET = 5000;
+ setExecutor(GetParam());
+
+ // Prepare graph
+ ir::Model model;
+ auto graph(createBranchedGraph());
+ model.push(ir::SubgraphIndex{0}, graph);
+ OperationIndex add_op_idx(0), mul1_op_idx(1), mul2_op_idx(2), fc1_op_idx(3), fc2_op_idx(4),
+ sub_op_idx(5);
+
+ // Set default execution and transfer time
+ setPermutationsExecutionTime(_mock_backends, OPERAND_SIZE, 1000);
+ setOperationsExecutionTime(_mock_backends, {"Add", "Sub", "Mul", "FullyConnected"},
+ {OPERATION_SIZE, OPERATION_SIZE, OPERATION_SIZE, OPERATION_SIZE}, 1e4);
+
+ // Test 1
+ // Expected behaviour: for the dataflow and linear executors the scheduler assigns the fastest
+ // backend to all nodes; in case of the parallel executor it assigns different backends to branches.
+ {
+ // Reduce execution time
+ ExecTime et(_mock_backends);
+ setOperationExecTime(et, _npu_backend, "Add", false, OPERATION_SIZE, NPU_ET);
+ setOperationExecTime(et, _npu_backend, "Mul", false, OPERATION_SIZE, NPU_ET);
+ setOperationExecTime(et, _npu_backend, "Sub", false, OPERATION_SIZE, NPU_ET);
+ setOperationExecTime(et, _npu_backend, "FullyConnected", false, OPERATION_SIZE, NPU_ET);
+ setOperationExecTime(et, _gpu_backend, "Mul", false, OPERATION_SIZE, NPU_ET + 1000);
+ setOperationExecTime(et, _gpu_backend, "FullyConnected", false, OPERATION_SIZE, NPU_ET + 1000);
+ et.storeOperationsExecTime();
+
+ // Test scheduler
+ auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig();
+ auto scheduler = compiler::HEScheduler(_mock_backends, coptions);
+ const auto br = scheduler.schedule(*graph);
+
+ std::string branch1_expected_backend("npu"), branch2_expected_backend("npu");
+ if (GetParam() == PARALLEL)
+ {
+ branch1_expected_backend =
+ br->getBackend(mul1_op_idx)->config()->id() == "npu" ? "npu" : "gpu";
+ branch2_expected_backend = branch1_expected_backend == "npu" ? "gpu" : "npu";
+ }
+
+ ASSERT_EQ(br->getBackend(add_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(mul1_op_idx)->config()->id(), branch1_expected_backend);
+ ASSERT_EQ(br->getBackend(mul2_op_idx)->config()->id(), branch1_expected_backend);
+ ASSERT_EQ(br->getBackend(fc1_op_idx)->config()->id(), branch2_expected_backend);
+ ASSERT_EQ(br->getBackend(fc2_op_idx)->config()->id(), branch2_expected_backend);
+ ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "npu");
+ }
+
+ // Test 2
+ // Expected behaviour: scheduler assigns a single backend to all nodes
+ {
+ // Increase execution time for GPU backend
+ ExecTime et(_mock_backends);
+ /* for parallel executor: set a time that is larger than sum_of_other_branches_nodes_cnt *
+ * npu_exec_time so that npu is preferred: the i-th branch will wait for npu until it finishes the
+ * nodes of the [0;i-1] branches in DFS order. In each branch it goes deep until it encounters
+ * branching or the scheduler assigns another backend to a node */
+ setOperationExecTime(et, _gpu_backend, "Mul", false, OPERATION_SIZE, NPU_ET * 3 + 1);
+ setOperationExecTime(et, _gpu_backend, "FullyConnected", false, OPERATION_SIZE, NPU_ET * 3 + 1);
+ et.storeOperationsExecTime();
+
+ // Test scheduler
+ auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig();
+ auto scheduler = compiler::HEScheduler(_mock_backends, coptions);
+ const auto br = scheduler.schedule(*graph);
+ ASSERT_EQ(br->getBackend(add_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(mul1_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(mul2_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(fc1_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(fc2_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "npu");
+ }
+}
+
+// Test scheduler behavior for branched graph and enabled profiling mode
+TEST_F(HESchedulerTest, branched_graph_profiling_mode)
+{
+ const int ET = 1e5;
+
+ // Turn on profiling mode
+ setProfilingMode(true);
+ setExecutor(DATAFLOW);
+
+ // Prepare graph
+ ir::Model model;
+ auto graph(createBranchedGraph());
+ model.push(ir::SubgraphIndex{0}, graph);
+ OperationIndex add_op_idx(0), mul1_op_idx(1), mul2_op_idx(2), fc1_op_idx(3), fc2_op_idx(4),
+ sub_op_idx(5);
+
+ // Test 1
+ // Expected behaviour: scheduler assigns backends to nodes with unknown execution time
+ {
+ // Set execution time for all backends/nodes except for cpu/Sub, npu/Mul, gpu/FC
+ ExecTime et(_mock_backends);
+ setOperationExecTime(et, _cpu_backend, "Add", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _cpu_backend, "Mul", false, OPERATION_SIZE, ET + 1);
+ setOperationExecTime(et, _cpu_backend, "FullyConnected", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _npu_backend, "Add", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _npu_backend, "FullyConnected", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _npu_backend, "Sub", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _gpu_backend, "Add", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _gpu_backend, "Mul", false, OPERATION_SIZE, ET + 1);
+ setOperationExecTime(et, _gpu_backend, "Sub", false, OPERATION_SIZE, ET);
+ et.storeOperationsExecTime();
+
+ // Test scheduler
+ auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig();
+ auto scheduler = compiler::HEScheduler(_mock_backends, coptions);
+ const auto br = scheduler.schedule(*graph);
+ ASSERT_EQ(br->getBackend(mul1_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(mul2_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(fc1_op_idx)->config()->id(), "gpu");
+ ASSERT_EQ(br->getBackend(fc2_op_idx)->config()->id(), "gpu");
+ ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "cpu");
+ }
+
+ // Test 2
+ // Expected behaviour: scheduler shuffles backends, so different backends are assigned to
+ // neighboring nodes
+ {
+ // Set execution time for rest backends/nodes (cpu/Sub, npu/Mul, gpu/FC)
+ ExecTime et(_mock_backends);
+ setOperationExecTime(et, _cpu_backend, "Sub", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _npu_backend, "Mul", false, OPERATION_SIZE, ET + 1);
+ setOperationExecTime(et, _gpu_backend, "FullyConnected", false, OPERATION_SIZE, ET);
+ et.storeOperationsExecTime();
+
+ // Test scheduler
+ auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig();
+ auto scheduler = compiler::HEScheduler(_mock_backends, coptions);
+ const auto br = scheduler.schedule(*graph);
+ ASSERT_NE(br->getBackend(add_op_idx)->config()->id(),
+ br->getBackend(mul1_op_idx)->config()->id());
+ ASSERT_NE(br->getBackend(add_op_idx)->config()->id(),
+ br->getBackend(fc1_op_idx)->config()->id());
+ ASSERT_NE(br->getBackend(mul1_op_idx)->config()->id(),
+ br->getBackend(mul2_op_idx)->config()->id());
+ ASSERT_NE(br->getBackend(fc1_op_idx)->config()->id(),
+ br->getBackend(fc2_op_idx)->config()->id());
+ ASSERT_NE(br->getBackend(mul2_op_idx)->config()->id(),
+ br->getBackend(sub_op_idx)->config()->id());
+ ASSERT_NE(br->getBackend(fc2_op_idx)->config()->id(),
+ br->getBackend(sub_op_idx)->config()->id());
+ }
+}
+
+// TODO: Add tests with unknown execution and permutation time
+
+} // unnamed namespace
diff --git a/runtime/onert/core/src/compiler/Linear.cc b/runtime/onert/core/src/compiler/Linear.cc
index 73ba96238..f85b8d1bd 100644
--- a/runtime/onert/core/src/compiler/Linear.cc
+++ b/runtime/onert/core/src/compiler/Linear.cc
@@ -14,15 +14,13 @@
* limitations under the License.
*/
-#include <algorithm>
-#include <sstream>
-
#include "Linear.h"
-#include "backend/IConfig.h"
-#include "backend/Backend.h"
+#include "../dumper/text/GraphDumper.h"
+
#include "util/logging.h"
-#include "dumper/text/GraphDumper.h"
+
+#include <sstream>
namespace onert
{
diff --git a/runtime/onert/core/src/compiler/LoweredGraph.cc b/runtime/onert/core/src/compiler/LoweredGraph.cc
index 999bffa7c..9e84753a7 100644
--- a/runtime/onert/core/src/compiler/LoweredGraph.cc
+++ b/runtime/onert/core/src/compiler/LoweredGraph.cc
@@ -16,24 +16,23 @@
#include "compiler/LoweredGraph.h"
-#include <assert.h>
-#include <algorithm>
-#include <sstream>
-#include "util/logging.h"
-#include "compiler/pass/ConstantInsertionPass.h"
-#include "compiler/pass/ConstantLoweringPass.h"
-#include "compiler/pass/PassRunner.h"
-#include "compiler/pass/PermutationOperationPass.h"
-#include "compiler/pass/PermutationInsertionPass.h"
-#include "compiler/pass/PermutationEliminationPass.h"
-#include "dumper/text/GraphDumper.h"
-#include "ir/verifier/Verifier.h"
+#include "HEScheduler.h"
+#include "ManualScheduler.h"
+#include "pass/ConstantInsertionPass.h"
+#include "pass/ConstantLoweringPass.h"
+#include "pass/PassRunner.h"
+#include "pass/PermutationEliminationPass.h"
+#include "pass/PermutationInsertionPass.h"
+#include "pass/PermutationOperationPass.h"
+#include "../dumper/text/GraphDumper.h"
+#include "../ir/verifier/Verifier.h"
+
#include "backend/Backend.h"
-#include "backend/IConfig.h"
#include "compiler/BackendResolver.h"
-#include "compiler/ManualScheduler.h"
-#include "compiler/HEScheduler.h"
-#include "util/TracingCtx.h"
+#include "util/logging.h"
+
+#include <cassert>
+#include <sstream>
namespace onert
{
@@ -42,7 +41,7 @@ namespace compiler
LoweredGraph::LoweredGraph(const ir::Graph &graph, const CompilerOptions &options) : _graph{graph}
{
- lowerGraph(graph, options);
+ lowerGraph(options);
}
// TODO Design better class and constructor to represent parent_graph
@@ -50,18 +49,11 @@ LoweredGraph::LoweredGraph(const ir::Graph &parent_graph, const ir::Graph &graph
const CompilerOptions &options)
: _graph{graph}, _parent_graph{parent_graph}
{
- lowerGraph(graph, options);
+ lowerGraph(options);
}
-void LoweredGraph::lowerGraph(const ir::Graph &graph, const CompilerOptions &options)
+void LoweredGraph::lowerGraph(const CompilerOptions &options)
{
- // set tracing_ctx for copied graph
- if (options.tracing_ctx)
- {
- auto subgraph_index = options.tracing_ctx->getSubgraphIndex(&graph);
- options.tracing_ctx->setSubgraphIndex(&_graph, subgraph_index.value());
- }
-
// Build backend contexts
auto &backend_manager = BackendManager::get();
// Create contexts for other backends
diff --git a/runtime/onert/core/src/compiler/ShapeValidator.cc b/runtime/onert/core/src/compiler/ShapeValidator.cc
index 1c7000986..8c6421744 100644
--- a/runtime/onert/core/src/compiler/ShapeValidator.cc
+++ b/runtime/onert/core/src/compiler/ShapeValidator.cc
@@ -34,77 +34,72 @@ namespace onert
namespace compiler
{
-ShapeValidator::ShapeValidator(const ir::Graph &graph)
- : _graph{graph}, _ctx{graph.operands()}, _current_layout{ir::Layout::UNKNOWN}
-{
-}
+ShapeValidator::ShapeValidator(const ir::Graph &graph) : _graph{graph} {}
void ShapeValidator::checkUnaryOp(const ir::Operation &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
const auto input_index{node.getInputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
// Check if I/O shapes match
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
+ OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
}
void ShapeValidator::operator()()
{
- // There is no reason for each subgraph to have subgraphs since compiler has subgraphs when
- // creating Compiler
- assert(_graph.subgraphs() == nullptr);
-
- _current_layout = _graph.layout();
-
_graph.operations().iterate(
[&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); });
}
void ShapeValidator::visit(const ir::operation::BatchMatMul &node)
{
+ const auto &operands = _graph.operands();
const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
const auto out_index{node.getOutputs().at(0)};
- if (_ctx.at(out_index).info().isDynamic())
+ if (operands.at(out_index).info().isDynamic())
return;
- OP_REQUIRES(_ctx.at(lhs_index).shape().rank() <= 4);
- OP_REQUIRES(_ctx.at(rhs_index).shape().rank() <= 4);
- OP_REQUIRES(_ctx.at(lhs_index).shape().rank() >= 2);
- OP_REQUIRES(_ctx.at(rhs_index).shape().rank() >= 2);
+ OP_REQUIRES(operands.at(lhs_index).shape().rank() <= 4);
+ OP_REQUIRES(operands.at(rhs_index).shape().rank() <= 4);
+ OP_REQUIRES(operands.at(lhs_index).shape().rank() >= 2);
+ OP_REQUIRES(operands.at(rhs_index).shape().rank() >= 2);
}
void ShapeValidator::visit(const ir::operation::BatchToSpaceND &node)
{
+ const auto &operands = _graph.operands();
const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
+ if (operands.at(ofm_index).info().isDynamic())
return;
const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
const auto block_size_index{
node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
- const auto frontend_layout = _current_layout;
- const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
- const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
+ const auto frontend_layout = _graph.layout();
+ const auto input_shape = operands.at(ifm_index).shape().asFeature(frontend_layout);
+ const auto output_shape = operands.at(ofm_index).shape().asFeature(frontend_layout);
// All requirement as per NNAPI specification.
- OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
- OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
+ OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
if (node.getInputs().size() != 2)
{
const auto crops_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::CROPS_DATA)};
- OP_REQUIRES(_ctx.at(crops_index).shape().rank() == 2);
- OP_REQUIRES(_ctx.at(crops_index).shape().dim(0) == (_ctx.at(ifm_index).shape().rank() - 2));
- OP_REQUIRES(_ctx.at(crops_index).shape().dim(1) == 2);
+ OP_REQUIRES(operands.at(crops_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(crops_index).shape().dim(0) ==
+ (operands.at(ifm_index).shape().rank() - 2));
+ OP_REQUIRES(operands.at(crops_index).shape().dim(1) == 2);
}
OP_REQUIRES(input_shape.C == output_shape.C);
@@ -112,8 +107,9 @@ void ShapeValidator::visit(const ir::operation::BatchToSpaceND &node)
void ShapeValidator::visit(const ir::operation::BCQFullyConnected &node)
{
+ const auto &operands = _graph.operands();
const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
+ if (operands.at(ofm_index).info().isDynamic())
return;
const auto ifm_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
@@ -125,16 +121,16 @@ void ShapeValidator::visit(const ir::operation::BCQFullyConnected &node)
node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
// const auto bias_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::BIAS)};
- OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 2);
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 2);
- OP_REQUIRES(_ctx.at(weight_scales_index).shape().rank() == 1);
- OP_REQUIRES(_ctx.at(weight_binary_index).shape().rank() == 2);
- OP_REQUIRES(_ctx.at(weight_cluster_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(weight_scales_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(weight_binary_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(weight_cluster_index).shape().rank() == 2);
- OP_REQUIRES(_ctx.at(ifm_index).shape().dim(1) == _ctx.at(ofm_index).shape().dim(1));
+ OP_REQUIRES(operands.at(ifm_index).shape().dim(1) == operands.at(ofm_index).shape().dim(1));
- OP_REQUIRES(_ctx.at(weight_cluster_index).shape().dim(0) > 0);
- OP_REQUIRES(_ctx.at(weight_cluster_index).shape().dim(1) == 2);
+ OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(0) > 0);
+ OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(1) == 2);
// more shape validation will be done inside kernel.
@@ -143,8 +139,9 @@ void ShapeValidator::visit(const ir::operation::BCQFullyConnected &node)
void ShapeValidator::visit(const ir::operation::BCQGather &node)
{
+ const auto &operands = _graph.operands();
const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
+ if (operands.at(ofm_index).info().isDynamic())
return;
const auto indices_index{node.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
@@ -153,13 +150,14 @@ void ShapeValidator::visit(const ir::operation::BCQGather &node)
const auto input_clusters_index{
node.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
- OP_REQUIRES(_ctx.at(indices_index).shape().rank() <= 2); // TODO : support rank up to 4 or more
- OP_REQUIRES(_ctx.at(input_binary_index).shape().rank() == 2);
- OP_REQUIRES(_ctx.at(input_scales_index).shape().rank() == 1);
- OP_REQUIRES(_ctx.at(input_clusters_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(indices_index).shape().rank() <=
+ 2); // TODO : support rank up to 4 or more
+ OP_REQUIRES(operands.at(input_binary_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(input_scales_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(input_clusters_index).shape().rank() == 2);
- OP_REQUIRES(_ctx.at(input_clusters_index).shape().dim(0) > 0);
- OP_REQUIRES(_ctx.at(input_clusters_index).shape().dim(1) == 2);
+ OP_REQUIRES(operands.at(input_clusters_index).shape().dim(0) > 0);
+ OP_REQUIRES(operands.at(input_clusters_index).shape().dim(1) == 2);
// more shape validation will be done inside kernel.
}
@@ -171,62 +169,67 @@ void ShapeValidator::visit(const ir::operation::Comparison &)
void ShapeValidator::visit(const ir::operation::Softmax &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto input_index{node.getInputs().at(0)};
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
+ OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
}
void ShapeValidator::visit(const ir::operation::InstanceNorm &node)
{
+ const auto &operands = _graph.operands();
const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
+ if (operands.at(ofm_index).info().isDynamic())
return;
const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
- OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(ifm_index).shape() == _ctx.at(ofm_index).shape());
- OP_REQUIRES(_ctx.at(gamma_index).shape().rank() == 1);
- OP_REQUIRES(_ctx.at(beta_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape());
+ OP_REQUIRES(operands.at(gamma_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(beta_index).shape().rank() == 1);
}
void ShapeValidator::visit(const ir::operation::Pool2D &node)
{
+ const auto &operands = _graph.operands();
const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
+ if (operands.at(ofm_index).info().isDynamic())
return;
const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
- OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
}
void ShapeValidator::visit(const ir::operation::Permute &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto input_index{node.getInputs().at(0)};
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
+ OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
}
void ShapeValidator::visit(const ir::operation::Reduce &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
- const auto input_shape = _ctx.at(input_index).shape();
- const auto output_shape = _ctx.at(output_index).shape();
+ const auto input_shape = operands.at(input_index).shape();
+ const auto output_shape = operands.at(output_index).shape();
OP_REQUIRES(input_shape.rank() <= 4);
OP_REQUIRES(output_shape.rank() <= input_shape.rank());
@@ -266,18 +269,20 @@ void ShapeValidator::visit(const ir::operation::Reduce &node)
void ShapeValidator::visit(const ir::operation::Transpose &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
const auto perm_index{node.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
- const auto &output_shape = _ctx.at(output_index).shape();
- const auto &input_shape = _ctx.at(input_index).shape();
+ const auto &output_shape = operands.at(output_index).shape();
+ const auto &input_shape = operands.at(input_index).shape();
- OP_REQUIRES(_ctx.at(perm_index).shape().num_elements() == 0 ||
- input_shape.rank() == static_cast<int>(_ctx.at(perm_index).shape().num_elements()));
+ OP_REQUIRES(operands.at(perm_index).shape().num_elements() == 0 ||
+ input_shape.rank() ==
+ static_cast<int>(operands.at(perm_index).shape().num_elements()));
OP_REQUIRES(input_shape.rank() == output_shape.rank());
}
@@ -285,8 +290,9 @@ void ShapeValidator::visit(const ir::operation::RNN &node)
{
// NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
// TODO Support dynamic rnn
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto hidden_state_out_index{
@@ -299,35 +305,36 @@ void ShapeValidator::visit(const ir::operation::RNN &node)
const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
- const auto batch_size = _ctx.at(output_index).shape().dim(0);
- const auto num_units = _ctx.at(output_index).shape().dim(1);
-
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == 2 &&
- _ctx.at(hidden_state_out_index).shape().rank() == 2 &&
- _ctx.at(input_index).shape().rank() == 2 &&
- _ctx.at(weights_index).shape().rank() == 2 &&
- _ctx.at(recurrent_weights_index).shape().rank() == 2 &&
- _ctx.at(hidden_state_in_index).shape().rank() == 2);
- OP_REQUIRES(_ctx.at(bias_index).shape().rank() == 1);
-
- OP_REQUIRES(batch_size == _ctx.at(input_index).shape().dim(0) &&
- batch_size == _ctx.at(hidden_state_in_index).shape().dim(0) &&
- batch_size == _ctx.at(hidden_state_out_index).shape().dim(0));
- OP_REQUIRES(_ctx.at(input_index).shape().dim(1) == _ctx.at(weights_index).shape().dim(1));
-
- OP_REQUIRES(num_units == _ctx.at(weights_index).shape().dim(0) &&
- num_units == _ctx.at(recurrent_weights_index).shape().dim(0) &&
- num_units == _ctx.at(bias_index).shape().dim(0));
- OP_REQUIRES(num_units == _ctx.at(output_index).shape().dim(1) &&
- num_units == _ctx.at(recurrent_weights_index).shape().dim(1) &&
- num_units == _ctx.at(hidden_state_in_index).shape().dim(1) &&
- num_units == _ctx.at(hidden_state_out_index).shape().dim(1));
+ const auto batch_size = operands.at(output_index).shape().dim(0);
+ const auto num_units = operands.at(output_index).shape().dim(1);
+
+ OP_REQUIRES(operands.at(output_index).shape().rank() == 2 &&
+ operands.at(hidden_state_out_index).shape().rank() == 2 &&
+ operands.at(input_index).shape().rank() == 2 &&
+ operands.at(weights_index).shape().rank() == 2 &&
+ operands.at(recurrent_weights_index).shape().rank() == 2 &&
+ operands.at(hidden_state_in_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(bias_index).shape().rank() == 1);
+
+ OP_REQUIRES(batch_size == operands.at(input_index).shape().dim(0) &&
+ batch_size == operands.at(hidden_state_in_index).shape().dim(0) &&
+ batch_size == operands.at(hidden_state_out_index).shape().dim(0));
+ OP_REQUIRES(operands.at(input_index).shape().dim(1) == operands.at(weights_index).shape().dim(1));
+
+ OP_REQUIRES(num_units == operands.at(weights_index).shape().dim(0) &&
+ num_units == operands.at(recurrent_weights_index).shape().dim(0) &&
+ num_units == operands.at(bias_index).shape().dim(0));
+ OP_REQUIRES(num_units == operands.at(output_index).shape().dim(1) &&
+ num_units == operands.at(recurrent_weights_index).shape().dim(1) &&
+ num_units == operands.at(hidden_state_in_index).shape().dim(1) &&
+ num_units == operands.at(hidden_state_out_index).shape().dim(1));
}
void ShapeValidator::visit(const ir::operation::SpaceToBatchND &node)
{
+ const auto &operands = _graph.operands();
const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
+ if (operands.at(ofm_index).info().isDynamic())
return;
const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
@@ -335,39 +342,40 @@ void ShapeValidator::visit(const ir::operation::SpaceToBatchND &node)
node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
- const auto frontend_layout = _current_layout;
- const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
- const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
+ const auto frontend_layout = _graph.layout();
+ const auto input_shape = operands.at(ifm_index).shape().asFeature(frontend_layout);
+ const auto output_shape = operands.at(ofm_index).shape().asFeature(frontend_layout);
// All requirement as per NNAPI specification.
- OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
- OP_REQUIRES(_ctx.at(paddings_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(paddings_index).shape().rank() == 2);
- OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
- OP_REQUIRES(_ctx.at(paddings_index).shape().dim(0) == 2);
- OP_REQUIRES(_ctx.at(paddings_index).shape().dim(1) == 2);
+ OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
+ OP_REQUIRES(operands.at(paddings_index).shape().dim(0) == 2);
+ OP_REQUIRES(operands.at(paddings_index).shape().dim(1) == 2);
OP_REQUIRES(input_shape.C == output_shape.C);
}
void ShapeValidator::visit(const ir::operation::SpaceToDepth &node)
{
+ const auto &operands = _graph.operands();
const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
+ if (operands.at(ofm_index).info().isDynamic())
return;
const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
- const auto frontend_layout = _current_layout;
- const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
- const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
+ const auto frontend_layout = _graph.layout();
+ const auto input_shape = operands.at(ifm_index).shape().asFeature(frontend_layout);
+ const auto output_shape = operands.at(ofm_index).shape().asFeature(frontend_layout);
const auto block_size = node.param().block_size;
// All assertions as per NNAPI specification.
- OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
OP_REQUIRES((input_shape.H % block_size == 0) && (input_shape.W % block_size == 0));
OP_REQUIRES(input_shape.N == output_shape.N);
OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
@@ -382,29 +390,31 @@ void ShapeValidator::visit(const ir::operation::ElementwiseBinary &)
void ShapeValidator::visit(const ir::operation::ElementwiseUnary &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
+ OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
}
void ShapeValidator::visit(const ir::operation::EmbeddingLookup &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
- const auto &output_obj = _ctx.at(output_index);
- const auto &lookups_obj = _ctx.at(lookups_index);
- const auto &values_obj = _ctx.at(values_index);
+ const auto &output_obj = operands.at(output_index);
+ const auto &lookups_obj = operands.at(lookups_index);
+ const auto &values_obj = operands.at(values_index);
// Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying
// TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
{
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto &output_shape = output_obj.shape();
@@ -427,26 +437,28 @@ void ShapeValidator::visit(const ir::operation::EmbeddingLookup &node)
void ShapeValidator::visit(const ir::operation::ExpandDims &node)
{
+ const auto &operands = _graph.operands();
const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
- if (_ctx.at(axis_index).info().isDynamic())
+ if (operands.at(axis_index).info().isDynamic())
return;
- OP_REQUIRES(_ctx.at(axis_index).shape().rank() <= 1);
+ OP_REQUIRES(operands.at(axis_index).shape().rank() <= 1);
}
void ShapeValidator::visit(const ir::operation::HashtableLookup &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
- const auto &output_obj = _ctx.at(output_index);
- const auto &lookups_obj = _ctx.at(lookups_index);
- const auto &keys_obj = _ctx.at(keys_index);
- const auto &values_obj = _ctx.at(values_index);
+ const auto &output_obj = operands.at(output_index);
+ const auto &lookups_obj = operands.at(lookups_index);
+ const auto &keys_obj = operands.at(keys_index);
+ const auto &values_obj = operands.at(values_index);
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto &output_shape = output_obj.shape();
@@ -464,28 +476,30 @@ void ShapeValidator::visit(const ir::operation::HashtableLookup &node)
void ShapeValidator::visit(const ir::operation::TransposeConv &node)
{
// shape check
+ const auto &operands = _graph.operands();
const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
+
+ if (operands.at(ofm_index).info().isDynamic())
return;
const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
// Only 4D tensors are supported
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ifm_index).shape().rank());
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ker_index).shape().rank());
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ifm_index).shape().rank());
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ker_index).shape().rank());
- const auto frontend_layout = _current_layout;
- const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
- const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
+ const auto frontend_layout = _graph.layout();
+ const auto ofm_shape = operands.at(ofm_index).shape().asFeature(frontend_layout);
+ const auto ifm_shape = operands.at(ifm_index).shape().asFeature(frontend_layout);
// The kernel has only IHWO layout on frontend
// So ker_shape is treated here below
// I -> N
// H -> H
// W -> W
// O -> C
- const auto ker_shape = _ctx.at(ker_index).shape().asFeature(ir::Layout::NHWC);
+ const auto ker_shape = operands.at(ker_index).shape().asFeature(ir::Layout::NHWC);
OP_REQUIRES(ifm_shape.N == ofm_shape.N);
OP_REQUIRES(ifm_shape.C == ker_shape.C);
@@ -494,16 +508,17 @@ void ShapeValidator::visit(const ir::operation::TransposeConv &node)
void ShapeValidator::visit(const ir::operation::Gather &node)
{
+ const auto &operands = _graph.operands();
const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
+ if (operands.at(ofm_index).info().isDynamic())
return;
const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
- const auto ifm_shape = _ctx.at(ifm_index).shape();
- const auto indices_shape = _ctx.at(indices_index).shape();
- const auto ofm_shape = _ctx.at(ofm_index).shape();
+ const auto ifm_shape = operands.at(ifm_index).shape();
+ const auto indices_shape = operands.at(indices_index).shape();
+ const auto ofm_shape = operands.at(ofm_index).shape();
OP_REQUIRES(ifm_shape.rank() <= 4);
OP_REQUIRES(indices_shape.rank() <= 3);
@@ -512,21 +527,22 @@ void ShapeValidator::visit(const ir::operation::Gather &node)
void ShapeValidator::visit(const ir::operation::DepthToSpace &node)
{
+ const auto &operands = _graph.operands();
int32_t block_size = node.param().block_size;
// shape check
const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
- const auto frontend_layout = _current_layout;
- const auto output_shape = _ctx.at(output_index).shape().asFeature(frontend_layout);
- const auto input_shape = _ctx.at(input_index).shape().asFeature(frontend_layout);
+ const auto frontend_layout = _graph.layout();
+ const auto output_shape = operands.at(output_index).shape().asFeature(frontend_layout);
+ const auto input_shape = operands.at(input_index).shape().asFeature(frontend_layout);
- OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
{
OP_REQUIRES(output_shape.N == input_shape.N);
@@ -539,22 +555,23 @@ void ShapeValidator::visit(const ir::operation::DepthToSpace &node)
void ShapeValidator::visit(const ir::operation::Pack &node)
{
+ const auto &operands = _graph.operands();
const auto axis{node.param().axis};
const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
// shape check
- const auto &output_shape = _ctx.at(output_index).shape();
+ const auto &output_shape = operands.at(output_index).shape();
const auto output_rank = static_cast<int32_t>(output_shape.rank());
const auto input1_index{node.getInputs().at(0)};
- const auto input_shape = _ctx.at(input1_index).shape();
+ const auto input_shape = operands.at(input1_index).shape();
OP_REQUIRES(axis >= -output_rank && axis < output_rank);
for (const auto &index : node.getInputs())
{
- OP_REQUIRES(input_shape == _ctx.at(index).shape());
+ OP_REQUIRES(input_shape == operands.at(index).shape());
}
}
@@ -562,8 +579,9 @@ void ShapeValidator::visit(const ir::operation::LSTM &node)
{
// NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
// TODO Support dynamic rnn
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto scratch_buffer_index{
@@ -611,91 +629,96 @@ void ShapeValidator::visit(const ir::operation::LSTM &node)
node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
- OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
- for (int i = 0; i < _ctx.at(input_index).shape().rank() - 1; ++i)
+ OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
+ for (int i = 0; i < operands.at(input_index).shape().rank() - 1; ++i)
{
- OP_REQUIRES(_ctx.at(input_index).shape().dim(i) == _ctx.at(output_index).shape().dim(i));
+ OP_REQUIRES(operands.at(input_index).shape().dim(i) ==
+ operands.at(output_index).shape().dim(i));
}
- OP_REQUIRES(
- (_ctx.at(output_index).shape().rank() == 2 || _ctx.at(output_index).shape().rank() == 3) &&
- (_ctx.at(input_index).shape().rank() == 2 || _ctx.at(input_index).shape().rank() == 3) &&
- (!_ctx.exist(input_to_input_weights_index) ||
- _ctx.at(input_to_input_weights_index).shape().rank() == 2) &&
- _ctx.at(input_to_forget_weights_index).shape().rank() == 2 &&
- _ctx.at(input_to_cell_weights_index).shape().rank() == 2 &&
- _ctx.at(input_to_output_weights_index).shape().rank() == 2 &&
- (!_ctx.exist(recurrent_to_input_weights_index) ||
- _ctx.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
- _ctx.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
- _ctx.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
- _ctx.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
- (!_ctx.exist(projection_weights_index) ||
- _ctx.at(projection_weights_index).shape().rank() == 2) &&
- _ctx.at(output_state_in_index).shape().rank() == 2 &&
- _ctx.at(cell_state_in_index).shape().rank() == 2);
-
- OP_REQUIRES(
- (!_ctx.exist(cell_to_input_weights_index) ||
- _ctx.at(cell_to_input_weights_index).shape().rank() == 1) &&
- (!_ctx.exist(cell_to_forget_weights_index) ||
- _ctx.at(cell_to_forget_weights_index).shape().rank() == 1) &&
- (!_ctx.exist(cell_to_output_weights_index) ||
- _ctx.at(cell_to_output_weights_index).shape().rank() == 1) &&
- (!_ctx.exist(input_gate_bias_index) || _ctx.at(input_gate_bias_index).shape().rank() == 1) &&
- _ctx.at(forget_gate_bias_index).shape().rank() == 1 &&
- _ctx.at(cell_bias_index).shape().rank() == 1 &&
- _ctx.at(output_gate_bias_index).shape().rank() == 1 &&
- (!_ctx.exist(projection_bias_index) || _ctx.at(projection_bias_index).shape().rank() == 1));
+ OP_REQUIRES((operands.at(output_index).shape().rank() == 2 ||
+ operands.at(output_index).shape().rank() == 3) &&
+ (operands.at(input_index).shape().rank() == 2 ||
+ operands.at(input_index).shape().rank() == 3) &&
+ (!operands.exist(input_to_input_weights_index) ||
+ operands.at(input_to_input_weights_index).shape().rank() == 2) &&
+ operands.at(input_to_forget_weights_index).shape().rank() == 2 &&
+ operands.at(input_to_cell_weights_index).shape().rank() == 2 &&
+ operands.at(input_to_output_weights_index).shape().rank() == 2 &&
+ (!operands.exist(recurrent_to_input_weights_index) ||
+ operands.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
+ operands.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
+ operands.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
+ operands.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
+ (!operands.exist(projection_weights_index) ||
+ operands.at(projection_weights_index).shape().rank() == 2) &&
+ operands.at(output_state_in_index).shape().rank() == 2 &&
+ operands.at(cell_state_in_index).shape().rank() == 2);
+
+ OP_REQUIRES((!operands.exist(cell_to_input_weights_index) ||
+ operands.at(cell_to_input_weights_index).shape().rank() == 1) &&
+ (!operands.exist(cell_to_forget_weights_index) ||
+ operands.at(cell_to_forget_weights_index).shape().rank() == 1) &&
+ (!operands.exist(cell_to_output_weights_index) ||
+ operands.at(cell_to_output_weights_index).shape().rank() == 1) &&
+ (!operands.exist(input_gate_bias_index) ||
+ operands.at(input_gate_bias_index).shape().rank() == 1) &&
+ operands.at(forget_gate_bias_index).shape().rank() == 1 &&
+ operands.at(cell_bias_index).shape().rank() == 1 &&
+ operands.at(output_gate_bias_index).shape().rank() == 1 &&
+ (!operands.exist(projection_bias_index) ||
+ operands.at(projection_bias_index).shape().rank() == 1));
// CIFG assertion
- OP_REQUIRES(
- ((!_ctx.exist(input_to_input_weights_index) ||
- (_ctx.at(input_to_input_weights_index).shape().dim(0) == 0 &&
- _ctx.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
- (!_ctx.exist(recurrent_to_input_weights_index) ||
- (_ctx.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
- _ctx.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
- (!_ctx.exist(input_gate_bias_index) || _ctx.at(input_gate_bias_index).shape().dim(0) == 0) &&
- (!_ctx.exist(cell_to_input_weights_index) ||
- _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
- ((_ctx.exist(input_to_input_weights_index) &&
- (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
- _ctx.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
- (_ctx.exist(recurrent_to_input_weights_index) &&
- (_ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
- _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
- (_ctx.exist(input_gate_bias_index) && _ctx.at(input_gate_bias_index).shape().dim(0) != 0)));
+ OP_REQUIRES(((!operands.exist(input_to_input_weights_index) ||
+ (operands.at(input_to_input_weights_index).shape().dim(0) == 0 &&
+ operands.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
+ (!operands.exist(recurrent_to_input_weights_index) ||
+ (operands.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
+ operands.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
+ (!operands.exist(input_gate_bias_index) ||
+ operands.at(input_gate_bias_index).shape().dim(0) == 0) &&
+ (!operands.exist(cell_to_input_weights_index) ||
+ operands.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
+ ((operands.exist(input_to_input_weights_index) &&
+ (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
+ operands.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
+ (operands.exist(recurrent_to_input_weights_index) &&
+ (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
+ operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
+ (operands.exist(input_gate_bias_index) &&
+ operands.at(input_gate_bias_index).shape().dim(0) != 0)));
// Peephole assertion
- OP_REQUIRES(((!_ctx.exist(cell_to_forget_weights_index) ||
- _ctx.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
- (!_ctx.exist(cell_to_output_weights_index) ||
- _ctx.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
- ((_ctx.exist(cell_to_forget_weights_index) &&
- _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
- (_ctx.exist(cell_to_output_weights_index) &&
- _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0)));
-
- bool has_input_to_input_weights = _ctx.exist(input_to_input_weights_index) &&
- (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
- _ctx.at(input_to_input_weights_index).shape().dim(1) != 0);
+ OP_REQUIRES(((!operands.exist(cell_to_forget_weights_index) ||
+ operands.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
+ (!operands.exist(cell_to_output_weights_index) ||
+ operands.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
+ ((operands.exist(cell_to_forget_weights_index) &&
+ operands.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
+ (operands.exist(cell_to_output_weights_index) &&
+ operands.at(cell_to_output_weights_index).shape().dim(0) != 0)));
+
+ bool has_input_to_input_weights =
+ operands.exist(input_to_input_weights_index) &&
+ (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
+ operands.at(input_to_input_weights_index).shape().dim(1) != 0);
bool has_recurrent_to_input_weights =
- _ctx.exist(recurrent_to_input_weights_index) &&
- (_ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
- _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
+ operands.exist(recurrent_to_input_weights_index) &&
+ (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
+ operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
bool has_input_gate_bias =
- _ctx.exist(input_gate_bias_index) && _ctx.at(input_gate_bias_index).shape().dim(0) != 0;
- bool has_cell_to_input_weights = _ctx.exist(cell_to_input_weights_index) &&
- _ctx.at(cell_to_input_weights_index).shape().dim(0) != 0;
- bool has_cell_to_forget_weights = _ctx.exist(cell_to_forget_weights_index) &&
- _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0;
- bool has_cell_to_output_weights = _ctx.exist(cell_to_output_weights_index) &&
- _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0;
- bool has_projection_weights = _ctx.exist(projection_weights_index) &&
- (_ctx.at(projection_weights_index).shape().dim(0) != 0 &&
- _ctx.at(projection_weights_index).shape().dim(1) != 0);
+ operands.exist(input_gate_bias_index) && operands.at(input_gate_bias_index).shape().dim(0) != 0;
+ bool has_cell_to_input_weights = operands.exist(cell_to_input_weights_index) &&
+ operands.at(cell_to_input_weights_index).shape().dim(0) != 0;
+ bool has_cell_to_forget_weights = operands.exist(cell_to_forget_weights_index) &&
+ operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
+ bool has_cell_to_output_weights = operands.exist(cell_to_output_weights_index) &&
+ operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
+ bool has_projection_weights = operands.exist(projection_weights_index) &&
+ (operands.at(projection_weights_index).shape().dim(0) != 0 &&
+ operands.at(projection_weights_index).shape().dim(1) != 0);
bool has_projection_bias =
- _ctx.exist(projection_bias_index) && _ctx.at(projection_bias_index).shape().dim(0) != 0;
+ operands.exist(projection_bias_index) && operands.at(projection_bias_index).shape().dim(0) != 0;
// NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
// true: no CIFG
@@ -710,46 +733,48 @@ void ShapeValidator::visit(const ir::operation::LSTM &node)
// NOTE The projection weights may have data but the projection bias may not.
bool has_projection_param = has_projection_weights;
- const auto batch_size = (_ctx.at(input_index).shape().rank() == 3 && node.param().time_major)
- ? _ctx.at(input_index).shape().dim(1)
- : _ctx.at(input_index).shape().dim(0);
- OP_REQUIRES(batch_size == _ctx.at(output_state_in_index).shape().dim(0) &&
- batch_size == _ctx.at(cell_state_in_index).shape().dim(0));
-
- const auto input_size = _ctx.at(input_index).shape().dim(_ctx.at(input_index).shape().rank() - 1);
- OP_REQUIRES(input_size == _ctx.at(input_to_forget_weights_index).shape().dim(1) &&
- input_size == _ctx.at(input_to_cell_weights_index).shape().dim(1) &&
- input_size == _ctx.at(input_to_output_weights_index).shape().dim(1));
-
- const auto num_units = _ctx.at(input_to_output_weights_index).shape().dim(0);
- OP_REQUIRES(num_units == _ctx.at(input_to_cell_weights_index).shape().dim(0) &&
- num_units == _ctx.at(input_to_output_weights_index).shape().dim(0) &&
- num_units == _ctx.at(recurrent_to_forget_weights_index).shape().dim(0) &&
- num_units == _ctx.at(recurrent_to_cell_weights_index).shape().dim(0) &&
- num_units == _ctx.at(recurrent_to_output_weights_index).shape().dim(0) &&
- num_units == _ctx.at(forget_gate_bias_index).shape().dim(0) &&
- num_units == _ctx.at(cell_bias_index).shape().dim(0) &&
- num_units == _ctx.at(output_gate_bias_index).shape().dim(0) &&
- num_units == _ctx.at(cell_state_in_index).shape().dim(1));
+ const auto batch_size = (operands.at(input_index).shape().rank() == 3 && node.param().time_major)
+ ? operands.at(input_index).shape().dim(1)
+ : operands.at(input_index).shape().dim(0);
+ OP_REQUIRES(batch_size == operands.at(output_state_in_index).shape().dim(0) &&
+ batch_size == operands.at(cell_state_in_index).shape().dim(0));
+
+ const auto input_size =
+ operands.at(input_index).shape().dim(operands.at(input_index).shape().rank() - 1);
+ OP_REQUIRES(input_size == operands.at(input_to_forget_weights_index).shape().dim(1) &&
+ input_size == operands.at(input_to_cell_weights_index).shape().dim(1) &&
+ input_size == operands.at(input_to_output_weights_index).shape().dim(1));
+
+ const auto num_units = operands.at(input_to_output_weights_index).shape().dim(0);
+ OP_REQUIRES(num_units == operands.at(input_to_cell_weights_index).shape().dim(0) &&
+ num_units == operands.at(input_to_output_weights_index).shape().dim(0) &&
+ num_units == operands.at(recurrent_to_forget_weights_index).shape().dim(0) &&
+ num_units == operands.at(recurrent_to_cell_weights_index).shape().dim(0) &&
+ num_units == operands.at(recurrent_to_output_weights_index).shape().dim(0) &&
+ num_units == operands.at(forget_gate_bias_index).shape().dim(0) &&
+ num_units == operands.at(cell_bias_index).shape().dim(0) &&
+ num_units == operands.at(output_gate_bias_index).shape().dim(0) &&
+ num_units == operands.at(cell_state_in_index).shape().dim(1));
const auto output_size =
- _ctx.at(output_index).shape().dim(_ctx.at(output_index).shape().rank() - 1);
- OP_REQUIRES(output_size == _ctx.at(recurrent_to_forget_weights_index).shape().dim(1) &&
- output_size == _ctx.at(recurrent_to_cell_weights_index).shape().dim(1) &&
- output_size == _ctx.at(recurrent_to_output_weights_index).shape().dim(1) &&
- output_size == _ctx.at(output_state_in_index).shape().dim(1));
+ operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
+ OP_REQUIRES(output_size == operands.at(recurrent_to_forget_weights_index).shape().dim(1) &&
+ output_size == operands.at(recurrent_to_cell_weights_index).shape().dim(1) &&
+ output_size == operands.at(recurrent_to_output_weights_index).shape().dim(1) &&
+ output_size == operands.at(output_state_in_index).shape().dim(1));
if (has_cifg_param)
{
- OP_REQUIRES(input_size == _ctx.at(input_to_input_weights_index).shape().dim(1));
- OP_REQUIRES(num_units == _ctx.at(input_to_input_weights_index).shape().dim(0) &&
- num_units == _ctx.at(recurrent_to_input_weights_index).shape().dim(0) &&
- ((_ctx.exist(cell_to_input_weights_index) &&
- num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0)) ||
- (!_ctx.exist(cell_to_input_weights_index) ||
- _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0) /* non-peephole */) &&
- num_units == _ctx.at(input_gate_bias_index).shape().dim(0));
- OP_REQUIRES(output_size == _ctx.at(recurrent_to_input_weights_index).shape().dim(1));
+ OP_REQUIRES(input_size == operands.at(input_to_input_weights_index).shape().dim(1));
+ OP_REQUIRES(
+ num_units == operands.at(input_to_input_weights_index).shape().dim(0) &&
+ num_units == operands.at(recurrent_to_input_weights_index).shape().dim(0) &&
+ ((operands.exist(cell_to_input_weights_index) &&
+ num_units == operands.at(cell_to_input_weights_index).shape().dim(0)) ||
+ (!operands.exist(cell_to_input_weights_index) ||
+ operands.at(cell_to_input_weights_index).shape().dim(0) == 0) /* non-peephole */) &&
+ num_units == operands.at(input_gate_bias_index).shape().dim(0));
+ OP_REQUIRES(output_size == operands.at(recurrent_to_input_weights_index).shape().dim(1));
OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
has_input_gate_bias);
if (has_cell_to_input_weights)
@@ -757,64 +782,65 @@ void ShapeValidator::visit(const ir::operation::LSTM &node)
// NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole.
OP_REQUIRES(has_peephole_param);
}
- if (_ctx.exist(scratch_buffer_index))
- OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
+ if (operands.exist(scratch_buffer_index))
+ OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
}
else
{
- if (_ctx.exist(scratch_buffer_index))
- OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
+ if (operands.exist(scratch_buffer_index))
+ OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
}
if (has_peephole_param)
{
- OP_REQUIRES(num_units == _ctx.at(cell_to_forget_weights_index).shape().dim(0) &&
- num_units == _ctx.at(cell_to_output_weights_index).shape().dim(0) &&
- (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) ||
- _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
+ OP_REQUIRES(num_units == operands.at(cell_to_forget_weights_index).shape().dim(0) &&
+ num_units == operands.at(cell_to_output_weights_index).shape().dim(0) &&
+ (num_units == operands.at(cell_to_input_weights_index).shape().dim(0) ||
+ operands.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
}
if (has_projection_param)
{
- OP_REQUIRES(num_units == _ctx.at(projection_weights_index).shape().dim(1));
- OP_REQUIRES(output_size == _ctx.at(projection_weights_index).shape().dim(0));
+ OP_REQUIRES(num_units == operands.at(projection_weights_index).shape().dim(1));
+ OP_REQUIRES(output_size == operands.at(projection_weights_index).shape().dim(0));
if (has_projection_bias)
{
- OP_REQUIRES(output_size == _ctx.at(projection_bias_index).shape().dim(0));
+ OP_REQUIRES(output_size == operands.at(projection_bias_index).shape().dim(0));
}
}
- if (_ctx.exist(scratch_buffer_index))
+ if (operands.exist(scratch_buffer_index))
{
- OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().rank() == 2);
- OP_REQUIRES(batch_size == _ctx.at(scratch_buffer_index).shape().dim(0));
+ OP_REQUIRES(operands.at(scratch_buffer_index).shape().rank() == 2);
+ OP_REQUIRES(batch_size == operands.at(scratch_buffer_index).shape().dim(0));
}
- if (_ctx.exist(output_state_out_index))
+ if (operands.exist(output_state_out_index))
{
- OP_REQUIRES(_ctx.at(output_state_out_index).shape().rank() == 2);
- OP_REQUIRES(batch_size == _ctx.at(output_state_out_index).shape().dim(0));
- OP_REQUIRES(output_size == _ctx.at(output_state_out_index).shape().dim(1));
+ OP_REQUIRES(operands.at(output_state_out_index).shape().rank() == 2);
+ OP_REQUIRES(batch_size == operands.at(output_state_out_index).shape().dim(0));
+ OP_REQUIRES(output_size == operands.at(output_state_out_index).shape().dim(1));
}
- if (_ctx.exist(cell_state_out_index))
+ if (operands.exist(cell_state_out_index))
{
- OP_REQUIRES(_ctx.at(cell_state_out_index).shape().rank() == 2);
- OP_REQUIRES(batch_size == _ctx.at(cell_state_out_index).shape().dim(0));
- OP_REQUIRES(num_units == _ctx.at(cell_state_out_index).shape().dim(1));
+ OP_REQUIRES(operands.at(cell_state_out_index).shape().rank() == 2);
+ OP_REQUIRES(batch_size == operands.at(cell_state_out_index).shape().dim(0));
+ OP_REQUIRES(num_units == operands.at(cell_state_out_index).shape().dim(1));
}
}
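For orientation, one illustrative set of static shapes that satisfies every check in the LSTM visit above (the numbers are made up; only the relationships between batch, input_size, num_units and output_size matter; time_major = false, input gate present, no peephole or projection):

    // input                                            [batch, input_size]      = [2, 8]
    // input_to_{input,forget,cell,output}_weights      [num_units, input_size]  = [16, 8]
    // recurrent_to_{input,forget,cell,output}_weights  [num_units, output_size] = [16, 4]
    // input/forget/output_gate_bias, cell_bias         [num_units]              = [16]
    // output_state_in (and output_state_out)           [batch, output_size]     = [2, 4]
    // cell_state_in (and cell_state_out)               [batch, num_units]       = [2, 16]
    // output                                           [batch, output_size]     = [2, 4]
    // scratch_buffer, if present                       [batch, 4 * num_units]   = [2, 64]
    //                                                  (3 * num_units when the input gate is absent)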
void ShapeValidator::visit(const ir::operation::L2Normalization &node)
{
+ const auto &operands = _graph.operands();
const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
+ if (operands.at(ofm_index).info().isDynamic())
return;
const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
- auto ifm_shape = _ctx.at(ifm_index).shape();
- auto ofm_shape = _ctx.at(ofm_index).shape();
+ auto ifm_shape = operands.at(ifm_index).shape();
+ auto ofm_shape = operands.at(ofm_index).shape();
OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
@@ -826,14 +852,15 @@ void ShapeValidator::visit(const ir::operation::L2Normalization &node)
void ShapeValidator::visit(const ir::operation::Unpack &node)
{
+ const auto &operands = _graph.operands();
const auto axis{node.param().axis};
const auto output_index{node.getInputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
- const auto &input_shape = _ctx.at(input_index).shape();
+ const auto &input_shape = operands.at(input_index).shape();
const auto input_rank = static_cast<int32_t>(input_shape.rank());
OP_REQUIRES(axis >= -input_rank && axis < input_rank);
@@ -841,22 +868,23 @@ void ShapeValidator::visit(const ir::operation::Unpack &node)
void ShapeValidator::visit(const ir::operation::Pad &node)
{
+ const auto &operands = _graph.operands();
const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
- OP_REQUIRES(_ctx.at(pad_index).typeInfo().type() == ir::DataType::INT32);
+ OP_REQUIRES(operands.at(pad_index).typeInfo().type() == ir::DataType::INT32);
const auto output_index{node.getInputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
- const auto &pad_shape = _ctx.at(pad_index).shape();
- const auto input_rank = static_cast<int32_t>(_ctx.at(input_index).shape().rank());
+ const auto &pad_shape = operands.at(pad_index).shape();
+ const auto input_rank = static_cast<int32_t>(operands.at(input_index).shape().rank());
OP_REQUIRES(pad_shape.rank() == 2);
OP_REQUIRES(pad_shape.dim(0) == input_rank);
OP_REQUIRES(pad_shape.dim(1) == 2);
- OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
+ OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
}
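As a concrete reading of the Pad checks above, with hypothetical shapes:

    // a rank-4 input such as [1, 8, 8, 3] requires an INT32 PAD operand of shape [4, 2]
    // (one {before, after} pair per input dimension), and the padded output must keep rank 4;
    // a PAD operand of shape [3, 2] would trip OP_REQUIRES(pad_shape.dim(0) == input_rank).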
void ShapeValidator::visit(const ir::operation::Select &)
@@ -866,65 +894,70 @@ void ShapeValidator::visit(const ir::operation::Select &)
void ShapeValidator::visit(const ir::operation::StridedSlice &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
- OP_REQUIRES(_ctx.at(input_index).shape().rank() <= 4);
+ OP_REQUIRES(operands.at(input_index).shape().rank() <= 4);
}
void ShapeValidator::visit(const ir::operation::Split &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
const auto axis_index{node.getInputs().at(ir::operation::Split::Input::AXIS)};
const auto num_splits = node.param().num_splits;
- const auto input_rank = _ctx.at(input_index).shape().rank();
- auto axis = *reinterpret_cast<const int32_t *>(_ctx.at(axis_index).data()->base());
+ const auto input_rank = operands.at(input_index).shape().rank();
+ auto axis = *reinterpret_cast<const int32_t *>(operands.at(axis_index).data()->base());
axis = axis < 0 ? axis + input_rank : axis;
OP_REQUIRES(axis >= 0 && axis < input_rank);
- OP_REQUIRES(_ctx.at(input_index).shape().dim(axis) % num_splits == 0);
+ OP_REQUIRES(operands.at(input_index).shape().dim(axis) % num_splits == 0);
}
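A small worked example of the Split constraint above, with made-up values:

    // input shape = [2, 6, 4], constant AXIS operand holding -2, num_splits = 3
    //   axis = -2 + 3 (input rank) = 1
    //   dim(1) = 6 and 6 % 3 == 0, so the check passes; num_splits = 4 would fail here.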
void ShapeValidator::visit(const ir::operation::Shape &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto input_index{node.getInputs().at(0)};
UNUSED_RELEASE(input_index);
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(output_index).shape().rank() == 1);
}
void ShapeValidator::visit(const ir::operation::ResizeBilinear &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
{
return;
}
- OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
}
void ShapeValidator::visit(const ir::operation::Reverse &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
+ OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
}
void ShapeValidator::visit(const ir::operation::If &)
@@ -940,17 +973,18 @@ void ShapeValidator::visit(const ir::operation::While &)
void ShapeValidator::visit(const ir::operation::SquaredDifference &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};
// Check for dimension constraints
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
- auto output_shape = _ctx.at(output_index).shape();
- auto lhs_shape = _ctx.at(lhs_index).shape();
- auto rhs_shape = _ctx.at(rhs_index).shape();
+ auto output_shape = operands.at(output_index).shape();
+ auto lhs_shape = operands.at(lhs_index).shape();
+ auto rhs_shape = operands.at(rhs_index).shape();
// Check for output rank
OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());
@@ -982,36 +1016,40 @@ void ShapeValidator::visit(const ir::operation::SquaredDifference &node)
}
void ShapeValidator::visit(const ir::operation::Tile &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto input_index{node.getInputs().at(0)};
const auto multiple_index{node.getInputs().at(1)};
- OP_REQUIRES(_ctx.at(multiple_index).shape().rank() == 1);
- OP_REQUIRES(_ctx.at(multiple_index).shape().dim(0) == _ctx.at(input_index).shape().rank());
- OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
+ OP_REQUIRES(operands.at(multiple_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(multiple_index).shape().dim(0) ==
+ operands.at(input_index).shape().rank());
+ OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
}
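An illustrative case for the Tile checks above:

    // input [2, 3] with a MULTIPLES operand of shape [2] holding {2, 1} yields output [4, 3]:
    // multiples is rank 1, its dim(0) equals the input rank, and input/output ranks match.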
void ShapeValidator::visit(const ir::operation::Range &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
// Check for dimension constraints
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
- OP_REQUIRES(_ctx.at(start_index).shape().rank() == 0);
- OP_REQUIRES(_ctx.at(limit_index).shape().rank() == 0);
- OP_REQUIRES(_ctx.at(delta_index).shape().rank() == 0);
+ OP_REQUIRES(operands.at(start_index).shape().rank() == 0);
+ OP_REQUIRES(operands.at(limit_index).shape().rank() == 0);
+ OP_REQUIRES(operands.at(delta_index).shape().rank() == 0);
}
void ShapeValidator::visit(const ir::operation::MatrixBandPart &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)};
const auto num_lower_index{
@@ -1020,23 +1058,24 @@ void ShapeValidator::visit(const ir::operation::MatrixBandPart &node)
node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)};
// Check for dimension constraints
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
- OP_REQUIRES(_ctx.at(input_index).shape().rank() >= 2); // input must be more than 2 dim matrix
- OP_REQUIRES(_ctx.at(num_upper_index).shape().rank() == 0); // num_lower must be scalar
- OP_REQUIRES(_ctx.at(num_lower_index).shape().rank() == 0); // num_upper must be scalar
+ OP_REQUIRES(operands.at(input_index).shape().rank() >= 2); // input must be at least a 2-dim matrix
+ OP_REQUIRES(operands.at(num_upper_index).shape().rank() == 0); // num_upper must be scalar
+ OP_REQUIRES(operands.at(num_lower_index).shape().rank() == 0); // num_lower must be scalar
}
void ShapeValidator::visit(const ir::operation::LogSoftmax &node)
{
+ const auto &operands = _graph.operands();
const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
+ if (operands.at(output_index).info().isDynamic())
return;
const auto input_index{node.getInputs().at(0)};
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
+ OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
}
} // namespace compiler
diff --git a/runtime/onert/core/src/compiler/ShapeValidator.h b/runtime/onert/core/src/compiler/ShapeValidator.h
index 763cf7ce3..a51e8adc0 100644
--- a/runtime/onert/core/src/compiler/ShapeValidator.h
+++ b/runtime/onert/core/src/compiler/ShapeValidator.h
@@ -39,8 +39,13 @@ class ShapeValidator : public ir::OperationVisitor
public:
ShapeValidator(void) = delete;
ShapeValidator(const ir::Graph &graph);
+ ShapeValidator(const ShapeValidator &) = delete;
+ ShapeValidator(ShapeValidator &&) = delete;
+ ~ShapeValidator() = default;
public:
+ ShapeValidator &operator=(const ShapeValidator &) = delete;
+ ShapeValidator &operator=(ShapeValidator &&) = delete;
void operator()();
public:
@@ -90,10 +95,7 @@ private:
void checkUnaryOp(const ir::Operation &node);
private:
- // TODO Remove _ctx field
const ir::Graph &_graph;
- const ir::Operands &_ctx;
- ir::Layout _current_layout;
};
} // namespace compiler
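With the _ctx and _current_layout members removed, ShapeValidator is driven entirely by the graph it is constructed with (the layout now comes from _graph.layout()). A minimal usage sketch follows; the include path and call site are assumptions, since ShapeValidator is an internal compiler class:

    #include "compiler/ShapeValidator.h"

    void validateShapes(const onert::ir::Graph &graph)
    {
      onert::compiler::ShapeValidator validator{graph};
      validator(); // visits every operation; OP_REQUIRES flags any violated shape constraint
    }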
diff --git a/runtime/onert/core/src/compiler/StaticShapeInferer.cc b/runtime/onert/core/src/compiler/StaticShapeInferer.cc
index f2fee2c3c..485450560 100644
--- a/runtime/onert/core/src/compiler/StaticShapeInferer.cc
+++ b/runtime/onert/core/src/compiler/StaticShapeInferer.cc
@@ -19,62 +19,90 @@
#include "util/logging.h"
#include <sstream>
+#include <stdexcept>
namespace onert
{
namespace compiler
{
-
-void StaticShapeInferer::inferSubgraph(ir::SubgraphIndex subg_ind)
+void OperandObserver::updateShapes(const std::vector<ir::OperandInfo> &changed_operands_info,
+ bool unpredictable)
{
- StaticShapeInferer inferer(subg_ind, _lowered_subgs);
- auto &lgraph = _lowered_subgs.at(subg_ind);
- for (auto op_ind : lgraph->graph().topolSortOperations())
+ assert(changed_operands_info.size() == _operands.size());
+ for (size_t i = 0; i < changed_operands_info.size(); ++i)
{
- auto &op = lgraph->graph().operations().at(op_ind);
- bool has_dynamic_tensor = inferer.infer(op);
- lgraph->setHasDynamicTensor(op_ind, has_dynamic_tensor);
+ const auto &changed_operand_info = changed_operands_info.at(i);
+ auto &operand = _operands.at(i);
+ // assert(changed_operand_info.typeInfo() == operand->typeInfo());
+ // This error check may be replaced by an assertion if this function is called after the
+ // validation of models is completed.
+ if (changed_operand_info.typeInfo() != operand->typeInfo())
+ {
+ throw std::runtime_error("OperandObserver: The types of operands are mismatched");
+ }
+ if (!operand->info().isConstant() && (changed_operand_info.isDynamic() || unpredictable))
+ {
+ operand->info().setDynamic();
+ }
+ else
+ {
+ const auto &new_shape = changed_operands_info.at(i).shape();
+ operand->info().shape(new_shape);
+ }
}
}
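A hedged sketch of how updateShapes() is meant to be called; the observer is assumed to have been constructed over the operands it should track (its constructor is not part of this hunk):

    // std::vector<ir::OperandInfo> infos;                     // one entry per observed operand
    // for (auto idx : op.getInputs())
    //   infos.emplace_back(graph.operands().at(idx).info());
    // observer->updateShapes(infos);                          // propagate static shapes
    // observer->updateShapes(infos, /*unpredictable=*/true);  // or force non-constant operands dynamic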
-bool StaticShapeInferer::infer(const ir::Operation &op)
+void StaticShapeInferer::infer()
{
- bool has_dynamic_tensor = false;
-
- auto opcode = op.opcode();
-
- _return_has_dynamic_tensor = false; // this is used as a return value inside operation's visit()
-
- // IF: need shape inference for then, else
- // While: need shape inference for condition, body
- if (opcode == ir::OpCode::If || opcode == ir::OpCode::While)
- {
- op.accept(*this);
- }
- else
+ for (const auto &op_idx : _lowered_subg->graph().topolSortOperations())
{
- _return_has_dynamic_tensor = checkDynamicInput(op);
-
- if (_return_has_dynamic_tensor)
+ const auto &op = _lowered_subg->graph().operations().at(op_idx);
+ bool has_dynamic_tensor = false;
+ const auto opcode = op.opcode();
+ // IF: requires shape inference for then, else
+ // While: requires shape inference for condition, body
+ if (opcode == ir::OpCode::If || opcode == ir::OpCode::While)
{
- setDynamicOutput(op);
+ op.accept(*this);
}
else
{
- op.accept(*this);
+ has_dynamic_tensor = checkDynamicInput(op);
+ if (has_dynamic_tensor)
+ {
+ setDynamicOutput(op);
+ }
+ else
+ {
+ op.accept(*this);
+ }
}
+ has_dynamic_tensor = has_dynamic_tensor || checkDynamicOutput(op);
+ _lowered_subg->setHasDynamicTensor(op_idx, has_dynamic_tensor);
}
- has_dynamic_tensor = has_dynamic_tensor || _return_has_dynamic_tensor;
-
- return has_dynamic_tensor;
+ if (_controlflow_output_observer != nullptr)
+ {
+ // re-sizing output shapes of the control flow operation branching to this subgraph
+ std::vector<ir::OperandInfo> outputs_info;
+ const auto &graph = _lowered_subg->graph();
+ const auto &outputs = graph.getOutputs();
+ for (size_t i = 0; i < outputs.size(); ++i)
+ {
+ const auto &operand_info = graph.operands().at(outputs.at(i)).info();
+ outputs_info.emplace_back(operand_info);
+ }
+ _controlflow_output_observer->updateShapes(outputs_info);
+ }
}
bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op)
{
+ const auto &operands = _lowered_subg->graph().operands();
for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
{
- if (_operands.at(input_idx).info().isDynamic())
+ if (operands.at(input_idx).info().isDynamic())
{
return true;
}
@@ -83,11 +111,25 @@ bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op)
return false;
}
+bool StaticShapeInferer::checkDynamicOutput(const ir::Operation &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+ for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
+ {
+ if (operands.at(output_idx).info().isDynamic())
+ {
+ return true;
+ }
+ }
+ return false;
+}
+
void StaticShapeInferer::setDynamicOutput(const ir::Operation &op)
{
+ auto &operands = _lowered_subg->graph().operands();
for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
{
- _operands.at(output_idx).info().setDynamic();
+ operands.at(output_idx).info().setDynamic();
}
}
@@ -95,11 +137,12 @@ void StaticShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op,
const ir::OperandIndex lhs_idx,
const ir::OperandIndex rhs_idx)
{
- const auto &lhs = _operands.at(lhs_idx);
- const auto &rhs = _operands.at(rhs_idx);
+ auto &operands = _lowered_subg->graph().operands();
+ const auto &lhs = operands.at(lhs_idx);
+ const auto &rhs = operands.at(rhs_idx);
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
// re-sizing output shape
ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs.info().shape(), rhs.info().shape());
@@ -109,11 +152,12 @@ void StaticShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op,
void StaticShapeInferer::handleSimpleUnaryOp(const ir::Operation &op,
const ir::OperandIndex input_idx)
{
- const auto &input = _operands.at(input_idx);
+ auto &operands = _lowered_subg->graph().operands();
+ const auto &input = operands.at(input_idx);
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
// re-sizing output shape
ir::Shape new_shape = input.info().shape();
@@ -136,36 +180,31 @@ void StaticShapeInferer::dump()
return sstream.str();
};
- for (const auto &pair : _lowered_subgs)
- {
- const auto index = pair.first;
- const auto &lowered_subg = pair.second;
- VERBOSE(StaticShapeInferer) << index << std::endl;
- lowered_subg->graph().operands().iterate(
- [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
- VERBOSE(StaticShapeInferer)
- << " " << ind << ", " << (operand.info().isDynamic() ? "Dynamic" : "Static") << ", "
- << get_shape_str(operand.info().shape()) << std::endl;
- });
- }
+ _lowered_subg->graph().operands().iterate(
+ [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
+ VERBOSE(StaticShapeInferer) << " " << ind << ", "
+ << (operand.info().isDynamic() ? "Dynamic" : "Static") << ", "
+ << get_shape_str(operand.info().shape()) << std::endl;
+ });
}
void StaticShapeInferer::visit(const ir::operation::ArgMinMax &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto axis_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::AXIS)};
- const auto &axis = _operands.at(axis_idx);
+ const auto &axis = operands.at(axis_idx);
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
if (!axis.isConstant())
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
return;
}
@@ -181,27 +220,31 @@ void StaticShapeInferer::visit(const ir::operation::ArgMinMax &op)
void StaticShapeInferer::visit(const ir::operation::BatchMatMul &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto lhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::LHS);
const auto rhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::RHS);
const auto output_index = op.getOutputs().at(0);
- const auto &lhs = _operands.at(lhs_index);
- const auto &rhs = _operands.at(rhs_index);
- auto &output = _operands.at(output_index);
+ const auto &lhs = operands.at(lhs_index);
+ const auto &rhs = operands.at(rhs_index);
+ auto &output = operands.at(output_index);
auto new_shape = shape_inference::inferBatchMatMulShape(lhs.shape(), rhs.shape(), op.param());
output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::BCQFullyConnected &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto cluster_idx{
op.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
- const auto &cluster = _operands.at(cluster_idx);
+ const auto &cluster = operands.at(cluster_idx);
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base());
assert(cluster_buf);
@@ -214,17 +257,19 @@ void StaticShapeInferer::visit(const ir::operation::BCQFullyConnected &op)
void StaticShapeInferer::visit(const ir::operation::BCQGather &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto indices_idx{op.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
- const auto &indices = _operands.at(indices_idx);
+ const auto &indices = operands.at(indices_idx);
const auto input_binary_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
- const auto &input_binary = _operands.at(input_binary_idx);
+ const auto &input_binary = operands.at(input_binary_idx);
const auto cluster_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
- const auto &cluster = _operands.at(cluster_idx);
+ const auto &cluster = operands.at(cluster_idx);
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base());
assert(cluster_buf);
@@ -247,16 +292,16 @@ void StaticShapeInferer::visit(const ir::operation::BinaryArithmetic &op)
void StaticShapeInferer::visit(const ir::operation::BroadcastTo &op)
{
// get mutable output operand
+ auto &operands = _lowered_subg->graph().operands();
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
const auto shape_idx{op.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)};
- const auto &shape = _operands.at(shape_idx);
+ const auto &shape = operands.at(shape_idx);
if (!shape.isConstant())
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
return;
}
@@ -276,16 +321,18 @@ void StaticShapeInferer::visit(const ir::operation::Comparison &op)
void StaticShapeInferer::visit(const ir::operation::Concat &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_count = op.getInputs().size();
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
shape_inference::Shapes input_shapes;
for (uint32_t i = 0; i < input_count; i++)
{
const auto input_idx{op.getInputs().at(i)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
input_shapes.emplace_back(input.shape());
}
@@ -297,12 +344,14 @@ void StaticShapeInferer::visit(const ir::operation::Concat &op)
void StaticShapeInferer::visit(const ir::operation::Conv2D &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto ker_idx{op.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
- const auto &ker = _operands.at(ker_idx);
+ const auto &ker = operands.at(ker_idx);
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
// re-sizing output shape
ir::Shape new_shape =
@@ -328,17 +377,18 @@ void StaticShapeInferer::visit(const ir::operation::ElementwiseUnary &op)
void StaticShapeInferer::visit(const ir::operation::ExpandDims &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto axis_idx{op.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
- const auto &axis = _operands.at(axis_idx);
+ const auto &axis = operands.at(axis_idx);
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
if (!axis.isConstant())
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
return;
}
@@ -360,15 +410,16 @@ void StaticShapeInferer::visit(const ir::operation::ExpandDims &op)
void StaticShapeInferer::visit(const ir::operation::Fill &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto shape_idx{op.getInputs().at(ir::operation::Fill::Input::SHAPE)};
- const auto &shape = _operands.at(shape_idx);
+ const auto &shape = operands.at(shape_idx);
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
if (!shape.isConstant())
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
return;
}
@@ -390,15 +441,17 @@ void StaticShapeInferer::visit(const ir::operation::Fill &op)
void StaticShapeInferer::visit(const ir::operation::FullyConnected &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto ker_idx{op.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
- const auto &ker = _operands.at(ker_idx);
+ const auto &ker = operands.at(ker_idx);
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
// re-sizing output shape
ir::Shape new_shape =
shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape());
@@ -412,15 +465,17 @@ void StaticShapeInferer::visit(const ir::operation::FusedBatchNorm &op)
void StaticShapeInferer::visit(const ir::operation::Gather &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::Gather::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
const auto indices_idx{op.getInputs().at(ir::operation::Gather::Input::INDICES)};
- const auto &indices = _operands.at(indices_idx);
+ const auto &indices = operands.at(indices_idx);
const auto rank = input.info().shape().rank();
const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
@@ -434,70 +489,21 @@ void StaticShapeInferer::visit(const ir::operation::Gather &op)
void StaticShapeInferer::visit(const ir::operation::If &op)
{
- auto &then_graph = _lowered_subgs.at(op.param().then_subg_index)->graph();
- auto &else_graph = _lowered_subgs.at(op.param().else_subg_index)->graph();
+ // re-sizing input shapes of the then/else subgraphs
const std::vector<ir::OperandIndex> inputs{op.getInputs().begin() + 1, op.getInputs().end()};
- const auto &outputs = op.getOutputs();
- // re-sizing input shapes of then subgraph
- const auto &then_inputs = then_graph.getInputs();
- assert(inputs.size() == then_inputs.size());
+ std::vector<ir::OperandInfo> inputs_info;
+ const auto &graph = _lowered_subg->graph();
for (size_t i = 0; i < inputs.size(); ++i)
{
- auto &then_input = then_graph.operands().at(then_inputs.at(i));
- if (_operands.at(inputs.at(i)).info().isDynamic())
- {
- then_input.info().setDynamic();
- }
- else
- {
- auto new_shape = _operands.at(inputs.at(i)).info().shape();
- then_input.info().shape(new_shape);
- }
+ const auto &operand_info = graph.operands().at(inputs.at(i)).info();
+ inputs_info.emplace_back(operand_info);
}
+ _subg_input_observers.at(op.param().then_subg_index)->updateShapes(inputs_info);
+ _child_inferers.at(op.param().then_subg_index)->infer();
- // re-sizing input shapes of else subgraph
- const auto &else_inputs = else_graph.getInputs();
- assert(inputs.size() == else_inputs.size());
- for (size_t i = 0; i < inputs.size(); ++i)
- {
- auto &else_input = else_graph.operands().at(else_inputs.at(i));
- if (_operands.at(inputs.at(i)).info().isDynamic())
- {
- else_input.info().setDynamic();
- }
- else
- {
- const auto &new_shape = _operands.at(inputs.at(i)).info().shape();
- else_input.info().shape(new_shape);
- }
- }
-
- inferSubgraph(op.param().then_subg_index);
- inferSubgraph(op.param().else_subg_index);
-
- // re-sizing output shapes
- // TODO use then_graph / else_graph instead
- const auto &then_outputs = _lowered_subgs.at(op.param().then_subg_index)->graph().getOutputs();
- const auto &else_outputs = _lowered_subgs.at(op.param().else_subg_index)->graph().getOutputs();
- assert(outputs.size() == then_outputs.size());
- assert(outputs.size() == else_outputs.size());
- for (size_t i = 0; i < outputs.size(); ++i)
- {
- const auto &then_output = then_graph.operands().at(then_outputs.at(i));
- const auto &else_output = else_graph.operands().at(else_outputs.at(i));
- auto &output = _operands.at(outputs.at(i));
- if (!then_output.info().isDynamic() && !else_output.info().isDynamic() &&
- then_output.shape() == else_output.shape())
- {
- output.info().shape(then_output.shape());
- }
- else
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- }
- }
+ _subg_input_observers.at(op.param().else_subg_index)->updateShapes(inputs_info);
+ _child_inferers.at(op.param().else_subg_index)->infer();
}
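Taken together with the tail of infer() above, the If handling forms a round trip between parent and child inferers; a schematic of the calls, using the members introduced in this diff (the exact wiring of _controlflow_output_observer is assumed):

    // parent.infer() -> visit(If)
    //   _subg_input_observers[then_subg]->updateShapes(inputs_info);
    //   _child_inferers[then_subg]->infer();
    //     ... static inference over the then-subgraph ...
    //     _controlflow_output_observer->updateShapes(outputs_info);
    //       // presumably writes the branch's output shapes back onto the If outputs
    //   (and the same sequence for the else branch)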
void StaticShapeInferer::visit(const ir::operation::L2Normalization &op)
@@ -507,8 +513,10 @@ void StaticShapeInferer::visit(const ir::operation::L2Normalization &op)
void StaticShapeInferer::visit(const ir::operation::LSTM &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto output_index{op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
- auto &output = _operands.at(output_index);
+ auto &output = operands.at(output_index);
const auto output_state_out_index{
op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};
@@ -518,24 +526,24 @@ void StaticShapeInferer::visit(const ir::operation::LSTM &op)
const auto scratch_buffer_index{op.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};
if (output.info().isDynamic() ||
- (_operands.exist(output_state_out_index) &&
- _operands.at(output_state_out_index).info().isDynamic()) ||
- (_operands.exist(cell_state_out_index) &&
- _operands.at(cell_state_out_index).info().isDynamic()) ||
- (_operands.exist(scratch_buffer_index) &&
- _operands.at(scratch_buffer_index).info().isDynamic()))
+ (operands.exist(output_state_out_index) &&
+ operands.at(output_state_out_index).info().isDynamic()) ||
+ (operands.exist(cell_state_out_index) &&
+ operands.at(cell_state_out_index).info().isDynamic()) ||
+ (operands.exist(scratch_buffer_index) &&
+ operands.at(scratch_buffer_index).info().isDynamic()))
return;
const auto input_index{op.getInputs().at(ir::operation::LSTM::Input::INPUT)};
- const auto &input = _operands.at(input_index);
+ const auto &input = operands.at(input_index);
const auto input_to_output_weights_index{
op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
- const auto &input_to_output_weights = _operands.at(input_to_output_weights_index);
+ const auto &input_to_output_weights = operands.at(input_to_output_weights_index);
const auto recurrent_to_output_weights_index{
op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
- const auto &recurrent_to_output_weights = _operands.at(recurrent_to_output_weights_index);
+ const auto &recurrent_to_output_weights = operands.at(recurrent_to_output_weights_index);
// re-sizing outputs
const int n_batch = (input.shape().rank() == 3 && op.param().time_major) ? input.shape().dim(1)
@@ -555,21 +563,21 @@ void StaticShapeInferer::visit(const ir::operation::LSTM &op)
output.info().shape(ir::Shape{n_batch, n_output});
}
- if (_operands.exist(output_state_out_index))
+ if (operands.exist(output_state_out_index))
{
- auto &output_state_out = _operands.at(output_state_out_index);
+ auto &output_state_out = operands.at(output_state_out_index);
output_state_out.info().shape(ir::Shape{n_batch, n_output});
}
- if (_operands.exist(cell_state_out_index))
+ if (operands.exist(cell_state_out_index))
{
- auto &cell_state_out = _operands.at(cell_state_out_index);
+ auto &cell_state_out = operands.at(cell_state_out_index);
cell_state_out.info().shape(ir::Shape{n_batch, n_cell});
}
- if (_operands.exist(scratch_buffer_index))
+ if (operands.exist(scratch_buffer_index))
{
- auto &scratch_buffer = _operands.at(scratch_buffer_index);
+ auto &scratch_buffer = operands.at(scratch_buffer_index);
const auto input_to_input_weights_index{
op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)};
@@ -577,11 +585,11 @@ void StaticShapeInferer::visit(const ir::operation::LSTM &op)
op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)};
bool has_input_to_input_weights =
- _operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
- _operands.at(input_to_input_weights_index).shape().dim(1) != 0;
+ operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
+ operands.at(input_to_input_weights_index).shape().dim(1) != 0;
bool has_recurrent_to_input_weights =
- _operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
- _operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0;
+ operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
+ operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0;
// NOTE The cell_to_input_weights do not exist in non-peephole mode, even for regular (non-CIFG) LSTM.
// true: no CIFG
@@ -605,20 +613,21 @@ void StaticShapeInferer::visit(const ir::operation::MatrixBandPart &op)
void StaticShapeInferer::visit(const ir::operation::OneHot &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto indice_idx{op.getInputs().at(ir::operation::OneHot::Input::INDICES)};
- const auto &indice = _operands.at(indice_idx);
+ const auto &indice = operands.at(indice_idx);
const auto depth_idx{op.getInputs().at(ir::operation::OneHot::Input::DEPTH)};
- const auto &depth = _operands.at(depth_idx);
+ const auto &depth = operands.at(depth_idx);
const auto axis = op.param().axis;
auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
if (!depth.isConstant())
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
return;
}
@@ -631,12 +640,14 @@ void StaticShapeInferer::visit(const ir::operation::OneHot &op)
void StaticShapeInferer::visit(const ir::operation::Pack &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(0)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
const auto rank = input.shape().rank() + 1;
const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
@@ -651,21 +662,22 @@ void StaticShapeInferer::visit(const ir::operation::Pack &op)
void StaticShapeInferer::visit(const ir::operation::Pad &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::Pad::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto pad_idx{op.getInputs().at(ir::operation::Pad::Input::PAD)};
- const auto &pad = _operands.at(pad_idx);
+ const auto &pad = operands.at(pad_idx);
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
// if pad is not constant, output also becomes dynamic
if (!pad.isConstant())
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
return;
}
@@ -678,10 +690,12 @@ void StaticShapeInferer::visit(const ir::operation::Pad &op)
void StaticShapeInferer::visit(const ir::operation::Permute &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(0)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
// re-sizing output shape
// Permute is a special operation: the layouts of its input/output may differ between backends
@@ -700,16 +714,18 @@ void StaticShapeInferer::visit(const ir::operation::Pow &op)
void StaticShapeInferer::visit(const ir::operation::Range &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto start_idx{op.getInputs().at(ir::operation::Range::Input::START)};
const auto limit_idx{op.getInputs().at(ir::operation::Range::Input::LIMIT)};
const auto delta_idx{op.getInputs().at(ir::operation::Range::Input::DELTA)};
- const auto &start_op = _operands.at(start_idx);
- const auto &limit_op = _operands.at(limit_idx);
- const auto &delta_op = _operands.at(delta_idx);
+ const auto &start_op = operands.at(start_idx);
+ const auto &limit_op = operands.at(limit_idx);
+ const auto &delta_op = operands.at(delta_idx);
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
ir::Shape new_shape;
if (start_op.isConstant() && limit_op.isConstant() && delta_op.isConstant())
@@ -731,21 +747,22 @@ void StaticShapeInferer::visit(const ir::operation::Range &op)
else
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
}
}
void StaticShapeInferer::visit(const ir::operation::Reduce &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::Reduce::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto axes_idx{op.getInputs().at(ir::operation::Reduce::Input::AXES)};
- const auto &axes = _operands.at(axes_idx);
+ const auto &axes = operands.at(axes_idx);
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
std::vector<int32_t> axes_vec;
for (size_t i = 0; i < axes.shape().num_elements(); ++i)
@@ -777,19 +794,21 @@ void StaticShapeInferer::visit(const ir::operation::Reduce &op)
void StaticShapeInferer::visit(const ir::operation::Reshape &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::Reshape::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
// New shape is given by second input tensor
if (op.getInputs().size() == 2)
{
// Let's check the second input
const auto shape_idx{op.getInputs().at(ir::operation::Reshape::Input::SHAPE)};
- const auto &shape = _operands.at(shape_idx);
+ const auto &shape = operands.at(shape_idx);
if (shape.isConstant())
{
@@ -810,7 +829,6 @@ void StaticShapeInferer::visit(const ir::operation::Reshape &op)
{
// if shape is NOT constant, set output shape to be dynamic
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
}
}
// New shape is given by option
@@ -835,21 +853,22 @@ void StaticShapeInferer::visit(const ir::operation::Reshape &op)
void StaticShapeInferer::visit(const ir::operation::ResizeBilinear &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
int32_t height_out, width_out;
if (op.getInputs().size() == 2)
{
- auto &size = _operands.at(op.getInputs().at(ir::operation::ResizeBilinear::Input::SIZE));
+ auto &size = operands.at(op.getInputs().at(ir::operation::ResizeBilinear::Input::SIZE));
if (!size.isConstant())
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
return;
}
const auto size_v = size.asVector<std::int32_t>();
@@ -881,17 +900,19 @@ void StaticShapeInferer::visit(const ir::operation::Reverse &op)
void StaticShapeInferer::visit(const ir::operation::Select &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_cond_idx{op.getInputs().at(ir::operation::Select::Input::CONDITION)};
- const auto &input_cond = _operands.at(input_cond_idx);
+ const auto &input_cond = operands.at(input_cond_idx);
const auto input_true_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
- const auto &input_true = _operands.at(input_true_idx);
+ const auto &input_true = operands.at(input_true_idx);
const auto input_false_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
- const auto &input_false = _operands.at(input_false_idx);
+ const auto &input_false = operands.at(input_false_idx);
auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
// Select output shape
ir::Shape new_shape = shape_inference::inferSelectShape(
@@ -901,12 +922,14 @@ void StaticShapeInferer::visit(const ir::operation::Select &op)
void StaticShapeInferer::visit(const ir::operation::Shape &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(0)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
// re-sizing output shape
ir::Shape output_shape;
@@ -917,20 +940,21 @@ void StaticShapeInferer::visit(const ir::operation::Shape &op)
void StaticShapeInferer::visit(const ir::operation::Slice &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_index{op.getInputs().at(ir::operation::Slice::Input::INPUT)};
- const auto &input = _operands.at(input_index);
+ const auto &input = operands.at(input_index);
const auto begins_index{op.getInputs().at(ir::operation::Slice::Input::BEGINS)};
- const auto &begins = _operands.at(begins_index);
+ const auto &begins = operands.at(begins_index);
const auto sizes_index{op.getInputs().at(ir::operation::Slice::Input::SIZES)};
- const auto &sizes = _operands.at(sizes_index);
+ const auto &sizes = operands.at(sizes_index);
const auto output_index = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_index);
+ ir::Operand &output = operands.at(output_index);
// Whether input is constant or not does not affect whether output is dynamic or not
if (!(begins.isConstant() && sizes.isConstant()))
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
return;
}
@@ -959,21 +983,22 @@ void StaticShapeInferer::visit(const ir::operation::Softmax &op)
void StaticShapeInferer::visit(const ir::operation::SpaceToBatchND &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto output_index = op.getOutputs().at(0);
const auto input_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
const auto block_shape_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
const auto padding_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
- ir::Operand &output = _operands.at(output_index);
- const auto &input = _operands.at(input_idx);
- const auto &block_shape = _operands.at(block_shape_idx);
- const auto &padding = _operands.at(padding_idx);
+ ir::Operand &output = operands.at(output_index);
+ const auto &input = operands.at(input_idx);
+ const auto &block_shape = operands.at(block_shape_idx);
+ const auto &padding = operands.at(padding_idx);
// Whether input is constant or not does not affect whether output is dynamic or not
if (!(block_shape.isConstant() && padding.isConstant()))
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
return;
}
@@ -992,21 +1017,22 @@ void StaticShapeInferer::visit(const ir::operation::SpaceToBatchND &op)
void StaticShapeInferer::visit(const ir::operation::Split &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::Split::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto axis_idx{op.getInputs().at(ir::operation::Split::Input::AXIS)};
- const auto &axis = _operands.at(axis_idx);
+ const auto &axis = operands.at(axis_idx);
auto outputs = op.getOutputs();
if (!axis.isConstant())
{
for (auto output_idx : outputs)
{
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
output.info().setDynamic();
}
- _return_has_dynamic_tensor = true;
return;
}
@@ -1022,7 +1048,7 @@ void StaticShapeInferer::visit(const ir::operation::Split &op)
shape_inference::inferSplitShape(input.info().shape(), axis_value, num_splits);
for (auto output_idx : outputs)
{
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
output.info().shape(new_shape);
}
}
@@ -1035,11 +1061,13 @@ void StaticShapeInferer::visit(const ir::operation::SquaredDifference &op)
void StaticShapeInferer::visit(const ir::operation::Squeeze &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::Squeeze::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
// Squeeze output shape
ir::Shape new_shape = shape_inference::inferSqueezeShape(input.info().shape(), op.param());
@@ -1048,21 +1076,22 @@ void StaticShapeInferer::visit(const ir::operation::Squeeze &op)
void StaticShapeInferer::visit(const ir::operation::StridedSlice &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_index{op.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
- const auto &input = _operands.at(input_index);
+ const auto &input = operands.at(input_index);
const auto starts_index{op.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
- const auto &starts = _operands.at(starts_index);
+ const auto &starts = operands.at(starts_index);
const auto ends_index{op.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
- const auto &ends = _operands.at(ends_index);
+ const auto &ends = operands.at(ends_index);
const auto strides_index{op.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
- const auto &strides = _operands.at(strides_index);
+ const auto &strides = operands.at(strides_index);
const auto output_index = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_index);
+ ir::Operand &output = operands.at(output_index);
if (!(starts.isConstant() && ends.isConstant() && strides.isConstant()))
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
return;
}
@@ -1085,19 +1114,20 @@ void StaticShapeInferer::visit(const ir::operation::StridedSlice &op)
void StaticShapeInferer::visit(const ir::operation::Tile &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::Tile::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto multiplier_idx{op.getInputs().at(ir::operation::Tile::Input::MULTIPLES)};
- const auto &multiplier = _operands.at(multiplier_idx);
+ const auto &multiplier = operands.at(multiplier_idx);
const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
if (!multiplier.isConstant())
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
return;
}
@@ -1112,11 +1142,13 @@ void StaticShapeInferer::visit(const ir::operation::Tile &op)
void StaticShapeInferer::visit(const ir::operation::Transpose &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(ir::operation::Transpose::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto perm_idx{op.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
- const auto &perm = _operands.at(perm_idx);
+ const auto &perm = operands.at(perm_idx);
// perm.shape() != ir::Shape{0} means that perm is (n-1...0)
// TODO This condition changes to perm.num_elements() == 0
@@ -1124,11 +1156,10 @@ void StaticShapeInferer::visit(const ir::operation::Transpose &op)
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
- auto &output = _operands.at(output_idx);
+ auto &output = operands.at(output_idx);
if (!perm.isConstant() && !is_regular_transpose)
{
output.info().setDynamic();
- _return_has_dynamic_tensor = true;
return;
}
@@ -1157,8 +1188,10 @@ void StaticShapeInferer::visit(const ir::operation::Transpose &op)
void StaticShapeInferer::visit(const ir::operation::Unpack &op)
{
+ auto &operands = _lowered_subg->graph().operands();
+
const auto input_idx{op.getInputs().at(0)};
- const auto &input = _operands.at(input_idx);
+ const auto &input = operands.at(input_idx);
const auto num = op.param().num;
const auto rank = input.shape().rank();
const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
@@ -1169,10 +1202,9 @@ void StaticShapeInferer::visit(const ir::operation::Unpack &op)
for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
{
const auto output_idx = op.getOutputs().at(out_tensor_idx);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
output.info().setDynamic();
}
- _return_has_dynamic_tensor = true;
return;
}
@@ -1182,69 +1214,43 @@ void StaticShapeInferer::visit(const ir::operation::Unpack &op)
for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
{
const auto output_idx = op.getOutputs().at(out_tensor_idx);
- ir::Operand &output = _operands.at(output_idx);
+ ir::Operand &output = operands.at(output_idx);
output.info().shape(new_shape);
}
}
void StaticShapeInferer::visit(const ir::operation::While &op)
{
- auto &cond_graph = _lowered_subgs.at(op.param().cond_subg_index)->graph();
- auto &body_graph = _lowered_subgs.at(op.param().body_subg_index)->graph();
+ auto body_input_observer = _subg_input_observers.at(op.param().body_subg_index).get();
+ auto cond_input_observer = _subg_input_observers.at(op.param().cond_subg_index).get();
+ // re-sizing input shapes of body subgraph
const auto inputs = op.getInputs();
- const auto &outputs = op.getOutputs();
-
- // re-sizing input shapes of then subgraph
- const auto &cond_inputs = cond_graph.getInputs();
- assert(inputs.size() == cond_inputs.size());
+ std::vector<ir::OperandInfo> inputs_info;
+ const auto &graph = _lowered_subg->graph();
for (size_t i = 0; i < inputs.size(); ++i)
{
- const auto &input = _operands.at(inputs.at(i));
- auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
- if (input.info().isDynamic())
- {
- cond_input.info().setDynamic();
- }
- else
- {
- auto new_shape = input.info().shape();
- cond_input.info().shape(new_shape);
- }
+ const auto &operand_info = graph.operands().at(inputs.at(i)).info();
+ inputs_info.emplace_back(operand_info);
}
- // re-sizing input shapes of body subgraph
- const auto &body_inputs = body_graph.getInputs();
- assert(cond_inputs.size() == body_inputs.size());
- for (size_t i = 0; i < cond_inputs.size(); ++i)
- {
- const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
- auto &body_input = body_graph.operands().at(body_inputs.at(i));
- if (cond_input.info().isDynamic())
- {
- body_input.info().setDynamic();
- }
- else
- {
- const auto &new_shape = cond_input.info().shape();
- body_input.info().shape(new_shape);
- }
- }
-
- // re-sizing operands of body subgraph
- inferSubgraph(op.param().body_subg_index);
+ body_input_observer->updateShapes(inputs_info);
+ _child_inferers.at(op.param().body_subg_index)->infer();
// Check whether while operation's shapes are predictable
- // If any of shape of body outputs and cond inputs are different, non-constant operands would be
- // set to dynamic
+ // This while op's outputs are also updated inside the call above
+ // ("_child_inferers.at(op.param().body_subg_index)->infer()"), so the body's outputs and this
+ // op's outputs must have the same shape. We can therefore predict whether the body subgraph's
+ // shapes will change at every iteration by comparing the shapes of inputs and outputs. If any
+ // body output shape differs from the corresponding input shape, non-constant operands are set to dynamic.
bool check_unpredictable_dynamic = false;
- const auto &body_outputs = body_graph.getOutputs();
- assert(body_outputs.size() == cond_inputs.size());
- for (size_t i = 0; i < body_outputs.size(); ++i)
+ const auto &updated_outputs = op.getOutputs();
+ assert(inputs_info.size() == updated_outputs.size());
+ for (size_t i = 0; i < updated_outputs.size(); ++i)
{
- const auto &body_output = body_graph.operands().at(body_outputs.at(i));
- auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
- if ((cond_input.info().isDynamic() != body_output.info().isDynamic()) ||
- (cond_input.shape() != body_output.shape()))
+ const auto &input_info = inputs_info.at(i);
+ const auto &output_info = graph.operands().at(updated_outputs.at(i)).info();
+ if (input_info.isDynamic() != output_info.isDynamic() ||
+ input_info.shape() != output_info.shape())
{
check_unpredictable_dynamic = true;
break;
@@ -1253,53 +1259,11 @@ void StaticShapeInferer::visit(const ir::operation::While &op)
if (check_unpredictable_dynamic)
{
- // Set inputs of body subgraph
- for (const auto &input_index : body_inputs)
- {
- auto &input = body_graph.operands().at(input_index);
- if (!input.isConstant())
- {
- input.info().setDynamic();
- }
- }
-
- // Set inputs of cond subgraph
- for (const auto &input_index : cond_inputs)
- {
- auto &input = cond_graph.operands().at(input_index);
- if (!input.isConstant())
- {
- input.info().setDynamic();
- }
- }
-
- // Set non-constant operands of body subgraph to dynamic
- inferSubgraph(op.param().body_subg_index);
- }
-
- // re-sizing operands of cond subgraph
- // If check_unpredictable_dynamic is true, non-constant operands of cond subgraph would be set to
- // dynamic
- inferSubgraph(op.param().cond_subg_index);
-
- // re-sizing outputs of while operation
- // If check_unpredictable_dynamic is true, outputs of while operation would be set to dynamic
- assert(cond_inputs.size() == outputs.size());
- for (size_t i = 0; i < cond_inputs.size(); ++i)
- {
- const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
- auto &output = _operands.at(outputs.at(i));
- if (cond_input.info().isDynamic())
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- }
- else
- {
- const auto new_shape = cond_input.info().shape();
- output.info().shape(new_shape);
- }
+ body_input_observer->updateShapes(inputs_info, check_unpredictable_dynamic);
+ _child_inferers.at(op.param().body_subg_index)->infer();
}
+ cond_input_observer->updateShapes(inputs_info, check_unpredictable_dynamic);
+ _child_inferers.at(op.param().cond_subg_index)->infer();
}
void StaticShapeInferer::visit(const ir::operation::DetectionPostProcess &op)
@@ -1307,24 +1271,52 @@ void StaticShapeInferer::visit(const ir::operation::DetectionPostProcess &op)
// TODO: NMS supports very limited input/output size.
ir::operation::DetectionPostProcess::Param param = op.param();
+ auto &operands = _lowered_subg->graph().operands();
const int num_detected_boxes = param.max_detections * param.max_classes_per_detection;
const auto output_idx1 = op.getOutputs().at(0);
- auto &output1 = _operands.at(output_idx1);
+ auto &output1 = operands.at(output_idx1);
output1.info().shape({1, num_detected_boxes, 4});
const auto output_idx2 = op.getOutputs().at(1);
- auto &output2 = _operands.at(output_idx2);
+ auto &output2 = operands.at(output_idx2);
output2.info().shape({1, num_detected_boxes});
const auto output_idx3 = op.getOutputs().at(2);
- auto &output3 = _operands.at(output_idx3);
+ auto &output3 = operands.at(output_idx3);
output3.info().shape({1, num_detected_boxes});
const auto output_idx4 = op.getOutputs().at(3);
- auto &output4 = _operands.at(output_idx4);
+ auto &output4 = operands.at(output_idx4);
output4.info().shape({1});
}
+void StaticShapeInferer::visit(const ir::operation::Bulk &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ // TODO: support multiple inputs/outputs
+ const auto input_idx{op.getInputs().at(0)};
+ const auto &input = operands.at(input_idx);
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ auto cur_input_shape = input.info().shape();
+ auto origin_input_shape = op.param().origin_input_shapes[0];
+ auto cur_output_shape = output.info().shape();
+ auto origin_output_shape = op.param().origin_output_shapes[0];
+
+ // TODO: add more checks for a valid batch request
+ assert(cur_input_shape.dim(0) >= origin_output_shape.dim(0));
+ assert(cur_input_shape.dim(0) % origin_output_shape.dim(0) == 0);
+ size_t batch_multiplier = cur_input_shape.dim(0) / origin_output_shape.dim(0);
+
+ ir::Shape new_shape;
+ new_shape.append(origin_output_shape.dim(0) * batch_multiplier);
+ for (int32_t d = 1; d < origin_output_shape.rank(); ++d)
+ new_shape.append(origin_output_shape.dim(d));
+
+ output.info().shape(new_shape);
+}
} // namespace compiler
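For reference, the batch handling introduced by the new Bulk shape inference above boils down to scaling the origin output batch dimension by the ratio of the current input batch to the origin output batch. A minimal standalone sketch of that arithmetic follows (plain C++ with illustrative names, not the onert types):

#include <cassert>
#include <cstdio>
#include <vector>

// Standalone sketch of the Bulk output-shape computation (illustrative only):
// the output batch dimension is the origin output batch scaled by
// cur_input_batch / origin_output_batch; remaining dims are copied as-is.
std::vector<int> infer_bulk_output_shape(const std::vector<int> &cur_input_shape,
                                         const std::vector<int> &origin_output_shape)
{
  assert(cur_input_shape[0] >= origin_output_shape[0]);
  assert(cur_input_shape[0] % origin_output_shape[0] == 0);
  const int batch_multiplier = cur_input_shape[0] / origin_output_shape[0];

  std::vector<int> new_shape;
  new_shape.push_back(origin_output_shape[0] * batch_multiplier);
  for (size_t d = 1; d < origin_output_shape.size(); ++d)
    new_shape.push_back(origin_output_shape[d]);
  return new_shape;
}

int main()
{
  // Example: the current request carries batch 8 while the origin output batch is 2,
  // so the inferred output shape becomes {8, 1000}.
  const auto shape = infer_bulk_output_shape({8, 224, 224, 3}, {2, 1000});
  std::printf("{%d, %d}\n", shape[0], shape[1]);
  return 0;
}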
diff --git a/runtime/onert/core/src/compiler/TensorRegistries.h b/runtime/onert/core/src/compiler/TensorRegistries.h
index 2a99db781..b3cc0bbe3 100644
--- a/runtime/onert/core/src/compiler/TensorRegistries.h
+++ b/runtime/onert/core/src/compiler/TensorRegistries.h
@@ -17,13 +17,14 @@
#ifndef __ONERT_COMPILER_TENSOR_REGISTRIES_H__
#define __ONERT_COMPILER_TENSOR_REGISTRIES_H__
-#include <unordered_set>
-#include <memory>
-#include "backend/BackendContext.h"
+#include "../backend/builtin/Config.h"
+#include "../backend/builtin/TensorRegistry.h"
+
#include "backend/Backend.h"
-#include "backend/builtin/Config.h"
-#include "backend/builtin/TensorBuilder.h"
-#include "backend/builtin/TensorRegistry.h"
+#include "backend/BackendContext.h"
+
+#include <memory>
+#include <unordered_set>
namespace onert
{
diff --git a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc
index 181f388de..c27ce3d09 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc
+++ b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc
@@ -15,7 +15,6 @@
*/
#include "PermutationEliminationPass.h"
-#include "backend/builtin/Config.h"
#include "util/logging.h"
diff --git a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc
index 6f9899114..71efa1bb5 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc
+++ b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc
@@ -17,18 +17,16 @@
#include "PermutationInsertionPass.h"
-#include <cassert>
-#include <utility>
-#include <unordered_map>
+#include "../../backend/builtin/Config.h"
-#include "backend/builtin/Config.h"
-#include "ir/Operand.h"
#include "compiler/OperationLowerInfo.h"
-#include "ir/Graph.h"
-#include "backend/IConfig.h"
+#include "ir/operation/Permute.h"
#include "util/logging.h"
+
+#include <cassert>
#include <memory>
-#include "ir/operation/Permute.h"
+#include <unordered_map>
+#include <utility>
namespace onert
{
@@ -125,6 +123,8 @@ ir::OperationIndex PermutationInsertionPass::insertPermute(const ir::OperandInde
// backend
auto &model_outputs = _graph.getOutputs();
const backend::Backend *builtin_backend = compiler::BackendManager::get().getBuiltin();
+ assert(builtin_backend->config()->id() == onert::backend::builtin::Config::ID);
+
if (model_outputs.contains(operand_index) && factor.backend() == builtin_backend)
{
model_outputs.replace(operand_index, out_operand_index);
@@ -141,6 +141,8 @@ ir::OperationIndex PermutationInsertionPass::insertPermute(const ir::OperandInde
const auto permute_node_layout = ir::Layout::UNKNOWN;
// NOTE If one backend supports several layouts, the backend must support the Permute operation
const backend::Backend *permute_node_backend = compiler::BackendManager::get().getBuiltin();
+ assert(permute_node_backend->config()->id() == onert::backend::builtin::Config::ID);
+
if (input_backend == output_backend)
{
permute_node_backend = input_backend;
diff --git a/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.test.cc b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.test.cc
new file mode 100644
index 000000000..572b4df24
--- /dev/null
+++ b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.test.cc
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "UnusedOperandEliminationPass.h"
+
+#include "ir/Graph.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::ir;
+using namespace onert::compiler::pass;
+
+TEST(UnusedOperandEliminationPass, Simple)
+{
+ Graph graph;
+
+ // Add tensors
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+ auto in = graph.addOperand(shape, type);
+ auto out = graph.addOperand(shape, type);
+
+ auto unused = graph.addOperand(shape, type);
+
+ // Set model inputs/outputs
+ graph.addInput(in);
+ graph.addOutput(out);
+
+ UnusedOperandEliminationPass{graph}.run();
+
+ ASSERT_TRUE(graph.operands().exist(in));
+ ASSERT_TRUE(graph.operands().exist(out));
+ ASSERT_FALSE(graph.operands().exist(unused));
+}
diff --git a/runtime/onert/core/src/dumper/dot/DotDumper.cc b/runtime/onert/core/src/dumper/dot/DotDumper.cc
index 714fb6fda..0bb2fa11f 100644
--- a/runtime/onert/core/src/dumper/dot/DotDumper.cc
+++ b/runtime/onert/core/src/dumper/dot/DotDumper.cc
@@ -19,6 +19,7 @@
#include "DotDumper.h"
#include "DotBuilder.h"
+#include "ir/OperandIndexMap.h"
#include "ir/OperationIndexMap.h"
#include "backend/Backend.h"
#include "backend/IConfig.h"
@@ -31,97 +32,72 @@ namespace dumper
namespace dot
{
-void DotDumper::dump(const std::string &tag)
+namespace
{
- if (_level == Level::OFF)
- {
- return;
- }
-
- onert::dumper::dot::DotBuilder dot_builder;
-
- auto &operations = _graph.operations();
- auto &operands = _graph.operands();
-
- ir::OperationIndexMap<std::unique_ptr<Operation>> operation_nodes;
- std::unordered_map<ir::OperandIndex, std::unique_ptr<Operand>> operand_nodes;
-
- auto backend_to_fillcolor = [](const backend::Backend *backend) {
- static const auto map = []() {
- std::unordered_map<const backend::Backend *, std::string> ret;
- uint32_t index = 1; // Start from 1 to avoid 0(red) which is too dark :(
- for (const auto backend : compiler::BackendManager::get().getAll())
- {
- ret.emplace(backend, Node::BG_COLORS[index]);
- index = (index + 1) % (sizeof(Node::BG_COLORS) / sizeof(Node::BG_COLORS[0]));
- }
- return ret;
- }();
-
- auto itr = map.find(backend);
- if (itr == map.end())
- {
- return Node::DEFAULT_FILLCOLOR;
- }
- else
+std::string backend_to_fillcolor(const backend::Backend *backend)
+{
+ static const auto map = []() {
+ std::unordered_map<const backend::Backend *, std::string> ret;
+ uint32_t index = 1; // Start from 1 to avoid 0(red) which is too dark :(
+ for (const auto backend : compiler::BackendManager::get().getAll())
{
- return itr->second;
+ ret.emplace(backend, Node::BG_COLORS[index]);
+ index = (index + 1) % (sizeof(Node::BG_COLORS) / sizeof(Node::BG_COLORS[0]));
}
- };
+ return ret;
+ }();
+ auto itr = map.find(backend);
+ if (itr == map.end())
+ {
+ return Node::DEFAULT_FILLCOLOR;
+ }
+ else
+ {
+ return itr->second;
+ }
+}
- util::Set<ir::OperandIndex> shown_operand_set;
+std::unordered_map<ir::OperandIndex, std::unique_ptr<Operand>>
+generate_dot_operands(const ir::Graph &graph, const DotDumper::Level level)
+{
+ std::unordered_map<ir::OperandIndex, std::unique_ptr<Operand>> dot_operands;
+ const auto &operands = graph.operands();
operands.iterate([&](const ir::OperandIndex &index, const ir::Operand &object) {
- bool showing_cond = false;
- if (_level == Level::ALL)
- {
- showing_cond = true;
- }
- else
- {
- showing_cond =
- !object.isConstant() || (_graph.getInputs() + _graph.getOutputs()).contains(index);
- }
+ bool showing_cond =
+ level == DotDumper::Level::ALL
+ ? true
+ : !object.isConstant() || (graph.getInputs() + graph.getOutputs()).contains(index);
if (showing_cond)
{
- shown_operand_set.add(index);
-
auto type = [&]() {
using onert::dumper::dot::Operand;
- if (_graph.getInputs().contains(index))
+ if (graph.getInputs().contains(index))
return Operand::Type::MODEL_INPUT;
- if (_graph.getOutputs().contains(index))
+ if (graph.getOutputs().contains(index))
return Operand::Type::MODEL_OUTPUT;
return Operand::Type::INTERNAL;
}();
auto node = std::make_unique<Operand>(index, type);
+ std::string label = std::to_string(index.value());
+ std::string fillcolor = "";
+ node->setAttribute("label", label);
+ node->setAttribute("fillcolor", fillcolor);
- {
- // Display LowerInfo attributes
- std::string label = std::to_string(index.value());
- std::string fillcolor = "";
- if (_lowered_graph)
- {
- auto lower_info = _lowered_graph->lower_info().operand.getRawPtr(index);
- const auto &def_factors = lower_info->def_factors();
- if (def_factors.size() > 0)
- {
- label += "\\n[";
- label += def_factors.getOnlyElement().backend()->config()->id();
- label += "]";
-
- fillcolor = backend_to_fillcolor(lower_info->def_factors().getOnlyElement().backend());
- }
- }
- node->setAttribute("label", label);
- node->setAttribute("fillcolor", fillcolor);
- }
-
- operand_nodes.emplace(index, std::move(node));
+ dot_operands.emplace(index, std::move(node));
}
});
+ return dot_operands;
+}
+
+ir::OperationIndexMap<std::unique_ptr<Operation>>
+generate_dot_operations(const ir::Graph &graph,
+ const ir::OperandIndexMap<std::unique_ptr<Operand>> &dot_operands)
+{
+ ir::OperationIndexMap<std::unique_ptr<Operation>> dot_operations;
+ const auto &operations = graph.operations();
operations.iterate([&](const ir::OperationIndex &index, const ir::Operation &op) {
auto node = std::make_unique<Operation>(index, op);
@@ -130,42 +106,79 @@ void DotDumper::dump(const std::string &tag)
using onert::dumper::dot::Operand;
// Constant input and dump level is ALL_BUT_CONSTANTS
- if (operand_nodes.find(input) == operand_nodes.end())
+ if (dot_operands.find(input) == dot_operands.end())
continue;
- auto &input_node = operand_nodes.at(input);
+ auto &input_node = dot_operands.at(input);
input_node->addOutEdge(node.get());
}
for (auto output : op.getOutputs() | ir::Remove::UNDEFINED)
{
using onert::dumper::dot::Operand;
- auto &output_node = operand_nodes.at(output);
+ auto &output_node = dot_operands.at(output);
node->addOutEdge(output_node.get());
}
- operation_nodes.emplace(index, std::move(node));
+ dot_operations.emplace(index, std::move(node));
});
- if (_lowered_graph)
- {
- _graph.operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &) {
- const auto lower_info = _lowered_graph->lower_info().operation.getRawPtr(index);
- if (lower_info)
+ return dot_operations;
+}
+
+void update_lower_info(const compiler::LoweredGraph &lowered_graph,
+ ir::OperandIndexMap<std::unique_ptr<Operand>> *dot_operands)
+{
+ const auto &operands = lowered_graph.graph().operands();
+ operands.iterate([&](const ir::OperandIndex &index, const ir::Operand &) {
+ auto itr = dot_operands->find(index);
+ if (itr != dot_operands->end())
+ {
+ auto &node = itr->second;
+ // Display LowerInfo attributes
+ std::string label = node->getAttribute("label");
+ std::string fillcolor = node->getAttribute("fillcolor");
+ auto lower_info = lowered_graph.lower_info().operand.getRawPtr(index);
+ const auto &def_factors = lower_info->def_factors();
+ if (def_factors.size() > 0)
{
- auto fillcolor = backend_to_fillcolor(lower_info->backend());
- std::string backend_label = "[" + lower_info->backend()->config()->id() + "]";
- auto itr = operation_nodes.find(index);
- if (itr != operation_nodes.end())
- {
- auto &node = itr->second;
- node->setAttribute("label", node->getAttribute("label") + "\n" + backend_label);
- node->setAttribute("fillcolor", fillcolor);
- }
+ label += "\\n[";
+ label += def_factors.getOnlyElement().backend()->config()->id();
+ label += "]";
+ fillcolor = backend_to_fillcolor(lower_info->def_factors().getOnlyElement().backend());
}
- });
- }
+ node->setAttribute("label", label);
+ node->setAttribute("fillcolor", fillcolor);
+ }
+ });
+}
+void update_lower_info(const compiler::LoweredGraph &lowered_graph,
+ ir::OperationIndexMap<std::unique_ptr<Operation>> *dot_operations)
+{
+ const auto &operations = lowered_graph.graph().operations();
+ operations.iterate([&](const ir::OperationIndex &index, const ir::Operation &) {
+ const auto lower_info = lowered_graph.lower_info().operation.getRawPtr(index);
+ if (lower_info)
+ {
+ auto fillcolor = backend_to_fillcolor(lower_info->backend());
+ std::string backend_label = "[" + lower_info->backend()->config()->id() + "]";
+ auto itr = dot_operations->find(index);
+ if (itr != dot_operations->end())
+ {
+ auto &node = itr->second;
+ node->setAttribute("label", node->getAttribute("label") + "\n" + backend_label);
+ node->setAttribute("fillcolor", fillcolor);
+ }
+ }
+ });
+}
+
+void dump_to_file(const ir::OperandIndexMap<std::unique_ptr<Operand>> &operand_nodes,
+ const ir::OperationIndexMap<std::unique_ptr<Operation>> &operation_nodes,
+ const std::string &tag)
+{
+ onert::dumper::dot::DotBuilder dot_builder;
for (const auto &e : operation_nodes)
dot_builder.update(*e.second);
for (const auto &e : operand_nodes)
@@ -186,6 +199,33 @@ void DotDumper::dump(const std::string &tag)
fb.close();
}
}
+} // namespace
+
+void DotDumper::dump(const ir::Graph &graph, const std::string &tag)
+{
+ if (_level == Level::OFF)
+ {
+ return;
+ }
+
+ const auto dot_operands = generate_dot_operands(graph, _level);
+ const auto dot_operations = generate_dot_operations(graph, dot_operands);
+ dump_to_file(dot_operands, dot_operations, tag);
+}
+
+void DotDumper::dump(const compiler::LoweredGraph &lowered_graph, const std::string &tag)
+{
+ if (_level == Level::OFF)
+ {
+ return;
+ }
+
+ auto dot_operands = generate_dot_operands(lowered_graph.graph(), _level);
+ auto dot_operations = generate_dot_operations(lowered_graph.graph(), dot_operands);
+ update_lower_info(lowered_graph, &dot_operands);
+ update_lower_info(lowered_graph, &dot_operations);
+ dump_to_file(dot_operands, dot_operations, tag);
+}
} // namespace dot
} // namespace dumper
diff --git a/runtime/onert/core/src/dumper/dot/DotDumper.h b/runtime/onert/core/src/dumper/dot/DotDumper.h
index f300c3432..6249010d3 100644
--- a/runtime/onert/core/src/dumper/dot/DotDumper.h
+++ b/runtime/onert/core/src/dumper/dot/DotDumper.h
@@ -38,27 +38,28 @@ public:
};
public:
- DotDumper(const ir::Graph &graph, Level level)
- : _lowered_graph{nullptr}, _graph(graph), _level{level}
- {
- }
- DotDumper(const compiler::LoweredGraph *lowered_graph, Level level)
- : _lowered_graph{lowered_graph}, _graph(_lowered_graph->graph()), _level{level}
- {
- }
+ DotDumper(Level level) : _level{level} {}
public:
/**
- * @brief Dump to dot file as tag name if "GRAPH_DOT_DUMP" is set
+ * @brief Dump graph information to dot file as tag name if "GRAPH_DOT_DUMP" is set
+ *
+ * @param[in] graph The graph that would be used to get operations and operands
+ * @param[in] tag The name of dot file that would be created
+ * @return N/A
+ */
+ void dump(const ir::Graph &graph, const std::string &tag);
+
+ /**
+ * @brief Dump lowered graph information to dot file as tag name if "GRAPH_DOT_DUMP" is set
*
+ * @param[in] lowered_graph The lowered graph that would be used to get operations and operands
* @param[in] tag The name of dot file that would be created
* @return N/A
*/
- void dump(const std::string &tag);
+ void dump(const compiler::LoweredGraph &lowered_graph, const std::string &tag);
private:
- const compiler::LoweredGraph *_lowered_graph;
- const ir::Graph &_graph;
Level _level;
};
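The reworked DotDumper above no longer captures a graph at construction, so one instance can be reused for both a plain ir::Graph and a compiler::LoweredGraph. A minimal usage sketch under that assumption (the include path and the surrounding helper are illustrative; only the constructor and the two dump() overloads come from the header above):

#include "dumper/dot/DotDumper.h" // assumed include path within the onert core sources

// Illustrative helper, not part of the runtime: dumps the same model before and
// after lowering so the backend assignment (labels/fill colors) can be compared.
void dump_debug_graphs(const onert::ir::Graph &graph,
                       const onert::compiler::LoweredGraph &lowered_graph)
{
  onert::dumper::dot::DotDumper dumper{onert::dumper::dot::DotDumper::Level::ALL};
  dumper.dump(graph, "before_lower");        // operands/operations only
  dumper.dump(lowered_graph, "after_lower"); // adds per-backend labels and colors
}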
diff --git a/runtime/onert/core/src/exec/DataflowExecutor.h b/runtime/onert/core/src/exec/DataflowExecutor.h
index bcac19d2e..1649be733 100644
--- a/runtime/onert/core/src/exec/DataflowExecutor.h
+++ b/runtime/onert/core/src/exec/DataflowExecutor.h
@@ -17,19 +17,18 @@
#ifndef __ONERT_EXEC_DATAFLOW_EXECUTOR_H__
#define __ONERT_EXEC_DATAFLOW_EXECUTOR_H__
-#include <list>
-#include <map>
-#include <unordered_map>
-
-#include "exec/FunctionSequence.h"
+#include "ExecutorBase.h"
#include "Job.h"
-#include "ir/OperandIndexSequence.h"
-#include "ir/Index.h"
-#include <memory>
-#include "exec/ExecutorBase.h"
+
#include "compiler/CodeMap.h"
+#include "ir/OperandIndexSequence.h"
#include "util/TracingCtx.h"
+#include <list>
+#include <map>
+#include <memory>
+#include <unordered_map>
+
namespace onert
{
namespace exec
diff --git a/runtime/onert/core/src/exec/ExecTime.cc b/runtime/onert/core/src/exec/ExecTime.cc
index 6bf2744a9..4b82655b9 100644
--- a/runtime/onert/core/src/exec/ExecTime.cc
+++ b/runtime/onert/core/src/exec/ExecTime.cc
@@ -14,12 +14,10 @@
* limitations under the License.
*/
-#include "exec/ExecTime.h"
+#include "ExecTime.h"
-#include <fstream>
-#include <cassert>
-#include <limits>
#include <algorithm>
+#include <cassert>
namespace onert
{
diff --git a/runtime/onert/core/src/exec/ExecTime.test.cc b/runtime/onert/core/src/exec/ExecTime.test.cc
new file mode 100644
index 000000000..1f7152e7b
--- /dev/null
+++ b/runtime/onert/core/src/exec/ExecTime.test.cc
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ExecTime.h"
+
+#include "backend/IConfig.h"
+#include "backend/Backend.h"
+
+#include <gtest/gtest.h>
+
+#include <string>
+
+namespace
+{
+using namespace onert;
+using namespace exec;
+using namespace backend;
+
+struct MockConfig : public IConfig
+{
+ std::string id() override { return "b1"; }
+ bool initialize() override { return true; };
+ bool supportPermutation() override { return false; }
+ ir::Layout supportLayout(const ir::Operation &, ir::Layout) override
+ {
+ return ir::Layout::UNKNOWN;
+ }
+ bool supportDynamicTensor() override { return false; }
+ bool supportFP16() override { return false; }
+};
+
+struct MockBackend : public ::onert::backend::Backend
+{
+ std::shared_ptr<onert::backend::IConfig> config() const override
+ {
+ return std::make_shared<MockConfig>();
+ }
+ std::unique_ptr<onert::backend::BackendContext> newContext(ContextData &&) const override
+ {
+ return nullptr;
+ }
+};
+
+TEST(ExecTime, roundtrip_ok)
+{
+ const auto *b = new MockBackend();
+ std::vector<const Backend *> bs = {b};
+ {
+ ExecTime et(bs);
+ et.updateOperationExecTime(b, "op1", true, 100, 100);
+ et.updateOperationExecTime(b, "op1", true, 200, 200);
+ et.updateOperationExecTime(b, "op1", false, 100, 888);
+ et.storeOperationsExecTime();
+ }
+ {
+ ExecTime et(bs);
+ auto time = et.getOperationExecTime(b, "op1", true, 100);
+ ASSERT_EQ(time, 100);
+ // Check interpolation
+ time = et.getOperationExecTime(b, "op1", true, 150);
+ ASSERT_EQ(time, 150);
+ time = et.getOperationExecTime(b, "op1", false, 100);
+ ASSERT_EQ(time, 888);
+ et.storeOperationsExecTime();
+ }
+ // clean up
+ EXPECT_EQ(remove("exec_time.json"), 0);
+}
+
+TEST(ExecTime, structure)
+{
+
+ const auto *b = new MockBackend();
+ std::vector<const Backend *> bs = {b};
+ {
+ ExecTime et(bs);
+ et.updateOperationExecTime(b, "op1", true, 100, 100);
+ et.updateOperationExecTime(b, "op1", true, 200, 200);
+ et.storeOperationsExecTime();
+ }
+ {
+ ExecTime et(bs);
+ auto time = et.getOperationExecTime(b, "op1", true, 100);
+ ASSERT_EQ(time, 100);
+ // Check interpolation
+ time = et.getOperationExecTime(b, "op1", true, 200);
+ ASSERT_EQ(time, 200);
+ et.storeOperationsExecTime();
+ }
+ // clean up
+ EXPECT_EQ(remove("exec_time.json"), 0);
+}
+} // unnamed namespace
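The "Check interpolation" expectations above (100 at size 100, 200 at size 200, hence 150 at size 150) are consistent with plain linear interpolation between the two nearest recorded sizes. A standalone sketch of that arithmetic, assuming that is indeed how ExecTime fills the gap (illustrative names, not the ExecTime API):

#include <cstdint>
#include <cstdio>

// Linear interpolation between two recorded measurements (size1, time1) and
// (size2, time2), queried at an intermediate operand size.
int64_t interpolate_exec_time(uint32_t size1, int64_t time1, uint32_t size2, int64_t time2,
                              uint32_t size)
{
  return time1 + (time2 - time1) * static_cast<int64_t>(size - size1) /
                   static_cast<int64_t>(size2 - size1);
}

int main()
{
  // Mirrors the test expectation: (100 -> 100), (200 -> 200), query 150 -> 150.
  std::printf("%lld\n",
              static_cast<long long>(interpolate_exec_time(100, 100, 200, 200, 150)));
  return 0;
}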
diff --git a/runtime/onert/core/src/exec/Execution.cc b/runtime/onert/core/src/exec/Execution.cc
index 8eff73bac..9d1e06d6c 100644
--- a/runtime/onert/core/src/exec/Execution.cc
+++ b/runtime/onert/core/src/exec/Execution.cc
@@ -23,13 +23,12 @@ namespace onert
namespace exec
{
-Execution::Execution(const std::shared_ptr<ExecutorMap> &executors) : _executors{executors}
+Execution::Execution(const std::shared_ptr<Executors> &executors) : _executors{executors}
{
assert(executors != nullptr);
assert(executors->at(ir::SubgraphIndex{0}) != nullptr);
- const auto &primary_subg = primary_subgraph();
- _io_desc.inputs.resize(primary_subg.getInputs().size());
- _io_desc.outputs.resize(primary_subg.getOutputs().size());
+ _io_desc.inputs.resize(_executors->inputSize());
+ _io_desc.outputs.resize(_executors->outputSize());
sem_init(&_async_io_descs_sem, 0, 1);
}
@@ -48,8 +47,7 @@ void Execution::changeInputShape(const ir::IOIndex &index, const ir::Shape &new_
void Execution::setInput(const ir::IOIndex &index, const void *buffer, size_t length,
ir::Layout layout)
{
- const auto input_index = primary_subgraph().getInputs().at(index);
- const auto info = primary_subgraph().operands().at(input_index).info();
+ const auto info = _executors->inputInfo(index);
// TODO handle when (!buffer && length != 0) : setting the input as an optional tensor
@@ -105,8 +103,7 @@ bool Execution::isEmptyQueue()
void Execution::executeAsyncInput(const ir::IOIndex &index, const void *buffer, size_t length,
ir::Layout layout)
{
- const auto input_index = primary_subgraph().getInputs().at(index);
- const auto info = primary_subgraph().operands().at(input_index).info();
+ const auto info = _executors->inputInfo(index);
IODescription *_async_io_desc = _async_io_descs.back().first;
{
@@ -135,8 +132,7 @@ void Execution::executeAsyncInput(const ir::IOIndex &index, const void *buffer,
void Execution::executeAsyncOutput(const ir::IOIndex &index, void *buffer, size_t length,
ir::Layout layout)
{
- const auto output_index = primary_subgraph().getOutputs().at(index);
- const auto info = primary_subgraph().operands().at(output_index).info();
+ const auto info = _executors->outputInfo(index);
IODescription *_async_io_desc = _async_io_descs.front().first;
if (length < info.total_size())
@@ -165,8 +161,7 @@ void Execution::setInput(const ir::IOIndex &index, const ir::TypeInfo &type, con
// TODO Remove default parameter
void Execution::setOutput(const ir::IOIndex &index, void *buffer, size_t length, ir::Layout layout)
{
- const auto output_index = primary_subgraph().getOutputs().at(index);
- const auto info = primary_subgraph().operands().at(output_index).info();
+ const auto info = _executors->outputInfo(index);
if (length < info.total_size())
{
@@ -208,7 +203,7 @@ void Execution::execute()
{
VERBOSE(Execution) << "Start execution" << std::endl;
- primary_executor()->execute(_io_desc);
+ _executors->execute(_io_desc);
finished = true;
VERBOSE(Execution) << "Execution finished" << std::endl;
@@ -248,8 +243,7 @@ ir::Shape Execution::getInputShape(ir::IOIndex ind) const
auto itr = _io_desc.dynamic_input_shapes.find(ind);
if (itr == _io_desc.dynamic_input_shapes.end())
{
- auto operand_idx = primary_subgraph().getInputs().at(ind);
- return primary_subgraph().operands().at(operand_idx).shape();
+ return _executors->inputInfo(ind).shape();
}
else
{
diff --git a/runtime/onert/core/src/exec/Execution.test.cc b/runtime/onert/core/src/exec/Execution.test.cc
new file mode 100644
index 000000000..e3ea49470
--- /dev/null
+++ b/runtime/onert/core/src/exec/Execution.test.cc
@@ -0,0 +1,302 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "exec/Execution.h"
+
+#include "compiler/Compiler.h"
+#include "ir/Graph.h"
+#include "ir/operation/BinaryArithmetic.h"
+#include "util/TracingCtx.h"
+
+#include <gtest/gtest.h>
+#include <thread>
+
+namespace
+{
+
+using namespace onert::ir;
+
+class CompiledMockUpModel
+{
+public:
+ CompiledMockUpModel()
+ {
+ // Model: two elementwise add operations
+ // model input: lhs, rhs1
+ // model output: second add result (result2)
+ // constant: rhs2
+ // result1 <= (lhs + rhs1)
+ // result2 <= (result1 + rhs2)
+ // lhs, rhs1, rhs2, result1, result2 shape: {1, 2, 2, 1}
+ // activation: none (constant)
+ graph = std::make_shared<Graph>();
+ // 1st add operands (result1 <= lhs + rhs1)
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+ static float rhs2_data[4] = {3, 1, -1, 5};
+ auto operand_lhs = graph->addOperand(shape, type);
+ auto operand_rhs1 = graph->addOperand(shape, type);
+ auto operand_result1 = graph->addOperand(shape, type);
+ auto operand_rhs2 = graph->addOperand(shape, type);
+ auto operand_result2 = graph->addOperand(shape, type);
+ graph->operands()
+ .at(operand_rhs2)
+ .data(std::make_unique<CachedData>(reinterpret_cast<const uint8_t *>(&rhs2_data), 16));
+ // Add operations (1st: result1 <= lhs + rhs1, 2nd: result2 <= result1 + rhs2)
+ operation::BinaryArithmetic::Param param1;
+ param1.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param1.activation = Activation::NONE;
+ auto input_set1 = OperandIndexSequence{operand_lhs, operand_rhs1};
+ auto output_set1 = OperandIndexSequence{operand_result1};
+ graph->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set1, output_set1, param1));
+ operation::BinaryArithmetic::Param param2;
+ param2.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param2.activation = Activation::NONE;
+ auto input_set2 = OperandIndexSequence{operand_result1, operand_rhs2};
+ auto output_set2 = OperandIndexSequence{operand_result2};
+ graph->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set2, output_set2, param2));
+ // Identify model inputs and outputs
+ graph->addInput(operand_lhs);
+ graph->addInput(operand_rhs1);
+ graph->addOutput(operand_result2);
+ graph->verify();
+
+ // Compile
+ auto model = std::make_shared<onert::ir::Model>();
+ model->push(onert::ir::SubgraphIndex{0}, graph);
+ coptions = onert::compiler::CompilerOptions::fromGlobalConfig();
+ onert::compiler::Compiler compiler{model, *coptions};
+ artifact = compiler.compile();
+ }
+
+public:
+ std::shared_ptr<Graph> graph;
+ std::unique_ptr<onert::compiler::CompilerOptions> coptions;
+ std::shared_ptr<onert::compiler::CompilerArtifact> artifact;
+};
+
+TEST(ExecInstance, simple)
+{
+ auto mockup = CompiledMockUpModel();
+ auto graph = mockup.graph;
+ auto executors = mockup.artifact->_executors;
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const float input1_buffer[4] = {1, 0, -1, -2};
+ const float input2_buffer[4] = {1, -3, 2, -4};
+ float output_buffer[4] = {};
+ const float output_expected[4] = {5, -2, 0, -1};
+
+ onert::exec::Execution execution{executors};
+
+ execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16);
+ execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16);
+ execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16);
+ execution.execute();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(output_buffer[i], output_expected[i]);
+ }
+}
+
+TEST(ExecInstance, twoCompile)
+{
+ auto mockup = CompiledMockUpModel();
+ auto graph = mockup.graph;
+ auto executors1 = mockup.artifact->_executors;
+ onert::exec::Execution execution1{executors1};
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const float exe1_input1_buffer[4] = {1, 0, -1, -2};
+ const float exe1_input2_buffer[4] = {1, -3, 2, -4};
+ float exe1_output_buffer[4] = {};
+ const float exe1_output_expected[4] = {5, -2, 0, -1};
+
+ execution1.setInput(input1, reinterpret_cast<const void *>(exe1_input1_buffer), 16);
+ execution1.setInput(input2, reinterpret_cast<const void *>(exe1_input2_buffer), 16);
+ execution1.setOutput(output, reinterpret_cast<void *>(exe1_output_buffer), 16);
+
+ // Make new executor: compile again
+ auto model = std::make_shared<onert::ir::Model>();
+ model->push(onert::ir::SubgraphIndex{0}, graph);
+ auto coptions = onert::compiler::CompilerOptions::fromGlobalConfig();
+ onert::compiler::Compiler compiler{model, *coptions};
+ std::shared_ptr<onert::compiler::CompilerArtifact> artifact = compiler.compile();
+ onert::exec::Execution execution2{artifact->_executors};
+
+ const float exe2_input1_buffer[4] = {2, 1, -2, 0};
+ const float exe2_input2_buffer[4] = {-3, 3, 1, 2};
+ float exe2_output_buffer[4] = {};
+ const float exe2_output_expected[4] = {2, 5, -2, 7};
+
+ execution2.setInput(input1, reinterpret_cast<const void *>(exe2_input1_buffer), 16);
+ execution2.setInput(input2, reinterpret_cast<const void *>(exe2_input2_buffer), 16);
+ execution2.setOutput(output, reinterpret_cast<void *>(exe2_output_buffer), 16);
+
+ execution1.execute();
+ execution2.execute();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]);
+ EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]);
+ }
+}
+
+// Support two initialized execution instances, executed in order
+TEST(ExecInstance, twoExecution)
+{
+ auto mockup = CompiledMockUpModel();
+ auto executors = mockup.artifact->_executors;
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output1 = IOIndex{0};
+
+ const float exe1_input1_buffer[4] = {1, 0, -1, -2};
+ const float exe1_input2_buffer[4] = {1, -3, 2, -4};
+ float exe1_output_buffer[4] = {};
+ const float exe1_output_expected[4] = {5, -2, 0, -1};
+ const float exe2_output_expected[4] = {2, 5, -2, 7};
+
+ onert::exec::Execution execution1{executors};
+ execution1.setInput(input1, reinterpret_cast<const void *>(exe1_input1_buffer), 16);
+ execution1.setInput(input2, reinterpret_cast<const void *>(exe1_input2_buffer), 16);
+ execution1.setOutput(output1, reinterpret_cast<void *>(exe1_output_buffer), 16);
+
+ const float exe2_input1_buffer[4] = {2, 1, -2, 0};
+ const float exe2_input2_buffer[4] = {-3, 3, 1, 2};
+ float exe2_output_buffer[4] = {};
+
+ // Make new execution
+ onert::exec::Execution execution2{executors};
+ execution2.setInput(input1, reinterpret_cast<const void *>(exe2_input1_buffer), 16);
+ execution2.setInput(input2, reinterpret_cast<const void *>(exe2_input2_buffer), 16);
+ execution2.setOutput(output1, reinterpret_cast<void *>(exe2_output_buffer), 16);
+
+ execution1.execute();
+ execution2.execute();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]);
+ EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]);
+ }
+}
+
+class Inference
+{
+public:
+ Inference(const float (&input1)[4], const float (&input2)[4], float (&output)[4],
+ std::shared_ptr<onert::exec::Executors> &executors)
+ : _input1{input1}, _input2{input2}, _output{output}, _executors{executors}
+ {
+ // DO NOTHING
+ }
+
+ void inference(void)
+ {
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output1 = IOIndex{0};
+
+ onert::exec::Execution execution{_executors};
+ execution.setInput(input1, reinterpret_cast<const void *>(_input1), 16);
+ execution.setInput(input2, reinterpret_cast<const void *>(_input2), 16);
+ execution.setOutput(output1, reinterpret_cast<void *>(_output), 16);
+
+ execution.execute();
+ }
+
+private:
+ const float (&_input1)[4];
+ const float (&_input2)[4];
+ float (&_output)[4];
+ std::shared_ptr<onert::exec::Executors> &_executors;
+};
+
+// Support multi-thread execution
+TEST(ExecInstance, twoThreads)
+{
+ auto mockup = CompiledMockUpModel();
+ auto executors = mockup.artifact->_executors;
+
+ const float exe1_input1_buffer[4] = {1, 0, -1, -2};
+ const float exe1_input2_buffer[4] = {1, -3, 2, -4};
+ float exe1_output_buffer[4] = {};
+ const float exe1_output_expected[4] = {5, -2, 0, -1};
+
+ Inference execution1{exe1_input1_buffer, exe1_input2_buffer, exe1_output_buffer, executors};
+
+ const float exe2_input1_buffer[4] = {2, 1, -2, 0};
+ const float exe2_input2_buffer[4] = {-3, 3, 1, 2};
+ float exe2_output_buffer[4] = {};
+ const float exe2_output_expected[4] = {2, 5, -2, 7};
+
+ Inference execution2{exe2_input1_buffer, exe2_input2_buffer, exe2_output_buffer, executors};
+
+ std::thread t1{&Inference::inference, &execution1};
+ std::thread t2{&Inference::inference, &execution2};
+
+ t1.join();
+ t2.join();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]);
+ EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]);
+ }
+}
+
+// Support asynchronous execution
+TEST(ExecInstance, async)
+{
+ auto mockup = CompiledMockUpModel();
+ auto graph = mockup.graph;
+ auto executors = mockup.artifact->_executors;
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const float input1_buffer[4] = {1, 0, -1, -2};
+ const float input2_buffer[4] = {1, -3, 2, -4};
+ float output_buffer[4] = {};
+ const float output_expected[4] = {5, -2, 0, -1};
+
+ onert::exec::Execution execution{executors};
+
+ execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16);
+ execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16);
+ execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16);
+ execution.startExecute();
+ execution.waitFinish();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(output_buffer[i], output_expected[i]);
+ }
+}
+
+} // namespace
diff --git a/runtime/onert/core/src/exec/ExecutionObservee.h b/runtime/onert/core/src/exec/ExecutionObservee.h
index 423b5026b..3ee1754c9 100644
--- a/runtime/onert/core/src/exec/ExecutionObservee.h
+++ b/runtime/onert/core/src/exec/ExecutionObservee.h
@@ -17,11 +17,12 @@
#ifndef __ONERT_EXEC_EXECUTION_OBSERVEE_H__
#define __ONERT_EXEC_EXECUTION_OBSERVEE_H__
-#include <list>
+#include "ExecutionObservers.h"
-#include "exec/ExecutionObservers.h"
#include "ir/Index.h"
+#include <list>
+
namespace onert
{
namespace exec
diff --git a/runtime/onert/core/src/exec/ExecutionObservers.cc b/runtime/onert/core/src/exec/ExecutionObservers.cc
index 386178ae6..9abde7ba4 100644
--- a/runtime/onert/core/src/exec/ExecutionObservers.cc
+++ b/runtime/onert/core/src/exec/ExecutionObservers.cc
@@ -14,16 +14,16 @@
* limitations under the License.
*/
-#include "exec/ExecutionObservers.h"
+#include "ExecutionObservers.h"
-#include <string>
-#include <sstream>
+#include "../util/EventWriter.h"
#include "util/logging.h"
-#include "exec/IExecutor.h"
-#include "misc/polymorphic_downcast.h"
-#include "ir/Operation.h"
-#include "util/EventWriter.h"
+
+#include <misc/polymorphic_downcast.h>
+
+#include <string>
+#include <sstream>
namespace
{
diff --git a/runtime/onert/core/src/exec/ExecutionObservers.h b/runtime/onert/core/src/exec/ExecutionObservers.h
index 4c6c7b18e..1aadac2f5 100644
--- a/runtime/onert/core/src/exec/ExecutionObservers.h
+++ b/runtime/onert/core/src/exec/ExecutionObservers.h
@@ -17,17 +17,16 @@
#ifndef __ONERT_EXEC_OBSREVERS_H__
#define __ONERT_EXEC_OBSREVERS_H__
-#include "exec/IFunction.h"
+#include "ExecTime.h"
+#include "../util/EventCollector.h"
+#include "../util/EventRecorder.h"
+#include "../util/EventWriter.h"
+
+#include "exec/Executors.h"
#include "ir/Index.h"
#include "ir/Operation.h"
-#include "ExecTime.h"
#include "util/ITimer.h"
-#include "exec/IExecutor.h"
-#include "util/EventCollector.h"
-#include "util/EventRecorder.h"
-#include "util/EventWriter.h"
#include "util/TracingCtx.h"
-#include "util/EventWriter.h"
namespace onert
{
diff --git a/runtime/onert/core/src/exec/ExecutorBase.cc b/runtime/onert/core/src/exec/ExecutorBase.cc
index efc22cfa5..d2d204a0b 100644
--- a/runtime/onert/core/src/exec/ExecutorBase.cc
+++ b/runtime/onert/core/src/exec/ExecutorBase.cc
@@ -15,11 +15,10 @@
*/
#include "ExecutorBase.h"
+
#include "ShapeConverter.h"
-#include "backend/builtin/UserTensor.h"
-#include "util/logging.h"
-#include "misc/polymorphic_downcast.h"
+#include <misc/polymorphic_downcast.h>
namespace onert
{
diff --git a/runtime/onert/core/src/exec/ExecutorBase.h b/runtime/onert/core/src/exec/ExecutorBase.h
index c0f609d11..e4f914546 100644
--- a/runtime/onert/core/src/exec/ExecutorBase.h
+++ b/runtime/onert/core/src/exec/ExecutorBase.h
@@ -17,22 +17,17 @@
#ifndef __ONERT_EXEC_EXECUTOR_BASE_H__
#define __ONERT_EXEC_EXECUTOR_BASE_H__
-#include "IPermuteFunction.h"
+#include "ExecutionObservee.h"
+#include "../backend/builtin/IOTensor.h"
+#include "../compiler/TensorRegistries.h"
+
+#include "compiler/LoweredGraph.h"
#include "exec/IExecutor.h"
-#include "exec/ExecTime.h"
-#include "exec/ExecutionObservee.h"
-#include "exec/IFunction.h"
#include "exec/IODescription.h"
#include "ir/Graph.h"
-#include "ir/Index.h"
-#include "compiler/GraphLowerInfo.h"
#include "ir/OperationIndexMap.h"
-#include "compiler/LoweredGraph.h"
-#include "compiler/TensorRegistries.h"
-#include "backend/builtin/IOTensor.h"
#include "util/TracingCtx.h"
-#include <cstdint>
#include <memory>
#include <mutex>
#include <vector>
diff --git a/runtime/onert/core/src/exec/Executors.cc b/runtime/onert/core/src/exec/Executors.cc
new file mode 100644
index 000000000..e0ee24fea
--- /dev/null
+++ b/runtime/onert/core/src/exec/Executors.cc
@@ -0,0 +1,183 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "exec/Executors.h"
+
+namespace onert
+{
+namespace exec
+{
+
+uint32_t Executors::inputSize() const
+{
+ return _model_edges ? _model_edges->pkg_inputs.size()
+ : _executors.at(ir::SubgraphIndex{0})->graph().getInputs().size();
+}
+
+uint32_t Executors::outputSize() const
+{
+ return _model_edges ? _model_edges->pkg_outputs.size()
+ : _executors.at(ir::SubgraphIndex{0})->graph().getOutputs().size();
+}
+
+const ir::OperandInfo Executors::inputInfo(const ir::IOIndex &index)
+{
+ if (_model_edges)
+ {
+    // Assume that each model has only one subgraph
+ // TODO handle general case
+ const auto desc = _model_edges->pkg_inputs[index.value()];
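+    // desc is a (ModelIndex, SubgraphIndex, IOIndex) triple; map the model index to its executor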
+ const auto model_idx = std::get<0>(desc);
+ const auto executor_idx = ir::SubgraphIndex{model_idx.value()};
+ const auto input_index = _executors.at(executor_idx)->graph().getInputs().at(std::get<2>(desc));
+ return _executors.at(executor_idx)->graph().operands().at(input_index).info();
+ }
+
+ const auto input_index = _executors.at(ir::SubgraphIndex{0})->graph().getInputs().at(index);
+ return _executors.at(ir::SubgraphIndex{0})->graph().operands().at(input_index).info();
+}
+
+const ir::OperandInfo Executors::outputInfo(const ir::IOIndex &index)
+{
+ if (_model_edges)
+ {
+    // Assume that each model has only one subgraph
+ // TODO handle general case
+ auto desc = _model_edges->pkg_outputs[index.value()];
+ auto model_idx = std::get<0>(desc);
+ auto executor_idx = ir::SubgraphIndex{model_idx.value()};
+ auto output_index = _executors.at(executor_idx)->graph().getOutputs().at(std::get<2>(desc));
+ return _executors.at(executor_idx)->graph().operands().at(output_index).info();
+ }
+
+ auto output_index = _executors.at(ir::SubgraphIndex{0})->graph().getOutputs().at(index);
+ return _executors.at(ir::SubgraphIndex{0})->graph().operands().at(output_index).info();
+}
+
+void Executors::execute(const IODescription &desc)
+{
+ if (_model_edges)
+ return executeEntries(desc);
+
+ _executors.at(ir::SubgraphIndex{0})->execute(desc);
+}
+
+void Executors::executeEntries(const IODescription &desc)
+{
+ // Assume 2 executors only
+  // Assume that each model has only one subgraph
+ // TODO Support general case
+ if (_executors.size() != 2)
+ throw std::runtime_error{"NYI: Multi model execution for this package is not supported yet"};
+
+ // Assume all edges are 0:0:x -> 1:0:x
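+  // i.e. every edge connects output x of model 0 (subgraph 0) to input x of model 1 (subgraph 0)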
+ for (auto edge : _model_edges->edges)
+ {
+ if ((std::get<ir::ModelIndex>(edge.from) != ir::ModelIndex{0}) ||
+ (std::get<ir::ModelIndex>(edge.to) != ir::ModelIndex{1}) ||
+ (std::get<ir::SubgraphIndex>(edge.from) != ir::SubgraphIndex{0}) ||
+ (std::get<ir::SubgraphIndex>(edge.to) != ir::SubgraphIndex{0}) ||
+ (std::get<ir::IOIndex>(edge.from) != std::get<ir::IOIndex>(edge.to)))
+ throw std::runtime_error{"NYI: Multi model execution for this edge is not supported yet"};
+ }
+
+ // Assume all package inputs are 0:0:x
+ for (uint32_t i = 0; i < _model_edges->pkg_inputs.size(); i++)
+ {
+ auto input = _model_edges->pkg_inputs[i];
+ if ((std::get<ir::ModelIndex>(input) != ir::ModelIndex{0}) ||
+ (std::get<ir::SubgraphIndex>(input) != ir::SubgraphIndex{0}) ||
+ (std::get<ir::IOIndex>(input) != ir::IOIndex{i}))
+ {
+ throw std::runtime_error{"NYI: Support package input to 1st model with same order"};
+ }
+ }
+
+ // Assume all package outputs are 1:0:x
+ for (uint32_t i = 0; i < _model_edges->pkg_outputs.size(); i++)
+ {
+ auto output = _model_edges->pkg_outputs[i];
+ if ((std::get<ir::ModelIndex>(output) != ir::ModelIndex{1}) ||
+ (std::get<ir::SubgraphIndex>(output) != ir::SubgraphIndex{0}) ||
+ (std::get<ir::IOIndex>(output) != ir::IOIndex{i}))
+ {
+ throw std::runtime_error{"NYI: Support package output from 2nd model with same order"};
+ }
+ }
+
+ const auto &executor1 = _executors.at(ir::SubgraphIndex{0});
+ const auto &graph1 = executor1->graph();
+ const auto &executor2 = _executors.at(ir::SubgraphIndex{1});
+ const auto &graph2 = executor2->graph();
+
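+  // Sanity check: package I/O counts and the model-0-output / model-1-input counts
+  // must match the declared edges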
+ if ((graph1.getInputs().size() != _model_edges->pkg_inputs.size()) ||
+ (graph2.getOutputs().size() != _model_edges->pkg_outputs.size()) ||
+ (graph1.getOutputs().size() != graph2.getInputs().size()) ||
+ (graph1.getOutputs().size() != _model_edges->edges.size()))
+ {
+ throw std::runtime_error{"NYI: Unsupported model edge pattern"};
+ }
+
+ // Prepare buffer
+ // Assume buffer layout is NHWC
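+  // Each buffer temporarily holds one output of the 1st executor and is then fed
+  // as the matching input of the 2nd executor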
+ std::vector<std::unique_ptr<uint8_t[]>> bufs(_model_edges->edges.size());
+ std::vector<const ir::OperandInfo *> buf_infos(_model_edges->edges.size());
+ const auto layout = ir::Layout::NHWC;
+
+ for (uint32_t i = 0; i < graph1.getOutputs().size(); i++)
+ {
+ const auto buf_index =
+ _executors.at(ir::SubgraphIndex{0})->graph().getOutputs().at(ir::IOIndex{i});
+ buf_infos[i] = &_executors.at(ir::SubgraphIndex{0})->graph().operands().at(buf_index).info();
+ const auto buf_size = buf_infos[i]->total_size();
+ bufs[i] = std::make_unique<uint8_t[]>(buf_size);
+ }
+
+ // 1st executor
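+  // Run with the package inputs; redirect the outputs into the intermediate buffers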
+ {
+ IODescription desc1;
+ const auto input_size = graph1.getInputs().size();
+ const auto output_size = graph1.getOutputs().size();
+ desc1.inputs.resize(input_size);
+ desc1.outputs.resize(output_size);
+ for (uint32_t i = 0; i < input_size; i++)
+ desc1.inputs[i] = std::make_unique<InputDesc>(*desc.inputs[i].get());
+ for (uint32_t i = 0; i < output_size; i++)
+ desc1.outputs[i] = std::make_unique<OutputDesc>(*buf_infos[i], bufs[i].get(),
+ buf_infos[i]->total_size(), layout);
+
+ executor1->execute(desc1);
+ }
+
+ // 2nd executor
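+  // Read the intermediate buffers as inputs; write the results to the package outputs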
+ {
+ IODescription desc2;
+ const auto input_size = graph2.getInputs().size();
+ const auto output_size = graph2.getOutputs().size();
+ desc2.inputs.resize(input_size);
+ desc2.outputs.resize(output_size);
+ for (uint32_t i = 0; i < input_size; i++)
+ desc2.inputs[i] = std::make_unique<InputDesc>(*buf_infos[i], bufs[i].get(),
+ buf_infos[i]->total_size(), layout);
+ for (uint32_t i = 0; i < output_size; i++)
+ desc2.outputs[i] = std::make_unique<OutputDesc>(*desc.outputs[i].get());
+
+ executor2->execute(desc2);
+ }
+}
+
+} // namespace exec
+} // namespace onert
diff --git a/runtime/onert/core/src/exec/FunctionSequence.cc b/runtime/onert/core/src/exec/FunctionSequence.cc
index df68b1b64..f87c271f7 100644
--- a/runtime/onert/core/src/exec/FunctionSequence.cc
+++ b/runtime/onert/core/src/exec/FunctionSequence.cc
@@ -34,9 +34,7 @@ void FunctionSequence::run()
  // Thus, those two backends cannot reach here.
// Do dynamic shape inference
- auto op_ind = _dynamic_tensor_ctx->op_ind;
- auto &op = _dynamic_tensor_ctx->operations->at(op_ind);
- op.accept(*_dynamic_tensor_ctx->dynamic_shape_inferer);
+ _dynamic_tensor_ctx->op->accept(*_dynamic_tensor_ctx->dynamic_shape_inferer);
for (const auto &function : _functions)
{
diff --git a/runtime/onert/core/src/exec/JSONExecTime.cc b/runtime/onert/core/src/exec/JSONExecTime.cc
index b29216a2f..d149345fd 100644
--- a/runtime/onert/core/src/exec/JSONExecTime.cc
+++ b/runtime/onert/core/src/exec/JSONExecTime.cc
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#include "exec/JSONExecTime.h"
-#include "backend/IConfig.h"
+#include "JSONExecTime.h"
+
#include <fstream>
namespace onert
diff --git a/runtime/onert/core/src/exec/LinearExecutor.h b/runtime/onert/core/src/exec/LinearExecutor.h
index 39d653154..a833466da 100644
--- a/runtime/onert/core/src/exec/LinearExecutor.h
+++ b/runtime/onert/core/src/exec/LinearExecutor.h
@@ -22,11 +22,10 @@
#ifndef __ONERT_EXEC_EXECUTOR_H_
#define __ONERT_EXEC_EXECUTOR_H_
-#include "ir/Index.h"
#include "ExecutorBase.h"
-#include "compiler/Linear.h"
-#include "exec/FunctionSequence.h"
+
#include "compiler/CodeMap.h"
+#include "ir/Index.h"
#include "util/TracingCtx.h"
namespace onert
diff --git a/runtime/onert/core/src/exec/ParallelExecutor.h b/runtime/onert/core/src/exec/ParallelExecutor.h
index 7f107fa22..7d459b0b4 100644
--- a/runtime/onert/core/src/exec/ParallelExecutor.h
+++ b/runtime/onert/core/src/exec/ParallelExecutor.h
@@ -17,19 +17,13 @@
#ifndef __ONERT_EXEC_PARALLEL_EXECUTOR_H__
#define __ONERT_EXEC_PARALLEL_EXECUTOR_H__
-#include <list>
-#include <queue>
-#include <unordered_map>
-
-#include "exec/FunctionSequence.h"
-#include "Job.h"
-#include "ir/OperandIndexSequence.h"
-#include "ir/Index.h"
-#include <memory>
-#include "exec/DataflowExecutor.h"
+#include "DataflowExecutor.h"
#include "ParallelScheduler.h"
+
#include "util/TracingCtx.h"
+#include <memory>
+
namespace onert
{
namespace exec
diff --git a/runtime/onert/core/src/exec/feature/MockTensor.h b/runtime/onert/core/src/exec/feature/MockTensor.h
new file mode 100644
index 000000000..1d2d375e2
--- /dev/null
+++ b/runtime/onert/core/src/exec/feature/MockTensor.h
@@ -0,0 +1,66 @@
+
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/ITensor.h"
+
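+// Minimal ITensor stub for the feature Reader/View tests: only buffer(), calcOffset()
+// and getShape() return meaningful values; the remaining overrides are dummies.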
+template <typename T> class MockTensor : public onert::backend::ITensor
+{
+public:
+ MockTensor<T>(onert::ir::Shape &shape, T *buf, onert::ir::Layout layout)
+ : _buf(reinterpret_cast<uint8_t *>(buf)), _shape(shape), _layout(layout)
+ {
+ }
+
+public:
+ uint8_t *buffer() const override { return _buf; }
+
+ size_t calcOffset(const onert::ir::Coordinates &coords) const override
+ {
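+    // Row-major offset over the tensor shape (a scalar is treated as rank 1), scaled by the element size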
+ size_t rank = _shape.rank();
+ rank = rank == 0 ? 1 : rank;
+ size_t offset = 0;
+ for (size_t i = 0; i < rank; ++i)
+ {
+ auto dim = _shape.rank() == 0 ? 1 : _shape.dim(i);
+ offset = offset * dim + coords[i];
+ }
+ offset *= sizeof(T);
+
+ return offset;
+ }
+
+ onert::ir::Shape getShape() const override { return _shape; }
+
+public: // DUMMY methods
+ size_t total_size() const override { return 0; }
+ onert::ir::Layout layout() const override { return _layout; }
+ onert::ir::DataType data_type() const override { return onert::ir::DataType::UINT8; }
+ float data_scale() const override { return 0; }
+ int32_t data_zero_point() const override { return 0; }
+ const std::vector<float> &data_scales() const override { return _dummy_scales; }
+ const std::vector<int32_t> &data_zero_points() const override { return _dummy_zerops; }
+ bool has_padding() const override { return false; }
+ void access(const std::function<void(ITensor &tensor)> &fn) override {}
+ bool is_dynamic() const override { return false; }
+
+private:
+ uint8_t *_buf = nullptr;
+ onert::ir::Shape _shape;
+ onert::ir::Layout _layout = onert::ir::Layout::UNKNOWN;
+ std::vector<float> _dummy_scales;
+ std::vector<int32_t> _dummy_zerops;
+};
diff --git a/runtime/onert/core/src/exec/feature/nchw/Reader.test.cc b/runtime/onert/core/src/exec/feature/nchw/Reader.test.cc
new file mode 100644
index 000000000..f439cafb5
--- /dev/null
+++ b/runtime/onert/core/src/exec/feature/nchw/Reader.test.cc
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Reader.h"
+
+#include "../MockTensor.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::exec::feature;
+
+template <typename T> class Reader_nchw : public testing::Test
+{
+public:
+ void setData(std::initializer_list<T> list) { _data = std::make_shared<std::vector<T>>(list); }
+
+ void setShape(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ _shape = onert::ir::FeatureShape(batch, depth, height, width);
+ }
+
+ void setStride(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ auto elem_size = sizeof(T);
+ _stride = onert::ir::FeatureShape(batch * elem_size, depth * elem_size, height * elem_size,
+ width * elem_size);
+ }
+
+ void createReader()
+ {
+ _reader =
+ std::make_shared<nchw::Reader<T>>(_shape, _stride, _data->data(), _data->size() * sizeof(T));
+ }
+
+ void createUsingMockTensor()
+ {
+ onert::ir::Shape shape = {_shape.N, _shape.H, _shape.W, _shape.C};
+ _tensor = std::make_shared<MockTensor<T>>(shape, _data->data(), onert::ir::Layout::NCHW);
+ _reader = std::make_shared<nchw::Reader<T>>(_tensor.get());
+ }
+
+ std::shared_ptr<Reader<T>> _reader = nullptr;
+
+private:
+ std::shared_ptr<std::vector<T>> _data = nullptr;
+ onert::ir::FeatureShape _shape;
+ onert::ir::FeatureShape _stride;
+ std::shared_ptr<MockTensor<T>> _tensor = nullptr;
+};
+
+using ReaderTypes = ::testing::Types<float, int32_t, uint8_t, int8_t, int16_t>;
+TYPED_TEST_SUITE(Reader_nchw, ReaderTypes);
+
+TYPED_TEST(Reader_nchw, basic_reader)
+{
+ this->setData({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ this->setShape(1, 2, 3, 2);
+ this->setStride(12, 6, 2, 1);
+ this->createReader();
+
+ // Data: NCHW
+ // Shape: NCHW
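+  // With strides {12, 6, 2, 1}, coordinates (0, 1, 1, 0) select element 0*12 + 1*6 + 1*2 + 0*1 = 8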
+ ASSERT_EQ(this->_reader->at(0, 1, 1, 0), 8);
+ ASSERT_EQ(this->_reader->at(1, 1, 0), 8);
+
+ // Data: NCHW
+ // Shape: NCHW
+ this->createUsingMockTensor();
+
+ ASSERT_EQ(this->_reader->at(0, 1, 1, 0), 6);
+ ASSERT_EQ(this->_reader->at(1, 1, 0), 6);
+}
diff --git a/runtime/onert/core/src/exec/feature/nchw/View.test.cc b/runtime/onert/core/src/exec/feature/nchw/View.test.cc
new file mode 100644
index 000000000..c6dcda710
--- /dev/null
+++ b/runtime/onert/core/src/exec/feature/nchw/View.test.cc
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "View.h"
+
+#include "../MockTensor.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::exec::feature;
+
+template <typename T> class View_nchw : public testing::Test
+{
+public:
+ void setData(std::initializer_list<T> list) { _data = std::make_shared<std::vector<T>>(list); }
+
+ void setShape(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ _shape = onert::ir::FeatureShape(batch, depth, height, width);
+ }
+
+ void setStride(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ auto elem_size = sizeof(T);
+ _stride = onert::ir::FeatureShape(batch * elem_size, depth * elem_size, height * elem_size,
+ width * elem_size);
+ }
+
+ void createView()
+ {
+ _view =
+ std::make_shared<nchw::View<T>>(_shape, _stride, _data->data(), _data->size() * sizeof(T));
+ }
+
+ void createUsingMockTensor()
+ {
+ onert::ir::Shape shape = {_shape.N, _shape.H, _shape.W, _shape.C};
+ _tensor = std::make_shared<MockTensor<T>>(shape, _data->data(), onert::ir::Layout::NCHW);
+ _view = std::make_shared<nchw::View<T>>(_tensor.get());
+ }
+
+ std::shared_ptr<nchw::View<T>> _view = nullptr;
+
+private:
+ std::shared_ptr<std::vector<T>> _data = nullptr;
+ onert::ir::FeatureShape _shape;
+ onert::ir::FeatureShape _stride;
+ std::shared_ptr<MockTensor<T>> _tensor = nullptr;
+};
+
+using ViewTypes = ::testing::Types<float, int32_t, uint8_t, int8_t, int16_t>;
+TYPED_TEST_SUITE(View_nchw, ViewTypes);
+
+TYPED_TEST(View_nchw, basic_view)
+{
+ this->setData({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ this->setShape(1, 2, 3, 2);
+ this->setStride(12, 6, 2, 1);
+ this->createView();
+
+ // Data: NCHW
+ // Shape: NCHW
+ ASSERT_EQ(this->_view->at(0, 1, 1, 0), 8);
+ ASSERT_EQ(this->_view->at(1, 1, 0), 8);
+
+ // Data: NCHW
+ // Shape: NCHW
+ this->createUsingMockTensor();
+
+ ASSERT_EQ(this->_view->at(0, 1, 1, 0), 6);
+ ASSERT_EQ(this->_view->at(1, 1, 0), 6);
+}
diff --git a/runtime/onert/core/src/exec/feature/nhwc/Reader.test.cc b/runtime/onert/core/src/exec/feature/nhwc/Reader.test.cc
new file mode 100644
index 000000000..773199042
--- /dev/null
+++ b/runtime/onert/core/src/exec/feature/nhwc/Reader.test.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Reader.h"
+
+#include "../MockTensor.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::exec::feature;
+
+template <typename T> class Reader_nhwc : public testing::Test
+{
+public:
+ void setData(std::initializer_list<T> list) { _data = std::make_shared<std::vector<T>>(list); }
+
+ void setShape(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ _shape = onert::ir::FeatureShape(batch, depth, height, width);
+ }
+
+ void setStride(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ auto elem_size = sizeof(T);
+ _stride = onert::ir::FeatureShape(batch * elem_size, depth * elem_size, height * elem_size,
+ width * elem_size);
+ }
+
+ void createReader()
+ {
+ _reader =
+ std::make_shared<nhwc::Reader<T>>(_shape, _stride, _data->data(), _data->size() * sizeof(T));
+ }
+
+ void createUsingMockTensor()
+ {
+ onert::ir::Shape shape = {_shape.N, _shape.H, _shape.W, _shape.C};
+ _tensor = std::make_shared<MockTensor<T>>(shape, _data->data(), onert::ir::Layout::NHWC);
+ _reader = std::make_shared<nhwc::Reader<T>>(_tensor.get());
+ }
+
+ std::shared_ptr<nhwc::Reader<T>> _reader = nullptr;
+
+private:
+ std::shared_ptr<std::vector<T>> _data = nullptr;
+ onert::ir::FeatureShape _shape;
+ onert::ir::FeatureShape _stride;
+ std::shared_ptr<MockTensor<T>> _tensor = nullptr;
+};
+
+using ReaderTypes = ::testing::Types<float, int32_t, uint8_t, int8_t, int16_t>;
+TYPED_TEST_SUITE(Reader_nhwc, ReaderTypes);
+
+TYPED_TEST(Reader_nhwc, basic_reader)
+{
+ this->setData({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ this->setShape(1, 2, 3, 2);
+ this->setStride(12, 1, 6, 2);
+ this->createReader();
+
+ // Data: NCHW
+ // Shape: NHWC
+ ASSERT_EQ(this->_reader->at(0, 1, 1, 0), 8);
+ ASSERT_EQ(this->_reader->at(1, 1, 0), 8);
+
+ // Data: NHWC
+ // Shape: NHWC
+ this->createUsingMockTensor();
+
+ ASSERT_EQ(this->_reader->at(0, 1, 1, 0), 6);
+ ASSERT_EQ(this->_reader->at(1, 1, 0), 6);
+}
diff --git a/runtime/onert/core/src/exec/feature/nhwc/View.h b/runtime/onert/core/src/exec/feature/nhwc/View.h
index 40d1d237c..c98d050c3 100644
--- a/runtime/onert/core/src/exec/feature/nhwc/View.h
+++ b/runtime/onert/core/src/exec/feature/nhwc/View.h
@@ -17,7 +17,7 @@
#ifndef __ONERT_EXEC_FEATURE_NHWC_VIEW_H__
#define __ONERT_EXEC_FEATURE_NHWC_VIEW_H__
-#include "../Reader.h"
+#include "Reader.h"
#include <cassert>
#include <cstddef>
diff --git a/runtime/onert/core/src/exec/feature/nhwc/View.test.cc b/runtime/onert/core/src/exec/feature/nhwc/View.test.cc
new file mode 100644
index 000000000..bdd73d5a7
--- /dev/null
+++ b/runtime/onert/core/src/exec/feature/nhwc/View.test.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "View.h"
+
+#include "../MockTensor.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::exec::feature;
+
+template <typename T> class View_nhwc : public testing::Test
+{
+public:
+ void setData(std::initializer_list<T> list) { _data = std::make_shared<std::vector<T>>(list); }
+
+ void setShape(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ _shape = onert::ir::FeatureShape(batch, depth, height, width);
+ }
+
+ void setStride(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ auto elem_size = sizeof(T);
+ _stride = onert::ir::FeatureShape(batch * elem_size, depth * elem_size, height * elem_size,
+ width * elem_size);
+ }
+
+ void createView()
+ {
+ _view =
+ std::make_shared<nhwc::View<T>>(_shape, _stride, _data->data(), _data->size() * sizeof(T));
+ }
+
+ void createUsingMockTensor()
+ {
+ onert::ir::Shape shape = {_shape.N, _shape.H, _shape.W, _shape.C};
+ _tensor = std::make_shared<MockTensor<T>>(shape, _data->data(), onert::ir::Layout::NHWC);
+ _view = std::make_shared<nhwc::View<T>>(_tensor.get());
+ }
+
+ std::shared_ptr<nhwc::View<T>> _view = nullptr;
+
+private:
+ std::shared_ptr<std::vector<T>> _data = nullptr;
+ onert::ir::FeatureShape _shape;
+ onert::ir::FeatureShape _stride;
+ std::shared_ptr<MockTensor<T>> _tensor = nullptr;
+};
+
+using ViewTypes = ::testing::Types<float, int32_t, uint8_t, int8_t, int16_t>;
+TYPED_TEST_SUITE(View_nhwc, ViewTypes);
+
+TYPED_TEST(View_nhwc, basic_view)
+{
+ this->setData({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ this->setShape(1, 2, 3, 2);
+ this->setStride(12, 1, 6, 2);
+ this->createView();
+
+ // Data: NCHW
+ // Shape: NHWC
+ ASSERT_EQ(this->_view->at(0, 1, 1, 0), 8);
+ ASSERT_EQ(this->_view->at(1, 1, 0), 8);
+
+ // Data: NHWC
+ // Shape: NHWC
+ this->createUsingMockTensor();
+
+ ASSERT_EQ(this->_view->at(0, 1, 1, 0), 6);
+ ASSERT_EQ(this->_view->at(1, 1, 0), 6);
+}
diff --git a/runtime/onert/core/src/interp/InterpExecutor.cc b/runtime/onert/core/src/interp/InterpExecutor.cc
index 44d1575d7..f04777174 100644
--- a/runtime/onert/core/src/interp/InterpExecutor.cc
+++ b/runtime/onert/core/src/interp/InterpExecutor.cc
@@ -14,9 +14,10 @@
* limitations under the License.
*/
-#include "interp/InterpExecutor.h"
-#include "interp/ExecEnv.h"
-#include "interp/Interpreter.h"
+#include "InterpExecutor.h"
+
+#include "ExecEnv.h"
+#include "Interpreter.h"
#include "util/logging.h"
diff --git a/runtime/onert/core/src/interp/InterpExecutor.h b/runtime/onert/core/src/interp/InterpExecutor.h
index df6153d09..d6d5dd0a3 100644
--- a/runtime/onert/core/src/interp/InterpExecutor.h
+++ b/runtime/onert/core/src/interp/InterpExecutor.h
@@ -74,7 +74,12 @@ public:
}
private:
- const ir::Graph &_graph;
+ /**
+ * @brief Copy of target graph for lowering
+   * @note It uses a copy of the graph, not a reference.
+   *       The original graph may be deallocated by the frontend.
+ */
+ const ir::Graph _graph;
ir::OperandIndexMap<std::shared_ptr<ITensor>> _tensor_map;
};
diff --git a/runtime/onert/core/src/interp/InterpExecutor.test.cc b/runtime/onert/core/src/interp/InterpExecutor.test.cc
new file mode 100644
index 000000000..9f95ffee0
--- /dev/null
+++ b/runtime/onert/core/src/interp/InterpExecutor.test.cc
@@ -0,0 +1,355 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "InterpExecutor.h"
+
+#include "exec/Execution.h"
+#include "ir/Graph.h"
+#include "ir/operation/BinaryArithmetic.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+namespace
+{
+
+using namespace onert::ir;
+using InterpExecutor = onert::interp::InterpExecutor;
+using Execution = onert::exec::Execution;
+using Executors = onert::exec::Executors;
+
+class InterpExecutorTest : public ::testing::Test
+{
+protected:
+ virtual void SetUp() {}
+ void CreateSimpleModel()
+ {
+ // Model: one elementwise add operation
+ // model input: lhs, rhs
+ // model output: add result
+ // lhs, rhs, result shape: {1, 2, 2, 1}
+ // activation: none (constant)
+ _graph = std::make_unique<Graph>();
+
+ // Add operands
+
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::INT32};
+ Shape shape_scalar(0);
+ TypeInfo type_scalar{DataType::INT32};
+
+ auto operand_lhs = _graph->addOperand(shape, type);
+ auto operand_rhs = _graph->addOperand(shape, type);
+ auto operand_result = _graph->addOperand(shape, type);
+
+ // Add operations
+
+ operation::BinaryArithmetic::Param param;
+ param.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param.activation = Activation::NONE;
+ auto input_set = OperandIndexSequence{operand_lhs, operand_rhs};
+ auto output_set = OperandIndexSequence{operand_result};
+ _graph->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set, output_set, param));
+
+ // Identify model inputs and outputs
+
+ _graph->getInputs().append(operand_lhs);
+ _graph->getInputs().append(operand_rhs);
+ _graph->getOutputs().append(operand_result);
+
+ _graph->verify();
+
+ auto model = std::make_shared<onert::ir::Model>();
+ model->push(onert::ir::SubgraphIndex{0}, _graph);
+
+ _executors = std::make_shared<Executors>();
+ _executors->emplace(onert::ir::SubgraphIndex{0}, std::make_unique<InterpExecutor>(*_graph));
+ }
+
+ void CreateTwoStepModel()
+ {
+    // Model: two elementwise add operations
+    // model input: lhs, rhs1
+    // model output: second add result (result2)
+    // constant: rhs2
+    // result1 <= (lhs + rhs1)
+    // result2 <= (result1 + rhs2)
+    // lhs, rhs1, rhs2, result1, result2 shape: {1, 2, 2, 1}
+    // activation: none (constant)
+ _graph = std::make_unique<Graph>();
+
+    // Add operands (for both adds)
+
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::INT32};
+ Shape shape_scalar(0);
+ TypeInfo type_scalar{DataType::INT32};
+
+ static int32_t rhs2_data[4] = {3, 1, -1, 5};
+
+ auto operand_lhs = _graph->addOperand(shape, type);
+ auto operand_rhs1 = _graph->addOperand(shape, type);
+ auto operand_result1 = _graph->addOperand(shape, type);
+ auto operand_rhs2 = _graph->addOperand(shape, type);
+ auto operand_result2 = _graph->addOperand(shape, type);
+ _graph->operands()
+ .at(operand_rhs2)
+ .data(std::make_unique<CachedData>(reinterpret_cast<const uint8_t *>(&rhs2_data), 16));
+
+    // Add operations (result1 <= lhs + rhs1, then result2 <= result1 + rhs2)
+
+ operation::BinaryArithmetic::Param param1;
+ param1.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param1.activation = Activation::NONE;
+ auto input_set1 = OperandIndexSequence{operand_lhs, operand_rhs1};
+ auto output_set1 = OperandIndexSequence{operand_result1};
+ _graph->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set1, output_set1, param1));
+
+ operation::BinaryArithmetic::Param param2;
+ param2.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param2.activation = Activation::NONE;
+ auto input_set2 = OperandIndexSequence{operand_result1, operand_rhs2};
+ auto output_set2 = OperandIndexSequence{operand_result2};
+ _graph->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set2, output_set2, param2));
+
+ // Identify model inputs and outputs
+
+ _graph->getInputs().append(operand_lhs);
+ _graph->getInputs().append(operand_rhs1);
+ _graph->getOutputs().append(operand_result2);
+
+ _graph->verify();
+
+ auto model = std::make_shared<onert::ir::Model>();
+ model->push(onert::ir::SubgraphIndex{0}, _graph);
+
+ _executors = std::make_shared<Executors>();
+ _executors->emplace(onert::ir::SubgraphIndex{0}, std::make_unique<InterpExecutor>(*_graph));
+ }
+
+ void CreateUnspecifiedDimensionsModel()
+ {
+ // Model: one elementwise add operation
+ // model input: lhs, rhs
+ // model output: add result
+ // lhs, rhs, result shape: {1, unknown, 2, 1}
+ // activation: none (constant)
+ _graph = std::make_unique<Graph>();
+
+ // Add operands
+
+ Shape shape{1, 0, 2, 1};
+ TypeInfo type{DataType::INT32};
+ Shape shape_scalar(0);
+ TypeInfo type_scalar{DataType::INT32};
+
+ auto operand_lhs = _graph->addOperand(shape, type);
+ auto operand_rhs = _graph->addOperand(shape, type);
+
+ auto operand_activation = _graph->addOperand(shape_scalar, type_scalar);
+ _graph->operands()
+ .at(operand_activation)
+ .data(std::make_unique<CachedData>(reinterpret_cast<const uint8_t *>(&_activation_value), 4));
+
+ auto operand_result = _graph->addOperand(shape, type);
+
+ // Add operations
+
+ operation::BinaryArithmetic::Param param;
+ param.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param.activation = Activation::NONE;
+ auto input_set = OperandIndexSequence{operand_lhs, operand_rhs};
+ auto output_set = OperandIndexSequence{operand_result};
+ _graph->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set, output_set, param));
+
+ // Identify model inputs and outputs
+
+ _graph->getInputs().append(operand_lhs);
+ _graph->getInputs().append(operand_rhs);
+ _graph->getOutputs().append(operand_result);
+
+ _graph->verify();
+
+ auto model = std::make_shared<onert::ir::Model>();
+ model->push(onert::ir::SubgraphIndex{0}, _graph);
+
+ _executors = std::make_shared<Executors>();
+ _executors->emplace(onert::ir::SubgraphIndex{0}, std::make_unique<InterpExecutor>(*_graph));
+ }
+
+ void createExecution() { _execution = std::make_unique<Execution>(_executors); }
+
+ virtual void TearDown() { _executors = nullptr; }
+
+ std::shared_ptr<Graph> _graph{nullptr};
+ std::shared_ptr<Executors> _executors{nullptr};
+ std::unique_ptr<Execution> _execution{nullptr};
+ const int32_t _activation_value{0};
+};
+
+TEST_F(InterpExecutorTest, create_empty)
+{
+ Graph graph;
+ graph.verify();
+ auto executor = std::make_unique<InterpExecutor>(graph);
+ ASSERT_NE(executor, nullptr);
+}
+
+TEST_F(InterpExecutorTest, create_simple)
+{
+ CreateSimpleModel();
+ ASSERT_NE(_executors, nullptr);
+ ASSERT_NE(_executors->at(onert::ir::SubgraphIndex{0}), nullptr);
+}
+
+TEST_F(InterpExecutorTest, neg_setInput)
+{
+ CreateSimpleModel();
+ createExecution();
+
+ auto input1 = IOIndex{0};
+ const int32_t input1_buffer[4] = {1, 0, -1, -2};
+
+ EXPECT_THROW(_execution->setInput(input1, reinterpret_cast<const void *>(input1_buffer), 4),
+ std::runtime_error);
+ EXPECT_THROW(_execution->setInput(input1, reinterpret_cast<const void *>(input1_buffer), 12),
+ std::runtime_error);
+ EXPECT_NO_THROW(_execution->setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16));
+}
+
+TEST_F(InterpExecutorTest, neg_setOutput)
+{
+ CreateSimpleModel();
+ createExecution();
+
+ auto output = IOIndex{0};
+ auto output_idx = _graph->getOutputs().at(output);
+
+ int32_t output_buffer[4] = {};
+
+ EXPECT_THROW(_execution->setOutput(output, reinterpret_cast<void *>(output_buffer), 4),
+ std::runtime_error);
+ EXPECT_THROW(_execution->setOutput(output, reinterpret_cast<void *>(output_buffer), 12),
+ std::runtime_error);
+ EXPECT_NO_THROW(_execution->setOutput(output, reinterpret_cast<void *>(output_buffer), 16));
+}
+
+TEST_F(InterpExecutorTest, neg_setInputForUnspecifiedDimensions)
+{
+ CreateUnspecifiedDimensionsModel();
+ createExecution();
+
+ auto input1 = IOIndex{0};
+ const int32_t input1_buffer[4] = {1, 0, -1, -2};
+
+ TypeInfo operand_type{DataType::INT32};
+ Shape operand_shape{1, 2, 2, 1};
+
+ EXPECT_THROW(_execution->setInput(input1, operand_type, operand_shape,
+ reinterpret_cast<const void *>(input1_buffer), 4),
+ std::runtime_error);
+ EXPECT_THROW(_execution->setInput(input1, operand_type, operand_shape,
+ reinterpret_cast<const void *>(input1_buffer), 12),
+ std::runtime_error);
+ EXPECT_NO_THROW(_execution->setInput(input1, operand_type, operand_shape,
+ reinterpret_cast<const void *>(input1_buffer), 16));
+}
+
+TEST_F(InterpExecutorTest, neg_setOutputForUnspecifiedDimensions)
+{
+ CreateUnspecifiedDimensionsModel();
+ createExecution();
+
+ auto output = IOIndex{0};
+ auto output_idx = _graph->getOutputs().at(output);
+
+ TypeInfo operand_type{DataType::INT32};
+ Shape operand_shape{1, 2, 2, 1};
+
+ int32_t output_buffer[4] = {};
+
+ EXPECT_THROW(_execution->setOutput(output, operand_type, operand_shape,
+ reinterpret_cast<void *>(output_buffer), 4),
+ std::runtime_error);
+ EXPECT_THROW(_execution->setOutput(output, operand_type, operand_shape,
+ reinterpret_cast<void *>(output_buffer), 12),
+ std::runtime_error);
+ EXPECT_NO_THROW(_execution->setOutput(output, operand_type, operand_shape,
+ reinterpret_cast<void *>(output_buffer), 16));
+}
+
+TEST_F(InterpExecutorTest, execute)
+{
+ CreateSimpleModel();
+ createExecution();
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto input1_idx = _graph->getInputs().at(input1);
+ auto input2_idx = _graph->getInputs().at(input2);
+
+ const int32_t input1_buffer[4] = {1, 0, -1, -2};
+ const int32_t input2_buffer[4] = {1, -3, 2, -4};
+
+ auto output = IOIndex{0};
+ auto output_idx = _graph->getOutputs().at(output);
+
+ int32_t output_buffer[4] = {};
+
+ EXPECT_NO_THROW(_execution->setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16));
+ EXPECT_NO_THROW(_execution->setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16));
+ EXPECT_NO_THROW(_execution->setOutput(output, reinterpret_cast<void *>(output_buffer), 16));
+ EXPECT_NO_THROW(_execution->execute());
+ EXPECT_EQ(output_buffer[0], 2);
+ EXPECT_EQ(output_buffer[1], -3);
+ EXPECT_EQ(output_buffer[2], 1);
+ EXPECT_EQ(output_buffer[3], -6);
+}
+
+TEST_F(InterpExecutorTest, executeTwoStep)
+{
+ CreateTwoStepModel();
+ createExecution();
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto input1_idx = _graph->getInputs().at(input1);
+ auto input2_idx = _graph->getInputs().at(input2);
+
+ const int32_t input1_buffer[4] = {1, 0, -1, -2};
+ const int32_t input2_buffer[4] = {1, -3, 2, -4};
+
+ auto output = IOIndex{0};
+ auto output_idx = _graph->getOutputs().at(output);
+
+ int32_t output_buffer[4] = {};
+
+ EXPECT_NO_THROW(_execution->setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16));
+ EXPECT_NO_THROW(_execution->setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16));
+ EXPECT_NO_THROW(_execution->setOutput(output, reinterpret_cast<void *>(output_buffer), 16));
+ EXPECT_NO_THROW(_execution->execute());
+ EXPECT_EQ(output_buffer[0], 5);
+ EXPECT_EQ(output_buffer[1], -2);
+ EXPECT_EQ(output_buffer[2], 0);
+ EXPECT_EQ(output_buffer[3], -1);
+}
+
+} // namespace
diff --git a/runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc b/runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc
index 804e9fb51..fe4acd309 100644
--- a/runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc
+++ b/runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc
@@ -14,14 +14,14 @@
* limitations under the License.
*/
-#include <cker/operation/BinaryArithmeticOps.h>
-
#include "OperationUtil.h"
+#include "../Registration.h"
-#include "interp/Registration.h"
#include "ir/operation/BinaryArithmetic.h"
-#include "misc/polymorphic_downcast.h"
-#include "cker/Types.h"
+
+#include <cker/operation/BinaryArithmeticOps.h>
+#include <cker/Types.h>
+#include <misc/polymorphic_downcast.h>
namespace onert
{
diff --git a/runtime/onert/core/src/interp/operations/Concat.cc b/runtime/onert/core/src/interp/operations/Concat.cc
index a063ab14a..103604631 100644
--- a/runtime/onert/core/src/interp/operations/Concat.cc
+++ b/runtime/onert/core/src/interp/operations/Concat.cc
@@ -14,13 +14,13 @@
* limitations under the License.
*/
-#include <cker/operation/Concatenation.h>
-
#include "OperationUtil.h"
+#include "../Registration.h"
-#include "interp/Registration.h"
#include "ir/operation/Concat.h"
-#include "misc/polymorphic_downcast.h"
+
+#include <cker/operation/Concatenation.h>
+#include <misc/polymorphic_downcast.h>
namespace onert
{
diff --git a/runtime/onert/core/src/interp/operations/Conv2D.cc b/runtime/onert/core/src/interp/operations/Conv2D.cc
index 0b43a4799..72c2057c2 100644
--- a/runtime/onert/core/src/interp/operations/Conv2D.cc
+++ b/runtime/onert/core/src/interp/operations/Conv2D.cc
@@ -14,15 +14,15 @@
* limitations under the License.
*/
-#include <cker/operation/Conv.h>
-
#include "OperationUtil.h"
+#include "../Registration.h"
-#include "interp/Registration.h"
#include "ir/operation/Conv2D.h"
-#include "util/Utils.h"
#include "util/ShapeInference.h"
-#include "misc/polymorphic_downcast.h"
+#include "util/Utils.h"
+
+#include <cker/operation/Conv.h>
+#include <misc/polymorphic_downcast.h>
namespace onert
{
diff --git a/runtime/onert/core/src/interp/operations/DepthwiseConv2D.cc b/runtime/onert/core/src/interp/operations/DepthwiseConv2D.cc
index d1c62d73f..9f527440e 100644
--- a/runtime/onert/core/src/interp/operations/DepthwiseConv2D.cc
+++ b/runtime/onert/core/src/interp/operations/DepthwiseConv2D.cc
@@ -14,15 +14,15 @@
* limitations under the License.
*/
-#include <cker/operation/DepthwiseConv.h>
-#include <misc/polymorphic_downcast.h>
-
#include "OperationUtil.h"
+#include "../Registration.h"
-#include "interp/Registration.h"
#include "ir/operation/DepthwiseConv2D.h"
-#include "util/Utils.h"
#include "util/ShapeInference.h"
+#include "util/Utils.h"
+
+#include <cker/operation/DepthwiseConv.h>
+#include <misc/polymorphic_downcast.h>
namespace onert
{
diff --git a/runtime/onert/core/src/interp/operations/ElementwiseActivations.cc b/runtime/onert/core/src/interp/operations/ElementwiseActivations.cc
index 197855ff4..e13080e76 100644
--- a/runtime/onert/core/src/interp/operations/ElementwiseActivations.cc
+++ b/runtime/onert/core/src/interp/operations/ElementwiseActivations.cc
@@ -14,17 +14,16 @@
* limitations under the License.
*/
-#include <cmath>
-
#include "OperationUtil.h"
-
-#include "interp/Registration.h"
+#include "../Registration.h"
#include "ir/operation/ElementwiseActivation.h"
-#include <misc/polymorphic_downcast.h>
#include <cker/operation/Logistic.h>
#include <cker/operation/Tanh.h>
+#include <misc/polymorphic_downcast.h>
+
+#include <cmath>
namespace onert
{
diff --git a/runtime/onert/core/src/interp/operations/FullyConnected.cc b/runtime/onert/core/src/interp/operations/FullyConnected.cc
index ef827605b..2bc9f517f 100644
--- a/runtime/onert/core/src/interp/operations/FullyConnected.cc
+++ b/runtime/onert/core/src/interp/operations/FullyConnected.cc
@@ -14,13 +14,13 @@
* limitations under the License.
*/
-#include <cker/operation/FullyConnected.h>
-
#include "OperationUtil.h"
+#include "../Registration.h"
-#include "interp/Registration.h"
#include "ir/operation/FullyConnected.h"
-#include "misc/polymorphic_downcast.h"
+
+#include <cker/operation/FullyConnected.h>
+#include <misc/polymorphic_downcast.h>
namespace onert
{
diff --git a/runtime/onert/core/src/interp/operations/Gather.cc b/runtime/onert/core/src/interp/operations/Gather.cc
index 0ea60875c..d686cfcf6 100644
--- a/runtime/onert/core/src/interp/operations/Gather.cc
+++ b/runtime/onert/core/src/interp/operations/Gather.cc
@@ -14,13 +14,13 @@
* limitations under the License.
*/
-#include <cker/operation/Gather.h>
-
#include "OperationUtil.h"
+#include "../Registration.h"
-#include "interp/Registration.h"
#include "ir/operation/Gather.h"
-#include "misc/polymorphic_downcast.h"
+
+#include <cker/operation/Gather.h>
+#include <misc/polymorphic_downcast.h>
namespace onert
{
diff --git a/runtime/onert/core/src/interp/operations/InstanceNorm.cc b/runtime/onert/core/src/interp/operations/InstanceNorm.cc
index b5c38819d..318088457 100644
--- a/runtime/onert/core/src/interp/operations/InstanceNorm.cc
+++ b/runtime/onert/core/src/interp/operations/InstanceNorm.cc
@@ -14,13 +14,13 @@
* limitations under the License.
*/
-#include <cker/operation/InstanceNorm.h>
-
#include "OperationUtil.h"
+#include "../Registration.h"
-#include "interp/Registration.h"
#include "ir/operation/InstanceNorm.h"
-#include "misc/polymorphic_downcast.h"
+
+#include <cker/operation/InstanceNorm.h>
+#include <misc/polymorphic_downcast.h>
namespace onert
{
diff --git a/runtime/onert/core/src/interp/operations/Pad.cc b/runtime/onert/core/src/interp/operations/Pad.cc
index 0eec7fe9a..3db0828eb 100644
--- a/runtime/onert/core/src/interp/operations/Pad.cc
+++ b/runtime/onert/core/src/interp/operations/Pad.cc
@@ -14,13 +14,13 @@
* limitations under the License.
*/
-#include <cker/operation/Pad.h>
-
#include "OperationUtil.h"
+#include "../Registration.h"
-#include "interp/Registration.h"
#include "ir/operation/Pad.h"
+#include <cker/operation/Pad.h>
+
namespace onert
{
namespace interp
diff --git a/runtime/onert/core/src/interp/operations/Pool2D.cc b/runtime/onert/core/src/interp/operations/Pool2D.cc
index 2f3b71655..3935d4756 100644
--- a/runtime/onert/core/src/interp/operations/Pool2D.cc
+++ b/runtime/onert/core/src/interp/operations/Pool2D.cc
@@ -14,16 +14,16 @@
* limitations under the License.
*/
-#include <cker/operation/AveragePool.h>
-#include <cker/operation/MaxPool.h>
-
#include "OperationUtil.h"
+#include "../Registration.h"
-#include "interp/Registration.h"
#include "ir/operation/Pool2D.h"
-#include "util/Utils.h"
#include "util/ShapeInference.h"
-#include "misc/polymorphic_downcast.h"
+#include "util/Utils.h"
+
+#include <cker/operation/AveragePool.h>
+#include <cker/operation/MaxPool.h>
+#include <misc/polymorphic_downcast.h>
namespace onert
{
diff --git a/runtime/onert/core/src/interp/operations/Reshape.cc b/runtime/onert/core/src/interp/operations/Reshape.cc
index 3a118456b..1de5a5762 100644
--- a/runtime/onert/core/src/interp/operations/Reshape.cc
+++ b/runtime/onert/core/src/interp/operations/Reshape.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "interp/Registration.h"
+#include "../Registration.h"
namespace onert
{
diff --git a/runtime/onert/core/src/interp/operations/Softmax.cc b/runtime/onert/core/src/interp/operations/Softmax.cc
index 1fc303117..8be2f2210 100644
--- a/runtime/onert/core/src/interp/operations/Softmax.cc
+++ b/runtime/onert/core/src/interp/operations/Softmax.cc
@@ -14,13 +14,13 @@
* limitations under the License.
*/
-#include <cker/operation/SoftMax.h>
-
#include "OperationUtil.h"
+#include "../Registration.h"
-#include "interp/Registration.h"
#include "ir/operation/Softmax.h"
-#include "misc/polymorphic_downcast.h"
+
+#include <cker/operation/SoftMax.h>
+#include <misc/polymorphic_downcast.h>
namespace onert
{
diff --git a/runtime/onert/core/src/interp/operations/TransposeConv.cc b/runtime/onert/core/src/interp/operations/TransposeConv.cc
index 755103dc2..59c8e8cdf 100644
--- a/runtime/onert/core/src/interp/operations/TransposeConv.cc
+++ b/runtime/onert/core/src/interp/operations/TransposeConv.cc
@@ -14,14 +14,14 @@
* limitations under the License.
*/
-#include <cker/operation/TransposeConv.h>
-#include <misc/polymorphic_downcast.h>
-
#include "OperationUtil.h"
+#include "../Registration.h"
-#include "interp/Registration.h"
#include "ir/operation/TransposeConv.h"
+#include <cker/operation/TransposeConv.h>
+#include <misc/polymorphic_downcast.h>
+
namespace onert
{
namespace interp
diff --git a/runtime/onert/core/src/ir/Graph.cc b/runtime/onert/core/src/ir/Graph.cc
index df30bbdbe..28cf4137d 100644
--- a/runtime/onert/core/src/ir/Graph.cc
+++ b/runtime/onert/core/src/ir/Graph.cc
@@ -17,19 +17,9 @@
#include "ir/Graph.h"
#include "OperationValidator.h"
+#include "verifier/Verifier.h"
-#include <algorithm>
-
-#include <bitset>
-#include <sstream>
-
-#include "util/logging.h"
#include "util/Set.h"
-#include "verifier/Verifier.h"
-#include "ir/OperandIndexMap.h"
-#include "ir/OperationIndexMap.h"
-#include "dumper/text/GraphDumper.h"
-#include "backend/IConfig.h"
namespace onert
{
@@ -38,6 +28,8 @@ namespace ir
Graph::Graph() = default;
+Graph::Graph(const Graph &) = default;
+
Graph::~Graph(void) = default;
OperandIndex Graph::addOperand(const Shape &shape, const TypeInfo &type)
diff --git a/runtime/onert/core/src/ir/Graph.test.cc b/runtime/onert/core/src/ir/Graph.test.cc
new file mode 100644
index 000000000..144500745
--- /dev/null
+++ b/runtime/onert/core/src/ir/Graph.test.cc
@@ -0,0 +1,147 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/Graph.h"
+#include "ir/operation/BinaryArithmetic.h"
+
+#include <gtest/gtest.h>
+
+TEST(Graph, neg_inputs_and_outputs)
+{
+ onert::ir::Graph graph;
+
+ onert::ir::OperandIndex index0{0u};
+ onert::ir::OperandIndex index1{1u};
+
+ graph.addInput({index0});
+ graph.addInput({index1});
+
+ onert::ir::OperandIndex index10{10u};
+ onert::ir::OperandIndex index11{11u};
+ onert::ir::OperandIndex index12{12u};
+
+ graph.addOutput({index10});
+ graph.addOutput({index11});
+ graph.addOutput({index12});
+
+ ASSERT_EQ(graph.getInputs().size(), 2);
+ ASSERT_EQ(graph.getOutputs().size(), 3);
+
+ onert::ir::IOIndex io_index0{0};
+ onert::ir::IOIndex io_index1{1};
+ onert::ir::IOIndex io_index2{2};
+
+ ASSERT_EQ(graph.getInputs().at(io_index0), 0);
+ ASSERT_EQ(graph.getInputs().at(io_index1), 1);
+
+ ASSERT_EQ(graph.getOutputs().at(io_index0), 10);
+ ASSERT_EQ(graph.getOutputs().at(io_index1), 11);
+ ASSERT_EQ(graph.getOutputs().at(io_index2), 12);
+
+ EXPECT_THROW(graph.getOutputs().at(onert::ir::IOIndex{3}), std::out_of_range);
+}
+
+using namespace onert::ir;
+
+OperationIndex addAddOperation(Graph &graph, const OperandIndexSequence inputs,
+ const OperandIndexSequence outputs)
+{
+ // Add "ADD" operation
+ operation::BinaryArithmetic::Param param;
+ param.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param.activation = Activation::NONE;
+ return graph.addOperation(std::make_unique<operation::BinaryArithmetic>(inputs, outputs, param));
+}
+
+TEST(Graph, OneOpGraphSimpleValid)
+{
+ // Simple Graph with just one Add operation
+
+ Graph graph;
+
+ // Add tensors
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+ auto lhs = graph.addOperand(shape, type);
+ auto rhs = graph.addOperand(shape, type);
+ auto res = graph.addOperand(shape, type);
+
+ addAddOperation(graph, {lhs, rhs}, {res});
+
+ // Set model inputs/outputs
+ graph.addInput(lhs);
+ graph.addInput(rhs);
+ graph.addOutput(res);
+
+ graph.verify();
+
+ SUCCEED();
+}
+
+TEST(Graph, neg_InvalidGraph_BadInput)
+{
+ Graph graph;
+
+ // Add tensors
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+ auto in = graph.addOperand(shape, type);
+ auto out = graph.addOperand(shape, type);
+
+ // Set model inputs/outputs
+ graph.addInput(in);
+ graph.addOutput(out);
+  graph.addInput(OperandIndex{89}); // Non-existing operand!
+
+ EXPECT_ANY_THROW(graph.verify());
+}
+
+TEST(Graph, neg_InvalidGraph_BadOutput)
+{
+ Graph graph;
+
+ // Add tensors
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+ auto in = graph.addOperand(shape, type);
+ auto out = graph.addOperand(shape, type);
+
+ // Set model inputs/outputs
+ graph.addInput(in);
+ graph.addOutput(out);
+  graph.addOutput(OperandIndex{12}); // Non-existing operand!
+
+ EXPECT_ANY_THROW(graph.verify());
+}
+
+TEST(Graph, neg_InvalidAddOperation_BadInputIndex)
+{
+ Graph graph;
+
+ // Add tensors
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+ auto lhs = graph.addOperand(shape, type);
+ auto rhs = graph.addOperand(shape, type);
+ auto res = graph.addOperand(shape, type);
+
+ // Set model inputs/outputs
+ graph.addInput(lhs);
+ graph.addInput(rhs);
+ graph.addOutput(res);
+
+ ASSERT_FALSE(addAddOperation(graph, {lhs, OperandIndex{99}}, {res}).valid());
+}
diff --git a/runtime/onert/core/src/ir/LayoutSet.test.cc b/runtime/onert/core/src/ir/LayoutSet.test.cc
new file mode 100644
index 000000000..fc956abe8
--- /dev/null
+++ b/runtime/onert/core/src/ir/LayoutSet.test.cc
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LayoutSet.h"
+
+#include <gtest/gtest.h>
+
+using onert::ir::Layout;
+using onert::ir::LayoutSet;
+
+TEST(ir_LayoutSet, neg_add_remove)
+{
+ LayoutSet set{Layout::NCHW};
+ set.remove(Layout::NHWC);
+ ASSERT_EQ(set.size(), 1);
+ set.add(Layout::NHWC);
+ ASSERT_EQ(set.size(), 2);
+ set.remove(Layout::NHWC);
+ ASSERT_EQ(set.size(), 1);
+ set.remove(Layout::NCHW);
+ ASSERT_EQ(set.size(), 0);
+ set.remove(Layout::NCHW);
+ ASSERT_EQ(set.size(), 0);
+}
+
+TEST(ir_LayoutSet, neg_add_twice)
+{
+ LayoutSet set;
+ set.add(Layout::NHWC);
+ ASSERT_EQ(set.size(), 1);
+ set.add(Layout::NHWC);
+ ASSERT_EQ(set.size(), 1);
+}
+
+TEST(ir_LayoutSet, set_operators)
+{
+ LayoutSet set1{Layout::NCHW};
+ LayoutSet set2{Layout::NHWC};
+ LayoutSet set3 = set1 | set2;
+
+ ASSERT_EQ(set3.size(), 2);
+
+ ASSERT_EQ((set3 - set1).size(), 1);
+ ASSERT_EQ((set3 - set1).contains(Layout::NHWC), true);
+ ASSERT_EQ((set3 - set2).size(), 1);
+ ASSERT_EQ((set3 - set2).contains(Layout::NCHW), true);
+ ASSERT_EQ((set3 - set3).size(), 0);
+
+ ASSERT_EQ((set3 & set1).size(), 1);
+ ASSERT_EQ((set3 & set1).contains(Layout::NCHW), true);
+ ASSERT_EQ((set3 & set2).size(), 1);
+ ASSERT_EQ((set3 & set2).contains(Layout::NHWC), true);
+ ASSERT_EQ((set1 & set2).size(), 0);
+}
diff --git a/runtime/onert/core/src/ir/MockNode.h b/runtime/onert/core/src/ir/MockNode.h
new file mode 100644
index 000000000..0e7ed977b
--- /dev/null
+++ b/runtime/onert/core/src/ir/MockNode.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TEST_GRAPH_MOCK_NODE_H__
+#define __ONERT_TEST_GRAPH_MOCK_NODE_H__
+
+#include "ir/Operation.h"
+#include "ir/OperandIndexSequence.h"
+
+namespace onert_test
+{
+namespace ir
+{
+
+class SimpleMock : public onert::ir::Operation
+{
+public:
+ SimpleMock(const onert::ir::OperandIndexSequence &inputs,
+ const onert::ir::OperandIndexSequence &outputs)
+ : Operation{onert::ir::OperandConstraint::createAny()}
+ {
+ setInputs(inputs);
+ setOutputs(outputs);
+ }
+
+public:
+ void accept(onert::ir::OperationVisitor &) const override {}
+ onert::ir::OpCode opcode() const final { return onert::ir::OpCode::Invalid; }
+};
+
+} // namespace ir
+} // namespace onert_test
+
+#endif // __ONERT_TEST_GRAPH_MOCK_NODE_H__
diff --git a/runtime/onert/core/src/ir/Operand.test.cc b/runtime/onert/core/src/ir/Operand.test.cc
new file mode 100644
index 000000000..0b858792a
--- /dev/null
+++ b/runtime/onert/core/src/ir/Operand.test.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/Graph.h"
+
+#include "MockNode.h"
+#include "verifier/Verifier.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <typeindex>
+
+namespace
+{
+
+using IndexSet = onert::ir::OperandIndexSequence;
+using Mock = onert_test::ir::SimpleMock;
+
+} // namespace
+
+TEST(ir_Operand, neg_usedef)
+{
+ onert::ir::Graph graph;
+ onert::ir::verifier::DAGChecker verifier;
+
+ onert::ir::Shape shape(3);
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ // Model Input/Output
+ auto input_operand = graph.addOperand(shape, type);
+ auto output_operand = graph.addOperand(shape, type);
+
+ graph.addInput(input_operand);
+ graph.addOutput(output_operand);
+
+ // MockNode1
+ auto operand_index1 = graph.addOperand(shape, type);
+ auto mocknode_index1 =
+ graph.addOperation(std::make_unique<Mock>(IndexSet{input_operand}, IndexSet{operand_index1}));
+
+ // MockNode2
+ auto operand_index2 = graph.addOperand(shape, type);
+ auto mocknode_index2 =
+ graph.addOperation(std::make_unique<Mock>(IndexSet{input_operand}, IndexSet{operand_index2}));
+
+ // MockNode3 (two inputs)
+ auto multiinput_index = graph.addOperation(
+ std::make_unique<Mock>(IndexSet{operand_index1, operand_index2}, IndexSet{output_operand}));
+
+ graph.verify();
+
+ ASSERT_TRUE(verifier.verify(graph));
+
+ // Check def
+ ASSERT_EQ(graph.operands().at(operand_index1).getDef(), mocknode_index1);
+ ASSERT_EQ(graph.operands().at(operand_index2).getDef(), mocknode_index2);
+ ASSERT_EQ(graph.operands().at(output_operand).getDef(), multiinput_index);
+
+ ASSERT_NE(graph.operands().at(operand_index1).getDef(), mocknode_index2);
+ ASSERT_NE(graph.operands().at(operand_index1).getDef(), multiinput_index);
+
+ // Check use
+ ASSERT_EQ(graph.operands().at(input_operand).getUses().contains(mocknode_index1), true);
+ ASSERT_EQ(graph.operands().at(input_operand).getUses().contains(mocknode_index2), true);
+ ASSERT_EQ(graph.operands().at(input_operand).getUses().contains(multiinput_index), false);
+ ASSERT_EQ(graph.operands().at(operand_index1).getUses().contains(multiinput_index), true);
+ ASSERT_EQ(graph.operands().at(operand_index2).getUses().contains(multiinput_index), true);
+
+ ASSERT_EQ(graph.operands().at(input_operand).getUses().size(), 2);
+ ASSERT_EQ(graph.operands().at(operand_index1).getUses().size(), 1);
+ ASSERT_EQ(graph.operands().at(output_operand).getUses().size(), 0);
+}
diff --git a/runtime/onert/core/src/ir/OperandIndexSequence.test.cc b/runtime/onert/core/src/ir/OperandIndexSequence.test.cc
new file mode 100644
index 000000000..588c4e419
--- /dev/null
+++ b/runtime/onert/core/src/ir/OperandIndexSequence.test.cc
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/OperandIndexSequence.h"
+
+#include <gtest/gtest.h>
+
+using onert::ir::OperandIndex;
+using onert::ir::OperandIndexSequence;
+
+TEST(ir_OperandIndexSequence, neg_append)
+{
+ OperandIndexSequence iset{0, 2, 4, 8};
+
+ ASSERT_EQ(iset.size(), 4);
+
+ iset.append(OperandIndex{10});
+
+ ASSERT_EQ(iset.size(), 5);
+
+ onert::ir::IOIndex index1{1};
+ onert::ir::IOIndex index2{4};
+
+ ASSERT_EQ(iset.at(index1), 2);
+ ASSERT_EQ(iset.at(index2), 10);
+
+ ASSERT_TRUE(iset.contains(OperandIndex{2}));
+ ASSERT_TRUE(iset.contains(OperandIndex{10}));
+ ASSERT_FALSE(iset.contains(OperandIndex{11}));
+}
+
+TEST(ir_OperandIndexSequence, neg_replace)
+{
+ OperandIndexSequence iset{0, 1, 2, 3};
+
+ iset.replace(OperandIndex{1}, OperandIndex{9});
+ ASSERT_FALSE(iset.contains(OperandIndex{1}));
+ ASSERT_TRUE(iset.contains(OperandIndex{9}));
+}
diff --git a/runtime/onert/core/src/ir/Operands.test.cc b/runtime/onert/core/src/ir/Operands.test.cc
new file mode 100644
index 000000000..aff228b10
--- /dev/null
+++ b/runtime/onert/core/src/ir/Operands.test.cc
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/Operands.h"
+
+#include <gtest/gtest.h>
+
+TEST(ir_Operands, neg_set_test)
+{
+ onert::ir::Operands set;
+
+ onert::ir::Shape shape0{1, 2, 3};
+
+ onert::ir::Shape shape1(4);
+ shape1.dim(0) = 10;
+ shape1.dim(1) = 20;
+ shape1.dim(2) = 30;
+ shape1.dim(3) = 40;
+
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ set.emplace(shape0, type);
+ set.emplace(shape1, type);
+
+ ASSERT_EQ(set.exist(onert::ir::OperandIndex{0u}), true);
+ ASSERT_EQ(set.exist(onert::ir::OperandIndex{1u}), true);
+ ASSERT_EQ(set.exist(onert::ir::OperandIndex{2u}), false);
+
+ ASSERT_EQ(set.at(onert::ir::OperandIndex{0u}).shape().dim(0), 1);
+ ASSERT_EQ(set.at(onert::ir::OperandIndex{0u}).shape().dim(1), 2);
+ ASSERT_EQ(set.at(onert::ir::OperandIndex{0u}).shape().dim(2), 3);
+}
diff --git a/runtime/onert/core/src/ir/Operation.test.cc b/runtime/onert/core/src/ir/Operation.test.cc
new file mode 100644
index 000000000..b3c4e852d
--- /dev/null
+++ b/runtime/onert/core/src/ir/Operation.test.cc
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/Graph.h"
+#include "ir/Index.h"
+#include "ir/OperandIndexSequence.h"
+#include "ir/operation/Concat.h"
+#include "ir/operation/Conv2D.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <stdexcept>
+
+using Index = onert::ir::IOIndex;
+using IndexSet = onert::ir::OperandIndexSequence;
+
+TEST(ir_Operation_setIO, operation_setIO_conv)
+{
+ onert::ir::Graph graph;
+
+ onert::ir::Shape shape{3};
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ // Add Conv
+ using Graph = onert::ir::operation::Conv2D;
+
+ auto input_operand = graph.addOperand(shape, type);
+ auto kernel_operand = graph.addOperand(shape, type);
+ auto bias_operand = graph.addOperand(shape, type);
+ IndexSet inputs{input_operand, kernel_operand, bias_operand};
+
+ Graph::Param conv_params;
+ conv_params.padding.type = onert::ir::PaddingType::SAME;
+ conv_params.stride.horizontal = 1;
+ conv_params.stride.vertical = 1;
+ conv_params.activation = onert::ir::Activation::NONE;
+
+ auto output_operand = graph.addOperand(shape, type).value();
+ IndexSet outputs{output_operand};
+
+ auto conv = std::make_unique<Graph>(inputs, outputs, conv_params);
+
+ ASSERT_NE(conv, nullptr);
+ ASSERT_EQ(conv->getInputs().at(Index{0}).value(), inputs.at(0).value());
+ conv->setInputs({8, 9, 10});
+ ASSERT_NE(conv->getInputs().at(Index{0}).value(), inputs.at(0).value());
+ ASSERT_EQ(conv->getInputs().at(Index{0}).value(), 8);
+}
+
+TEST(ir_Operation_setIO, neg_operation_setIO_concat)
+{
+ onert::ir::Graph graph;
+
+ onert::ir::Shape shape{3};
+
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ using Graph = onert::ir::operation::Concat;
+
+ // Add Concat
+ IndexSet inputs;
+ for (int i = 0; i < 6; ++i)
+ {
+ inputs.append(graph.addOperand(shape, type));
+ }
+
+ Graph::Param concat_params{0};
+
+ auto output_operand = graph.addOperand(shape, type).value();
+ IndexSet outputs{output_operand};
+
+ auto concat = std::make_unique<Graph>(inputs, outputs, concat_params);
+
+ ASSERT_NE(concat, nullptr);
+ ASSERT_EQ(concat->getInputs().size(), 6);
+ ASSERT_EQ(concat->getInputs().at(Index{0}).value(), inputs.at(0).value());
+
+ concat->setInputs({80, 6, 9, 11});
+ ASSERT_EQ(concat->getInputs().size(), 4);
+ ASSERT_NE(concat->getInputs().at(Index{0}).value(), inputs.at(0).value());
+ ASSERT_EQ(concat->getInputs().at(Index{0}).value(), 80);
+ ASSERT_EQ(concat->getInputs().at(Index{2}).value(), 9);
+ ASSERT_THROW(concat->getInputs().at(Index{5}), std::out_of_range);
+}
diff --git a/runtime/onert/core/src/ir/Operations.test.cc b/runtime/onert/core/src/ir/Operations.test.cc
new file mode 100644
index 000000000..e57872689
--- /dev/null
+++ b/runtime/onert/core/src/ir/Operations.test.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/Operations.h"
+
+#include "MockNode.h"
+
+#include <gtest/gtest.h>
+
+using onert::ir::Operation;
+using onert::ir::OperationIndex;
+using onert::ir::Operations;
+
+TEST(ir_Operations, basic)
+{
+ Operations ops;
+ ops.push(std::unique_ptr<Operation>(new onert_test::ir::SimpleMock({1, 2, 3, 4}, {5, 6, 7})));
+ OperationIndex idx{0u};
+ ASSERT_EQ(ops.at(idx).getInputs().size(), 4);
+ ASSERT_EQ(ops.at(idx).getOutputs().size(), 3);
+}
+
+TEST(ir_Operations, neg_at)
+{
+ Operations ops;
+ ops.push(std::unique_ptr<Operation>(new onert_test::ir::SimpleMock({1, 2, 3, 4}, {5, 6, 7})));
+ OperationIndex idx{99u};
+ EXPECT_THROW(ops.at(idx), std::out_of_range);
+}
diff --git a/runtime/onert/core/src/ir/Shape.test.cc b/runtime/onert/core/src/ir/Shape.test.cc
new file mode 100644
index 000000000..afdb29254
--- /dev/null
+++ b/runtime/onert/core/src/ir/Shape.test.cc
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/Shape.h"
+
+#include <gtest/gtest.h>
+
+TEST(ShapeTest, basic_test)
+{
+ {
+ onert::ir::Shape shape(3);
+
+ shape.dim(0) = 1;
+ shape.dim(1) = 2;
+ shape.dim(2) = 3;
+
+ ASSERT_EQ(shape.rank(), 3);
+ ASSERT_EQ(shape.num_elements(), 6);
+ ASSERT_EQ(onert::ir::rankMaybeUnspecified(shape), false);
+ ASSERT_EQ(shape.hasUnspecifiedDims(), false);
+ }
+ {
+ onert::ir::Shape shape; // scalar or rank is unspecified
+
+ ASSERT_EQ(shape.rank(), 0);
+ ASSERT_EQ(shape.num_elements(), 1);
+ ASSERT_EQ(onert::ir::rankMaybeUnspecified(shape), true);
+ ASSERT_EQ(shape.hasUnspecifiedDims(), false);
+ }
+}
+
+TEST(ShapeTest, neg_basic_test)
+{
+ {
+ onert::ir::Shape shape(2);
+
+ shape.dim(0) = 1;
+ shape.dim(1) = onert::ir::Shape::UNSPECIFIED_DIM;
+
+ ASSERT_EQ(shape.rank(), 2);
+ ASSERT_EQ(onert::ir::rankMaybeUnspecified(shape), false);
+ ASSERT_EQ(shape.hasUnspecifiedDims(), true);
+ EXPECT_ANY_THROW(shape.num_elements());
+ }
+}
diff --git a/runtime/onert/core/src/ir/verifier/Verifier.test.cc b/runtime/onert/core/src/ir/verifier/Verifier.test.cc
new file mode 100644
index 000000000..1ec71cd55
--- /dev/null
+++ b/runtime/onert/core/src/ir/verifier/Verifier.test.cc
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Verifier.h"
+
+#include "../MockNode.h"
+
+#include "ir/Graph.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+using IndexSet = onert::ir::OperandIndexSequence;
+using Mock = onert_test::ir::SimpleMock;
+
+TEST(Verifier, dag_checker)
+{
+ onert::ir::Graph graph;
+
+ onert::ir::Shape shape{3};
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ auto operand1 = graph.addOperand(shape, type);
+ auto operand2 = graph.addOperand(shape, type);
+
+ graph.addInput(operand1);
+ graph.addOutput(operand2);
+
+ graph.addOperation(std::make_unique<Mock>(IndexSet{operand1}, IndexSet{operand2}));
+
+ onert::ir::verifier::DAGChecker verifier;
+
+ ASSERT_TRUE(verifier.verify(graph));
+}
+
+TEST(Verifier, neg_edge_consistency_checker_1)
+{
+ onert::ir::Graph graph;
+
+ onert::ir::Shape shape{3};
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ auto operand1 = graph.addOperand(shape, type);
+ auto operand2 = graph.addOperand(shape, type);
+
+ graph.addInput(operand1);
+ graph.addOutput(operand2);
+
+ auto mock_op = std::make_unique<Mock>(IndexSet{operand1}, IndexSet{operand2});
+ auto op_ind = graph.addOperation(std::move(mock_op));
+
+ graph.operands().at(operand1).removeUse(op_ind); // Manipulate the operand alone
+
+ onert::ir::verifier::EdgeChecker verifier;
+ ASSERT_FALSE(verifier.verify(graph));
+}
+
+TEST(Verifier, neg_edge_consistency_checker_2)
+{
+ onert::ir::Graph graph;
+
+ onert::ir::Shape shape{3};
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ auto operand1 = graph.addOperand(shape, type);
+ auto operand2 = graph.addOperand(shape, type);
+
+ graph.addInput(operand1);
+ graph.addOutput(operand2);
+
+ auto mock_op = std::make_unique<Mock>(IndexSet{operand1}, IndexSet{operand2});
+ auto mock_op_ptr = mock_op.get();
+ auto op_ind = graph.addOperation(std::move(mock_op));
+
+ mock_op_ptr->setInputs({operand2}); // Manipulate the operation alone
+
+ onert::ir::verifier::EdgeChecker verifier;
+ ASSERT_FALSE(verifier.verify(graph));
+}
diff --git a/runtime/onert/core/src/util/ChromeTracingEventWriter.cc b/runtime/onert/core/src/util/ChromeTracingEventWriter.cc
index 3fc0c8ece..d868efedf 100644
--- a/runtime/onert/core/src/util/ChromeTracingEventWriter.cc
+++ b/runtime/onert/core/src/util/ChromeTracingEventWriter.cc
@@ -14,12 +14,12 @@
* limitations under the License.
*/
-#include "util/EventWriter.h"
+#include "EventWriter.h"
-#include <sstream>
-#include <vector>
#include <cassert>
+#include <sstream>
#include <utility>
+#include <vector>
// json type for ChromeTracingWriter
namespace
diff --git a/runtime/onert/core/src/util/ConfigSource.cc b/runtime/onert/core/src/util/ConfigSource.cc
index 9da93f68a..b7fcefc7a 100644
--- a/runtime/onert/core/src/util/ConfigSource.cc
+++ b/runtime/onert/core/src/util/ConfigSource.cc
@@ -15,13 +15,15 @@
*/
#include "util/ConfigSource.h"
-#include "util/GeneralConfigSource.h"
-#include "util/EnvConfigSource.h"
+#include "util/logging.h"
+
+#include <misc/EnvConfigSource.h>
+#include <misc/GeneralConfigSource.h>
+#include <misc/IConfigSource.h>
-#include <array>
#include <algorithm>
+#include <array>
#include <cassert>
-
#include <memory>
namespace onert
@@ -29,12 +31,27 @@ namespace onert
namespace util
{
+using namespace nnfw::misc;
+
static std::unique_ptr<IConfigSource> _source;
static std::unique_ptr<IConfigSource> _source_ext;
void config_source(std::unique_ptr<IConfigSource> &&source) { _source = std::move(source); }
void config_source_ext(std::unique_ptr<IConfigSource> &&source) { _source_ext = std::move(source); }
+void setConfigKeyValues(const CfgKeyValues &keyValues)
+{
+ auto configsrc = std::make_unique<GeneralConfigSource>();
+
+ for (auto it = keyValues.begin(); it != keyValues.end(); ++it)
+ {
+ VERBOSE(NNPKG_CONFIGS) << "(" << it->first << ") = (" << it->second << ")" << std::endl;
+ configsrc->set(it->first, it->second);
+ }
+
+ onert::util::config_source_ext(std::move(configsrc));
+}
+
static IConfigSource *config_source()
{
if (!_source)
diff --git a/runtime/onert/core/src/util/EventCollector.cc b/runtime/onert/core/src/util/EventCollector.cc
index 83c2649d1..c1b9c4315 100644
--- a/runtime/onert/core/src/util/EventCollector.cc
+++ b/runtime/onert/core/src/util/EventCollector.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "util/EventCollector.h"
+#include "EventCollector.h"
// C++ standard libraries
#include <chrono>
diff --git a/runtime/onert/core/src/util/EventCollector.h b/runtime/onert/core/src/util/EventCollector.h
index 774fe05ef..effb72373 100644
--- a/runtime/onert/core/src/util/EventCollector.h
+++ b/runtime/onert/core/src/util/EventCollector.h
@@ -17,12 +17,13 @@
#ifndef __ONERT_UTIL_EVENT_COLLECTOR_H__
#define __ONERT_UTIL_EVENT_COLLECTOR_H__
-#include "util/EventRecorder.h"
+#include "EventRecorder.h"
+
#include "util/TracingCtx.h"
-#include <vector>
-#include <utility>
#include <string>
+#include <utility>
+#include <vector>
class EventCollector
{
diff --git a/runtime/onert/core/src/util/EventRecorder.cc b/runtime/onert/core/src/util/EventRecorder.cc
index 5d3d5f5c6..85a588d38 100644
--- a/runtime/onert/core/src/util/EventRecorder.cc
+++ b/runtime/onert/core/src/util/EventRecorder.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "util/EventRecorder.h"
+#include "EventRecorder.h"
void EventRecorder::emit(std::unique_ptr<DurationEvent> &&evt)
{
diff --git a/runtime/onert/core/src/util/EventWriter.cc b/runtime/onert/core/src/util/EventWriter.cc
index c42c53730..ca4bd302e 100644
--- a/runtime/onert/core/src/util/EventWriter.cc
+++ b/runtime/onert/core/src/util/EventWriter.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "util/EventWriter.h"
+#include "EventWriter.h"
#include <cassert>
diff --git a/runtime/onert/core/src/util/GeneralConfigSource.cc b/runtime/onert/core/src/util/GeneralConfigSource.cc
deleted file mode 100644
index 7d2757e58..000000000
--- a/runtime/onert/core/src/util/GeneralConfigSource.cc
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "util/GeneralConfigSource.h"
-#include "util/logging.h"
-
-namespace onert
-{
-namespace util
-{
-
-std::string GeneralConfigSource::get(const std::string &key) const
-{
- auto itr = _map.find(key);
- if (itr == _map.end())
- {
- return "";
- }
- else
- {
- return itr->second;
- }
-}
-
-void GeneralConfigSource::set(const std::string &key, const std::string &val)
-{
- VERBOSE(GeneralConfigSource) << key << " : " << val << std::endl;
- _map[key] = val;
-}
-
-} // namespace util
-} // namespace onert
diff --git a/runtime/onert/core/src/util/EnvConfigSource.cc b/runtime/onert/core/src/util/Index.test.cc
index 0d25b7353..ff73e5e59 100644
--- a/runtime/onert/core/src/util/EnvConfigSource.cc
+++ b/runtime/onert/core/src/util/Index.test.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,27 +14,21 @@
* limitations under the License.
*/
-#include "util/EnvConfigSource.h"
+#include "util/Index.h"
-#include <cstdlib>
+#include <gtest/gtest.h>
-namespace onert
-{
-namespace util
-{
+using Index = ::onert::util::Index<uint32_t, struct TestTag>;
-std::string EnvConfigSource::get(const std::string &key) const
+TEST(Index, neg_index_test)
{
- const char *value = std::getenv(key.c_str());
- if (value != nullptr)
- {
- return value;
- }
- else
- {
- return GeneralConfigSource::get(key);
- }
-}
+ Index idx1{1u};
+ Index idx2{2u};
+ Index idx3{idx1};
-} // namespace util
-} // namespace onert
+ ASSERT_EQ(idx1, 1);
+ ASSERT_EQ(idx1, 1u);
+ ASSERT_EQ(idx1.value(), 1u);
+ ASSERT_NE(idx1, idx2);
+ ASSERT_EQ(idx1, idx3);
+}
diff --git a/runtime/onert/core/src/util/MDTableEventWriter.cc b/runtime/onert/core/src/util/MDTableEventWriter.cc
index b7fbac5e2..7a8b9f234 100644
--- a/runtime/onert/core/src/util/MDTableEventWriter.cc
+++ b/runtime/onert/core/src/util/MDTableEventWriter.cc
@@ -14,16 +14,16 @@
* limitations under the License.
*/
-#include "util/EventWriter.h"
+#include "EventWriter.h"
-#include <sstream>
-#include <vector>
-#include <unordered_map>
#include <cassert>
-#include <utility>
#include <map>
#include <set>
+#include <sstream>
#include <stdint.h>
+#include <unordered_map>
+#include <utility>
+#include <vector>
// md table type
namespace
diff --git a/runtime/onert/core/src/util/ObjectManager.test.cc b/runtime/onert/core/src/util/ObjectManager.test.cc
new file mode 100644
index 000000000..3fe735732
--- /dev/null
+++ b/runtime/onert/core/src/util/ObjectManager.test.cc
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/Index.h"
+#include "util/ObjectManager.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert;
+
+struct TestTag;
+using Index = typename util::Index<uint32_t, TestTag>;
+
+TEST(ObjectManager, emplace)
+{
+ util::ObjectManager<Index, int> man;
+
+ auto index = man.emplace(100);
+ ASSERT_EQ(man.at(index), 100);
+}
+
+TEST(ObjectManager, neg_remove_1)
+{
+ util::ObjectManager<Index, int> man;
+
+ Index index = man.emplace(100);
+ ASSERT_TRUE(man.exist(index));
+ ASSERT_EQ(man.at(index), 100);
+
+ man.remove(index);
+ ASSERT_FALSE(man.exist(index));
+}
+
+TEST(ObjectManager, neg_remove_2)
+{
+ util::ObjectManager<Index, int> man;
+
+ auto index0 = man.emplace(100);
+ auto index1 = man.emplace(200);
+ ASSERT_TRUE(man.exist(index0));
+ ASSERT_EQ(man.at(index0), 100);
+ ASSERT_TRUE(man.exist(index1));
+ ASSERT_EQ(man.at(index1), 200);
+
+ man.remove(index0);
+ ASSERT_FALSE(man.exist(index0));
+ ASSERT_TRUE(man.exist(index1));
+ ASSERT_EQ(man.at(index1), 200);
+}
+
+TEST(ObjectManager, push)
+{
+ util::ObjectManager<Index, int> man;
+
+ // No index specified
+ auto index = man.push(std::make_unique<int>(100));
+ ASSERT_EQ(man.at(index), 100);
+
+ // Specify index
+ auto index2 = man.push(std::make_unique<int>(200), Index{33});
+ ASSERT_EQ(index2.value(), 33);
+ ASSERT_EQ(man.at(index2), 200);
+
+ auto index3 = man.push(std::make_unique<int>(300));
+ // NOTE The auto-generated index is always (the largest index in the ObjectManager + 1)
+ ASSERT_EQ(index3.value(), 34);
+ ASSERT_EQ(man.at(index3), 300);
+
+ auto index4 = man.push(std::make_unique<int>(400), Index{22});
+ ASSERT_EQ(index4.value(), 22);
+ ASSERT_EQ(man.at(index4), 400);
+
+ auto index5 = man.push(std::make_unique<int>(500));
+ // NOTE The auto-generated index is always (the largest index in the ObjectManager + 1)
+ ASSERT_EQ(index5.value(), 35);
+ ASSERT_EQ(man.at(index5), 500);
+}
+
+TEST(ObjectManager, neg_push)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Specify index
+ auto index = man.push(std::make_unique<int>(100), Index{55});
+ ASSERT_EQ(index.value(), 55);
+ ASSERT_EQ(man.at(index), 100);
+
+ // Specify the same index
+ auto index2 = man.push(std::make_unique<int>(200), Index{55});
+ ASSERT_FALSE(index2.valid());
+}
+
+static const uint32_t kMaxUInt32 = std::numeric_limits<uint32_t>::max();
+
+TEST(ObjectManager, neg_push_undefined_index)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Try inserting an invalid (undefined) index
+ auto index = man.push(std::make_unique<int>(100), Index{kMaxUInt32});
+ ASSERT_FALSE(index.valid());
+ ASSERT_EQ(man.size(), 0);
+}
+
+TEST(ObjectManager, neg_push_max_index)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Insert an object with maximum valid index
+ auto index = man.push(std::make_unique<int>(100), Index{kMaxUInt32 - 1});
+ ASSERT_EQ(index.value(), kMaxUInt32 - 1);
+ ASSERT_EQ(man.at(index), 100);
+ ASSERT_EQ(man.size(), 1);
+
+ // Reached the final index, so the next push/emplace must fail
+ auto index2 = man.push(std::make_unique<int>(200));
+ ASSERT_EQ(man.size(), 1);
+ ASSERT_FALSE(index2.valid());
+}
+
+TEST(ObjectManager, neg_emplace_max_index)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Insert an object with maximum valid index
+ auto index = man.push(std::make_unique<int>(100), Index{kMaxUInt32 - 1});
+ ASSERT_EQ(index.value(), kMaxUInt32 - 1);
+ ASSERT_EQ(man.at(index), 100);
+ ASSERT_EQ(man.size(), 1);
+
+ // Reached the final index, so the next push/emplace must fail
+ auto index3 = man.emplace(200);
+ ASSERT_EQ(man.size(), 1);
+ ASSERT_FALSE(index3.valid());
+}
+
+TEST(ObjectManager, const_iterate)
+{
+ util::ObjectManager<Index, int> man;
+
+ auto index0 = man.emplace(100);
+ auto index1 = man.emplace(200);
+ auto index2 = man.emplace(300);
+
+ int sum = 0;
+ man.iterate([&](const Index &index, const int &val) { sum += val; });
+ ASSERT_EQ(sum, 600);
+}
+
+TEST(ObjectManager, non_const_iterate)
+{
+ util::ObjectManager<Index, int> man;
+
+ auto index0 = man.emplace(100);
+ auto index1 = man.emplace(200);
+ auto index2 = man.emplace(300);
+
+ man.iterate([&](const Index &index, int &val) { val += 1; });
+ ASSERT_EQ(man.at(index0), 101);
+ ASSERT_EQ(man.at(index1), 201);
+ ASSERT_EQ(man.at(index2), 301);
+}
+
+TEST(ObjectManager, set)
+{
+ util::ObjectManager<Index, int> man;
+ auto index = man.set(Index{1}, std::make_unique<int>(100)); // Insert
+ ASSERT_EQ(index, Index{1});
+ auto index2 = man.set(index, std::make_unique<int>(200)); // Overwrite
+ ASSERT_EQ(index2, index);
+ ASSERT_EQ(man.at(index2), 200);
+}
+
+TEST(ObjectManager, neg_set)
+{
+ auto v = std::make_unique<int>(100);
+ util::ObjectManager<Index, int> man;
+ auto index = man.set(Index{}, std::move(v)); // Try to set with an invalid index
+ ASSERT_EQ(index, Index{});
+ ASSERT_FALSE(index.valid());
+ ASSERT_NE(v, nullptr); // v must be kept on failure
+}
+
+TEST(ObjectManager, getRawPtr)
+{
+ auto v = std::make_unique<int>(100);
+ auto v_ptr = v.get();
+ util::ObjectManager<Index, int> man;
+ auto index = man.push(std::move(v));
+ ASSERT_EQ(v_ptr, man.getRawPtr(index));
+}
+
+TEST(ObjectManager, neg_getRawPtr)
+{
+ util::ObjectManager<Index, int> man;
+ auto ptr = man.getRawPtr(Index{1});
+ ASSERT_EQ(ptr, nullptr);
+}
diff --git a/runtime/onert/core/src/util/SNPEEventWriter.cc b/runtime/onert/core/src/util/SNPEEventWriter.cc
index 6f03cfccf..4dea6d16c 100644
--- a/runtime/onert/core/src/util/SNPEEventWriter.cc
+++ b/runtime/onert/core/src/util/SNPEEventWriter.cc
@@ -14,11 +14,12 @@
* limitations under the License.
*/
-#include "util/EventWriter.h"
+#include "EventWriter.h"
-#include <unordered_map>
#include <json/json.h>
+
#include <cassert>
+#include <unordered_map>
#include <utility>
/**
diff --git a/runtime/onert/core/src/util/ShapeInference.test.cc b/runtime/onert/core/src/util/ShapeInference.test.cc
new file mode 100644
index 000000000..96579bfa2
--- /dev/null
+++ b/runtime/onert/core/src/util/ShapeInference.test.cc
@@ -0,0 +1,544 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/ShapeInference.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::ir;
+
+TEST(ShapeInference, Elementwise)
+{
+ Shape lhs_shape{1, 299, 299, 3};
+ Shape rhs_shape{3};
+ auto infered_out_shape = onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.dim(0), 1);
+ ASSERT_EQ(infered_out_shape.dim(1), 299);
+ ASSERT_EQ(infered_out_shape.dim(2), 299);
+ ASSERT_EQ(infered_out_shape.dim(3), 3);
+}
+
+TEST(ShapeInference, neg_Elementwise)
+{
+ Shape lhs_shape{1, 299, 299, 3};
+ Shape rhs_shape{5, 3};
+ ASSERT_THROW(onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape), std::runtime_error);
+}
+
+TEST(ShapeInference, Pool2DNodeSame)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Stride stride{3, 7};
+ Padding padding{PaddingType::SAME};
+
+ operation::Pool2D::Param avg_pool_param{
+ operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
+ auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+
+ operation::Pool2D::Param max_pool_param{
+ operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
+ infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+}
+
+TEST(ShapeInference, Pool2DNodeValid)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Stride stride{3, 7};
+ Padding padding{PaddingType::VALID};
+
+ operation::Pool2D::Param avg_pool_param{
+ operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
+ auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+
+ operation::Pool2D::Param max_pool_param{
+ operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
+ infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+}
+
+TEST(ShapeInference, Pool2DNodeExplicit)
+{
+ Shape in_shape{10, 3, 5, 20};
+
+ Stride stride{3, 7};
+ Padding padding{4, 3, 2, 1};
+
+ operation::Pool2D::Param avg_pool_param{
+ operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
+ auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+
+ operation::Pool2D::Param max_pool_param{
+ operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
+ infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+}
+
+TEST(ShapeInference, neg_Pool2DNode_InvalidStride)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Stride stride{0, 7};
+ Padding padding{PaddingType::SAME};
+
+ operation::Pool2D::Param avg_pool_param{
+ operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
+ ASSERT_THROW(onert::shape_inference::inferPoolShape(in_shape, avg_pool_param),
+ std::runtime_error);
+}
+
+TEST(ShapeInference, Conv2D)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Shape ker_shape{30, 3, 6, 20};
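+ // Kernel is laid out as [out_channels, kernel_h, kernel_w, in_channels]:
+ // 30 output channels and a 3x6 window over the input's 20 channels.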
+
+ operation::Conv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, Activation::NONE,
+ Dilation{1, 1}};
+ auto infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
+
+ param = operation::Conv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, Activation::NONE,
+ Dilation{1, 1}};
+ infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
+
+ param =
+ operation::Conv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, Activation::NONE, Dilation{1, 1}};
+ infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
+}
+
+TEST(ShapeInference, neg_Conv2D_InvalidStride)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Shape ker_shape{30, 3, 6, 20};
+
+ operation::Conv2D::Param param{Stride{0, 0}, Padding{PaddingType::VALID}, Activation::NONE,
+ Dilation{1, 1}};
+ ASSERT_THROW(onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param),
+ std::runtime_error);
+}
+
+TEST(ShapeInference, DepthwiseConv2D)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Shape ker_shape{1, 3, 6, 60};
+
+ operation::DepthwiseConv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, 3,
+ Activation::NONE, Dilation{1, 1}};
+ auto infered_out_shape =
+ onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
+
+ param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, 3,
+ Activation::NONE, Dilation{1, 1}};
+ infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
+
+ param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, 3, Activation::NONE,
+ Dilation{1, 1}};
+ infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
+}
+
+TEST(ShapeInference, neg_DepthwiseConv2D_InvalidStride)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Shape ker_shape{1, 3, 6, 60};
+
+ operation::DepthwiseConv2D::Param param{Stride{3, 0}, Padding{PaddingType::VALID}, 3,
+ Activation::NONE, Dilation{1, 1}};
+ ASSERT_THROW(onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param),
+ std::runtime_error);
+}
+
+TEST(ShapeInference, Concat)
+{
+ {
+ Shape in1{10, 20, 30, 3, 50};
+ Shape in2{10, 20, 30, 2, 50};
+ Shape in3{10, 20, 30, 2, 50};
+
+ operation::Concat::Param param{3};
+ auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2, in3}, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 5);
+ ASSERT_EQ(infered_out_shape.dim(0), 10);
+ ASSERT_EQ(infered_out_shape.dim(1), 20);
+ ASSERT_EQ(infered_out_shape.dim(2), 30);
+ ASSERT_EQ(infered_out_shape.dim(3), 7);
+ ASSERT_EQ(infered_out_shape.dim(4), 50);
+ }
+ {
+ // case 1. when axis < 0
+ Shape in1{10, 20, 2};
+ Shape in2{10, 20, 3};
+
+ operation::Concat::Param param{-1};
+ auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 3);
+ ASSERT_EQ(infered_out_shape.dim(0), 10);
+ ASSERT_EQ(infered_out_shape.dim(1), 20);
+ ASSERT_EQ(infered_out_shape.dim(2), 5);
+ }
+ {
+ // case 2. when axis < 0
+ Shape in1{2, 20, 2};
+ Shape in2{3, 20, 2};
+
+ operation::Concat::Param param{-3};
+ auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 3);
+ ASSERT_EQ(infered_out_shape.dim(0), 5);
+ ASSERT_EQ(infered_out_shape.dim(1), 20);
+ ASSERT_EQ(infered_out_shape.dim(2), 2);
+ }
+}
+
+TEST(ShapeInference, neg_Concat)
+{
+ {
+ operation::Concat::Param param{2};
+ Shape in1{10, 1, 3};
+ Shape in2{10, 2, 4}; // dim[1] should be 1 but is 2
+
+ EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param));
+ }
+ { // wrong rank
+ operation::Concat::Param param{2};
+ Shape in1{10, 2, 3, 4};
+ Shape in2{10, 2, 4}; // rank should be 4
+
+ EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param));
+ }
+}
+
+TEST(ShapeInference, ExpandDims)
+{
+ Shape in_shape{30, 40};
+
+ auto check = [&](int32_t axis, Shape &expected) {
+ auto actual = onert::shape_inference::inferExpandDimsShape(in_shape, axis);
+
+ ASSERT_EQ(actual.rank(), 3);
+ for (int32_t dim = 0; dim < expected.rank(); dim++)
+ ASSERT_EQ(actual.dim(dim), expected.dim(dim));
+ };
+
+ { // boundary
+ int32_t axis = 0;
+ Shape expected{1, 30, 40};
+ check(axis, expected);
+ }
+ { // boundary
+ int32_t axis = 2;
+ Shape expected{30, 40, 1};
+ check(axis, expected);
+ }
+ { // inside
+ int32_t axis = 1;
+ Shape expected{30, 1, 40};
+ check(axis, expected);
+ }
+ { // negative boundary
+ int32_t axis = -1;
+ Shape expected{30, 40, 1};
+ check(axis, expected);
+ }
+ { // negative boundary
+ int32_t axis = -3;
+ Shape expected{1, 30, 40};
+ check(axis, expected);
+ }
+}
+
+TEST(ShapeInference, neg_ExpandDims)
+{
+ Shape in_shape{30, 40};
+
+ { // over boundary
+ int32_t axis = 3;
+ ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
+ }
+ { // over boundary
+ int32_t axis = -4;
+ ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
+ }
+}
+
+TEST(ShapeInference, FullyConnected)
+{
+ Shape in_shape{3, 4, 5, 6};
+ Shape ker_shape{3, 10};
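+ // The input's 3*4*5*6 = 360 elements are flattened against the weight width (10),
+ // giving 36 rows; the output's second dim follows the weight height (3).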
+ auto infered_out_shape = onert::shape_inference::inferFullyConnectedShape(in_shape, ker_shape);
+
+ ASSERT_EQ(infered_out_shape.rank(), 2);
+ ASSERT_EQ(infered_out_shape.dim(0), 36);
+ ASSERT_EQ(infered_out_shape.dim(1), 3);
+}
+
+TEST(ShapeInference, Transpose)
+{
+ auto check = [&](Shape &in_shape, std::vector<int> perm, Shape &expected) {
+ // pre-conditions
+ ASSERT_EQ(in_shape.rank(), perm.size());
+ ASSERT_EQ(expected.rank(), perm.size());
+ auto inferred_out_shape =
+ onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size());
+ // post-conditions
+ ASSERT_EQ(inferred_out_shape.rank(), perm.size());
+ for (int32_t dim = 0; dim < expected.rank(); dim++)
+ {
+ ASSERT_EQ(inferred_out_shape.dim(dim), expected.dim(dim));
+ }
+ };
+ // check for 2-D
+ {
+ Shape in_shape{2, 3};
+ std::vector<int> perm = {1, 0};
+ Shape expected{3, 2};
+ // int32_t rank = 2;
+ check(in_shape, perm, expected);
+ }
+ // check for 3-D
+ {
+ Shape in_shape{1, 2, 3};
+ std::vector<int> perm = {2, 0, 1};
+ Shape expected{3, 1, 2};
+ // int32_t rank = 3;
+ check(in_shape, perm, expected);
+ }
+ // check for 4-D
+ {
+ Shape in_shape{1, 2, 3, 4};
+ std::vector<int> perm = {1, 3, 0, 2};
+ Shape expected{2, 4, 1, 3};
+ // int32_t rank = 4;
+ check(in_shape, perm, expected);
+ }
+}
+
+TEST(ShapeInference, neg_Transpose)
+{
+ Shape in_shape{1, 2, 3};
+ // Invalid parameter size
+ {
+ std::vector<int> perm = {2, 0, 1, 0};
+ // int32_t rank = 3;
+ ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size()),
+ std::runtime_error);
+ }
+ // Invalid parameter value
+ {
+ std::vector<int> perm = {2, 0, 3};
+ // int32_t rank = 3;
+ ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size()),
+ std::runtime_error);
+ }
+}
+
+TEST(ShapeInference, Gather)
+{
+ auto check = [&](Shape &input, Shape &indices, Shape &expected, int32_t axis) {
+ int rank = input.rank();
+ auto actual = onert::shape_inference::inferGatherShape(input, indices, axis, rank);
+
+ ASSERT_EQ(actual.rank(), expected.rank());
+
+ for (int32_t dim = 0; dim < expected.rank(); dim++)
+ ASSERT_EQ(actual.dim(dim), expected.dim(dim));
+ };
+
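+ // In every case below, the dim at `axis` is replaced by the whole indices shape,
+ // so output rank = input rank + indices rank - 1.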
+ // check for 2-D, 3-D, axis 0
+ {
+ Shape input{3, 4};
+ Shape indices{1, 1, 2};
+ int32_t axis = 0;
+ Shape expected{1, 1, 2, 4};
+ check(input, indices, expected, axis);
+ }
+
+ // check for 2-D, 3-D, axis 1
+ {
+ Shape input{3, 4};
+ Shape indices{1, 2, 1};
+ int32_t axis = 1;
+ Shape expected{3, 1, 2, 1};
+ check(input, indices, expected, axis);
+ }
+
+ // check for 3-D, 2-D, axis 0
+ {
+ Shape input{2, 3, 4};
+ Shape indices{1, 2};
+ int32_t axis = 0;
+ Shape expected{1, 2, 3, 4};
+ check(input, indices, expected, axis);
+ }
+
+ // check for 3-D, 2-D, axis 2
+ {
+ Shape input{2, 3, 4};
+ Shape indices{2, 1};
+ int32_t axis = 2;
+ Shape expected{2, 3, 2, 1};
+ check(input, indices, expected, axis);
+ }
+
+ // check for 4D, axis 0
+ {
+ Shape input{1, 2, 3, 4};
+ Shape indices{2};
+ int32_t axis = 0;
+ Shape expected{2, 2, 3, 4};
+ check(input, indices, expected, axis);
+ }
+}
+
+TEST(ShapeInference, BCQFullyConnected)
+{
+ auto check = [&](Shape &in_shape, Shape &cluster_shape, std::vector<int> cluster,
+ Shape &expected) {
+ auto actual =
+ onert::shape_inference::inferBCQFullyConnectedShape(in_shape, cluster_shape, cluster.data());
+ ASSERT_EQ(actual.rank(), expected.rank());
+
+ for (int32_t dim = 0; dim < expected.rank(); dim++)
+ ASSERT_EQ(actual.dim(dim), expected.dim(dim));
+ };
+
+ {
+ Shape in_shape{10, 1};
+ Shape cluster_shape{3, 2};
+ std::vector<int> cluster = {1, 10, 2, 10, 3, 10};
+
+ Shape expected{30, 1};
+ check(in_shape, cluster_shape, cluster, expected);
+ }
+
+ {
+ Shape in_shape{1, 1};
+ Shape cluster_shape{1, 2};
+ std::vector<int> cluster = {3, 50};
+
+ Shape expected{50, 1};
+ check(in_shape, cluster_shape, cluster, expected);
+ }
+}
+
+TEST(ShapeInference, BCQGather)
+{
+ auto check = [&](Shape &indices_shape, Shape &cluster_shape, std::vector<int> cluster,
+ uint32_t hidden_size, uint32_t axis, int rank, Shape &expected) {
+ operation::BCQGather::Param param{hidden_size, axis};
+ auto actual = onert::shape_inference::inferBCQGatherShape(indices_shape, cluster_shape,
+ cluster.data(), rank, param);
+ ASSERT_EQ(actual.rank(), expected.rank());
+
+ for (int32_t dim = 0; dim < expected.rank(); dim++)
+ ASSERT_EQ(actual.dim(dim), expected.dim(dim));
+ };
+
+ {
+ Shape indices_shape{5, 1};
+ Shape cluster_shape{3, 2};
+ std::vector<int> cluster = {1, 10, 2, 10, 3, 10};
+ uint32_t hidden_size = 10;
+ uint32_t axis = 0;
+ int rank = 2;
+
+ Shape expected{5, 1, 10};
+ check(indices_shape, cluster_shape, cluster, hidden_size, axis, rank, expected);
+ }
+
+ {
+ Shape indices_shape{5, 1};
+ Shape cluster_shape{3, 2};
+ std::vector<int> cluster = {1, 10, 2, 10, 3, 10};
+ uint32_t hidden_size = 10;
+ uint32_t axis = 1;
+ int rank = 2;
+
+ Shape expected{30, 5, 1};
+ check(indices_shape, cluster_shape, cluster, hidden_size, axis, rank, expected);
+ }
+}